upgrade transformers==5.3.0 trl==0.29.0 kernels (#3459)

* upgrade transformers==5.3.0 trl==0.29.0 kernels

* use latest deepspeed fixes

* use corect image for cleanup

* fix test outputs for tokenizer fixes upstream

* fix import:

* keep trl at 0.28.0

* handle updated API

* use latest trl since 0.28.0 doesn't work with latest transformers

* use trl experimental for pad to length

* monkeypatch trl with ORPOTrainer so liger doesn't croak

* upgrade accelerate

* more fixes

* move patch for orpotrainer

* load the imports later

* remove use_logits_to_keep

* fix loss_type arg as a list

* fetch hf cache from s3

* just manually download the missing model for now

* lint for pre-commit update

* a few more missing models on disk

* fix: loss_type internally now list

* fix: remove deprecated code and raise deprecate

* fix: remove unneeded blocklist

* fix: remove reliance on transformers api to find package available

* chore: refactor shim for less sideeffect

* fix: silent trl experimental warning

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Wing Lian
2026-03-06 09:11:20 -05:00
committed by GitHub
parent 56162f71db
commit cada93cee5
19 changed files with 81 additions and 49 deletions

View File

@@ -6,5 +6,6 @@ from axolotl.logging_config import configure_logging
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
os.environ.setdefault("TRL_EXPERIMENTAL_SILENCE", "1")
configure_logging()

View File

@@ -120,11 +120,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
else:
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:

View File

@@ -26,7 +26,7 @@ from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from trl.experimental.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (

View File

@@ -25,17 +25,13 @@ class DPOStrategy:
# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None:
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
if cfg.dpo_norm_loss is not None:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
if cfg.dpo_use_liger_kernel is not None:
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
return training_args_kwargs

View File

@@ -103,10 +103,10 @@ class AxolotlDPOTrainer(
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: str = self.loss_type # type: ignore[has-type]
loss_type: list[str] = self.loss_type # type: ignore[has-type]
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = "ipo"
self.loss_type = ["ipo"]
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type
return res

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters
def create_optimizer(self):
def create_optimizer(self, model=None):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
return super().create_optimizer(model=model)
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
opt_model = self.model if model is None else model
if (
not self.optimizer

View File

@@ -8,9 +8,6 @@ import sys
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .models.base import patch_lce_forward
from .utils import patch_with_compile_disable
LOG = get_logger(__name__)
@@ -23,10 +20,18 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
# shim: liger-kernel 0.7.0 imports ORPOTrainer from old trl path
import trl.trainer
from trl.experimental.orpo import ORPOTrainer
trl.trainer.ORPOTrainer = ORPOTrainer
if cfg.torch_compile:
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
import liger_kernel.ops.fused_linear_cross_entropy
from .utils import patch_with_compile_disable
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_forward",
@@ -35,6 +40,7 @@ class LigerPlugin(BasePlugin):
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_backward",
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
@@ -192,6 +198,8 @@ class LigerPlugin(BasePlugin):
)
elif cfg.liger_fused_linear_cross_entropy:
try:
from .models.base import patch_lce_forward
patch_lce_forward(cfg.model_config_type)
LOG.warning_once(
f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"

View File

@@ -674,8 +674,8 @@ class ModelLoader:
del self.model_kwargs["device_map"]
transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = (
lambda: True
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = lambda: (
True
)
return hf_ds_cfg

View File

@@ -201,7 +201,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # nosec B105
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Mistral's official FA implementation requires left padding

View File

@@ -189,7 +189,7 @@ def _get_remote_filesystem(
try:
import gcsfs
storage_options = {"token": None} # type: ignore
storage_options = {"token": None} # type: ignore # nosec B105
return gcsfs.GCSFileSystem(**storage_options), storage_options
except ImportError as exc:
raise ImportError(

View File

@@ -173,7 +173,6 @@ class AxolotlInputConfig(
"description": "Whether to perform weighting in DPO trainer"
},
)
dpo_use_logits_to_keep: bool | None = None
dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None
@@ -183,7 +182,6 @@ class AxolotlInputConfig(
)
dpo_padding_free: bool | None = None
dpo_generate_during_eval: bool | None = None
datasets: (
Annotated[

View File

@@ -19,6 +19,8 @@ class DeprecatedParameters(BaseModel):
evaluation_strategy: str | None = None
eval_table_size: int | None = None
eval_max_new_tokens: int | None = None
dpo_use_logits_to_keep: bool | None = None
dpo_generate_during_eval: bool | None = None
@field_validator("max_packed_sequence_len")
@classmethod
@@ -78,6 +80,26 @@ class DeprecatedParameters(BaseModel):
)
return eval_max_new_tokens
@field_validator("dpo_use_logits_to_keep")
@classmethod
def validate_dpo_use_logits_to_keep(cls, dpo_use_logits_to_keep):
if dpo_use_logits_to_keep is not None:
raise DeprecationWarning(
"`dpo_use_logits_to_keep` is no longer supported, "
"it has been removed in TRL >= 0.29.0"
)
return dpo_use_logits_to_keep
@field_validator("dpo_generate_during_eval")
@classmethod
def validate_dpo_generate_during_eval(cls, dpo_generate_during_eval):
if dpo_generate_during_eval is not None:
raise DeprecationWarning(
"`dpo_generate_during_eval` is no longer supported, "
"it has been removed in TRL >= 0.29.0"
)
return dpo_generate_during_eval
class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names"""