From 98333e639a35bd36a108786a6daaa42f03488aca Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 29 Oct 2025 18:02:16 -0400 Subject: [PATCH] upgrade trl to 0.24.0 and liger to 0.6.3 (#3230) * upgrade trl to 0.24.0 * fix reward collator init * use newer DataCollatorForPreference instead * DataCollatorForPreference doesn't use padding kwarg * fix input id labels * fix fbgemm-gpu version for pytorch versions * tweak pinned deps * transformers doesn't support hub 1.0 yet * upgrade liger dep to 0.6.3 * set TORCH_CUDA_ARCH_LIST correctly --- cicd/Dockerfile.jinja | 2 +- requirements.txt | 12 ++++++------ setup.py | 8 ++++++-- src/axolotl/core/builders/causal.py | 9 ++++++--- .../prompt_strategies/bradley_terry/chat_template.py | 4 ++-- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 6a1ddb66d..c3a613ecc 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -1,6 +1,6 @@ FROM axolotlai/axolotl-base:{{ BASE_TAG }} -ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" +ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}" ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}" ENV CUDA="{{ CUDA }}" diff --git a/requirements.txt b/requirements.txt index e1f1b10a5..5621d94b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,27 +5,27 @@ bitsandbytes==0.47.0 triton>=3.0.0 mamba-ssm==1.2.0.post1 xformers>=0.0.23.post1 -liger-kernel==0.6.1 +liger-kernel==0.6.3 # END section packaging==23.2 -huggingface_hub>=0.33.0 +huggingface_hub>=0.36.0 peft>=0.17.1 tokenizers>=0.21.1 transformers==4.57.1 accelerate==1.10.1 datasets==4.0.0 deepspeed>=0.17.0 -trl==0.23.1 -hf_xet==1.1.5 -kernels==0.9.0 +trl==0.24.0 +hf_xet==1.2.0 +kernels>=0.9.0 trackio optimum==1.16.2 hf_transfer sentencepiece -gradio==5.41.1 +gradio==5.49.1 modal==1.0.2 pydantic==2.10.6 diff --git a/setup.py b/setup.py index 9e3de48b5..2845bb151 100644 --- a/setup.py +++ b/setup.py @@ -62,8 +62,12 @@ def parse_requirements(extras_require_map): else: raise ValueError("Invalid version format") - if (major, minor) >= (2, 8): - pass + if (major, minor) >= (2, 9): + extras_require_map.pop("fbgemm-gpu") + extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"] + elif (major, minor) >= (2, 8): + extras_require_map.pop("fbgemm-gpu") + extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"] elif (major, minor) >= (2, 7): _install_requires.pop(_install_requires.index(xformers_version)) if patch == 0: diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 820304230..7a06431dc 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -12,7 +12,7 @@ from transformers import ( EarlyStoppingCallback, Trainer, ) -from trl.trainer.utils import RewardDataCollatorWithPadding +from trl.trainer.reward_trainer import DataCollatorForPreference from axolotl.core.builders.base import TrainerBuilderBase from axolotl.core.trainers import ( @@ -453,7 +453,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, DataCollatorWithFlattening, - RewardDataCollatorWithPadding, + DataCollatorForPreference, ] ] collator_args = [self.tokenizer] @@ -470,7 +470,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if kwargs and isinstance(kwargs, dict): kwargs.update(collator_cls_and_kwargs[1]) elif self.cfg.reward_model: - collator = RewardDataCollatorWithPadding + collator = DataCollatorForPreference + tokenizer = collator_args.pop(0) + kwargs["pad_token_id"] = tokenizer.pad_token_id + kwargs.pop("padding") elif use_batch_sampler_collator: # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, # supported multipack models, or non-flash-attention llama diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index fd0d76f51..03336b3ef 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -71,10 +71,10 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): ] return { - "input_ids_chosen": chosen_tokenized["input_ids"], + "chosen_input_ids": chosen_tokenized["input_ids"], "attention_mask_chosen": chosen_tokenized["attention_mask"], "labels_chosen": 1.0, - "input_ids_rejected": rejected_tokenized["input_ids"], + "rejected_input_ids": rejected_tokenized["input_ids"], "attention_mask_rejected": rejected_tokenized["attention_mask"], "labels_rejected": 0.0, }