Simplify trainer mixins and add them to the RL trainers

This commit is contained in:
Dan Saunders
2025-04-01 17:53:12 +00:00
parent 7d0eb66b54
commit 3ce43b6db9
10 changed files with 43 additions and 71 deletions

View File

@@ -3,16 +3,16 @@
# pylint: disable=unused-import # pylint: disable=unused-import
# flake8: noqa # flake8: noqa
from .base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer from axolotl.core.trainers.dpo import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer from axolotl.core.trainers.mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer from axolotl.core.trainers.relora import ReLoRATrainer
from .trl import ( from axolotl.core.trainers.trl import (
AxolotlCPOTrainer, AxolotlCPOTrainer,
AxolotlKTOTrainer, AxolotlKTOTrainer,
AxolotlORPOTrainer, AxolotlORPOTrainer,
AxolotlPPOTrainer,
AxolotlPRMTrainer, AxolotlPRMTrainer,
AxolotlRewardTrainer, AxolotlRewardTrainer,
TRLPPOTrainer,
) )

View File

@@ -25,12 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from typing_extensions import override from typing_extensions import override
from axolotl.core.trainers.mixins import ( from axolotl.core.trainers.mixins import TrainerMixins
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import ( from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging, sanitize_kwargs_for_tagging,
@@ -40,9 +35,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class AxolotlTrainer( class AxolotlTrainer(TrainerMixins, Trainer):
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
"""Extend the base Trainer for axolotl helpers""" """Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]

View File

@@ -1,14 +1,10 @@
""" """DPO Specific Strategy for training"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy: class DPOStrategy:
""" """Strategy for DPO training"""
Strategy for DPO training
"""
@classmethod @classmethod
def get_trainer_class(cls): def get_trainer_class(cls):

View File

@@ -1,6 +1,4 @@
""" """Axolotl specific DPO args"""
Axolotl specific DPO args
"""
from dataclasses import dataclass from dataclasses import dataclass
@@ -11,6 +9,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass @dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
""" """DPO config for DPO training"""
DPO config for DPO training
"""

View File

@@ -13,7 +13,7 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.utils import ( from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging, sanitize_kwargs_for_tagging,
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
""" """
Extend the base DPOTrainer for axolotl helpers Extend the base DPOTrainer for axolotl helpers
""" """

View File

@@ -1,6 +1,4 @@
""" """Axolotl GRPO trainer"""
Axolotl GRPO trainer
"""
from contextlib import nullcontext from contextlib import nullcontext
@@ -8,16 +6,14 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.mixins import TrainerMixins
if is_deepspeed_available(): if is_deepspeed_available():
import deepspeed import deepspeed
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
""" """Extend the base GRPOTrainer for axolotl helpers"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"] _tag_names = ["trl", "grpo", "axolotl"]

View File

@@ -3,7 +3,13 @@
# pylint: disable=unused-import # pylint: disable=unused-import
# flake8: noqa # flake8: noqa
from .optimizer import OptimizerMixin from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelMixin
class TrainerMixins(
    OptimizerMixin, RngLoaderMixin, SchedulerMixin, SequenceParallelMixin
):
    """Stub class combining all mixins for Axolotl trainers.

    Bundles ``OptimizerMixin``, ``RngLoaderMixin``, ``SchedulerMixin`` and
    ``SequenceParallelMixin`` (in that base-class order) so a trainer class
    can inherit all of them through a single base. The class defines no
    attributes or methods of its own.
    """

View File

@@ -21,9 +21,7 @@ LOG = logging.getLogger(__name__)
class RngLoaderMixin(Trainer): class RngLoaderMixin(Trainer):
""" """Mixin for method override to load RNG states from a checkpoint"""
mixin for method override to load RNG states from a checkpoint
"""
def _load_rng_state(self, checkpoint): def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint` # Load RNG states from `checkpoint`

View File

@@ -13,11 +13,10 @@ from trl import (
RewardTrainer, RewardTrainer,
) )
from axolotl.core.trainers.mixins import RngLoaderMixin from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TRLPPOTrainer(PPOTrainer): class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations""" """Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"] tag_names = ["axolotl", "ppo"]
@@ -75,10 +74,8 @@ class TRLPPOTrainer(PPOTrainer):
) )
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
""" """Extend the base ORPOTrainer for axolotl helpers"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"] tag_names = ["axolotl", "orpo"]
@@ -155,18 +152,14 @@ class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
return loss, metrics return loss, metrics
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
""" """Extend the base KTOTrainer for axolotl helpers"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"] tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
""" """Extend the base CPOTrainer for axolotl helpers"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"] tag_names = ["axolotl", "cpo"]
@@ -245,17 +238,13 @@ class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
return loss, metrics return loss, metrics
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
""" """Extend the base RewardTrainer for axolotl helpers"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"] tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer): class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
""" """Extend the base trl.PRMTrainer for axolotl helpers"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"] tag_names = ["axolotl", "prm"]

View File

@@ -12,9 +12,7 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass @dataclass
class AxolotlTrainingMixins: class AxolotlTrainingMixins:
""" """Mixin class for the Axolotl training args."""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
model_type: Optional[str] = field( model_type: Optional[str] = field(