use mixins for orpo and kto configs so they work with axolotl customizations (#1674)
This commit is contained in:
@@ -91,11 +91,12 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
@@ -227,6 +228,30 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||
"""
|
||||
Training arguments for Causal trainer
|
||||
|
||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
||||
so it can't be used as a mixin.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||
"""
|
||||
ORPO config for ORPO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
||||
"""
|
||||
KTO config for KTO training
|
||||
"""
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
@@ -1583,14 +1608,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
training_args_cls = AxolotlTrainingArguments
|
||||
if self.cfg.rl == "orpo":
|
||||
training_args_cls = ORPOConfig
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
if self.cfg.rl == "kto":
|
||||
training_args_cls = KTOConfig
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
|
||||
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
||||
training_args_kwargs["desirable_weight"] = (
|
||||
@@ -1605,12 +1630,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
training_args = training_args_cls(
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
output_dir=self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
max_steps=self.cfg.max_steps or total_num_steps,
|
||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||
learning_rate=self.cfg.learning_rate,
|
||||
output_dir=self.cfg.output_dir,
|
||||
warmup_steps=self.cfg.warmup_steps,
|
||||
logging_first_step=True,
|
||||
logging_steps=1,
|
||||
|
||||
Reference in New Issue
Block a user