From 86d6ee7c0551393dd537a2c1c5e5c6362e1b3e41 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 16 Sep 2025 14:53:01 -0400
Subject: [PATCH] upgrade trl and accelerate (#3161)

* upgrade trl==0.23.0

* upgrade accelerate patch fix

* add hints when using gradient_checkpointing with DPO

* set gradient-checkpointing properly
---
 requirements.txt                        |  4 ++--
 src/axolotl/core/builders/base.py       |  2 +-
 src/axolotl/utils/schemas/validation.py | 15 +++++++++++++++
 tests/e2e/multigpu/test_llama.py        |  4 ++--
 4 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 6138707af..44a3c0277 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,10 +15,10 @@ huggingface_hub>=0.33.0
 peft>=0.17.0
 transformers==4.56.1
 tokenizers>=0.21.1
-accelerate==1.10.0
+accelerate==1.10.1
 datasets==4.0.0
 deepspeed>=0.17.0
-trl==0.21.0
+trl==0.23.0
 hf_xet==1.1.5
 kernels==0.9.0
 trackio
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 1ec818004..3ad8012f9 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -435,7 +435,7 @@ class TrainerBuilderBase(abc.ABC):
             # don't use the HF gradient checkpointing, manually wrap
             training_args_kwargs["gradient_checkpointing"] = False
             training_args_kwargs["activation_offloading"] = True
-        elif self.cfg.gradient_checkpointing:
+        elif self.cfg.gradient_checkpointing is not None:
             training_args_kwargs["gradient_checkpointing"] = (
                 self.cfg.gradient_checkpointing
             )
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 64018ca48..9671b10ae 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1378,6 +1378,21 @@ class ComplexValidationMixin:
 
         return self
 
+    def hint_gradient_checkpointing_dpo_lora_ddp(self):
+        if (
+            (self.gradient_checkpointing is True or self.gradient_checkpointing is None)
+            and self.capabilities
+            and self.capabilities.get("n_gpu", 1) > 1
+            and self.adapter in ("lora", "qlora")
+            and self.rl == RLType.DPO
+            and not self.fsdp
+            and not self.deepspeed
+        ):
+            LOG.warning(
+                "gradient_checkpointing with DPO + DDP + LoRA is not recommended."
+            )
+        return self
+
 
 class DistributedValidationMixin:
     """validation for distributed training."""
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index ad15d628b..c16ef0c60 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -199,7 +199,7 @@ class TestMultiGPULlama:
                 "max_steps": 2,
                 "micro_batch_size": 2,
                 "gradient_accumulation_steps": 2,
-                # "gradient_checkpointing": True,
+                "gradient_checkpointing": False,
                 "output_dir": temp_dir,
                 "dataset_prepared_path": temp_dir + "/last_run_prepared",
                 "warmup_steps": 0,
@@ -278,7 +278,7 @@ class TestMultiGPULlama:
                 "max_steps": 2,
                 "micro_batch_size": 2,
                 "gradient_accumulation_steps": 2,
-                # "gradient_checkpointing": True,
+                "gradient_checkpointing": False,
                 "output_dir": temp_dir,
                 "dataset_prepared_path": temp_dir + "/last_run_prepared",
                 "warmup_steps": 0,
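
Note (commentary, not part of the applied patch): the subtlest change here is the
builder check in src/axolotl/core/builders/base.py moving from a truthiness test to
`is not None`. With the old check, an explicit `gradient_checkpointing: false` in a
config never reached `training_args_kwargs`, so the trainer's own default silently won;
the new check forwards an explicit False too, which the updated multi-GPU DPO tests now
rely on, and the new `hint_gradient_checkpointing_dpo_lora_ddp` validator warns when it
is left on (or unset) under DPO + DDP + LoRA. Below is a minimal, self-contained Python
sketch of that tri-state behavior; `forward_gc_old` and `forward_gc_new` are
hypothetical stand-ins for the before/after builder logic, not functions from the
codebase:

    # Hypothetical repro: cfg_value stands in for self.cfg.gradient_checkpointing,
    # the returned dict for training_args_kwargs in TrainerBuilderBase.

    def forward_gc_old(cfg_value: bool | None) -> dict:
        kwargs = {}
        if cfg_value:  # truthy check: an explicit False is dropped
            kwargs["gradient_checkpointing"] = cfg_value
        return kwargs

    def forward_gc_new(cfg_value: bool | None) -> dict:
        kwargs = {}
        if cfg_value is not None:  # only an unset (None) value is skipped
            kwargs["gradient_checkpointing"] = cfg_value
        return kwargs

    assert forward_gc_old(False) == {}                                  # False was lost
    assert forward_gc_new(False) == {"gradient_checkpointing": False}   # False is kept
    assert forward_gc_new(None) == {}                                   # unset stays unset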