fix: refactor sft and rl trainer to set same base args
This commit is contained in:
@@ -230,6 +230,101 @@ class TrainerBuilderBase(abc.ABC):
|
||||
# TODO
|
||||
return trainer
|
||||
|
||||
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
|
||||
training_args_kwargs = {}
|
||||
|
||||
warmup_steps = None
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||
if warmup_steps == 1:
|
||||
warmup_steps = 2
|
||||
|
||||
logging_steps = (
|
||||
self.cfg.logging_steps
|
||||
if self.cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
training_args_kwargs["warmup_steps"] = warmup_steps
|
||||
training_args_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
if self.cfg.hub_model_id:
|
||||
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||
training_args_kwargs["push_to_hub"] = True
|
||||
training_args_kwargs["hub_private_repo"] = True
|
||||
training_args_kwargs["hub_always_push"] = True
|
||||
|
||||
if self.cfg.hub_strategy:
|
||||
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.save_safetensors is not None:
|
||||
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
# set save_strategy and save_steps
|
||||
if self.cfg.save_steps:
|
||||
training_args_kwargs["save_strategy"] = "steps"
|
||||
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_args_kwargs[
|
||||
"gradient_checkpointing"
|
||||
] = self.cfg.gradient_checkpointing
|
||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||
training_args_kwargs[
|
||||
"gradient_checkpointing_kwargs"
|
||||
] = self.cfg.gradient_checkpointing_kwargs
|
||||
else:
|
||||
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
||||
"use_reentrant": False
|
||||
}
|
||||
|
||||
for arg in [
|
||||
"adam_beta1",
|
||||
"adam_beta2",
|
||||
"adam_epsilon",
|
||||
"max_grad_norm",
|
||||
"dataloader_num_workers",
|
||||
"dataloader_pin_memory",
|
||||
"dataloader_prefetch_factor",
|
||||
"include_tokens_per_second",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
|
||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||
|
||||
if self.cfg.eval_batch_size:
|
||||
training_args_kwargs[
|
||||
"per_device_eval_batch_size"
|
||||
] = self.cfg.eval_batch_size
|
||||
|
||||
training_args_kwargs[
|
||||
"gradient_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
|
||||
training_args_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||
training_args_kwargs["output_dir"] = self.cfg.output_dir
|
||||
training_args_kwargs["save_total_limit"] = (
|
||||
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||
)
|
||||
|
||||
training_args_kwargs["max_steps"] = (
|
||||
total_num_steps if self.cfg.max_steps else -1
|
||||
)
|
||||
|
||||
return training_args_kwargs
|
||||
|
||||
|
||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
@@ -319,29 +414,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
warmup_steps = None
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||
if warmup_steps == 1:
|
||||
warmup_steps = 2
|
||||
|
||||
logging_steps = (
|
||||
self.cfg.logging_steps
|
||||
if self.cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
training_arguments_kwargs = {}
|
||||
|
||||
if self.cfg.include_tokens_per_second is not None:
|
||||
training_arguments_kwargs["include_tokens_per_second"] = (
|
||||
self.cfg.include_tokens_per_second
|
||||
)
|
||||
|
||||
training_arguments_kwargs = self._set_base_training_args(total_num_steps)
|
||||
if self.cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
@@ -350,20 +423,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.fp16 and not self.cfg.bf16
|
||||
) or False
|
||||
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
if self.cfg.seed is not None:
|
||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_arguments_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||
training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
|
||||
self.cfg.gradient_checkpointing_kwargs
|
||||
)
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
@@ -383,39 +446,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.lr_quadratic_warmup
|
||||
)
|
||||
|
||||
if self.cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||
if self.cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
|
||||
if self.cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
|
||||
if self.cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
|
||||
|
||||
if self.cfg.hub_model_id:
|
||||
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||
training_arguments_kwargs["push_to_hub"] = True
|
||||
training_arguments_kwargs["hub_private_repo"] = True
|
||||
training_arguments_kwargs["hub_always_push"] = True
|
||||
|
||||
if self.cfg.hub_strategy:
|
||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.save_safetensors is not None:
|
||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.cfg.dataloader_pin_memory is not None:
|
||||
training_arguments_kwargs["dataloader_pin_memory"] = (
|
||||
self.cfg.dataloader_pin_memory
|
||||
)
|
||||
if self.cfg.dataloader_num_workers is not None:
|
||||
training_arguments_kwargs["dataloader_num_workers"] = (
|
||||
self.cfg.dataloader_num_workers
|
||||
)
|
||||
if self.cfg.dataloader_prefetch_factor is not None:
|
||||
training_arguments_kwargs["dataloader_prefetch_factor"] = (
|
||||
self.cfg.dataloader_prefetch_factor
|
||||
)
|
||||
if self.cfg.dataloader_drop_last is not None:
|
||||
training_arguments_kwargs["dataloader_drop_last"] = (
|
||||
self.cfg.dataloader_drop_last
|
||||
@@ -440,17 +470,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["eval_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.save_steps:
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_arguments_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||
if self.cfg.bench_dataset:
|
||||
@@ -493,33 +512,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
|
||||
# these are all the "standard" kwargs that are def used
|
||||
training_arguments_kwargs["max_steps"] = (
|
||||
self.cfg.max_steps if self.cfg.max_steps else -1
|
||||
)
|
||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||
training_arguments_kwargs["per_device_train_batch_size"] = (
|
||||
self.cfg.micro_batch_size
|
||||
)
|
||||
if self.cfg.eval_batch_size:
|
||||
training_arguments_kwargs["per_device_eval_batch_size"] = (
|
||||
self.cfg.eval_batch_size
|
||||
)
|
||||
|
||||
if self.cfg.auto_find_batch_size is not None:
|
||||
training_arguments_kwargs["auto_find_batch_size"] = (
|
||||
self.cfg.auto_find_batch_size
|
||||
)
|
||||
training_arguments_kwargs["gradient_accumulation_steps"] = (
|
||||
self.cfg.gradient_accumulation_steps
|
||||
)
|
||||
training_arguments_kwargs["eval_accumulation_steps"] = (
|
||||
self.cfg.gradient_accumulation_steps
|
||||
)
|
||||
training_arguments_kwargs[
|
||||
"auto_find_batch_size"
|
||||
] = self.cfg.auto_find_batch_size
|
||||
|
||||
training_arguments_kwargs[
|
||||
"eval_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
||||
training_arguments_kwargs["save_total_limit"] = (
|
||||
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||
)
|
||||
|
||||
training_arguments_kwargs["load_best_model_at_end"] = (
|
||||
(
|
||||
self.cfg.load_best_model_at_end is not False
|
||||
@@ -974,34 +978,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def build_training_arguments(self, total_num_steps):
|
||||
training_args_kwargs = {}
|
||||
for arg in [
|
||||
"adam_beta1",
|
||||
"adam_beta2",
|
||||
"adam_epsilon",
|
||||
"dataloader_num_workers",
|
||||
"dataloader_pin_memory",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
training_args_kwargs = self._set_base_training_args(
|
||||
total_num_steps=total_num_steps
|
||||
)
|
||||
|
||||
if self.cfg.hub_model_id:
|
||||
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||
training_args_kwargs["push_to_hub"] = True
|
||||
training_args_kwargs["hub_private_repo"] = True
|
||||
training_args_kwargs["hub_always_push"] = True
|
||||
|
||||
if self.cfg.hub_strategy:
|
||||
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.save_safetensors is not None:
|
||||
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.eval_dataset:
|
||||
if not self.eval_dataset:
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_args_kwargs["eval_strategy"] = "steps"
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
else:
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_strategy:
|
||||
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
|
||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||
training_args_kwargs["bf16"] = True
|
||||
@@ -1014,6 +1001,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["lr_scheduler_kwargs"] = (
|
||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||
)
|
||||
|
||||
if self.cfg.remove_unused_columns is not None:
|
||||
training_args_kwargs["remove_unused_columns"] = (
|
||||
self.cfg.remove_unused_columns
|
||||
@@ -1021,47 +1009,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
training_args_kwargs["remove_unused_columns"] = False
|
||||
|
||||
if self.cfg.dataloader_pin_memory is not None:
|
||||
training_args_kwargs["dataloader_pin_memory"] = (
|
||||
self.cfg.dataloader_pin_memory
|
||||
)
|
||||
if self.cfg.dataloader_num_workers is not None:
|
||||
training_args_kwargs["dataloader_num_workers"] = (
|
||||
self.cfg.dataloader_num_workers
|
||||
)
|
||||
if self.cfg.dataloader_prefetch_factor is not None:
|
||||
training_args_kwargs["dataloader_prefetch_factor"] = (
|
||||
self.cfg.dataloader_prefetch_factor
|
||||
)
|
||||
|
||||
if self.cfg.seed is not None:
|
||||
training_args_kwargs["seed"] = self.cfg.seed
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||
training_args_kwargs["gradient_checkpointing_kwargs"] = (
|
||||
self.cfg.gradient_checkpointing_kwargs
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
||||
"use_reentrant": False
|
||||
}
|
||||
|
||||
# set save_strategy and save_steps
|
||||
if self.cfg.save_steps:
|
||||
training_args_kwargs["save_strategy"] = "steps"
|
||||
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
|
||||
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
@@ -1137,19 +1084,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if blocklist_key in training_args_kwargs:
|
||||
del training_args_kwargs[blocklist_key]
|
||||
|
||||
max_steps = self.cfg.max_steps or total_num_steps or -1
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
max_steps=max_steps,
|
||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||
learning_rate=self.cfg.learning_rate,
|
||||
warmup_steps=self.cfg.warmup_steps,
|
||||
logging_first_step=True,
|
||||
logging_steps=1,
|
||||
optim=self.cfg.optimizer,
|
||||
save_total_limit=self.cfg.save_total_limit or 5,
|
||||
**training_args_kwargs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user