Compare commits
9 Commits
lora-quant
...
activation
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7610a02881 | ||
|
|
b0cd54bcb9 | ||
|
|
54960d4de0 | ||
|
|
ed922796b7 | ||
|
|
3dd9c3bf3f | ||
|
|
0ba7d362fa | ||
|
|
e4f73bc98e | ||
|
|
bcb59c70e2 | ||
|
|
6a3e6f8c53 |
6
.github/workflows/preview-docs.yml
vendored
6
.github/workflows/preview-docs.yml
vendored
@@ -4,6 +4,12 @@ on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
# Run the workflow only when one of these files changes
|
||||
paths:
|
||||
- '**/*.md' # any Markdown file
|
||||
- '**/*.qmd' # any Quarto file
|
||||
- '_quarto.yaml'
|
||||
|
||||
permissions:
|
||||
checks: write
|
||||
contents: write
|
||||
|
||||
@@ -15,7 +15,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.evaluate import evaluate
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -32,7 +32,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
cli_args: CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
patch_optimized_env()
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
|
||||
@@ -29,7 +29,7 @@ from axolotl.cli.utils import (
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
|
||||
@@ -55,6 +55,8 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
patch_optimized_env()
|
||||
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_preprocess
|
||||
|
||||
@@ -100,7 +102,7 @@ def train(
|
||||
config options.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
patch_optimized_env()
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
|
||||
@@ -18,7 +18,7 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -36,7 +36,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
patch_optimized_env()
|
||||
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
|
||||
@@ -610,3 +610,15 @@ class AxolotlTrainer(
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
def compute_loss_context_manager(self):
|
||||
from contextlib import ExitStack
|
||||
|
||||
from torchtune.training import OffloadActivations
|
||||
|
||||
stack = ExitStack()
|
||||
|
||||
stack.enter_context(super().compute_loss_context_manager())
|
||||
stack.enter_context(OffloadActivations())
|
||||
|
||||
return stack
|
||||
|
||||
@@ -177,12 +177,8 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@@ -151,6 +151,30 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3":
|
||||
from axolotl.integrations.liger.models.qwen3 import (
|
||||
apply_liger_kernel_to_qwen3,
|
||||
)
|
||||
|
||||
apply_liger_kernel_to_qwen3(
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
glu_activation=cfg.liger_glu_activation,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_moe":
|
||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||
apply_liger_kernel_to_qwen3_moe,
|
||||
)
|
||||
|
||||
apply_liger_kernel_to_qwen3_moe(
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
glu_activation=cfg.liger_glu_activation,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
||||
|
||||
160
src/axolotl/integrations/liger/models/qwen3.py
Normal file
160
src/axolotl/integrations/liger/models/qwen3.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Liger FLCE for Qwen3. Based on transformers v4.51.3.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = None
|
||||
loss = None
|
||||
# if in training mode, don't materialize logits
|
||||
if self.training and (labels is not None):
|
||||
loss = LigerForCausalLMLoss(
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=self.lm_head.weight,
|
||||
labels=labels,
|
||||
hidden_size=self.config.hidden_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else: # if in inference mode materialize logits
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def apply_liger_kernel_to_qwen3(
|
||||
cross_entropy: bool = False,
|
||||
fused_linear_cross_entropy: bool = False,
|
||||
rms_norm: bool = False,
|
||||
glu_activation: bool = False,
|
||||
layer_norm: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
"""
|
||||
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
|
||||
|
||||
Args:
|
||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||
fused_linear_cross_entropy (bool):
|
||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
|
||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||
"""
|
||||
|
||||
import transformers.models.qwen3.modeling_qwen3 # noqa: F401 # pylint: disable=unused-import
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
assert not (
|
||||
cross_entropy and fused_linear_cross_entropy
|
||||
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||
|
||||
modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]
|
||||
|
||||
if rms_norm:
|
||||
modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
|
||||
|
||||
if glu_activation:
|
||||
modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
|
||||
|
||||
if layer_norm:
|
||||
modeling_qwen3.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if fused_linear_cross_entropy:
|
||||
modeling_qwen3.Qwen3ForCausalLM.forward = lce_forward
|
||||
191
src/axolotl/integrations/liger/models/qwen3_moe.py
Normal file
191
src/axolotl/integrations/liger/models/qwen3_moe.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Liger FLCE for Qwen3 MoE. Based on transformers v4.51.3.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func
|
||||
|
||||
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> MoeCausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_router_logits=output_router_logits,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = None
|
||||
loss = None
|
||||
# if in training mode, don't materialize logits
|
||||
if self.training and (labels is not None):
|
||||
loss = LigerForCausalLMLoss(
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=self.lm_head.weight,
|
||||
labels=labels,
|
||||
hidden_size=self.config.hidden_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else: # if in inference mode materialize logits
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits,
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss.to(
|
||||
loss.device
|
||||
) # make sure to reside in the same device
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def apply_liger_kernel_to_qwen3_moe(
|
||||
cross_entropy: bool = False,
|
||||
fused_linear_cross_entropy: bool = False,
|
||||
rms_norm: bool = False,
|
||||
glu_activation: bool = False,
|
||||
layer_norm: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
"""
|
||||
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
|
||||
|
||||
Args:
|
||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||
fused_linear_cross_entropy (bool):
|
||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
|
||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||
"""
|
||||
|
||||
import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401 # pylint: disable=unused-import
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
assert not (
|
||||
cross_entropy and fused_linear_cross_entropy
|
||||
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||
|
||||
modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]
|
||||
|
||||
if rms_norm:
|
||||
modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
|
||||
|
||||
if glu_activation:
|
||||
|
||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
||||
"Accepts intermediate_size to pass to LigerSwiGLUMLP"
|
||||
# clone config to avoid modifying the original
|
||||
config = deepcopy(config)
|
||||
if intermediate_size:
|
||||
setattr(config, "intermediate_size", intermediate_size)
|
||||
return LigerSwiGLUMLP(config, **kwargs)
|
||||
|
||||
modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper
|
||||
|
||||
if layer_norm:
|
||||
modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if fused_linear_cross_entropy:
|
||||
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward
|
||||
@@ -55,16 +55,13 @@ def dequantize(
|
||||
target_device = W.device
|
||||
|
||||
# Extract quantization state
|
||||
nested = False
|
||||
if not isinstance(quant_state, list):
|
||||
# New style quant_state class
|
||||
absmax = quant_state.absmax.to(target_device)
|
||||
shape = quant_state.shape
|
||||
dtype = quant_state.dtype
|
||||
blocksize = quant_state.blocksize
|
||||
if quant_state.nested:
|
||||
nested = True
|
||||
offset = quant_state.offset.to(target_device)
|
||||
offset = quant_state.offset.to(target_device)
|
||||
state2 = quant_state.state2
|
||||
absmax2 = state2.absmax.to(target_device)
|
||||
code2 = state2.code.to(target_device)
|
||||
@@ -118,8 +115,7 @@ def dequantize(
|
||||
ctypes.c_int(n_elements_absmax),
|
||||
)
|
||||
|
||||
if nested:
|
||||
out_absmax += offset
|
||||
out_absmax += offset
|
||||
|
||||
# Choose appropriate dequantization function
|
||||
fx = (
|
||||
|
||||
@@ -18,6 +18,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"mixtral",
|
||||
"qwen2",
|
||||
"qwen2_moe",
|
||||
"qwen3",
|
||||
"qwen3_moe",
|
||||
"falcon",
|
||||
"phi",
|
||||
"phi3",
|
||||
|
||||
@@ -43,3 +43,12 @@ def set_pytorch_cuda_alloc_conf():
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
|
||||
"expandable_segments:True,roundup_power2_divisions:16"
|
||||
)
|
||||
|
||||
|
||||
def patch_optimized_env():
|
||||
"""
|
||||
Patch environment variables to improve VRAM usage and increase download speed
|
||||
"""
|
||||
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
@@ -59,7 +59,7 @@ def choose_device(cfg):
|
||||
|
||||
def resolve_dtype(cfg):
|
||||
if (
|
||||
cfg.bf16 == "auto" and not cfg.use_ray
|
||||
not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray
|
||||
): # if we use ray we want to defer this check to the worker node
|
||||
if is_torch_bf16_gpu_available():
|
||||
LOG.debug("bf16 support detected, enabling for this configuration.")
|
||||
|
||||
@@ -2,6 +2,13 @@
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.utils.checkpoint import (
|
||||
CheckpointPolicy,
|
||||
checkpoint,
|
||||
create_selective_checkpoint_contexts,
|
||||
)
|
||||
|
||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
||||
)
|
||||
@@ -18,3 +25,32 @@ def hf_grad_checkpoint_offload_wrapper(
|
||||
),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
compute_intensive_ops = [
|
||||
aten.mm.default,
|
||||
aten.bmm.default,
|
||||
aten.addmm.default,
|
||||
]
|
||||
|
||||
|
||||
def policy_fn(ctx, op, *args, **kwargs):
|
||||
if op in compute_intensive_ops:
|
||||
return CheckpointPolicy.MUST_SAVE
|
||||
else:
|
||||
return CheckpointPolicy.PREFER_RECOMPUTE
|
||||
|
||||
|
||||
context_fn = partial(create_selective_checkpoint_contexts, policy_fn)
|
||||
|
||||
|
||||
def checkpoint_w_policy(
|
||||
decoder_layer, *args, use_reentrant=None
|
||||
): # pylint: disable=unused-argument
|
||||
return checkpoint(
|
||||
decoder_layer,
|
||||
*args,
|
||||
use_reentrant=use_reentrant,
|
||||
context_fn=context_fn,
|
||||
)
|
||||
|
||||
@@ -190,7 +190,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.len_across_ranks = None
|
||||
|
||||
if self.sequential and not isinstance(sampler, SequentialSampler):
|
||||
LOG.warn(
|
||||
LOG.warning(
|
||||
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
|
||||
)
|
||||
|
||||
|
||||
@@ -512,10 +512,17 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def hint_sample_packing_padding(cls, data):
|
||||
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
)
|
||||
if data.get("sample_packing"):
|
||||
pad_to_sequence_len = data.get("pad_to_sequence_len")
|
||||
if pad_to_sequence_len is False:
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
)
|
||||
elif pad_to_sequence_len is None:
|
||||
LOG.info(
|
||||
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
|
||||
)
|
||||
data["pad_to_sequence_len"] = True
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
|
||||
DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": None,
|
||||
"pad_to_sequence_len": False,
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
def test_packing_autoset(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": None,
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
with self._caplog.at_level(logging.INFO):
|
||||
cfg = validate_config(cfg)
|
||||
assert any(
|
||||
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
assert cfg.pad_to_sequence_len is True
|
||||
|
||||
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
|
||||
"""
|
||||
This is assumed to be run on a CPU machine, so bf16 is not supported.
|
||||
|
||||
Reference in New Issue
Block a user