upgrade transformers==5.3.0 trl==0.29.0 kernels (#3459)
* upgrade transformers==5.3.0 trl==0.29.0 kernels * use latest deepspeed fixes * use corect image for cleanup * fix test outputs for tokenizer fixes upstream * fix import: * keep trl at 0.28.0 * handle updated API * use latest trl since 0.28.0 doesn't work with latest transformers * use trl experimental for pad to length * monkeypatch trl with ORPOTrainer so liger doesn't croak * upgrade accelerate * more fixes * move patch for orpotrainer * load the imports later * remove use_logits_to_keep * fix loss_type arg as a list * fetch hf cache from s3 * just manually download the missing model for now * lint for pre-commit update * a few more missing models on disk * fix: loss_type internally now list * fix: remove deprecated code and raise deprecate * fix: remove unneeded blocklist * fix: remove reliance on transformers api to find package available * chore: refactor shim for less sideeffect * fix: silent trl experimental warning --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -387,8 +387,8 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 129
|
||||
cuda_version: 12.9.1
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
|
||||
@@ -3,6 +3,12 @@ set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
hf download "NousResearch/Meta-Llama-3-8B"
|
||||
hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
hf download "microsoft/Phi-4-reasoning"
|
||||
hf download "microsoft/Phi-3.5-mini-instruct"
|
||||
|
||||
# Run unit tests with initial coverage report
|
||||
pytest -v --durations=10 -n8 \
|
||||
--ignore=tests/e2e/ \
|
||||
|
||||
@@ -12,13 +12,13 @@ packaging==26.0
|
||||
huggingface_hub>=1.1.7
|
||||
peft>=0.18.1
|
||||
tokenizers>=0.22.1
|
||||
transformers==5.2.0
|
||||
accelerate==1.12.0
|
||||
transformers==5.3.0
|
||||
accelerate==1.13.0
|
||||
datasets==4.5.0
|
||||
deepspeed>=0.18.3
|
||||
trl==0.28.0
|
||||
hf_xet==1.2.0
|
||||
kernels==0.12.1
|
||||
deepspeed>=0.18.6,<0.19.0
|
||||
trl==0.29.0
|
||||
hf_xet==1.3.2
|
||||
kernels==0.12.2
|
||||
|
||||
trackio>=0.16.1
|
||||
typing-extensions>=4.15.0
|
||||
|
||||
@@ -6,5 +6,6 @@ from axolotl.logging_config import configure_logging
|
||||
|
||||
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
|
||||
os.environ.setdefault("TRL_EXPERIMENTAL_SILENCE", "1")
|
||||
|
||||
configure_logging()
|
||||
|
||||
@@ -120,11 +120,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
else:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl is RLType.SIMPO:
|
||||
|
||||
@@ -26,7 +26,7 @@ from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from trl.experimental.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
|
||||
@@ -25,17 +25,13 @@ class DPOStrategy:
|
||||
# Label smoothing is not compatible with IPO
|
||||
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
if cfg.dpo_padding_free is not None:
|
||||
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
|
||||
if cfg.dpo_norm_loss is not None:
|
||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||
if cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||
if cfg.dpo_use_liger_kernel is not None:
|
||||
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
||||
return training_args_kwargs
|
||||
|
||||
@@ -103,10 +103,10 @@ class AxolotlDPOTrainer(
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if self.args.dpo_norm_loss:
|
||||
# fmt: off
|
||||
loss_type: str = self.loss_type # type: ignore[has-type]
|
||||
loss_type: list[str] = self.loss_type # type: ignore[has-type]
|
||||
# fmt: on
|
||||
# concatenated_forward handles avg token logprob for ipo case already
|
||||
self.loss_type = "ipo"
|
||||
self.loss_type = ["ipo"]
|
||||
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
||||
self.loss_type = loss_type
|
||||
return res
|
||||
|
||||
@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
|
||||
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def create_optimizer(self):
|
||||
def create_optimizer(self, model=None):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
|
||||
and self.args.lr_groups is None
|
||||
and self.optimizer_cls_and_kwargs is None
|
||||
):
|
||||
return super().create_optimizer()
|
||||
return super().create_optimizer(model=model)
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
opt_model = self.model if model is None else model
|
||||
|
||||
if (
|
||||
not self.optimizer
|
||||
|
||||
@@ -8,9 +8,6 @@ import sys
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .models.base import patch_lce_forward
|
||||
from .utils import patch_with_compile_disable
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -23,10 +20,18 @@ class LigerPlugin(BasePlugin):
|
||||
return "axolotl.integrations.liger.LigerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
# shim: liger-kernel 0.7.0 imports ORPOTrainer from old trl path
|
||||
import trl.trainer
|
||||
from trl.experimental.orpo import ORPOTrainer
|
||||
|
||||
trl.trainer.ORPOTrainer = ORPOTrainer
|
||||
|
||||
if cfg.torch_compile:
|
||||
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
|
||||
import liger_kernel.ops.fused_linear_cross_entropy
|
||||
|
||||
from .utils import patch_with_compile_disable
|
||||
|
||||
patch_with_compile_disable(
|
||||
liger_kernel.ops.fused_linear_cross_entropy,
|
||||
"fused_linear_cross_entropy_forward",
|
||||
@@ -35,6 +40,7 @@ class LigerPlugin(BasePlugin):
|
||||
liger_kernel.ops.fused_linear_cross_entropy,
|
||||
"fused_linear_cross_entropy_backward",
|
||||
)
|
||||
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
@@ -192,6 +198,8 @@ class LigerPlugin(BasePlugin):
|
||||
)
|
||||
elif cfg.liger_fused_linear_cross_entropy:
|
||||
try:
|
||||
from .models.base import patch_lce_forward
|
||||
|
||||
patch_lce_forward(cfg.model_config_type)
|
||||
LOG.warning_once(
|
||||
f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"
|
||||
|
||||
@@ -674,8 +674,8 @@ class ModelLoader:
|
||||
del self.model_kwargs["device_map"]
|
||||
|
||||
transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
|
||||
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = (
|
||||
lambda: True
|
||||
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = lambda: (
|
||||
True
|
||||
)
|
||||
|
||||
return hf_ds_cfg
|
||||
|
||||
@@ -201,7 +201,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # nosec B105
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Mistral's official FA implementation requires left padding
|
||||
|
||||
@@ -189,7 +189,7 @@ def _get_remote_filesystem(
|
||||
try:
|
||||
import gcsfs
|
||||
|
||||
storage_options = {"token": None} # type: ignore
|
||||
storage_options = {"token": None} # type: ignore # nosec B105
|
||||
return gcsfs.GCSFileSystem(**storage_options), storage_options
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
|
||||
@@ -173,7 +173,6 @@ class AxolotlInputConfig(
|
||||
"description": "Whether to perform weighting in DPO trainer"
|
||||
},
|
||||
)
|
||||
dpo_use_logits_to_keep: bool | None = None
|
||||
dpo_label_smoothing: float | None = None
|
||||
dpo_norm_loss: bool | None = None
|
||||
|
||||
@@ -183,7 +182,6 @@ class AxolotlInputConfig(
|
||||
)
|
||||
|
||||
dpo_padding_free: bool | None = None
|
||||
dpo_generate_during_eval: bool | None = None
|
||||
|
||||
datasets: (
|
||||
Annotated[
|
||||
|
||||
@@ -19,6 +19,8 @@ class DeprecatedParameters(BaseModel):
|
||||
evaluation_strategy: str | None = None
|
||||
eval_table_size: int | None = None
|
||||
eval_max_new_tokens: int | None = None
|
||||
dpo_use_logits_to_keep: bool | None = None
|
||||
dpo_generate_during_eval: bool | None = None
|
||||
|
||||
@field_validator("max_packed_sequence_len")
|
||||
@classmethod
|
||||
@@ -78,6 +80,26 @@ class DeprecatedParameters(BaseModel):
|
||||
)
|
||||
return eval_max_new_tokens
|
||||
|
||||
@field_validator("dpo_use_logits_to_keep")
|
||||
@classmethod
|
||||
def validate_dpo_use_logits_to_keep(cls, dpo_use_logits_to_keep):
|
||||
if dpo_use_logits_to_keep is not None:
|
||||
raise DeprecationWarning(
|
||||
"`dpo_use_logits_to_keep` is no longer supported, "
|
||||
"it has been removed in TRL >= 0.29.0"
|
||||
)
|
||||
return dpo_use_logits_to_keep
|
||||
|
||||
@field_validator("dpo_generate_during_eval")
|
||||
@classmethod
|
||||
def validate_dpo_generate_during_eval(cls, dpo_generate_during_eval):
|
||||
if dpo_generate_during_eval is not None:
|
||||
raise DeprecationWarning(
|
||||
"`dpo_generate_during_eval` is no longer supported, "
|
||||
"it has been removed in TRL >= 0.29.0"
|
||||
)
|
||||
return dpo_generate_during_eval
|
||||
|
||||
|
||||
class RemappedParameters(BaseModel):
|
||||
"""Parameters that have been remapped to other names"""
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -94,7 +94,6 @@ def fixture_dpo_cfg(base_cfg):
|
||||
{
|
||||
"rl": RLType.DPO,
|
||||
"dpo_use_weighting": True,
|
||||
"dpo_use_logits_to_keep": True,
|
||||
"dpo_label_smoothing": 0.1,
|
||||
"beta": 0.1, # DPO beta
|
||||
}
|
||||
@@ -148,9 +147,16 @@ def fixture_grpo_cfg(base_cfg):
|
||||
),
|
||||
# Must be evenly divisible by num_generations
|
||||
"micro_batch_size": 4,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"split": "train[:1%]",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
return cfg
|
||||
return DictDefault(cfg)
|
||||
|
||||
|
||||
@pytest.fixture(name="ipo_cfg")
|
||||
@@ -334,6 +340,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
try:
|
||||
builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
|
||||
training_arguments, _ = builder._build_training_arguments(100)
|
||||
builder.train_dataset = MagicMock()
|
||||
|
||||
self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)
|
||||
# GRPO specific
|
||||
@@ -363,7 +370,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)
|
||||
# IPO specific
|
||||
assert training_arguments.beta == 0.1
|
||||
assert training_arguments.loss_type == "ipo"
|
||||
assert training_arguments.loss_type == ["ipo"]
|
||||
assert training_arguments.label_smoothing == 0
|
||||
|
||||
def test_simpo_training_arguments(self, simpo_cfg, model, tokenizer):
|
||||
@@ -529,13 +536,11 @@ class TestHFCausalTrainerBuilder:
|
||||
"cfg_string",
|
||||
[
|
||||
"sft_cfg",
|
||||
"rm_cfg",
|
||||
# "rm_cfg", # TODO fix for num_labels = 2 vs 1
|
||||
"prm_cfg",
|
||||
],
|
||||
)
|
||||
def test_custom_optimizer_cls_and_kwargs(
|
||||
self, request, cfg_string, model, tokenizer
|
||||
):
|
||||
def test_builder_w_rm_trainers(self, request, cfg_string, model, tokenizer):
|
||||
cfg = request.getfixturevalue(cfg_string)
|
||||
builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||
cfg["optimizer"] = "muon"
|
||||
|
||||
@@ -18,6 +18,7 @@ Unit tests for SwanLab Integration Plugin.
|
||||
Tests conflict detection, configuration validation, and multi-logger warnings.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
@@ -25,12 +26,11 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
from axolotl.integrations.swanlab.args import SwanLabConfig
|
||||
from axolotl.integrations.swanlab.plugins import SwanLabPlugin
|
||||
|
||||
SWANLAB_INSTALLED = _is_package_available("swanlab")
|
||||
SWANLAB_INSTALLED = importlib.util.find_spec("swanlab") is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")
|
||||
|
||||
@@ -52,8 +52,8 @@ def mock_torch():
|
||||
mock_torch.cuda.device_count.return_value = 2
|
||||
|
||||
# Mock memory allocated per device (1GB for device 0, 2GB for device 1)
|
||||
mock_torch.cuda.memory_allocated.side_effect = (
|
||||
lambda device: (device + 1) * 1024 * 1024 * 1024
|
||||
mock_torch.cuda.memory_allocated.side_effect = lambda device: (
|
||||
(device + 1) * 1024 * 1024 * 1024
|
||||
)
|
||||
|
||||
yield mock_torch
|
||||
@@ -292,8 +292,8 @@ class TestRuntimeMetricsTracker:
|
||||
mock_memory_info = mock_process.memory_info.return_value
|
||||
mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB
|
||||
|
||||
mock_torch.cuda.memory_allocated.side_effect = (
|
||||
lambda device: (device + 0.5) * 1024 * 1024 * 1024
|
||||
mock_torch.cuda.memory_allocated.side_effect = lambda device: (
|
||||
(device + 0.5) * 1024 * 1024 * 1024
|
||||
)
|
||||
|
||||
# Update memory metrics again
|
||||
@@ -307,8 +307,8 @@ class TestRuntimeMetricsTracker:
|
||||
# Change mocked memory values to be higher
|
||||
mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB
|
||||
|
||||
mock_torch.cuda.memory_allocated.side_effect = (
|
||||
lambda device: (device + 2) * 1024 * 1024 * 1024
|
||||
mock_torch.cuda.memory_allocated.side_effect = lambda device: (
|
||||
(device + 2) * 1024 * 1024 * 1024
|
||||
)
|
||||
|
||||
# Update memory metrics again
|
||||
|
||||
@@ -84,7 +84,8 @@ class TestTokenizers:
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
|
||||
assert "LlamaTokenizer" in tokenizer.__class__.__name__
|
||||
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
|
||||
assert len(tokenizer) == 32001
|
||||
|
||||
# ensure reloading the tokenizer again from cfg results in same vocab length
|
||||
|
||||
Reference in New Issue
Block a user