upgrade transformers==5.3.0 trl==0.29.0 kernels (#3459)

* upgrade transformers==5.3.0 trl==0.29.0 kernels

* use latest deepspeed fixes

* use corect image for cleanup

* fix test outputs for tokenizer fixes upstream

* fix import:

* keep trl at 0.28.0

* handle updated API

* use latest trl since 0.28.0 doesn't work with latest transformers

* use trl experimental for pad to length

* monkeypatch trl with ORPOTrainer so liger doesn't croak

* upgrade accelerate

* more fixes

* move patch for orpotrainer

* load the imports later

* remove use_logits_to_keep

* fix loss_type arg as a list

* fetch hf cache from s3

* just manually download the missing model for now

* lint for pre-commit update

* a few more missing models on disk

* fix: loss_type internally now list

* fix: remove deprecated code and raise deprecate

* fix: remove unneeded blocklist

* fix: remove reliance on transformers api to find package available

* chore: refactor shim for less sideeffect

* fix: silent trl experimental warning

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Wing Lian
2026-03-06 09:11:20 -05:00
committed by GitHub
parent 56162f71db
commit cada93cee5
19 changed files with 81 additions and 49 deletions

View File

@@ -387,8 +387,8 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 129
cuda_version: 12.9.1
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
num_gpus: 1

View File

@@ -3,6 +3,12 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
hf download "NousResearch/Meta-Llama-3-8B"
hf download "NousResearch/Meta-Llama-3-8B-Instruct"
hf download "microsoft/Phi-4-reasoning"
hf download "microsoft/Phi-3.5-mini-instruct"
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \
--ignore=tests/e2e/ \

View File

@@ -12,13 +12,13 @@ packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers==5.2.0
accelerate==1.12.0
transformers==5.3.0
accelerate==1.13.0
datasets==4.5.0
deepspeed>=0.18.3
trl==0.28.0
hf_xet==1.2.0
kernels==0.12.1
deepspeed>=0.18.6,<0.19.0
trl==0.29.0
hf_xet==1.3.2
kernels==0.12.2
trackio>=0.16.1
typing-extensions>=4.15.0

View File

@@ -6,5 +6,6 @@ from axolotl.logging_config import configure_logging
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
os.environ.setdefault("TRL_EXPERIMENTAL_SILENCE", "1")
configure_logging()

View File

@@ -120,11 +120,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
else:
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:

View File

@@ -26,7 +26,7 @@ from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from trl.experimental.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (

View File

@@ -25,17 +25,13 @@ class DPOStrategy:
# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None:
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
if cfg.dpo_norm_loss is not None:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
if cfg.dpo_use_liger_kernel is not None:
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
return training_args_kwargs

View File

@@ -103,10 +103,10 @@ class AxolotlDPOTrainer(
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: str = self.loss_type # type: ignore[has-type]
loss_type: list[str] = self.loss_type # type: ignore[has-type]
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = "ipo"
self.loss_type = ["ipo"]
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type
return res

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters
def create_optimizer(self):
def create_optimizer(self, model=None):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
return super().create_optimizer(model=model)
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
opt_model = self.model if model is None else model
if (
not self.optimizer

View File

@@ -8,9 +8,6 @@ import sys
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .models.base import patch_lce_forward
from .utils import patch_with_compile_disable
LOG = get_logger(__name__)
@@ -23,10 +20,18 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
# shim: liger-kernel 0.7.0 imports ORPOTrainer from old trl path
import trl.trainer
from trl.experimental.orpo import ORPOTrainer
trl.trainer.ORPOTrainer = ORPOTrainer
if cfg.torch_compile:
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
import liger_kernel.ops.fused_linear_cross_entropy
from .utils import patch_with_compile_disable
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_forward",
@@ -35,6 +40,7 @@ class LigerPlugin(BasePlugin):
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_backward",
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
@@ -192,6 +198,8 @@ class LigerPlugin(BasePlugin):
)
elif cfg.liger_fused_linear_cross_entropy:
try:
from .models.base import patch_lce_forward
patch_lce_forward(cfg.model_config_type)
LOG.warning_once(
f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"

View File

@@ -674,8 +674,8 @@ class ModelLoader:
del self.model_kwargs["device_map"]
transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = (
lambda: True
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = lambda: (
True
)
return hf_ds_cfg

View File

@@ -201,7 +201,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # nosec B105
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Mistral's official FA implementation requires left padding

View File

@@ -189,7 +189,7 @@ def _get_remote_filesystem(
try:
import gcsfs
storage_options = {"token": None} # type: ignore
storage_options = {"token": None} # type: ignore # nosec B105
return gcsfs.GCSFileSystem(**storage_options), storage_options
except ImportError as exc:
raise ImportError(

View File

@@ -173,7 +173,6 @@ class AxolotlInputConfig(
"description": "Whether to perform weighting in DPO trainer"
},
)
dpo_use_logits_to_keep: bool | None = None
dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None
@@ -183,7 +182,6 @@ class AxolotlInputConfig(
)
dpo_padding_free: bool | None = None
dpo_generate_during_eval: bool | None = None
datasets: (
Annotated[

View File

@@ -19,6 +19,8 @@ class DeprecatedParameters(BaseModel):
evaluation_strategy: str | None = None
eval_table_size: int | None = None
eval_max_new_tokens: int | None = None
dpo_use_logits_to_keep: bool | None = None
dpo_generate_during_eval: bool | None = None
@field_validator("max_packed_sequence_len")
@classmethod
@@ -78,6 +80,26 @@ class DeprecatedParameters(BaseModel):
)
return eval_max_new_tokens
@field_validator("dpo_use_logits_to_keep")
@classmethod
def validate_dpo_use_logits_to_keep(cls, dpo_use_logits_to_keep):
if dpo_use_logits_to_keep is not None:
raise DeprecationWarning(
"`dpo_use_logits_to_keep` is no longer supported, "
"it has been removed in TRL >= 0.29.0"
)
return dpo_use_logits_to_keep
@field_validator("dpo_generate_during_eval")
@classmethod
def validate_dpo_generate_during_eval(cls, dpo_generate_during_eval):
if dpo_generate_during_eval is not None:
raise DeprecationWarning(
"`dpo_generate_during_eval` is no longer supported, "
"it has been removed in TRL >= 0.29.0"
)
return dpo_generate_during_eval
class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names"""

View File

@@ -2,7 +2,7 @@
import sys
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
@@ -94,7 +94,6 @@ def fixture_dpo_cfg(base_cfg):
{
"rl": RLType.DPO,
"dpo_use_weighting": True,
"dpo_use_logits_to_keep": True,
"dpo_label_smoothing": 0.1,
"beta": 0.1, # DPO beta
}
@@ -148,9 +147,16 @@ def fixture_grpo_cfg(base_cfg):
),
# Must be evenly divisible by num_generations
"micro_batch_size": 4,
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"split": "train[:1%]",
}
],
}
)
return cfg
return DictDefault(cfg)
@pytest.fixture(name="ipo_cfg")
@@ -334,6 +340,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
try:
builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
builder.train_dataset = MagicMock()
self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)
# GRPO specific
@@ -363,7 +370,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)
# IPO specific
assert training_arguments.beta == 0.1
assert training_arguments.loss_type == "ipo"
assert training_arguments.loss_type == ["ipo"]
assert training_arguments.label_smoothing == 0
def test_simpo_training_arguments(self, simpo_cfg, model, tokenizer):
@@ -529,13 +536,11 @@ class TestHFCausalTrainerBuilder:
"cfg_string",
[
"sft_cfg",
"rm_cfg",
# "rm_cfg", # TODO fix for num_labels = 2 vs 1
"prm_cfg",
],
)
def test_custom_optimizer_cls_and_kwargs(
self, request, cfg_string, model, tokenizer
):
def test_builder_w_rm_trainers(self, request, cfg_string, model, tokenizer):
cfg = request.getfixturevalue(cfg_string)
builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
cfg["optimizer"] = "muon"

View File

@@ -18,6 +18,7 @@ Unit tests for SwanLab Integration Plugin.
Tests conflict detection, configuration validation, and multi-logger warnings.
"""
import importlib.util
import logging
import os
import time
@@ -25,12 +26,11 @@ from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from transformers.utils.import_utils import _is_package_available
from axolotl.integrations.swanlab.args import SwanLabConfig
from axolotl.integrations.swanlab.plugins import SwanLabPlugin
SWANLAB_INSTALLED = _is_package_available("swanlab")
SWANLAB_INSTALLED = importlib.util.find_spec("swanlab") is not None
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")

View File

@@ -52,8 +52,8 @@ def mock_torch():
mock_torch.cuda.device_count.return_value = 2
# Mock memory allocated per device (1GB for device 0, 2GB for device 1)
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 1) * 1024 * 1024 * 1024
mock_torch.cuda.memory_allocated.side_effect = lambda device: (
(device + 1) * 1024 * 1024 * 1024
)
yield mock_torch
@@ -292,8 +292,8 @@ class TestRuntimeMetricsTracker:
mock_memory_info = mock_process.memory_info.return_value
mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 0.5) * 1024 * 1024 * 1024
mock_torch.cuda.memory_allocated.side_effect = lambda device: (
(device + 0.5) * 1024 * 1024 * 1024
)
# Update memory metrics again
@@ -307,8 +307,8 @@ class TestRuntimeMetricsTracker:
# Change mocked memory values to be higher
mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 2) * 1024 * 1024 * 1024
mock_torch.cuda.memory_allocated.side_effect = lambda device: (
(device + 2) * 1024 * 1024 * 1024
)
# Update memory metrics again

View File

@@ -84,7 +84,8 @@ class TestTokenizers:
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
assert "LlamaTokenizer" in tokenizer.__class__.__name__
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
assert len(tokenizer) == 32001
# ensure reloading the tokenizer again from cfg results in same vocab length