simplifying trainer mixins and adding to rl trainers
This commit is contained in:
@@ -3,16 +3,16 @@
|
|||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
from .base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
from .dpo.trainer import AxolotlDPOTrainer
|
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
|
||||||
from .grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
|
||||||
from .mamba import AxolotlMambaTrainer
|
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
|
||||||
from .relora import ReLoRATrainer
|
from axolotl.core.trainers.relora import ReLoRATrainer
|
||||||
from .trl import (
|
from axolotl.core.trainers.trl import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
|
AxolotlPPOTrainer,
|
||||||
AxolotlPRMTrainer,
|
AxolotlPRMTrainer,
|
||||||
AxolotlRewardTrainer,
|
AxolotlRewardTrainer,
|
||||||
TRLPPOTrainer,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,12 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
|||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
OptimizerMixin,
|
|
||||||
RngLoaderMixin,
|
|
||||||
SchedulerMixin,
|
|
||||||
SequenceParallelMixin,
|
|
||||||
)
|
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
@@ -40,9 +35,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(
|
class AxolotlTrainer(TrainerMixins, Trainer):
|
||||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
|
|
||||||
):
|
|
||||||
"""Extend the base Trainer for axolotl helpers"""
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|||||||
@@ -1,14 +1,10 @@
|
|||||||
"""
|
"""DPO Specific Strategy for training"""
|
||||||
DPO Specific Strategy for training
|
|
||||||
"""
|
|
||||||
|
|
||||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||||
|
|
||||||
|
|
||||||
class DPOStrategy:
|
class DPOStrategy:
|
||||||
"""
|
"""Strategy for DPO training"""
|
||||||
Strategy for DPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_trainer_class(cls):
|
def get_trainer_class(cls):
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
"""
|
"""Axolotl specific DPO args"""
|
||||||
Axolotl specific DPO args
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@@ -11,6 +9,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||||
"""
|
"""DPO config for DPO training"""
|
||||||
DPO config for DPO training
|
|
||||||
"""
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from transformers import Trainer
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
|
|||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base DPOTrainer for axolotl helpers
|
Extend the base DPOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
"""
|
"""Axolotl GRPO trainer"""
|
||||||
Axolotl GRPO trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
@@ -8,16 +6,14 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
|
|||||||
from trl import GRPOTrainer
|
from trl import GRPOTrainer
|
||||||
from trl.extras.profiling import profiling_decorator
|
from trl.extras.profiling import profiling_decorator
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
|
|
||||||
if is_deepspeed_available():
|
if is_deepspeed_available():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
|
||||||
"""
|
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||||
Extend the base GRPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,13 @@
|
|||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
from .optimizer import OptimizerMixin
|
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
|
||||||
from .rng_state_loader import RngLoaderMixin
|
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
|
||||||
from .scheduler import SchedulerMixin
|
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||||
from .sequence_parallel import SequenceParallelMixin
|
from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelMixin
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerMixins(
|
||||||
|
OptimizerMixin, RngLoaderMixin, SchedulerMixin, SequenceParallelMixin
|
||||||
|
):
|
||||||
|
"""Stub class combining all mixins for Axolotl trainers."""
|
||||||
|
|||||||
@@ -21,9 +21,7 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class RngLoaderMixin(Trainer):
|
class RngLoaderMixin(Trainer):
|
||||||
"""
|
"""Mixin for method override to load RNG states from a checkpoint"""
|
||||||
mixin for method override to load RNG states from a checkpoint
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _load_rng_state(self, checkpoint):
|
def _load_rng_state(self, checkpoint):
|
||||||
# Load RNG states from `checkpoint`
|
# Load RNG states from `checkpoint`
|
||||||
|
|||||||
@@ -13,11 +13,10 @@ from trl import (
|
|||||||
RewardTrainer,
|
RewardTrainer,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
|
||||||
|
|
||||||
|
|
||||||
class TRLPPOTrainer(PPOTrainer):
|
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
||||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "ppo"]
|
tag_names = ["axolotl", "ppo"]
|
||||||
@@ -75,10 +74,8 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
||||||
"""
|
"""Extend the base ORPOTrainer for axolotl helpers"""
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
@@ -155,18 +152,14 @@ class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
|||||||
return loss, metrics
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
|
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
|
||||||
"""
|
"""Extend the base KTOTrainer for axolotl helpers"""
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
||||||
"""
|
"""Extend the base CPOTrainer for axolotl helpers"""
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
@@ -245,17 +238,13 @@ class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
|||||||
return loss, metrics
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
|
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
|
||||||
"""
|
"""Extend the base RewardTrainer for axolotl helpers"""
|
||||||
Extend the base RewardTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
|
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
|
||||||
"""
|
"""Extend the base trl.PRMTrainer for axolotl helpers"""
|
||||||
Extend the base trl.PRMTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "prm"]
|
tag_names = ["axolotl", "prm"]
|
||||||
|
|||||||
@@ -12,9 +12,7 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingMixins:
|
class AxolotlTrainingMixins:
|
||||||
"""
|
"""Mixin class for the Axolotl training args."""
|
||||||
Mixin class for the Axolotl training args.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
|
|||||||
Reference in New Issue
Block a user