use mixins for orpo and kto configs so they work with axolotl customizations (#1674)
This commit is contained in:
@@ -91,11 +91,12 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(TrainingArguments):
|
class AxolotlTrainingMixins:
|
||||||
"""
|
"""
|
||||||
Extend the base TrainingArguments for axolotl helpers
|
Mixin class for the Axolotl training args.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
default=None, metadata={"help": "HF model configuration model_type."}
|
default=None, metadata={"help": "HF model configuration model_type."}
|
||||||
)
|
)
|
||||||
@@ -227,6 +228,30 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
|
"""
|
||||||
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
||||||
|
so it can't be used as a mixin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||||
|
"""
|
||||||
|
ORPO config for ORPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
||||||
|
"""
|
||||||
|
KTO config for KTO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class AxolotlTrainer(Trainer):
|
||||||
"""
|
"""
|
||||||
Extend the base Trainer for axolotl helpers
|
Extend the base Trainer for axolotl helpers
|
||||||
@@ -1583,14 +1608,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
training_args_cls = AxolotlTrainingArguments
|
training_args_cls = AxolotlTrainingArguments
|
||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
training_args_cls = ORPOConfig
|
training_args_cls = AxolotlORPOConfig
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
if self.cfg.max_prompt_len:
|
if self.cfg.max_prompt_len:
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
if self.cfg.rl == "kto":
|
if self.cfg.rl == "kto":
|
||||||
training_args_cls = KTOConfig
|
training_args_cls = AxolotlKTOConfig
|
||||||
|
|
||||||
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
||||||
training_args_kwargs["desirable_weight"] = (
|
training_args_kwargs["desirable_weight"] = (
|
||||||
@@ -1605,12 +1630,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.max_prompt_len:
|
if self.cfg.max_prompt_len:
|
||||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
training_args = training_args_cls(
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
|
output_dir=self.cfg.output_dir,
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
max_steps=self.cfg.max_steps or total_num_steps,
|
max_steps=self.cfg.max_steps or total_num_steps,
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||||
learning_rate=self.cfg.learning_rate,
|
learning_rate=self.cfg.learning_rate,
|
||||||
output_dir=self.cfg.output_dir,
|
|
||||||
warmup_steps=self.cfg.warmup_steps,
|
warmup_steps=self.cfg.warmup_steps,
|
||||||
logging_first_step=True,
|
logging_first_step=True,
|
||||||
logging_steps=1,
|
logging_steps=1,
|
||||||
|
|||||||
Reference in New Issue
Block a user