chore: refactor set_base_training_args into smaller modules

This commit is contained in:
NanoCode012
2025-05-22 18:39:33 +07:00
parent 58842ded9c
commit 79472241e8

View File

@@ -178,8 +178,8 @@ class TrainerBuilderBase(abc.ABC):
# TODO
return trainer
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
training_args_kwargs: Dict[str, Any] = {}
def _configure_warmup_and_logging(self, total_num_steps):
training_args_kwargs = {}
warmup_steps = 0
warmup_ratio = 0.0
@@ -212,7 +212,11 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["warmup_steps"] = warmup_steps
training_args_kwargs["logging_steps"] = logging_steps
# precision
return training_args_kwargs
def _configure_precision_settings(self):
training_args_kwargs = {}
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
training_args_kwargs["tf32"] = self.cfg.tf32
if self.cfg.bf16 == "full":
@@ -220,116 +224,11 @@ class TrainerBuilderBase(abc.ABC):
else:
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
# hub
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True
return training_args_kwargs
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
def _configure_optimizer_and_scheduler(self):
training_args_kwargs = {}
# save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
# eval_strategy and eval_steps
if not self.eval_dataset or self.cfg.val_set_size == 0:
# do not eval if no eval_dataset or val_set_size=0
training_args_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
# set arg into trainer_args_kwargs with same name if value not None
for arg in [
"adam_beta1",
"adam_beta2",
"adam_epsilon",
"max_grad_norm",
"dataloader_num_workers",
"dataloader_pin_memory",
"dataloader_prefetch_factor",
"gradient_accumulation_steps",
"learning_rate",
"embedding_lr",
"embedding_lr_scale",
"lr_groups",
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"sequence_parallel_degree",
"ring_attn_func",
"seed",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
if self.cfg.eval_batch_size:
training_args_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
training_args_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len
# reporting
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_args_kwargs["report_to"] = report_to
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_args_kwargs["run_name"] = None
# optim/scheduler
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep", "rex"]:
training_args_kwargs["lr_scheduler_type"] = "cosine"
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
@@ -462,7 +361,78 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.optim_target_modules:
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
# torch compile
return training_args_kwargs
def _configure_hub_parameters(self):
training_args_kwargs = {}
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
return training_args_kwargs
def _configure_save_and_eval_strategy(self):
training_args_kwargs = {}
# save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
# eval_strategy and eval_steps
if not self.eval_dataset or self.cfg.val_set_size == 0:
# do not eval if no eval_dataset or val_set_size=0
training_args_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
return training_args_kwargs
def _configure_reporting(self):
training_args_kwargs = {}
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_args_kwargs["report_to"] = report_to
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_args_kwargs["run_name"] = None
return training_args_kwargs
def _configure_torch_compile(self):
training_args_kwargs = {}
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
@@ -476,3 +446,85 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
return training_args_kwargs
def _configure_gradient_checkpointing(self):
training_args_kwargs = {}
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
return training_args_kwargs
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
training_args_kwargs: Dict[str, Any] = {}
training_args_kwargs.update(self._configure_warmup_and_logging(total_num_steps))
training_args_kwargs.update(self._configure_precision_settings())
training_args_kwargs.update(self._configure_save_and_eval_strategy())
training_args_kwargs.update(self._configure_gradient_checkpointing())
# set arg into trainer_args_kwargs with same name if value not None
for arg in [
"adam_beta1",
"adam_beta2",
"adam_epsilon",
"max_grad_norm",
"dataloader_num_workers",
"dataloader_pin_memory",
"dataloader_prefetch_factor",
"gradient_accumulation_steps",
"learning_rate",
"embedding_lr",
"embedding_lr_scale",
"lr_groups",
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"sequence_parallel_degree",
"ring_attn_func",
"seed",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
if self.cfg.eval_batch_size:
training_args_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs.update(self._configure_reporting())
training_args_kwargs.update(self._configure_hub_parameters())
training_args_kwargs.update(self._configure_optimizer_and_scheduler())
training_args_kwargs.update(self._configure_torch_compile())
return training_args_kwargs