use mixins for orpo and kto configs so they work with axolotl customizations (#1674)

Author: Wing Lian
Date:   2024-05-29 22:27:00 -04:00
Committed by: GitHub
Parent: 16d46b74e4
Commit: f7332ac449


@@ -91,11 +91,12 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
 @dataclass
-class AxolotlTrainingArguments(TrainingArguments):
+class AxolotlTrainingMixins:
     """
-    Extend the base TrainingArguments for axolotl helpers
+    Mixin class for the Axolotl training args.
     """
+    # pylint: disable=duplicate-code
     model_type: Optional[str] = field(
         default=None, metadata={"help": "HF model configuration model_type."}
     )
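
The idea, in miniature: hoist every axolotl-specific field into a defaults-only dataclass mixin, then compose that mixin with each upstream config class. Below is a minimal, self-contained sketch of the pattern; TrainingArguments and ORPOConfig here are simplified stand-ins so the sketch runs without transformers or trl installed, not the real classes.

    from dataclasses import dataclass, field
    from typing import Optional

    # Simplified stand-ins for transformers.TrainingArguments and trl's
    # ORPOConfig (assumptions for this sketch, not the real classes).
    @dataclass
    class TrainingArguments:
        output_dir: str  # required, no default

    @dataclass
    class ORPOConfig(TrainingArguments):
        max_length: Optional[int] = None

    # Shared axolotl fields live once, in a mixin whose fields all have defaults.
    @dataclass
    class AxolotlTrainingMixins:
        model_type: Optional[str] = field(
            default=None, metadata={"help": "HF model configuration model_type."}
        )

    # Each concrete config just combines the mixin with the upstream base.
    @dataclass
    class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
        """ORPO config that also carries the axolotl fields."""

    # One constructor now accepts upstream kwargs and axolotl kwargs alike.
    args = AxolotlORPOConfig(output_dir="./out", max_length=2048, model_type="llama")
    print(args.max_length, args.model_type)  # 2048 llama
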
@@ -227,6 +228,30 @@ class AxolotlTrainingArguments(TrainingArguments):
 )
+
+@dataclass
+class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
+    """
+    Training arguments for Causal trainer
+
+    This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
+    so it can't be used as a mixin.
+    """
+
+@dataclass
+class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
+    """
+    ORPO config for ORPO training
+    """
+
+@dataclass
+class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
+    """
+    KTO config for KTO training
+    """
+
 class AxolotlTrainer(Trainer):
     """
     Extend the base Trainer for axolotl helpers
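
The docstring's point about output_dir is a plain dataclass constraint: fields are collected base-first (reverse MRO), and a field without a default may not follow fields that have one. A small sketch of the constraint, with hypothetical class names mirroring the shapes involved:

    from dataclasses import dataclass

    # RequiredBase mirrors HF TrainingArguments, whose output_dir has no
    # default; DefaultsMixin mirrors the axolotl mixin (defaults only).
    @dataclass
    class RequiredBase:
        output_dir: str  # no default

    @dataclass
    class DefaultsMixin:
        model_type: str = "llama"

    # Base fields are collected first, so the required field lands ahead of
    # the defaulted one and this composition works:
    @dataclass
    class Works(DefaultsMixin, RequiredBase):
        pass

    print(Works(output_dir="./out"))  # Works(output_dir='./out', model_type='llama')

    # Reversing the bases would put the defaulted field first and raise at
    # class-creation time:
    #   TypeError: non-default argument 'output_dir' follows default argument
    # @dataclass
    # class Breaks(RequiredBase, DefaultsMixin):
    #     pass
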
@@ -1583,14 +1608,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         training_args_cls = AxolotlTrainingArguments
         if self.cfg.rl == "orpo":
-            training_args_cls = ORPOConfig
+            training_args_cls = AxolotlORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
         if self.cfg.rl == "kto":
-            training_args_cls = KTOConfig
+            training_args_cls = AxolotlKTOConfig
             training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
             training_args_kwargs["desirable_weight"] = (
@@ -1605,12 +1630,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
-        training_args = training_args_cls(
+        training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
+            output_dir=self.cfg.output_dir,
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
-            output_dir=self.cfg.output_dir,
             warmup_steps=self.cfg.warmup_steps,
             logging_first_step=True,
             logging_steps=1,
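
A sketch of the builder's dispatch pattern, with hypothetical stand-in classes. The concrete config class is only known at runtime, which is presumably why the explicit pylint disable is needed on the constructor call: a static checker can resolve training_args_cls to the wrong candidate class and flag its keywords as unexpected.

    from dataclasses import dataclass
    from typing import Optional

    # Hypothetical stand-ins for the real config classes.
    @dataclass
    class BaseArgs:
        output_dir: str
        learning_rate: float = 5e-5

    @dataclass
    class OrpoArgs(BaseArgs):
        max_length: Optional[int] = None

    @dataclass
    class KtoArgs(BaseArgs):
        beta: float = 0.1

    def build_args(rl: Optional[str], cfg: dict) -> BaseArgs:
        # Pick the concrete config class at runtime, mirroring the builder above.
        args_cls = BaseArgs
        extra = {}
        if rl == "orpo":
            args_cls = OrpoArgs
            extra["max_length"] = cfg["sequence_len"]
        if rl == "kto":
            args_cls = KtoArgs
            extra["beta"] = cfg.get("rl_beta") or 0.1
        # args_cls is rebound above, so a linter may check this call against
        # the wrong candidate; the diff silences that false positive.
        return args_cls(
            output_dir=cfg["output_dir"],
            **extra,
        )

    print(build_args("kto", {"output_dir": "./out", "sequence_len": 2048}))
    # KtoArgs(output_dir='./out', learning_rate=5e-05, beta=0.1)
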