From f7332ac449244643a93a19d99b4a8e28f59ef383 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 29 May 2024 22:27:00 -0400 Subject: [PATCH] use mixins for orpo and kto configs so they work with axolotl customizations (#1674) --- src/axolotl/core/trainer_builder.py | 37 ++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index a37652ade..c881e4cd3 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -91,11 +91,12 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): @dataclass -class AxolotlTrainingArguments(TrainingArguments): +class AxolotlTrainingMixins: """ - Extend the base TrainingArguments for axolotl helpers + Mixin class for the Axolotl training args. """ + # pylint: disable=duplicate-code model_type: Optional[str] = field( default=None, metadata={"help": "HF model configuration model_type."} ) @@ -227,6 +228,30 @@ class AxolotlTrainingArguments(TrainingArguments): ) +@dataclass +class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): + """ + Training arguments for Causal trainer + + This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value + so it can't be used as a mixin. + """ + + +@dataclass +class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig): + """ + ORPO config for ORPO training + """ + + +@dataclass +class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig): + """ + KTO config for KTO training + """ + + class AxolotlTrainer(Trainer): """ Extend the base Trainer for axolotl helpers @@ -1583,14 +1608,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_cls = AxolotlTrainingArguments if self.cfg.rl == "orpo": - training_args_cls = ORPOConfig + training_args_cls = AxolotlORPOConfig training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["max_length"] = self.cfg.sequence_len if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len if self.cfg.rl == "kto": - training_args_cls = KTOConfig + training_args_cls = AxolotlKTOConfig training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1 training_args_kwargs["desirable_weight"] = ( @@ -1605,12 +1630,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len - training_args = training_args_cls( + training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg + output_dir=self.cfg.output_dir, per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=self.cfg.max_steps or total_num_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, learning_rate=self.cfg.learning_rate, - output_dir=self.cfg.output_dir, warmup_steps=self.cfg.warmup_steps, logging_first_step=True, logging_steps=1,