chore: consolidate eval_strat, loraplus, lr sched, max_length

NanoCode012
2025-01-28 15:04:18 +07:00
parent fd271b2547
commit 053e5fd7d1
2 changed files with 42 additions and 65 deletions
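For context, the change moves settings that HFCausalTrainerBuilder and HFRLTrainerBuilder each resolved on their own (eval strategy, LoRA+ learning rates, LR scheduler, max_length) up into TrainerBuilderBase, so both builders inherit a single resolution path. A minimal sketch of the pattern; shared_training_args_kwargs is a hypothetical name standing in for the base class's training-args construction:

# Hypothetical condensation: shared kwargs resolved once in the base class
# instead of being duplicated in each subclass.
class TrainerBuilderBase:
    def shared_training_args_kwargs(self, cfg):
        kwargs = {}
        kwargs["loraplus_lr_ratio"] = cfg.loraplus_lr_ratio
        kwargs["loraplus_lr_embedding"] = cfg.loraplus_lr_embedding
        kwargs["lr_scheduler_type"] = cfg.lr_scheduler or "cosine"
        if cfg.reward_model or cfg.rl:
            kwargs["max_length"] = cfg.sequence_len
        return kwargs

class HFCausalTrainerBuilder(TrainerBuilderBase):
    pass  # previously re-implemented all of the above

class HFRLTrainerBuilder(TrainerBuilderBase):
    pass  # likewise
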

View File

@@ -529,7 +529,7 @@ profiler_steps: # enable the pytorch profiler to capture the first N steps of training
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (requires safetensors package)
# Save model as safetensors (requires safetensors package). Transformers defaults to True
save_safetensors:
# Whether to mask out or include the human's prompt from the training labels
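The doc tweak above matches the builder behavior later in this commit: the flag is only forwarded when the user actually sets it, so an unset value defers to the Transformers default of True. A minimal sketch, assuming cfg.save_safetensors is None when unset:

def resolve_save_safetensors(cfg_value):
    # only forward the flag when the user set it; otherwise let
    # transformers.TrainingArguments fall back to its default (True)
    kwargs = {}
    if cfg_value is not None:
        kwargs["save_safetensors"] = cfg_value
    return kwargs

assert resolve_save_safetensors(None) == {}
assert resolve_save_safetensors(False) == {"save_safetensors": False}
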

View File

@@ -252,14 +252,15 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["warmup_steps"] = warmup_steps
training_args_kwargs["logging_steps"] = logging_steps
# precision
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
training_args_kwargs["tf32"] = self.cfg.tf32
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
# hub
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
@@ -269,10 +270,7 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors is not None:
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
# set save_strategy and save_steps
# save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
@@ -282,7 +280,15 @@ class TrainerBuilderBase(abc.ABC):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
# eval_strategy and eval_steps
if not self.eval_dataset or self.cfg.val_set_size == 0:
# do not eval if no eval_dataset or val_set_size=0
training_args_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
if self.cfg.gradient_checkpointing:
training_args_kwargs[
@@ -297,6 +303,7 @@ class TrainerBuilderBase(abc.ABC):
"use_reentrant": False
}
# copy each arg into training_args_kwargs under the same name if its value is not None
for arg in [
"adam_beta1",
"adam_beta2",
@@ -305,6 +312,11 @@ class TrainerBuilderBase(abc.ABC):
"dataloader_num_workers",
"dataloader_pin_memory",
"dataloader_prefetch_factor",
"gradient_accumulation_steps",
"learning_rate",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
@@ -317,12 +329,6 @@ class TrainerBuilderBase(abc.ABC):
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_args_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_args_kwargs["learning_rate"] = self.cfg.learning_rate
training_args_kwargs["output_dir"] = self.cfg.output_dir
training_args_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
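The new pass-through loop above is what makes the hand-written assignments in this hunk redundant, which is why gradient_accumulation_steps, learning_rate, and output_dir are deleted here: any listed config attribute is copied into training_args_kwargs under the same name whenever it is set. The same pattern in isolation:

from types import SimpleNamespace

def passthrough_args(cfg, names):
    # copy each named attribute verbatim when present and not None
    return {
        name: getattr(cfg, name)
        for name in names
        if getattr(cfg, name, None) is not None
    }

cfg = SimpleNamespace(learning_rate=2e-5, output_dir="./out", adam_beta1=None)
assert passthrough_args(cfg, ["learning_rate", "output_dir", "adam_beta1"]) == {
    "learning_rate": 2e-5,
    "output_dir": "./out",
}
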
@@ -331,6 +337,11 @@ class TrainerBuilderBase(abc.ABC):
total_num_steps if self.cfg.max_steps else -1
)
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len
# reporting
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
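max_length is only consumed by the reward-model and RL trainers, so the base class now sets it once from sequence_len rather than each RL branch repeating it (see the deletions in HFRLTrainerBuilder below). Roughly:

def resolve_max_length(reward_model, rl, sequence_len):
    # CausalTrainer ignores max_length, so only set it for reward/RL runs
    if reward_model or rl:
        return {"max_length": sequence_len}
    return {}
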
@@ -349,6 +360,24 @@ class TrainerBuilderBase(abc.ABC):
else:
training_args_kwargs["run_name"] = None
# optim/scheduler
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_args_kwargs["lr_scheduler_type"] = "cosine"
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
else:
training_args_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_args_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_args_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
return training_args_kwargs
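The scheduler block funnels the custom schedules (one_cycle, log_sweep) through a stock cosine schedule while recording the requested one in alternate_lr_scheduler_type, and defaults everything else to cosine. As a standalone sketch:

def resolve_scheduler(lr_scheduler, lr_scheduler_kwargs):
    kwargs = {}
    if lr_scheduler in ["one_cycle", "log_sweep"]:
        # HF only knows the stock schedules; record the custom one separately
        kwargs["lr_scheduler_type"] = "cosine"
        kwargs["alternate_lr_scheduler_type"] = lr_scheduler
    else:
        kwargs["lr_scheduler_type"] = lr_scheduler or "cosine"
    kwargs["lr_scheduler_kwargs"] = lr_scheduler_kwargs or {}
    return kwargs

assert resolve_scheduler(None, None)["lr_scheduler_type"] == "cosine"
assert resolve_scheduler("one_cycle", None)["alternate_lr_scheduler_type"] == "one_cycle"

Note that the removed causal-builder version further down also listed "rex"; the consolidated hunk shows only one_cycle and log_sweep.
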
@@ -476,18 +505,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.remove_unused_columns
)
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_arguments_kwargs["eval_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.eval_strategy:
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["eval_strategy"] = "epoch"
if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
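This deletion is covered by the consolidated eval block added to TrainerBuilderBase earlier in the diff. One visible difference worth noting: the removed causal version fell back to per-epoch evaluation when an eval set existed but no steps or strategy were given, whereas the base-class hunk shows no such fallback, leaving the TrainingArguments default to apply. The shared resolution reads roughly as:

def resolve_eval_strategy(has_eval_dataset, val_set_size, eval_steps, eval_strategy):
    if not has_eval_dataset or val_set_size == 0:
        return {"eval_strategy": "no"}  # nothing to evaluate on
    if eval_steps:
        return {"eval_strategy": "steps", "eval_steps": eval_steps}
    if eval_strategy:
        return {"eval_strategy": eval_strategy}
    return {}  # fall through to the TrainingArguments default
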
@@ -582,30 +599,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs["alternate_lr_scheduler_type"] = (
self.cfg.lr_scheduler
)
else:
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_arguments_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs["cosine_constant_lr_ratio"] = (
self.cfg.cosine_constant_lr_ratio
)
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
@@ -671,9 +668,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.reward_model:
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# Handle custom optimizer
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
@@ -1006,22 +1000,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
total_num_steps=total_num_steps
)
if not self.eval_dataset:
training_args_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
training_args_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
if self.cfg.remove_unused_columns is not None:
training_args_kwargs["remove_unused_columns"] = (
@@ -1056,14 +1037,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl is RLType.SIMPO:
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
@@ -1077,7 +1056,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0
)
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
@@ -1090,7 +1068,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_cls = AxolotlDPOConfig
if self.cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb