simplifying trainer mixins and adding to rl trainers

This commit is contained in:
Dan Saunders
2025-04-01 17:53:12 +00:00
parent 7d0eb66b54
commit 3ce43b6db9
10 changed files with 43 additions and 71 deletions

View File

@@ -3,16 +3,16 @@
# pylint: disable=unused-import
# flake8: noqa
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
from axolotl.core.trainers.relora import ReLoRATrainer
from axolotl.core.trainers.trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
AxolotlPPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
TRLPPOTrainer,
)

View File

@@ -25,12 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -40,9 +35,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
class AxolotlTrainer(TrainerMixins, Trainer):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]

View File

@@ -1,14 +1,10 @@
"""
DPO Specific Strategy for training
"""
"""DPO Specific Strategy for training"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy:
"""
Strategy for DPO training
"""
"""Strategy for DPO training"""
@classmethod
def get_trainer_class(cls):

View File

@@ -1,6 +1,4 @@
"""
Axolotl specific DPO args
"""
"""Axolotl specific DPO args"""
from dataclasses import dataclass
@@ -11,6 +9,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
"""DPO config for DPO training"""

View File

@@ -13,7 +13,7 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import TrainerMixins
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""

View File

@@ -1,6 +1,4 @@
"""
Axolotl GRPO trainer
"""
"""Axolotl GRPO trainer"""
from contextlib import nullcontext
@@ -8,16 +6,14 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import TrainerMixins
if is_deepspeed_available():
import deepspeed
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]

View File

@@ -3,7 +3,13 @@
# pylint: disable=unused-import
# flake8: noqa
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelMixin
class TrainerMixins(
OptimizerMixin, RngLoaderMixin, SchedulerMixin, SequenceParallelMixin
):
"""Stub class combining all mixins for Axolotl trainers."""

View File

@@ -21,9 +21,7 @@ LOG = logging.getLogger(__name__)
class RngLoaderMixin(Trainer):
"""
mixin for method override to load RNG states from a checkpoint
"""
"""Mixin for method override to load RNG states from a checkpoint"""
def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint`

View File

@@ -13,11 +13,10 @@ from trl import (
RewardTrainer,
)
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
from axolotl.core.trainers.mixins import TrainerMixins
class TRLPPOTrainer(PPOTrainer):
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"]
@@ -75,10 +74,8 @@ class TRLPPOTrainer(PPOTrainer):
)
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
"""Extend the base ORPOTrainer for axolotl helpers"""
tag_names = ["axolotl", "orpo"]
@@ -155,18 +152,14 @@ class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
return loss, metrics
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
"""Extend the base KTOTrainer for axolotl helpers"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
"""Extend the base CPOTrainer for axolotl helpers"""
tag_names = ["axolotl", "cpo"]
@@ -245,17 +238,13 @@ class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
return loss, metrics
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
"""Extend the base RewardTrainer for axolotl helpers"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
"""Extend the base trl.PRMTrainer for axolotl helpers"""
tag_names = ["axolotl", "prm"]

View File

@@ -12,9 +12,7 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
"""Mixin class for the Axolotl training args."""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(