simplifying trainer mixins and adding to rl trainers
This commit is contained in:
@@ -3,16 +3,16 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .relora import ReLoRATrainer
|
||||
from .trl import (
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
|
||||
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
|
||||
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
|
||||
from axolotl.core.trainers.relora import ReLoRATrainer
|
||||
from axolotl.core.trainers.trl import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
AxolotlPPOTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
TRLPPOTrainer,
|
||||
)
|
||||
|
||||
@@ -25,12 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
SequenceParallelMixin,
|
||||
)
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
@@ -40,9 +35,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AxolotlTrainer(
|
||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
|
||||
):
|
||||
class AxolotlTrainer(TrainerMixins, Trainer):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
"""
|
||||
DPO Specific Strategy for training
|
||||
"""
|
||||
"""DPO Specific Strategy for training"""
|
||||
|
||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||
|
||||
|
||||
class DPOStrategy:
|
||||
"""
|
||||
Strategy for DPO training
|
||||
"""
|
||||
"""Strategy for DPO training"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
"""Axolotl specific DPO args"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -11,6 +9,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
"""DPO config for DPO training"""
|
||||
|
||||
@@ -13,7 +13,7 @@ from transformers import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
"""Axolotl GRPO trainer"""
|
||||
|
||||
from contextlib import nullcontext
|
||||
|
||||
@@ -8,16 +6,14 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||
from trl import GRPOTrainer
|
||||
from trl.extras.profiling import profiling_decorator
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
"""
|
||||
Extend the base GRPOTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
|
||||
@@ -3,7 +3,13 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .optimizer import OptimizerMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
from .sequence_parallel import SequenceParallelMixin
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
|
||||
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelMixin
|
||||
|
||||
|
||||
class TrainerMixins(
|
||||
OptimizerMixin, RngLoaderMixin, SchedulerMixin, SequenceParallelMixin
|
||||
):
|
||||
"""Stub class combining all mixins for Axolotl trainers."""
|
||||
|
||||
@@ -21,9 +21,7 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RngLoaderMixin(Trainer):
|
||||
"""
|
||||
mixin for method override to load RNG states from a checkpoint
|
||||
"""
|
||||
"""Mixin for method override to load RNG states from a checkpoint"""
|
||||
|
||||
def _load_rng_state(self, checkpoint):
|
||||
# Load RNG states from `checkpoint`
|
||||
|
||||
@@ -13,11 +13,10 @@ from trl import (
|
||||
RewardTrainer,
|
||||
)
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import TrainerMixins
|
||||
|
||||
|
||||
class TRLPPOTrainer(PPOTrainer):
|
||||
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||
|
||||
tag_names = ["axolotl", "ppo"]
|
||||
@@ -75,10 +74,8 @@ class TRLPPOTrainer(PPOTrainer):
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
||||
"""Extend the base ORPOTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
@@ -155,18 +152,14 @@ class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
|
||||
"""Extend the base KTOTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
||||
"""Extend the base CPOTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
@@ -245,17 +238,13 @@ class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
||||
return loss, metrics
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
|
||||
"""Extend the base RewardTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
|
||||
"""Extend the base trl.PRMTrainer for axolotl helpers"""
|
||||
|
||||
tag_names = ["axolotl", "prm"]
|
||||
|
||||
@@ -12,9 +12,7 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
"""Mixin class for the Axolotl training args."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
|
||||
Reference in New Issue
Block a user