From 3ce43b6db99946a10ea7d3cf5d532a161832f4cb Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 1 Apr 2025 17:53:12 +0000 Subject: [PATCH] simplifying trainer mixins and adding to rl trainers --- src/axolotl/core/trainers/__init__.py | 14 ++++---- src/axolotl/core/trainers/base.py | 11 ++---- src/axolotl/core/trainers/dpo/__init__.py | 8 ++--- src/axolotl/core/trainers/dpo/args.py | 8 ++--- src/axolotl/core/trainers/dpo/trainer.py | 4 +-- src/axolotl/core/trainers/grpo/trainer.py | 12 +++---- src/axolotl/core/trainers/mixins/__init__.py | 14 +++++--- .../core/trainers/mixins/rng_state_loader.py | 4 +-- src/axolotl/core/trainers/trl.py | 35 +++++++------------ src/axolotl/core/training_args.py | 4 +-- 10 files changed, 43 insertions(+), 71 deletions(-) diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py index 32a889af9..1fa6cb4f5 100644 --- a/src/axolotl/core/trainers/__init__.py +++ b/src/axolotl/core/trainers/__init__.py @@ -3,16 +3,16 @@ # pylint: disable=unused-import # flake8: noqa -from .base import AxolotlTrainer -from .dpo.trainer import AxolotlDPOTrainer -from .grpo.trainer import AxolotlGRPOTrainer -from .mamba import AxolotlMambaTrainer -from .relora import ReLoRATrainer -from .trl import ( +from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.core.trainers.dpo import AxolotlDPOTrainer +from axolotl.core.trainers.grpo import AxolotlGRPOTrainer +from axolotl.core.trainers.mamba import AxolotlMambaTrainer +from axolotl.core.trainers.relora import ReLoRATrainer +from axolotl.core.trainers.trl import ( AxolotlCPOTrainer, AxolotlKTOTrainer, AxolotlORPOTrainer, + AxolotlPPOTrainer, AxolotlPRMTrainer, AxolotlRewardTrainer, - TRLPPOTrainer, ) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 9fed78eb7..a1cb819b6 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -25,12 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker from trl.trainer.utils import pad_to_length from typing_extensions import override -from axolotl.core.trainers.mixins import ( - OptimizerMixin, - RngLoaderMixin, - SchedulerMixin, - SequenceParallelMixin, -) +from axolotl.core.trainers.mixins import TrainerMixins from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_tagging, @@ -40,9 +35,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = logging.getLogger(__name__) -class AxolotlTrainer( - SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer -): +class AxolotlTrainer(TrainerMixins, Trainer): """Extend the base Trainer for axolotl helpers""" args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 2d6835cf7..8b1a08b2b 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -1,14 +1,10 @@ -""" -DPO Specific Strategy for training -""" +"""DPO Specific Strategy for training""" from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer class DPOStrategy: - """ - Strategy for DPO training - """ + """Strategy for DPO training""" @classmethod def get_trainer_class(cls): diff --git a/src/axolotl/core/trainers/dpo/args.py b/src/axolotl/core/trainers/dpo/args.py index de1758ed0..8b6d9a950 100644 --- a/src/axolotl/core/trainers/dpo/args.py +++ b/src/axolotl/core/trainers/dpo/args.py @@ -1,6 +1,4 @@ -""" -Axolotl specific DPO args -""" +"""Axolotl specific DPO args""" from dataclasses import dataclass @@ -11,6 +9,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins @dataclass class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): - """ - DPO config for DPO training - """ + """DPO config for DPO training""" diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 89c77dca4..a6b8f56ba 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -13,7 +13,7 @@ from transformers import Trainer from transformers.utils import is_sagemaker_mp_enabled from trl import DPOTrainer -from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin +from axolotl.core.trainers.mixins import TrainerMixins from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_tagging, @@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp -class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): +class AxolotlDPOTrainer(TrainerMixins, DPOTrainer): """ Extend the base DPOTrainer for axolotl helpers """ diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 25aafa6a7..459d89fa3 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -1,6 +1,4 @@ -""" -Axolotl GRPO trainer -""" +"""Axolotl GRPO trainer""" from contextlib import nullcontext @@ -8,16 +6,14 @@ from accelerate.utils import is_deepspeed_available, is_peft_model from trl import GRPOTrainer from trl.extras.profiling import profiling_decorator -from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin +from axolotl.core.trainers.mixins import TrainerMixins if is_deepspeed_available(): import deepspeed -class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): - """ - Extend the base GRPOTrainer for axolotl helpers - """ +class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer): + """Extend the base GRPOTrainer for axolotl helpers""" _tag_names = ["trl", "grpo", "axolotl"] diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index 44751b465..052754a2f 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -3,7 +3,13 @@ # pylint: disable=unused-import # flake8: noqa -from .optimizer import OptimizerMixin -from .rng_state_loader import RngLoaderMixin -from .scheduler import SchedulerMixin -from .sequence_parallel import SequenceParallelMixin +from axolotl.core.trainers.mixins.optimizer import OptimizerMixin +from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin +from axolotl.core.trainers.mixins.scheduler import SchedulerMixin +from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelMixin + + +class TrainerMixins( + OptimizerMixin, RngLoaderMixin, SchedulerMixin, SequenceParallelMixin +): + """Stub class combining all mixins for Axolotl trainers.""" diff --git a/src/axolotl/core/trainers/mixins/rng_state_loader.py b/src/axolotl/core/trainers/mixins/rng_state_loader.py index 0e101dabb..e7ce052dc 100644 --- a/src/axolotl/core/trainers/mixins/rng_state_loader.py +++ b/src/axolotl/core/trainers/mixins/rng_state_loader.py @@ -21,9 +21,7 @@ LOG = logging.getLogger(__name__) class RngLoaderMixin(Trainer): - """ - mixin for method override to load RNG states from a checkpoint - """ + """Mixin for method override to load RNG states from a checkpoint""" def _load_rng_state(self, checkpoint): # Load RNG states from `checkpoint` diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py index b2c5c54ca..fadceef14 100644 --- a/src/axolotl/core/trainers/trl.py +++ b/src/axolotl/core/trainers/trl.py @@ -13,11 +13,10 @@ from trl import ( RewardTrainer, ) -from axolotl.core.trainers.mixins import RngLoaderMixin -from axolotl.core.trainers.mixins.scheduler import SchedulerMixin +from axolotl.core.trainers.mixins import TrainerMixins -class TRLPPOTrainer(PPOTrainer): +class AxolotlPPOTrainer(TrainerMixins, PPOTrainer): """Wrapper for TRL PPO trainer to handle customizations""" tag_names = ["axolotl", "ppo"] @@ -75,10 +74,8 @@ class TRLPPOTrainer(PPOTrainer): ) -class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer): - """ - Extend the base ORPOTrainer for axolotl helpers - """ +class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer): + """Extend the base ORPOTrainer for axolotl helpers""" tag_names = ["axolotl", "orpo"] @@ -155,18 +152,14 @@ class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer): return loss, metrics -class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer): - """ - Extend the base KTOTrainer for axolotl helpers - """ +class AxolotlKTOTrainer(TrainerMixins, KTOTrainer): + """Extend the base KTOTrainer for axolotl helpers""" tag_names = ["axolotl", "kto"] -class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer): - """ - Extend the base CPOTrainer for axolotl helpers - """ +class AxolotlCPOTrainer(TrainerMixins, CPOTrainer): + """Extend the base CPOTrainer for axolotl helpers""" tag_names = ["axolotl", "cpo"] @@ -245,17 +238,13 @@ class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer): return loss, metrics -class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer): - """ - Extend the base RewardTrainer for axolotl helpers - """ +class AxolotlRewardTrainer(TrainerMixins, RewardTrainer): + """Extend the base RewardTrainer for axolotl helpers""" tag_names = ["axolotl", "reward"] -class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer): - """ - Extend the base trl.PRMTrainer for axolotl helpers - """ +class AxolotlPRMTrainer(TrainerMixins, PRMTrainer): + """Extend the base trl.PRMTrainer for axolotl helpers""" tag_names = ["axolotl", "prm"] diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 18843abb4..fd8e2c7d0 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -12,9 +12,7 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig @dataclass class AxolotlTrainingMixins: - """ - Mixin class for the Axolotl training args. - """ + """Mixin class for the Axolotl training args.""" # pylint: disable=duplicate-code model_type: Optional[str] = field(