chore: refactor set_base_training_args into smaller modules
This commit is contained in:
@@ -178,8 +178,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
# TODO
|
# TODO
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
|
def _configure_warmup_and_logging(self, total_num_steps):
|
||||||
training_args_kwargs: Dict[str, Any] = {}
|
training_args_kwargs = {}
|
||||||
|
|
||||||
warmup_steps = 0
|
warmup_steps = 0
|
||||||
warmup_ratio = 0.0
|
warmup_ratio = 0.0
|
||||||
@@ -212,7 +212,11 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["warmup_steps"] = warmup_steps
|
training_args_kwargs["warmup_steps"] = warmup_steps
|
||||||
training_args_kwargs["logging_steps"] = logging_steps
|
training_args_kwargs["logging_steps"] = logging_steps
|
||||||
|
|
||||||
# precision
|
return training_args_kwargs
|
||||||
|
|
||||||
|
def _configure_precision_settings(self):
|
||||||
|
training_args_kwargs = {}
|
||||||
|
|
||||||
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
|
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
|
||||||
training_args_kwargs["tf32"] = self.cfg.tf32
|
training_args_kwargs["tf32"] = self.cfg.tf32
|
||||||
if self.cfg.bf16 == "full":
|
if self.cfg.bf16 == "full":
|
||||||
@@ -220,116 +224,11 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
else:
|
else:
|
||||||
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
||||||
|
|
||||||
# hub
|
return training_args_kwargs
|
||||||
if self.cfg.hub_model_id:
|
|
||||||
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
|
||||||
training_args_kwargs["push_to_hub"] = True
|
|
||||||
training_args_kwargs["hub_private_repo"] = True
|
|
||||||
training_args_kwargs["hub_always_push"] = True
|
|
||||||
|
|
||||||
if self.cfg.hub_strategy:
|
def _configure_optimizer_and_scheduler(self):
|
||||||
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
training_args_kwargs = {}
|
||||||
|
|
||||||
# save_strategy and save_steps
|
|
||||||
if self.cfg.save_steps:
|
|
||||||
training_args_kwargs["save_strategy"] = "steps"
|
|
||||||
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
|
||||||
elif self.cfg.save_strategy:
|
|
||||||
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
|
|
||||||
else:
|
|
||||||
# default to saving each epoch if not defined
|
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
|
||||||
|
|
||||||
# eval_strategy and eval_steps
|
|
||||||
if not self.eval_dataset or self.cfg.val_set_size == 0:
|
|
||||||
# do not eval if no eval_dataset or val_set_size=0
|
|
||||||
training_args_kwargs["eval_strategy"] = "no"
|
|
||||||
elif self.cfg.eval_steps:
|
|
||||||
training_args_kwargs["eval_strategy"] = "steps"
|
|
||||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
|
||||||
elif self.cfg.eval_strategy:
|
|
||||||
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing:
|
|
||||||
training_args_kwargs["gradient_checkpointing"] = (
|
|
||||||
self.cfg.gradient_checkpointing
|
|
||||||
)
|
|
||||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
|
||||||
training_args_kwargs["gradient_checkpointing_kwargs"] = (
|
|
||||||
self.cfg.gradient_checkpointing_kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
|
||||||
"use_reentrant": False
|
|
||||||
}
|
|
||||||
|
|
||||||
# set arg into trainer_args_kwargs with same name if value not None
|
|
||||||
for arg in [
|
|
||||||
"adam_beta1",
|
|
||||||
"adam_beta2",
|
|
||||||
"adam_epsilon",
|
|
||||||
"max_grad_norm",
|
|
||||||
"dataloader_num_workers",
|
|
||||||
"dataloader_pin_memory",
|
|
||||||
"dataloader_prefetch_factor",
|
|
||||||
"gradient_accumulation_steps",
|
|
||||||
"learning_rate",
|
|
||||||
"embedding_lr",
|
|
||||||
"embedding_lr_scale",
|
|
||||||
"lr_groups",
|
|
||||||
"loraplus_lr_ratio",
|
|
||||||
"loraplus_lr_embedding",
|
|
||||||
"output_dir",
|
|
||||||
"save_safetensors",
|
|
||||||
"save_only_model",
|
|
||||||
"include_tokens_per_second",
|
|
||||||
"weight_decay",
|
|
||||||
"sequence_parallel_degree",
|
|
||||||
"ring_attn_func",
|
|
||||||
"seed",
|
|
||||||
]:
|
|
||||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
|
||||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
|
||||||
|
|
||||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
|
||||||
|
|
||||||
if self.cfg.eval_batch_size:
|
|
||||||
training_args_kwargs["per_device_eval_batch_size"] = (
|
|
||||||
self.cfg.eval_batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
training_args_kwargs["save_total_limit"] = (
|
|
||||||
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
|
||||||
)
|
|
||||||
|
|
||||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
|
||||||
|
|
||||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
|
||||||
|
|
||||||
# max_length is not used in CausalTrainer
|
|
||||||
if self.cfg.reward_model or self.cfg.rl:
|
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
|
||||||
|
|
||||||
# reporting
|
|
||||||
report_to = []
|
|
||||||
if self.cfg.use_wandb:
|
|
||||||
report_to.append("wandb")
|
|
||||||
if self.cfg.use_mlflow:
|
|
||||||
report_to.append("mlflow")
|
|
||||||
if self.cfg.use_tensorboard:
|
|
||||||
report_to.append("tensorboard")
|
|
||||||
if self.cfg.use_comet:
|
|
||||||
report_to.append("comet_ml")
|
|
||||||
|
|
||||||
training_args_kwargs["report_to"] = report_to
|
|
||||||
if self.cfg.use_wandb:
|
|
||||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
|
||||||
elif self.cfg.use_mlflow:
|
|
||||||
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
|
||||||
else:
|
|
||||||
training_args_kwargs["run_name"] = None
|
|
||||||
|
|
||||||
# optim/scheduler
|
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep", "rex"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep", "rex"]:
|
||||||
training_args_kwargs["lr_scheduler_type"] = "cosine"
|
training_args_kwargs["lr_scheduler_type"] = "cosine"
|
||||||
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
|
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
|
||||||
@@ -462,7 +361,78 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.optim_target_modules:
|
if self.cfg.optim_target_modules:
|
||||||
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
|
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
|
||||||
|
|
||||||
# torch compile
|
return training_args_kwargs
|
||||||
|
|
||||||
|
def _configure_hub_parameters(self):
|
||||||
|
training_args_kwargs = {}
|
||||||
|
|
||||||
|
if self.cfg.hub_model_id:
|
||||||
|
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||||
|
training_args_kwargs["push_to_hub"] = True
|
||||||
|
training_args_kwargs["hub_private_repo"] = True
|
||||||
|
training_args_kwargs["hub_always_push"] = True
|
||||||
|
|
||||||
|
if self.cfg.hub_strategy:
|
||||||
|
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||||
|
|
||||||
|
return training_args_kwargs
|
||||||
|
|
||||||
|
def _configure_save_and_eval_strategy(self):
|
||||||
|
training_args_kwargs = {}
|
||||||
|
|
||||||
|
# save_strategy and save_steps
|
||||||
|
if self.cfg.save_steps:
|
||||||
|
training_args_kwargs["save_strategy"] = "steps"
|
||||||
|
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
||||||
|
elif self.cfg.save_strategy:
|
||||||
|
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||||
|
else:
|
||||||
|
# default to saving each epoch if not defined
|
||||||
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
|
training_args_kwargs["save_total_limit"] = (
|
||||||
|
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||||
|
)
|
||||||
|
|
||||||
|
# eval_strategy and eval_steps
|
||||||
|
if not self.eval_dataset or self.cfg.val_set_size == 0:
|
||||||
|
# do not eval if no eval_dataset or val_set_size=0
|
||||||
|
training_args_kwargs["eval_strategy"] = "no"
|
||||||
|
elif self.cfg.eval_steps:
|
||||||
|
training_args_kwargs["eval_strategy"] = "steps"
|
||||||
|
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
|
elif self.cfg.eval_strategy:
|
||||||
|
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||||
|
|
||||||
|
return training_args_kwargs
|
||||||
|
|
||||||
|
def _configure_reporting(self):
|
||||||
|
training_args_kwargs = {}
|
||||||
|
|
||||||
|
report_to = []
|
||||||
|
if self.cfg.use_wandb:
|
||||||
|
report_to.append("wandb")
|
||||||
|
if self.cfg.use_mlflow:
|
||||||
|
report_to.append("mlflow")
|
||||||
|
if self.cfg.use_tensorboard:
|
||||||
|
report_to.append("tensorboard")
|
||||||
|
if self.cfg.use_comet:
|
||||||
|
report_to.append("comet_ml")
|
||||||
|
|
||||||
|
training_args_kwargs["report_to"] = report_to
|
||||||
|
|
||||||
|
if self.cfg.use_wandb:
|
||||||
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
|
elif self.cfg.use_mlflow:
|
||||||
|
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||||
|
else:
|
||||||
|
training_args_kwargs["run_name"] = None
|
||||||
|
|
||||||
|
return training_args_kwargs
|
||||||
|
|
||||||
|
def _configure_torch_compile(self):
|
||||||
|
training_args_kwargs = {}
|
||||||
|
|
||||||
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
|
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
|
||||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||||
True
|
True
|
||||||
@@ -476,3 +446,85 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
return training_args_kwargs
|
return training_args_kwargs
|
||||||
|
|
||||||
|
def _configure_gradient_checkpointing(self):
|
||||||
|
training_args_kwargs = {}
|
||||||
|
|
||||||
|
if self.cfg.gradient_checkpointing:
|
||||||
|
training_args_kwargs["gradient_checkpointing"] = (
|
||||||
|
self.cfg.gradient_checkpointing
|
||||||
|
)
|
||||||
|
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||||
|
training_args_kwargs["gradient_checkpointing_kwargs"] = (
|
||||||
|
self.cfg.gradient_checkpointing_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
||||||
|
"use_reentrant": False
|
||||||
|
}
|
||||||
|
|
||||||
|
return training_args_kwargs
|
||||||
|
|
||||||
|
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
|
||||||
|
training_args_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
training_args_kwargs.update(self._configure_warmup_and_logging(total_num_steps))
|
||||||
|
|
||||||
|
training_args_kwargs.update(self._configure_precision_settings())
|
||||||
|
|
||||||
|
training_args_kwargs.update(self._configure_save_and_eval_strategy())
|
||||||
|
|
||||||
|
training_args_kwargs.update(self._configure_gradient_checkpointing())
|
||||||
|
|
||||||
|
# set arg into trainer_args_kwargs with same name if value not None
|
||||||
|
for arg in [
|
||||||
|
"adam_beta1",
|
||||||
|
"adam_beta2",
|
||||||
|
"adam_epsilon",
|
||||||
|
"max_grad_norm",
|
||||||
|
"dataloader_num_workers",
|
||||||
|
"dataloader_pin_memory",
|
||||||
|
"dataloader_prefetch_factor",
|
||||||
|
"gradient_accumulation_steps",
|
||||||
|
"learning_rate",
|
||||||
|
"embedding_lr",
|
||||||
|
"embedding_lr_scale",
|
||||||
|
"lr_groups",
|
||||||
|
"loraplus_lr_ratio",
|
||||||
|
"loraplus_lr_embedding",
|
||||||
|
"output_dir",
|
||||||
|
"save_safetensors",
|
||||||
|
"save_only_model",
|
||||||
|
"include_tokens_per_second",
|
||||||
|
"weight_decay",
|
||||||
|
"sequence_parallel_degree",
|
||||||
|
"ring_attn_func",
|
||||||
|
"seed",
|
||||||
|
]:
|
||||||
|
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||||
|
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||||
|
|
||||||
|
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||||
|
|
||||||
|
if self.cfg.eval_batch_size:
|
||||||
|
training_args_kwargs["per_device_eval_batch_size"] = (
|
||||||
|
self.cfg.eval_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||||
|
|
||||||
|
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
|
|
||||||
|
# max_length is not used in CausalTrainer
|
||||||
|
if self.cfg.reward_model or self.cfg.rl:
|
||||||
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
|
training_args_kwargs.update(self._configure_reporting())
|
||||||
|
|
||||||
|
training_args_kwargs.update(self._configure_hub_parameters())
|
||||||
|
|
||||||
|
training_args_kwargs.update(self._configure_optimizer_and_scheduler())
|
||||||
|
|
||||||
|
training_args_kwargs.update(self._configure_torch_compile())
|
||||||
|
|
||||||
|
return training_args_kwargs
|
||||||
|
|||||||
Reference in New Issue
Block a user