Release update 20250331 (#2460) [skip ci]

* make torch 2.6.0 the default image

* fix tests against upstream main

* fix attribute access

* use fixture dataset

* fix dataset load

* correct the fixtures + tests

* more fixtures

* add accidentally removed shakespeare fixture

* fix conversion from unittest to pytest class

* nightly main ci caches

* build 12.6.3 cuda base image

* override for fix from huggingface/transformers#37162

* address PR feedback
Wing Lian
2025-04-01 08:47:50 -04:00
committed by GitHub
parent 328d598114
commit e0aba74dd0
17 changed files with 347 additions and 169 deletions

View File

@@ -28,6 +28,7 @@ from typing_extensions import override
 from axolotl.core.trainers.mixins import (
     OptimizerMixin,
+    RngLoaderMixin,
     SchedulerMixin,
     SequenceParallelMixin,
 )
@@ -40,7 +41,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 LOG = logging.getLogger(__name__)


-class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
+class AxolotlTrainer(
+    SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
+):
     """Extend the base Trainer for axolotl helpers"""

     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
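
Aside: this override-by-mixin approach works because Python resolves methods left to right along the MRO, so a mixin listed before Trainer shadows the upstream method without monkey-patching. A minimal, self-contained sketch of the pattern (hypothetical Base/PatchMixin/Combined names, not the axolotl classes):

class Base:
    def load_state(self):
        return "upstream"

class PatchMixin(Base):
    # Listed before Base in subclasses, so the MRO finds this first.
    def load_state(self):
        return "patched"

class Combined(PatchMixin, Base):
    pass

assert Combined().load_state() == "patched"
assert [c.__name__ for c in Combined.__mro__] == ["Combined", "PatchMixin", "Base", "object"]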

View File

@@ -13,7 +13,7 @@ from transformers import Trainer
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer

-from axolotl.core.trainers.mixins import SchedulerMixin
+from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
 from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_ds_tagging,
     sanitize_kwargs_for_tagging,
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp


-class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
+class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
     """
     Extend the base DPOTrainer for axolotl helpers
     """

View File

@@ -8,13 +8,13 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
 from trl import GRPOTrainer
 from trl.extras.profiling import profiling_decorator

-from axolotl.core.trainers.base import SchedulerMixin
+from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin

 if is_deepspeed_available():
     import deepspeed


-class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
+class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
     """
     Extend the base GRPOTrainer for axolotl helpers
     """

View File

@@ -4,5 +4,6 @@
 # flake8: noqa

 from .optimizer import OptimizerMixin
+from .rng_state_loader import RngLoaderMixin
 from .scheduler import SchedulerMixin
 from .sequence_parallel import SequenceParallelMixin

View File

@@ -0,0 +1,67 @@
+"""
+Temporary fix/override for bug in resume from checkpoint
+See https://github.com/huggingface/transformers/pull/37162
+
+TODO: Remove when the upstream PR is included in a release
+"""
+
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+from transformers import Trainer, is_torch_npu_available
+from transformers.trainer import safe_globals
+from transformers.trainer_pt_utils import set_rng_state_for_device
+from transformers.training_args import ParallelMode
+
+LOG = logging.getLogger(__name__)
+
+
+class RngLoaderMixin(Trainer):
+    """
+    Mixin for method override to load RNG states from a checkpoint
+    """
+
+    def _load_rng_state(self, checkpoint):
+        # Load RNG states from `checkpoint`
+        if checkpoint is None:
+            return
+
+        if self.args.world_size > 1:
+            process_index = self.args.process_index
+            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
+            if not os.path.isfile(rng_file):
+                LOG.info(
+                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
+                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+                )
+                return
+        else:
+            rng_file = os.path.join(checkpoint, "rng_state.pth")
+            if not os.path.isfile(rng_file):
+                LOG.info(
+                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                    "fashion, reproducibility is not guaranteed."
+                )
+                return
+
+        # Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
+        # which requires allowlisted classes when loading with weights_only=True.
+        with safe_globals():
+            checkpoint_rng_state = torch.load(rng_file)  # nosec B614
+
+        random.setstate(checkpoint_rng_state["python"])
+        np.random.set_state(checkpoint_rng_state["numpy"])
+        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+
+        is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
+        if torch.cuda.is_available():
+            set_rng_state_for_device(
+                "CUDA", torch.cuda, checkpoint_rng_state, is_distributed
+            )
+        if is_torch_npu_available():
+            set_rng_state_for_device(
+                "NPU", torch.npu, checkpoint_rng_state, is_distributed
+            )
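
For context: PyTorch 2.6 flipped the torch.load default to weights_only=True, which rejects the non-tensor objects inside pickled RNG states unless they are allowlisted; safe_globals supplies that allowlist. A minimal sketch of the save/load round trip the override expects — the dict keys mirror the ones _load_rng_state reads above, while the file name and save side are illustrative assumptions, not the upstream checkpointing code:

import random

import numpy as np
import torch
from transformers.trainer import safe_globals

# Illustrative save side: the same keys the override reads back.
state = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "cpu": torch.random.get_rng_state(),
}
torch.save(state, "rng_state.pth")

# Without safe_globals(), torch.load(..., weights_only=True) on torch>=2.6
# raises on the numpy objects embedded in the state dict.
with safe_globals():
    restored = torch.load("rng_state.pth")

random.setstate(restored["python"])
np.random.set_state(restored["numpy"])
torch.random.set_rng_state(restored["cpu"])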

View File

@@ -13,6 +13,7 @@ from trl import (
     RewardTrainer,
 )

+from axolotl.core.trainers.mixins import RngLoaderMixin
 from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
@@ -74,7 +75,7 @@ class TRLPPOTrainer(PPOTrainer):
     )


-class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
+class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
     """
     Extend the base ORPOTrainer for axolotl helpers
     """
@@ -154,7 +155,7 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
         return loss, metrics


-class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
+class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
     """
     Extend the base KTOTrainer for axolotl helpers
     """
@@ -162,7 +163,7 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
     tag_names = ["axolotl", "kto"]


-class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
+class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
     """
     Extend the base CPOTrainer for axolotl helpers
     """
@@ -244,7 +245,7 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
         return loss, metrics


-class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
+class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
     """
     Extend the base RewardTrainer for axolotl helpers
     """
@@ -252,7 +253,7 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
     tag_names = ["axolotl", "reward"]


-class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
+class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
     """
     Extend the base trl.PRMTrainer for axolotl helpers
     """

View File

@@ -1270,3 +1270,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         if data["beta"] != data["trl"]["beta"]:
             raise ValueError("beta and trl.beta must match or one must be removed")
         return data
+
+    @model_validator(mode="after")
+    def check_min_torch_version(self):
+        if self.env_capabilities and self.env_capabilities.torch_version:
+            torch_version = self.env_capabilities.torch_version
+            if version.parse(torch_version) < version.parse("2.5.1"):
+                LOG.warning(
+                    f"torch=={torch_version} may not be supported in future versions. "
+                    "Please consider upgrading to torch>=2.5.1."
+                )
+        return self
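
Note on the pattern: a pydantic v2 model_validator with mode="after" runs on the fully constructed model and must return the instance, hence the trailing return self. A minimal, self-contained sketch of the same shape — the EnvCapabilities and Config classes here are illustrative stand-ins, not the axolotl schema:

import logging

from packaging import version
from pydantic import BaseModel, model_validator

LOG = logging.getLogger(__name__)

class EnvCapabilities(BaseModel):
    torch_version: str | None = None

class Config(BaseModel):
    env_capabilities: EnvCapabilities | None = None

    @model_validator(mode="after")
    def check_min_torch_version(self):
        if self.env_capabilities and self.env_capabilities.torch_version:
            if version.parse(self.env_capabilities.torch_version) < version.parse("2.5.1"):
                LOG.warning("torch<2.5.1 may be unsupported in future releases")
        return self  # mode="after" validators must return the model instance

Config(env_capabilities=EnvCapabilities(torch_version="2.4.0"))  # triggers the warning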