Compare commits

...

6 Commits

Author SHA1 Message Date
Wing Lian
3b5a9d1d88 update create_optimizer for updated api 2026-02-19 23:49:32 -05:00
Wing Lian
eb59070040 fix labels 2026-02-19 23:44:46 -05:00
Wing Lian
9722aaf7d8 fix for tokenizers change 2026-02-19 21:52:44 -05:00
Wing Lian
c5d20bbd79 integration branch for transformers#44041 2026-02-19 18:34:13 -05:00
NanoCode012
7fbedbd300 fix(doc): add limitation for unfrozen_parameters (#3416) 2026-02-19 18:32:26 -05:00
Wing Lian
145ffc9be1 upgrade transformers to 5.2.0 and torchao to 0.16.0 (#3407)
* upgrade transformers to 5.1.0 and torchao to 0.16.0

* upgrade trl for parity

* handle trl api changes

* orpo doesn't have max_prompt_len to check anymore

* cpoconfig doesn't take max_prompt_length and fix cpu offload

* slow fsdp1 test

* triton min 3.4.0 and liger to 0.7.0

* use transformers main for now for zero3 fix

* handle group_by_length change

* fix changes upstream

* mark skip flaky test

* use transformers latest release 5.2.0
2026-02-19 18:27:27 -05:00
14 changed files with 62 additions and 35 deletions

View File

@@ -2,21 +2,21 @@
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1 bitsandbytes==0.49.1
triton>=3.0.0 triton>=3.4.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
liger-kernel==0.6.4 liger-kernel==0.7.0
# END section # END section
packaging==26.0 packaging==26.0
huggingface_hub>=1.1.7 huggingface_hub>=1.1.7
peft>=0.18.1 peft>=0.18.1
tokenizers>=0.22.1 tokenizers>=0.22.1
transformers==5.0.0 transformers @ git+https://github.com/winglian/transformers.git@refactor-inner-training-loop-reorder-only
accelerate==1.12.0 accelerate==1.12.0
datasets==4.5.0 datasets==4.5.0
deepspeed>=0.18.3 deepspeed>=0.18.3
trl==0.27.1 trl==0.28.0
hf_xet==1.2.0 hf_xet==1.2.0
kernels==0.11.5 kernels==0.11.5
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0 immutabledict==4.2.0
antlr4-python3-runtime==4.13.2 antlr4-python3-runtime==4.13.2
torchao==0.13.0 torchao==0.16.0
openenv-core==0.1.0 openenv-core==0.1.0
schedulefree==1.4.1 schedulefree==1.4.1

View File

@@ -246,7 +246,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ddp_find_unused_parameters ddp_find_unused_parameters
) )
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length if self.cfg.group_by_length:
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)

View File

@@ -11,7 +11,6 @@ from axolotl.core.trainers import (
) )
from axolotl.core.trainers.dpo import DPOStrategy from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback from axolotl.utils.callbacks.qat import QATCallback
@@ -53,6 +52,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}: if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
trainer_cls = GRPOStrategy.get_trainer_class( trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1 sequence_parallel=self.cfg.context_parallel_size > 1
) )
@@ -133,21 +134,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None: if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
# Handle when max_prompt_length == max_length from defaults blocklist_args_kwargs.append("max_prompt_length")
# CPOTrainer requires strictly less than
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
elif self.cfg.rl is RLType.ORPO: elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig training_args_cls = AxolotlORPOConfig
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.KTO: elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length # KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs = ["max_prompt_length"] blocklist_args_kwargs.append("max_prompt_length")
training_args_kwargs["desirable_weight"] = ( training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0 self.cfg.kto_desirable_weight or 1.0
@@ -157,6 +154,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
) )
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}: elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class() training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()

View File

@@ -57,16 +57,18 @@ class AxolotlDPOTrainer(
def tokenize_row( def tokenize_row(
features, features,
processing_class, processing_class,
max_prompt_length, max_prompt_length: int | None = None,
max_completion_length, max_completion_length: int | None = None,
add_special_tokens, add_special_tokens: bool = True,
is_chat: bool = False,
) -> Dict: ) -> Dict:
res = DPOTrainer.tokenize_row( res = DPOTrainer.tokenize_row(
features, features,
processing_class, processing_class,
max_prompt_length, max_prompt_length=max_prompt_length,
max_completion_length, max_completion_length=max_completion_length,
add_special_tokens, add_special_tokens=add_special_tokens,
is_chat=is_chat,
) )
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None: if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters return optimizer_grouped_parameters
def create_optimizer(self): def create_optimizer(self, model=None):
if ( if (
self.args.loraplus_lr_ratio is None self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None and self.optimizer_cls_and_kwargs is None
): ):
return super().create_optimizer() return super().create_optimizer(model=model)
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model if model is None else model
if ( if (
not self.optimizer not self.optimizer

View File

@@ -10,6 +10,7 @@ from functools import cached_property
import addict import addict
import transformers import transformers
from transformers import PretrainedConfig, PreTrainedModel from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
@@ -500,6 +501,7 @@ class PatchManager:
and not self.cfg.trust_remote_code and not self.cfg.trust_remote_code
and not self.cfg.gptq and not self.cfg.gptq
and self.cfg.flash_attention and self.cfg.flash_attention
and is_flash_attn_available()
and not self.inference and not self.inference
): ):
# TODO(MengqingCao): split these patches separately # TODO(MengqingCao): split these patches separately

View File

@@ -59,7 +59,12 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
hidden_states = hidden_states.to("cuda", non_blocking=True).detach() hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True hidden_states.requires_grad = True
with torch.enable_grad(): with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args) output = ctx.forward_function(hidden_states, *ctx.args)
# Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
# return a plain tensor, not a tuple. Older models return tuples
# like (hidden_states, present_kv, ...). Unwrap if needed.
if isinstance(output, (tuple, list)):
(output,) = output
torch.autograd.backward(output, dY) torch.autograd.backward(output, dY)
return ( return (
None, None,

View File

@@ -28,8 +28,12 @@ PATCHED_EVAL_CODE = {
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()', "array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
} }
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()" ORIGINAL_MAYBE_CODE = (
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()" "tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()"
)
PATCHED_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
)
def check_evaluation_loop_is_patchable() -> bool: def check_evaluation_loop_is_patchable() -> bool:

View File

@@ -446,7 +446,16 @@ class AxolotlInputConfig(
}, },
) )
unfrozen_parameters: list[str] | None = None unfrozen_parameters: list[str] | None = Field(
default=None,
json_schema_extra={
"description": "List of regex patterns for parameter names to keep unfrozen. "
"All other parameters will be frozen via requires_grad=False. "
"Note: range-based patterns (e.g. embed_tokens.weight$[:32000]) use gradient "
"zeroing rather than a true freeze, so weight decay will still apply to the "
"frozen portion and optimizer states are allocated for the full parameter."
},
)
sequence_len: int = Field( sequence_len: int = Field(
default=512, default=512,

View File

@@ -300,7 +300,6 @@ class TestHFRLTrainerBuilder:
self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl) self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
# ORPO specific # ORPO specific
assert training_arguments.beta == 0.1 # maps from orpo_alpha assert training_arguments.beta == 0.1 # maps from orpo_alpha
assert training_arguments.max_prompt_length == 512
def test_kto_training_arguments(self, kto_cfg, model, tokenizer): def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer) builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)

View File

@@ -186,6 +186,7 @@ class TestFSDP1:
verify_training_success(temp_dir) verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test, deprecate fsdp1 asap")
def test_dpo_fft(self, temp_dir): def test_dpo_fft(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(
{ {

View File

@@ -365,6 +365,7 @@ class TestFSDP2:
verify_training_success(temp_dir) verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test w cu129 + torch 2.9.1 + py3.12")
@require_torch_2_7_0 @require_torch_2_7_0
def test_dpo_fft(self, temp_dir): def test_dpo_fft(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(

View File

@@ -115,6 +115,9 @@ class TestAssistantChatTemplateLlama3:
def test_phi35(self, phi35_tokenizer, assistant_dataset): def test_phi35(self, phi35_tokenizer, assistant_dataset):
LOG.info("Testing phi-3.5 with assistant dataset") LOG.info("Testing phi-3.5 with assistant dataset")
assert "LlamaTokenizer" in phi35_tokenizer.__class__.__name__, (
"phi35 tokenizer should be a LlamaTokenizer"
)
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
phi35_tokenizer, phi35_tokenizer,
@@ -140,13 +143,13 @@ class TestAssistantChatTemplateLlama3:
# fmt: off # fmt: off
expected_input_ids = [ expected_input_ids = [
32010, # user 32010, # user
22172, 32007, # user eot 12199, 32007, # user eot
32001, # assistant 32001, # assistant
22172, 32007, # assistant eot 12199, 32007, # assistant eot
32010, # user 32010, # user
1781, 26966, 32007, # user eot 16773, 26966, 32007, # user eot
32001, # assistant 32001, # assistant
1781, 26966, 32007, # assistant eot 16773, 26966, 32007, # assistant eot
] ]
expected_labels = [ expected_labels = [
-100, # user -100, # user
@@ -156,7 +159,7 @@ class TestAssistantChatTemplateLlama3:
-100, # user -100, # user
-100, -100, -100, # user eot -100, -100, -100, # user eot
-100, # assistant -100, # assistant
1781, 26966, 32007, # assistant eot 16773, 26966, 32007, # assistant eot
] ]
# fmt: on # fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}") LOG.debug(f"Expected input_ids: {expected_input_ids}")

View File

@@ -84,7 +84,8 @@ class TestTokenizers:
} }
) )
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404] assert "LlamaTokenizer" in tokenizer.__class__.__name__
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
assert len(tokenizer) == 32001 assert len(tokenizer) == 32001
# ensure reloading the tokenizer again from cfg results in same vocab length # ensure reloading the tokenizer again from cfg results in same vocab length