chore: consolidate eval_strat, loraplus, lr sched, max_length
This commit is contained in:
@@ -529,7 +529,7 @@ profiler_steps: # enable the pytorch profiler to capture the first N steps of tr
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
|
||||
# Save model as safetensors (require safetensors package)
|
||||
# Save model as safetensors (require safetensors package). Transformers default True
|
||||
save_safetensors:
|
||||
|
||||
# Whether to mask out or include the human's prompt from the training labels
|
||||
|
||||
@@ -252,14 +252,15 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["warmup_steps"] = warmup_steps
|
||||
training_args_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
# precision
|
||||
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
|
||||
training_args_kwargs["tf32"] = self.cfg.tf32
|
||||
|
||||
if self.cfg.bf16 == "full":
|
||||
training_args_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
||||
|
||||
# hub
|
||||
if self.cfg.hub_model_id:
|
||||
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||
training_args_kwargs["push_to_hub"] = True
|
||||
@@ -269,10 +270,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.hub_strategy:
|
||||
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.save_safetensors is not None:
|
||||
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
# set save_strategy and save_steps
|
||||
# save_strategy and save_steps
|
||||
if self.cfg.save_steps:
|
||||
training_args_kwargs["save_strategy"] = "steps"
|
||||
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
||||
@@ -282,7 +280,15 @@ class TrainerBuilderBase(abc.ABC):
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
|
||||
# eval_strategy and eval_steps
|
||||
if not self.eval_dataset or self.cfg.val_set_size == 0:
|
||||
# do not eval if no eval_dataset or val_set_size=0
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_args_kwargs["eval_strategy"] = "steps"
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.eval_strategy:
|
||||
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_args_kwargs[
|
||||
@@ -297,6 +303,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
"use_reentrant": False
|
||||
}
|
||||
|
||||
# set arg into trainer_args_kwargs with same name if value not None
|
||||
for arg in [
|
||||
"adam_beta1",
|
||||
"adam_beta2",
|
||||
@@ -305,6 +312,11 @@ class TrainerBuilderBase(abc.ABC):
|
||||
"dataloader_num_workers",
|
||||
"dataloader_pin_memory",
|
||||
"dataloader_prefetch_factor",
|
||||
"gradient_accumulation_steps",
|
||||
"learning_rate",
|
||||
"output_dir",
|
||||
"save_safetensors",
|
||||
"save_only_model",
|
||||
"include_tokens_per_second",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
@@ -317,12 +329,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
"per_device_eval_batch_size"
|
||||
] = self.cfg.eval_batch_size
|
||||
|
||||
training_args_kwargs[
|
||||
"gradient_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
|
||||
training_args_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||
training_args_kwargs["output_dir"] = self.cfg.output_dir
|
||||
training_args_kwargs["save_total_limit"] = (
|
||||
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||
)
|
||||
@@ -331,6 +337,11 @@ class TrainerBuilderBase(abc.ABC):
|
||||
total_num_steps if self.cfg.max_steps else -1
|
||||
)
|
||||
|
||||
# max_length is not used in CausalTrainer
|
||||
if self.cfg.reward_model or self.cfg.rl:
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
# reporting
|
||||
report_to = []
|
||||
if self.cfg.use_wandb:
|
||||
report_to.append("wandb")
|
||||
@@ -349,6 +360,24 @@ class TrainerBuilderBase(abc.ABC):
|
||||
else:
|
||||
training_args_kwargs["run_name"] = None
|
||||
|
||||
# optim/scheduler
|
||||
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||
training_args_kwargs["lr_scheduler_type"] = "cosine"
|
||||
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
|
||||
else:
|
||||
training_args_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
||||
)
|
||||
training_args_kwargs["lr_scheduler_kwargs"] = (
|
||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||
)
|
||||
training_args_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
||||
training_args_kwargs[
|
||||
"cosine_constant_lr_ratio"
|
||||
] = self.cfg.cosine_constant_lr_ratio
|
||||
|
||||
return training_args_kwargs
|
||||
|
||||
|
||||
@@ -476,18 +505,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.remove_unused_columns
|
||||
)
|
||||
|
||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_arguments_kwargs["eval_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.eval_strategy:
|
||||
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["eval_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||
if self.cfg.bench_dataset:
|
||||
@@ -582,30 +599,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs[
|
||||
"optim_target_modules"
|
||||
] = self.cfg.optim_target_modules
|
||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
training_arguments_kwargs["alternate_lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler
|
||||
)
|
||||
else:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
||||
)
|
||||
training_arguments_kwargs["lr_scheduler_kwargs"] = (
|
||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||
)
|
||||
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
||||
training_arguments_kwargs["cosine_constant_lr_ratio"] = (
|
||||
self.cfg.cosine_constant_lr_ratio
|
||||
)
|
||||
training_arguments_kwargs["weight_decay"] = (
|
||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||
)
|
||||
@@ -671,9 +668,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
# Handle custom optimizer
|
||||
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
||||
if self.cfg.optimizer in custom_supported_optimizers:
|
||||
@@ -1006,22 +1000,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
total_num_steps=total_num_steps
|
||||
)
|
||||
|
||||
if not self.eval_dataset:
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_args_kwargs["eval_strategy"] = "steps"
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.eval_strategy:
|
||||
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
|
||||
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
|
||||
training_args_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
||||
)
|
||||
training_args_kwargs["lr_scheduler_kwargs"] = (
|
||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||
)
|
||||
|
||||
if self.cfg.remove_unused_columns is not None:
|
||||
training_args_kwargs["remove_unused_columns"] = (
|
||||
@@ -1056,14 +1037,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rl is RLType.SIMPO:
|
||||
training_args_cls = AxolotlCPOConfig
|
||||
training_args_kwargs["loss_type"] = "simpo"
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
@@ -1077,7 +1056,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kto_undesirable_weight or 1.0
|
||||
)
|
||||
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
@@ -1090,7 +1068,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rl is RLType.IPO:
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
|
||||
Reference in New Issue
Block a user