fix: refactor sft and rl trainer to set same base args
This commit is contained in:
@@ -230,6 +230,101 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
# TODO
|
# TODO
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
|
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
|
||||||
|
training_args_kwargs = {}
|
||||||
|
|
||||||
|
warmup_steps = None
|
||||||
|
if self.cfg.warmup_steps is not None:
|
||||||
|
warmup_steps = self.cfg.warmup_steps
|
||||||
|
elif self.cfg.warmup_ratio is not None:
|
||||||
|
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||||
|
else:
|
||||||
|
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||||
|
if warmup_steps == 1:
|
||||||
|
warmup_steps = 2
|
||||||
|
|
||||||
|
logging_steps = (
|
||||||
|
self.cfg.logging_steps
|
||||||
|
if self.cfg.logging_steps is not None
|
||||||
|
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
training_args_kwargs["warmup_steps"] = warmup_steps
|
||||||
|
training_args_kwargs["logging_steps"] = logging_steps
|
||||||
|
|
||||||
|
if self.cfg.hub_model_id:
|
||||||
|
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||||
|
training_args_kwargs["push_to_hub"] = True
|
||||||
|
training_args_kwargs["hub_private_repo"] = True
|
||||||
|
training_args_kwargs["hub_always_push"] = True
|
||||||
|
|
||||||
|
if self.cfg.hub_strategy:
|
||||||
|
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||||
|
|
||||||
|
if self.cfg.save_safetensors is not None:
|
||||||
|
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
|
# set save_strategy and save_steps
|
||||||
|
if self.cfg.save_steps:
|
||||||
|
training_args_kwargs["save_strategy"] = "steps"
|
||||||
|
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
||||||
|
elif self.cfg.save_strategy:
|
||||||
|
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||||
|
else:
|
||||||
|
# default to saving each epoch if not defined
|
||||||
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
|
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
|
||||||
|
|
||||||
|
if self.cfg.gradient_checkpointing:
|
||||||
|
training_args_kwargs[
|
||||||
|
"gradient_checkpointing"
|
||||||
|
] = self.cfg.gradient_checkpointing
|
||||||
|
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||||
|
training_args_kwargs[
|
||||||
|
"gradient_checkpointing_kwargs"
|
||||||
|
] = self.cfg.gradient_checkpointing_kwargs
|
||||||
|
else:
|
||||||
|
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
||||||
|
"use_reentrant": False
|
||||||
|
}
|
||||||
|
|
||||||
|
for arg in [
|
||||||
|
"adam_beta1",
|
||||||
|
"adam_beta2",
|
||||||
|
"adam_epsilon",
|
||||||
|
"max_grad_norm",
|
||||||
|
"dataloader_num_workers",
|
||||||
|
"dataloader_pin_memory",
|
||||||
|
"dataloader_prefetch_factor",
|
||||||
|
"include_tokens_per_second",
|
||||||
|
]:
|
||||||
|
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||||
|
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||||
|
|
||||||
|
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||||
|
|
||||||
|
if self.cfg.eval_batch_size:
|
||||||
|
training_args_kwargs[
|
||||||
|
"per_device_eval_batch_size"
|
||||||
|
] = self.cfg.eval_batch_size
|
||||||
|
|
||||||
|
training_args_kwargs[
|
||||||
|
"gradient_accumulation_steps"
|
||||||
|
] = self.cfg.gradient_accumulation_steps
|
||||||
|
|
||||||
|
training_args_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||||
|
training_args_kwargs["output_dir"] = self.cfg.output_dir
|
||||||
|
training_args_kwargs["save_total_limit"] = (
|
||||||
|
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||||
|
)
|
||||||
|
|
||||||
|
training_args_kwargs["max_steps"] = (
|
||||||
|
total_num_steps if self.cfg.max_steps else -1
|
||||||
|
)
|
||||||
|
|
||||||
|
return training_args_kwargs
|
||||||
|
|
||||||
|
|
||||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||||
"""
|
"""
|
||||||
@@ -319,29 +414,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
warmup_steps = None
|
training_arguments_kwargs = self._set_base_training_args(total_num_steps)
|
||||||
if self.cfg.warmup_steps is not None:
|
|
||||||
warmup_steps = self.cfg.warmup_steps
|
|
||||||
elif self.cfg.warmup_ratio is not None:
|
|
||||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
|
||||||
else:
|
|
||||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
|
||||||
if warmup_steps == 1:
|
|
||||||
warmup_steps = 2
|
|
||||||
|
|
||||||
logging_steps = (
|
|
||||||
self.cfg.logging_steps
|
|
||||||
if self.cfg.logging_steps is not None
|
|
||||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
training_arguments_kwargs = {}
|
|
||||||
|
|
||||||
if self.cfg.include_tokens_per_second is not None:
|
|
||||||
training_arguments_kwargs["include_tokens_per_second"] = (
|
|
||||||
self.cfg.include_tokens_per_second
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.bf16 == "full":
|
if self.cfg.bf16 == "full":
|
||||||
training_arguments_kwargs["bf16_full_eval"] = True
|
training_arguments_kwargs["bf16_full_eval"] = True
|
||||||
else:
|
else:
|
||||||
@@ -350,20 +423,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.fp16 and not self.cfg.bf16
|
self.cfg.fp16 and not self.cfg.bf16
|
||||||
) or False
|
) or False
|
||||||
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
||||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
|
||||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
|
||||||
|
|
||||||
if self.cfg.seed is not None:
|
if self.cfg.seed is not None:
|
||||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing:
|
|
||||||
training_arguments_kwargs["gradient_checkpointing"] = (
|
|
||||||
self.cfg.gradient_checkpointing
|
|
||||||
)
|
|
||||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
|
||||||
training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
|
|
||||||
self.cfg.gradient_checkpointing_kwargs
|
|
||||||
)
|
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||||
if self.cfg.fsdp_config:
|
if self.cfg.fsdp_config:
|
||||||
@@ -383,39 +446,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.lr_quadratic_warmup
|
self.cfg.lr_quadratic_warmup
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.adam_beta1:
|
|
||||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
|
||||||
if self.cfg.adam_beta2:
|
|
||||||
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
|
|
||||||
if self.cfg.adam_epsilon:
|
|
||||||
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
|
|
||||||
if self.cfg.max_grad_norm:
|
|
||||||
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
|
|
||||||
|
|
||||||
if self.cfg.hub_model_id:
|
|
||||||
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
|
||||||
training_arguments_kwargs["push_to_hub"] = True
|
|
||||||
training_arguments_kwargs["hub_private_repo"] = True
|
|
||||||
training_arguments_kwargs["hub_always_push"] = True
|
|
||||||
|
|
||||||
if self.cfg.hub_strategy:
|
|
||||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
|
||||||
|
|
||||||
if self.cfg.save_safetensors is not None:
|
|
||||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
|
||||||
training_arguments_kwargs["dataloader_pin_memory"] = (
|
|
||||||
self.cfg.dataloader_pin_memory
|
|
||||||
)
|
|
||||||
if self.cfg.dataloader_num_workers is not None:
|
|
||||||
training_arguments_kwargs["dataloader_num_workers"] = (
|
|
||||||
self.cfg.dataloader_num_workers
|
|
||||||
)
|
|
||||||
if self.cfg.dataloader_prefetch_factor is not None:
|
|
||||||
training_arguments_kwargs["dataloader_prefetch_factor"] = (
|
|
||||||
self.cfg.dataloader_prefetch_factor
|
|
||||||
)
|
|
||||||
if self.cfg.dataloader_drop_last is not None:
|
if self.cfg.dataloader_drop_last is not None:
|
||||||
training_arguments_kwargs["dataloader_drop_last"] = (
|
training_arguments_kwargs["dataloader_drop_last"] = (
|
||||||
self.cfg.dataloader_drop_last
|
self.cfg.dataloader_drop_last
|
||||||
@@ -440,17 +470,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# we have an eval set, but no steps defined, default to use epoch
|
# we have an eval set, but no steps defined, default to use epoch
|
||||||
training_arguments_kwargs["eval_strategy"] = "epoch"
|
training_arguments_kwargs["eval_strategy"] = "epoch"
|
||||||
|
|
||||||
if self.cfg.save_steps:
|
|
||||||
training_arguments_kwargs["save_strategy"] = "steps"
|
|
||||||
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
|
|
||||||
elif self.cfg.save_strategy:
|
|
||||||
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
|
|
||||||
else:
|
|
||||||
# default to saving each epoch if not defined
|
|
||||||
training_arguments_kwargs["save_strategy"] = "epoch"
|
|
||||||
|
|
||||||
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
|
|
||||||
|
|
||||||
if self.cfg.do_bench_eval:
|
if self.cfg.do_bench_eval:
|
||||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||||
if self.cfg.bench_dataset:
|
if self.cfg.bench_dataset:
|
||||||
@@ -493,33 +512,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# these are all the "standard" kwargs that are def used
|
# these are all the "standard" kwargs that are def used
|
||||||
training_arguments_kwargs["max_steps"] = (
|
|
||||||
self.cfg.max_steps if self.cfg.max_steps else -1
|
|
||||||
)
|
|
||||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||||
training_arguments_kwargs["per_device_train_batch_size"] = (
|
|
||||||
self.cfg.micro_batch_size
|
|
||||||
)
|
|
||||||
if self.cfg.eval_batch_size:
|
|
||||||
training_arguments_kwargs["per_device_eval_batch_size"] = (
|
|
||||||
self.cfg.eval_batch_size
|
|
||||||
)
|
|
||||||
if self.cfg.auto_find_batch_size is not None:
|
if self.cfg.auto_find_batch_size is not None:
|
||||||
training_arguments_kwargs["auto_find_batch_size"] = (
|
training_arguments_kwargs[
|
||||||
self.cfg.auto_find_batch_size
|
"auto_find_batch_size"
|
||||||
)
|
] = self.cfg.auto_find_batch_size
|
||||||
training_arguments_kwargs["gradient_accumulation_steps"] = (
|
|
||||||
self.cfg.gradient_accumulation_steps
|
training_arguments_kwargs[
|
||||||
)
|
"eval_accumulation_steps"
|
||||||
training_arguments_kwargs["eval_accumulation_steps"] = (
|
] = self.cfg.gradient_accumulation_steps
|
||||||
self.cfg.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
|
||||||
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
|
||||||
training_arguments_kwargs["save_total_limit"] = (
|
|
||||||
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
|
||||||
)
|
|
||||||
training_arguments_kwargs["load_best_model_at_end"] = (
|
training_arguments_kwargs["load_best_model_at_end"] = (
|
||||||
(
|
(
|
||||||
self.cfg.load_best_model_at_end is not False
|
self.cfg.load_best_model_at_end is not False
|
||||||
@@ -974,34 +978,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def build_training_arguments(self, total_num_steps):
|
def build_training_arguments(self, total_num_steps):
|
||||||
training_args_kwargs = {}
|
training_args_kwargs = self._set_base_training_args(
|
||||||
for arg in [
|
total_num_steps=total_num_steps
|
||||||
"adam_beta1",
|
)
|
||||||
"adam_beta2",
|
|
||||||
"adam_epsilon",
|
|
||||||
"dataloader_num_workers",
|
|
||||||
"dataloader_pin_memory",
|
|
||||||
]:
|
|
||||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
|
||||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
|
||||||
|
|
||||||
if self.cfg.hub_model_id:
|
if not self.eval_dataset:
|
||||||
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
training_args_kwargs["eval_strategy"] = "no"
|
||||||
training_args_kwargs["push_to_hub"] = True
|
elif self.cfg.eval_steps:
|
||||||
training_args_kwargs["hub_private_repo"] = True
|
|
||||||
training_args_kwargs["hub_always_push"] = True
|
|
||||||
|
|
||||||
if self.cfg.hub_strategy:
|
|
||||||
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
|
||||||
|
|
||||||
if self.cfg.save_safetensors is not None:
|
|
||||||
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
|
||||||
|
|
||||||
if self.eval_dataset:
|
|
||||||
training_args_kwargs["eval_strategy"] = "steps"
|
training_args_kwargs["eval_strategy"] = "steps"
|
||||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
else:
|
elif self.cfg.eval_strategy:
|
||||||
training_args_kwargs["eval_strategy"] = "no"
|
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||||
|
|
||||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||||
training_args_kwargs["bf16"] = True
|
training_args_kwargs["bf16"] = True
|
||||||
@@ -1014,6 +1001,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["lr_scheduler_kwargs"] = (
|
training_args_kwargs["lr_scheduler_kwargs"] = (
|
||||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.remove_unused_columns is not None:
|
if self.cfg.remove_unused_columns is not None:
|
||||||
training_args_kwargs["remove_unused_columns"] = (
|
training_args_kwargs["remove_unused_columns"] = (
|
||||||
self.cfg.remove_unused_columns
|
self.cfg.remove_unused_columns
|
||||||
@@ -1021,47 +1009,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
training_args_kwargs["remove_unused_columns"] = False
|
training_args_kwargs["remove_unused_columns"] = False
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
|
||||||
training_args_kwargs["dataloader_pin_memory"] = (
|
|
||||||
self.cfg.dataloader_pin_memory
|
|
||||||
)
|
|
||||||
if self.cfg.dataloader_num_workers is not None:
|
|
||||||
training_args_kwargs["dataloader_num_workers"] = (
|
|
||||||
self.cfg.dataloader_num_workers
|
|
||||||
)
|
|
||||||
if self.cfg.dataloader_prefetch_factor is not None:
|
|
||||||
training_args_kwargs["dataloader_prefetch_factor"] = (
|
|
||||||
self.cfg.dataloader_prefetch_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.seed is not None:
|
|
||||||
training_args_kwargs["seed"] = self.cfg.seed
|
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing:
|
|
||||||
training_args_kwargs["gradient_checkpointing"] = (
|
|
||||||
self.cfg.gradient_checkpointing
|
|
||||||
)
|
|
||||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
|
||||||
training_args_kwargs["gradient_checkpointing_kwargs"] = (
|
|
||||||
self.cfg.gradient_checkpointing_kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
|
||||||
"use_reentrant": False
|
|
||||||
}
|
|
||||||
|
|
||||||
# set save_strategy and save_steps
|
|
||||||
if self.cfg.save_steps:
|
|
||||||
training_args_kwargs["save_strategy"] = "steps"
|
|
||||||
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
|
||||||
elif self.cfg.save_strategy:
|
|
||||||
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
|
|
||||||
else:
|
|
||||||
# default to saving each epoch if not defined
|
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
|
||||||
|
|
||||||
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
|
|
||||||
|
|
||||||
if self.cfg.dataset_processes:
|
if self.cfg.dataset_processes:
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
@@ -1137,19 +1084,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if blocklist_key in training_args_kwargs:
|
if blocklist_key in training_args_kwargs:
|
||||||
del training_args_kwargs[blocklist_key]
|
del training_args_kwargs[blocklist_key]
|
||||||
|
|
||||||
max_steps = self.cfg.max_steps or total_num_steps or -1
|
|
||||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
self.cfg.output_dir,
|
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
|
||||||
max_steps=max_steps,
|
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
|
||||||
learning_rate=self.cfg.learning_rate,
|
|
||||||
warmup_steps=self.cfg.warmup_steps,
|
|
||||||
logging_first_step=True,
|
logging_first_step=True,
|
||||||
logging_steps=1,
|
|
||||||
optim=self.cfg.optimizer,
|
optim=self.cfg.optimizer,
|
||||||
save_total_limit=self.cfg.save_total_limit or 5,
|
|
||||||
**training_args_kwargs,
|
**training_args_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user