Compare commits


1 Commit

Author     SHA1        Message                               Date
Wing Lian  1a22d16842  handle empty offset for quant state   2025-05-01 13:01:00 -04:00
17 changed files with 23 additions and 484 deletions

View File

@@ -4,12 +4,6 @@ on:
pull_request:
types: [opened, synchronize, reopened]
# Run the workflow only when one of these files changes
paths:
- '**/*.md' # any Markdown file
- '**/*.qmd' # any Quarto file
- '_quarto.yaml'
permissions:
checks: write
contents: write

View File

@@ -15,7 +15,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@@ -32,7 +32,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cli_args: CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
set_pytorch_cuda_alloc_conf()
# pylint: disable=duplicate-code
print_axolotl_text_art()

View File

@@ -29,7 +29,7 @@ from axolotl.cli.utils import (
filter_none_kwargs,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -55,8 +55,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
patch_optimized_env()
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
@@ -102,7 +100,7 @@ def train(
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
set_pytorch_cuda_alloc_conf()
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False

View File

@@ -18,7 +18,7 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils import patch_optimized_env
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
@@ -36,7 +36,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
cli_args: Training-specific CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
set_pytorch_cuda_alloc_conf()
print_axolotl_text_art()
check_accelerate_default_config()

View File

@@ -610,15 +610,3 @@ class AxolotlTrainer(
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
def compute_loss_context_manager(self):
from contextlib import ExitStack
from torchtune.training import OffloadActivations
stack = ExitStack()
stack.enter_context(super().compute_loss_context_manager())
stack.enter_context(OffloadActivations())
return stack
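
For reference on the override above: `contextlib.ExitStack` lets several context managers be entered and torn down as one, which is how the parent trainer's loss context and torchtune's `OffloadActivations()` can be combined into a single returned context. A minimal sketch with stand-in contexts (only `ExitStack` and the stacking pattern come from the diff):

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def demo_ctx(name):
    # stand-in for any context manager, e.g. torchtune's OffloadActivations()
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

# An ExitStack is itself a context manager, so a method can enter several
# contexts on it and hand back the stack as one combined context, as the
# compute_loss_context_manager override above does.
stack = ExitStack()
stack.enter_context(demo_ctx("base loss context"))
stack.enter_context(demo_ctx("activation offloading"))
with stack:
    pass  # loss computation would run here; both contexts exit afterwards
```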

View File

@@ -177,8 +177,12 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
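
The tokenization fix above boils down to: if the tokenizer already injected the BOS id, the prepended copy must be stripped from input ids, labels, and attention mask together, or the three sequences drift out of alignment. A toy sketch with made-up token ids:

```python
bos_token_id = 1  # hypothetical tokenizer BOS id

res = {
    "chosen_input_ids": [1, 15, 27, 2],
    "chosen_labels": [-100, 15, 27, 2],
    "chosen_attention_mask": [1, 1, 1, 1],
}

# Drop a spuriously prepended BOS from every parallel field so ids, labels,
# and attention mask stay aligned, mirroring the logic above.
if res["chosen_input_ids"][0] == bos_token_id:
    for key in ("chosen_input_ids", "chosen_labels", "chosen_attention_mask"):
        res[key] = res[key][1:]

print(res["chosen_input_ids"])  # [15, 27, 2]
```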

View File

@@ -151,30 +151,6 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3":
from axolotl.integrations.liger.models.qwen3 import (
apply_liger_kernel_to_qwen3,
)
apply_liger_kernel_to_qwen3(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
)
apply_liger_kernel_to_qwen3_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
else:
logging.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."

View File

@@ -1,160 +0,0 @@
"""
Liger FLCE for Qwen3. Based on transformers v4.51.3.
"""
import sys
from typing import Optional, Tuple, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> None:
# pylint: disable=duplicate-code
"""
Apply Liger kernels to replace the original implementations in HuggingFace Qwen3 models
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3.modeling_qwen3 # noqa: F401 # pylint: disable=unused-import
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]
if rms_norm:
modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
if glu_activation:
modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
if layer_norm:
modeling_qwen3.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3.Qwen3ForCausalLM.forward = lce_forward
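
As a usage sketch for the patcher above (assuming this module is importable at `axolotl.integrations.liger.models.qwen3`, the path used by the plugin dispatch): the patch has to run before the model is instantiated so the replaced classes and forward are the ones actually constructed. The checkpoint id is illustrative.

```python
from transformers import AutoModelForCausalLM

# hypothetical import path, taken from the plugin code shown earlier
from axolotl.integrations.liger.models.qwen3 import apply_liger_kernel_to_qwen3

# swap in the Liger implementations first ...
apply_liger_kernel_to_qwen3(
    fused_linear_cross_entropy=True,  # routes Qwen3ForCausalLM.forward through lce_forward
    rms_norm=True,                    # Qwen3RMSNorm -> LigerRMSNorm
    glu_activation=True,              # Qwen3MLP -> LigerSwiGLUMLP
)

# ... then build the model so the patched classes are used
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")  # illustrative checkpoint
```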

View File

@@ -1,191 +0,0 @@
"""
Liger FLCE for Qwen3 MoE. Based on transformers v4.51.3.
"""
import sys
from copy import deepcopy
from typing import List, Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(
loss.device
) # make sure to reside in the same device
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3_moe(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> None:
# pylint: disable=duplicate-code
"""
Apply Liger kernels to replace the original implementations in HuggingFace Qwen3 MoE models
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401 # pylint: disable=unused-import
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]
if rms_norm:
modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"Accepts intermediate_size to pass to LigerSwiGLUMLP"
# clone config to avoid modifying the original
config = deepcopy(config)
if intermediate_size:
setattr(config, "intermediate_size", intermediate_size)
return LigerSwiGLUMLP(config, **kwargs)
modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward
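
One detail worth noting above: unlike the dense Qwen3 patch, the MoE patch wraps `LigerSwiGLUMLP` because the MoE modeling code builds its expert MLPs with an explicit per-expert `intermediate_size` rather than relying on `config.intermediate_size` alone. A stand-alone sketch of the wrapper's behavior, using a dummy config object in place of the real HF config:

```python
from copy import deepcopy

class DummyConfig:
    # stand-in for the HF model config; only the field the wrapper touches
    intermediate_size = 256

def swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
    # clone the config so the caller's copy is untouched, then override the
    # size when one is passed explicitly -- the same shape as the wrapper above
    config = deepcopy(config)
    if intermediate_size:
        config.intermediate_size = intermediate_size
    return config  # the real wrapper returns LigerSwiGLUMLP(config, **kwargs)

original = DummyConfig()
patched = swiglu_mlp_wrapper(original, intermediate_size=128)
print(patched.intermediate_size, original.intermediate_size)  # 128 256
```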

View File

@@ -55,13 +55,16 @@ def dequantize(
target_device = W.device
# Extract quantization state
nested = False
if not isinstance(quant_state, list):
# New style quant_state class
absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
offset = quant_state.offset.to(target_device)
if quant_state.nested:
nested = True
offset = quant_state.offset.to(target_device)
state2 = quant_state.state2
absmax2 = state2.absmax.to(target_device)
code2 = state2.code.to(target_device)
@@ -115,7 +118,8 @@ def dequantize(
ctypes.c_int(n_elements_absmax),
)
out_absmax += offset
if nested:
out_absmax += offset
# Choose appropriate dequantization function
fx = (
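
The change above is the commit's namesake: in bitsandbytes' 4-bit `QuantState`, `offset` and `state2` are only populated when the absmax values are themselves quantized (nested / double quantization), so `offset` can be empty for non-nested states and should only be applied inside the nested branch. A minimal stand-in illustrating the guard (`FakeQuantState` is not the real bitsandbytes class):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeQuantState:
    # stand-in for bitsandbytes' QuantState; only the fields used here
    absmax: List[float]
    nested: bool = False
    offset: Optional[float] = None  # populated only for nested (double-quantized) states

def restore_absmax(qs: FakeQuantState) -> List[float]:
    absmax = list(qs.absmax)
    if qs.nested:
        # the second-level dequantization of absmax would happen here;
        # only then is adding the offset defined and meaningful
        absmax = [a + qs.offset for a in absmax]
    return absmax

print(restore_absmax(FakeQuantState(absmax=[0.5, 1.0])))                            # no offset applied
print(restore_absmax(FakeQuantState(absmax=[0.5, 1.0], nested=True, offset=0.1)))   # offset applied
```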

View File

@@ -18,8 +18,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"qwen2_moe",
"qwen3",
"qwen3_moe",
"falcon",
"phi",
"phi3",

View File

@@ -43,12 +43,3 @@ def set_pytorch_cuda_alloc_conf():
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
"expandable_segments:True,roundup_power2_divisions:16"
)
def patch_optimized_env():
"""
Patch environment variables to improve VRAM usage and increase download speed
"""
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()
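
For context, the removed helper above layers two opt-in environment tweaks, each skipped when the user already set the variable. A self-contained sketch of that pattern (the guard inside `set_pytorch_cuda_alloc_conf` is assumed here; the diff only shows the assignment):

```python
import os

def set_pytorch_cuda_alloc_conf():
    # expandable segments reduce fragmentation-driven VRAM growth;
    # respect an allocator config the user already exported
    if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
            "expandable_segments:True,roundup_power2_divisions:16"
        )

def patch_optimized_env():
    # opt into hf_transfer for faster Hub downloads, then tune the allocator
    if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    set_pytorch_cuda_alloc_conf()

patch_optimized_env()
print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])
```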

View File

@@ -59,7 +59,7 @@ def choose_device(cfg):
def resolve_dtype(cfg):
if (
not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray
cfg.bf16 == "auto" and not cfg.use_ray
): # if we use ray we want to defer this check to the worker node
if is_torch_bf16_gpu_available():
LOG.debug("bf16 support detected, enabling for this configuration.")

View File

@@ -2,13 +2,6 @@
from functools import partial
import torch
from torch.utils.checkpoint import (
CheckpointPolicy,
checkpoint,
create_selective_checkpoint_contexts,
)
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
@@ -25,32 +18,3 @@ def hf_grad_checkpoint_offload_wrapper(
),
*args,
)
aten = torch.ops.aten
compute_intensive_ops = [
aten.mm.default,
aten.bmm.default,
aten.addmm.default,
]
def policy_fn(ctx, op, *args, **kwargs):
if op in compute_intensive_ops:
return CheckpointPolicy.MUST_SAVE
else:
return CheckpointPolicy.PREFER_RECOMPUTE
context_fn = partial(create_selective_checkpoint_contexts, policy_fn)
def checkpoint_w_policy(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return checkpoint(
decoder_layer,
*args,
use_reentrant=use_reentrant,
context_fn=context_fn,
)
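
A small usage sketch tying the pieces above together: the policy function keeps matmul outputs and lets everything else be recomputed, and `checkpoint(..., context_fn=...)` applies that policy to a layer's forward. The torch APIs match the imports in the diff; the toy layer is illustrative.

```python
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

aten = torch.ops.aten
compute_intensive_ops = [aten.mm.default, aten.bmm.default, aten.addmm.default]

def policy_fn(ctx, op, *args, **kwargs):
    # keep matmul outputs, recompute everything else during backward
    if op in compute_intensive_ops:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = partial(create_selective_checkpoint_contexts, policy_fn)

layer = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(2, 16, requires_grad=True)
out = checkpoint(layer, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```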

View File

@@ -190,7 +190,7 @@ class MultipackBatchSampler(BatchSampler):
self.len_across_ranks = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warning(
LOG.warn(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)

View File

@@ -512,17 +512,10 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def hint_sample_packing_padding(cls, data):
if data.get("sample_packing"):
pad_to_sequence_len = data.get("pad_to_sequence_len")
if pad_to_sequence_len is False:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
elif pad_to_sequence_len is None:
LOG.info(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
)
data["pad_to_sequence_len"] = True
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
return data
@model_validator(mode="before")
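
A standalone sketch of the two validator behaviors being compared above: one side only warns when `sample_packing` is enabled without `pad_to_sequence_len`, the other also defaults the value to True. The `autoset` flag below is just a device to show both sides in one function; it is not part of the real validator.

```python
import logging

LOG = logging.getLogger(__name__)

def hint_sample_packing_padding(data: dict, autoset: bool = True) -> dict:
    if data.get("sample_packing"):
        pad = data.get("pad_to_sequence_len")
        if pad is None and autoset:
            # variant that fills in the value (logged at INFO level)
            LOG.info(
                "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
            )
            data["pad_to_sequence_len"] = True
        elif not pad:
            # warn-only variant; also covers an explicit `pad_to_sequence_len: false`
            LOG.warning(
                "`pad_to_sequence_len: true` is recommended when using sample_packing"
            )
    return data

print(hint_sample_packing_padding({"sample_packing": True}))
# {'sample_packing': True, 'pad_to_sequence_len': True}
```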

View File

@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": False,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
@@ -662,26 +662,6 @@ class TestValidation(BaseValidation):
for record in self._caplog.records
)
def test_packing_autoset(self, minimal_cfg):
cfg = (
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
| minimal_cfg
)
with self._caplog.at_level(logging.INFO):
cfg = validate_config(cfg)
assert any(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
in record.message
for record in self._caplog.records
)
assert cfg.pad_to_sequence_len is True
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
"""
This is assumed to be run on a CPU machine, so bf16 is not supported.