diff --git a/requirements.txt b/requirements.txt
index 565224e92..dc57ebfab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,21 +2,21 @@
 # START section of dependencies that don't install on Darwin/MacOS
 bitsandbytes==0.49.1
-triton>=3.0.0
+triton>=3.4.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
-liger-kernel==0.6.4
+liger-kernel==0.7.0
 # END section
 packaging==26.0
 huggingface_hub>=1.1.7
 peft>=0.18.1
 tokenizers>=0.22.1
-transformers==5.0.0
+transformers==5.2.0
 accelerate==1.12.0
 datasets==4.5.0
 deepspeed>=0.18.3
-trl==0.27.1
+trl==0.28.0
 hf_xet==1.2.0
 kernels==0.11.5
@@ -63,7 +63,7 @@ langdetect==1.0.9
 immutabledict==4.2.0
 antlr4-python3-runtime==4.13.2
-torchao==0.13.0
+torchao==0.16.0
 openenv-core==0.1.0
 schedulefree==1.4.1
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 09bcff450..7bfc5e874 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -246,7 +246,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ddp_find_unused_parameters
         )

-        training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
+        if self.cfg.group_by_length:
+            training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling

         training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py
index 0bd2eedfc..5a7343ca7 100644
--- a/src/axolotl/core/builders/rl.py
+++ b/src/axolotl/core/builders/rl.py
@@ -11,7 +11,6 @@ from axolotl.core.trainers import (
 )
 from axolotl.core.trainers.dpo import DPOStrategy
 from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
-from axolotl.core.trainers.grpo import GRPOStrategy
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import ensure_dtype
 from axolotl.utils.callbacks.qat import QATCallback
@@ -53,6 +52,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         trainer_cls_args = [self.model]

         if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
+            from axolotl.core.trainers.grpo import GRPOStrategy
+
             trainer_cls = GRPOStrategy.get_trainer_class(
                 sequence_parallel=self.cfg.context_parallel_size > 1
             )
@@ -133,21 +134,17 @@
             if self.cfg.cpo_alpha is not None:
                 training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
-            # Handle when max_prompt_length == max_length from defaults
-            # CPOTrainer requires strictly less than
-            if (
-                training_args_kwargs["max_prompt_length"]
-                == training_args_kwargs["max_length"]
-            ):
-                training_args_kwargs["max_prompt_length"] -= 1
+            blocklist_args_kwargs.append("max_prompt_length")

         elif self.cfg.rl is RLType.ORPO:
             training_args_cls = AxolotlORPOConfig
+            blocklist_args_kwargs.append("max_prompt_length")
+
         elif self.cfg.rl is RLType.KTO:
             training_args_cls = AxolotlKTOConfig

             # KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
-            blocklist_args_kwargs = ["max_prompt_length"]
+            blocklist_args_kwargs.append("max_prompt_length")

             training_args_kwargs["desirable_weight"] = (
                 self.cfg.kto_desirable_weight or 1.0
             )
@@ -157,6 +154,8 @@

         elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
+            from axolotl.core.trainers.grpo import GRPOStrategy
+
             training_args_cls = GRPOStrategy.get_training_args_class()
             training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
             blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
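Reviewer note on the rl.py change above: moving the GRPOStrategy import inside the GRPO/GDPO branches defers it until that code path is actually selected, so merely importing the builder module no longer hard-requires GRPO's optional dependencies. A minimal sketch of the pattern; the module path and names below are illustrative, not from this repo:

    def get_trainer_cls(rl_type: str):
        # Deferred import: the optional dependency is only resolved when
        # this branch is taken, so importing this module stays cheap and
        # never fails just because an extra is not installed.
        if rl_type == "grpo":
            from my_pkg.grpo import GRPOStrategy  # hypothetical optional module
            return GRPOStrategy.get_trainer_class()
        raise ValueError(f"unsupported rl type: {rl_type}")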
diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py
index b04505d89..92307fe23 100644
--- a/src/axolotl/core/trainers/dpo/trainer.py
+++ b/src/axolotl/core/trainers/dpo/trainer.py
@@ -57,16 +57,18 @@ class AxolotlDPOTrainer(
     def tokenize_row(
         features,
         processing_class,
-        max_prompt_length,
-        max_completion_length,
-        add_special_tokens,
+        max_prompt_length: int | None = None,
+        max_completion_length: int | None = None,
+        add_special_tokens: bool = True,
+        is_chat: bool = False,
     ) -> Dict:
         res = DPOTrainer.tokenize_row(
             features,
             processing_class,
-            max_prompt_length,
-            max_completion_length,
-            add_special_tokens,
+            max_prompt_length=max_prompt_length,
+            max_completion_length=max_completion_length,
+            add_special_tokens=add_special_tokens,
+            is_chat=is_chat,
         )
         # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
         if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 3cf8bbd20..222260020 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -10,6 +10,7 @@ from functools import cached_property
 import addict
 import transformers
 from transformers import PretrainedConfig, PreTrainedModel
+from transformers.modeling_flash_attention_utils import is_flash_attn_available

 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import (
@@ -500,6 +501,7 @@ class PatchManager:
             and not self.cfg.trust_remote_code
             and not self.cfg.gptq
             and self.cfg.flash_attention
+            and is_flash_attn_available()
             and not self.inference
         ):
             # TODO(MengqingCao): split these patches separately
diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
index 8d06f172d..886441196 100644
--- a/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
+++ b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
@@ -59,7 +59,12 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
         hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
         hidden_states.requires_grad = True
         with torch.enable_grad():
-            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+            output = ctx.forward_function(hidden_states, *ctx.args)
+            # Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
+            # return a plain tensor, not a tuple. Older models return tuples
+            # like (hidden_states, present_kv, ...). Unwrap if needed.
+            if isinstance(output, (tuple, list)):
+                output = output[0]
         torch.autograd.backward(output, dY)
         return (
             None,
diff --git a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
index b8172bbe6..3a99d0115 100644
--- a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
+++ b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
@@ -28,8 +28,12 @@ PATCHED_EVAL_CODE = {
     "array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
 }

-ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
-PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
+ORIGINAL_MAYBE_CODE = (
+    "tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()"
+)
+PATCHED_MAYBE_CODE = (
+    "tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
+)


 def check_evaluation_loop_is_patchable() -> bool:
diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py
index 194950e15..fc16f723e 100644
--- a/tests/core/test_builders.py
+++ b/tests/core/test_builders.py
@@ -300,7 +300,6 @@ class TestHFRLTrainerBuilder:
         self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
         # ORPO specific
         assert training_arguments.beta == 0.1  # maps from orpo_alpha
-        assert training_arguments.max_prompt_length == 512

     def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
         builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
diff --git a/tests/e2e/multigpu/test_fsdp1.py b/tests/e2e/multigpu/test_fsdp1.py
index e50316287..5b6724791 100644
--- a/tests/e2e/multigpu/test_fsdp1.py
+++ b/tests/e2e/multigpu/test_fsdp1.py
@@ -186,6 +186,7 @@ class TestFSDP1:
         verify_training_success(temp_dir)

+    @pytest.mark.skip(reason="slow test, deprecate fsdp1 asap")
     def test_dpo_fft(self, temp_dir):
         cfg = DictDefault(
             {
diff --git a/tests/e2e/multigpu/test_fsdp2.py b/tests/e2e/multigpu/test_fsdp2.py
index 19239a3ec..e10456240 100644
--- a/tests/e2e/multigpu/test_fsdp2.py
+++ b/tests/e2e/multigpu/test_fsdp2.py
@@ -365,6 +365,7 @@ class TestFSDP2:
         verify_training_success(temp_dir)

+    @pytest.mark.skip(reason="slow test w cu129 + torch 2.9.1 + py3.12")
     @require_torch_2_7_0
     def test_dpo_fft(self, temp_dir):
         cfg = DictDefault(
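Reviewer note on the offload_cpu.py hunk: the unwrap logic added there boils down to a small, reusable pattern. A standalone sketch under the same assumption stated in the hunk's comment (first tuple element is the hidden states); this is illustrative, not code from the repo:

    import torch

    def unwrap_layer_output(output):
        # GradientCheckpointingLayer-style modules return the hidden-states
        # tensor directly; older decoder layers return a tuple whose first
        # element is the hidden states. Normalize both cases to a tensor.
        if isinstance(output, (tuple, list)):
            return output[0]
        return output

    hidden = torch.randn(1, 4, 8)
    assert unwrap_layer_output(hidden) is hidden
    assert unwrap_layer_output((hidden, None)) is hidden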
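And on the trainer_loss_calc.py hunk: the patched string swaps mean() for nanmean() so a single NaN loss step cannot poison the logged training-loss scalar. A quick self-contained illustration in plain PyTorch:

    import torch

    tr_loss = torch.tensor([2.1, float("nan"), 1.9])
    print(tr_loss.mean())     # tensor(nan) -- the NaN propagates into the log
    print(tr_loss.nanmean())  # tensor(2.)  -- NaN entries are ignored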