fix: pass the training_args_kwargs dict into each _configure_* helper to be filled in place, instead of merging each helper's returned dict via update()
This commit is contained in:
@@ -178,9 +178,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
# TODO
|
# TODO
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def _configure_warmup_and_logging(self, total_num_steps):
|
def _configure_warmup_and_logging(self, total_num_steps, training_args_kwargs):
|
||||||
training_args_kwargs = {}
|
|
||||||
|
|
||||||
warmup_steps = 0
|
warmup_steps = 0
|
||||||
warmup_ratio = 0.0
|
warmup_ratio = 0.0
|
||||||
if self.cfg.warmup_steps:
|
if self.cfg.warmup_steps:
|
||||||
@@ -198,25 +196,19 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if warmup_steps == 1:
|
if warmup_steps == 1:
|
||||||
warmup_steps = 2
|
warmup_steps = 2
|
||||||
|
|
||||||
logging_steps = (
|
if self.cfg.logging_steps is not None:
|
||||||
self.cfg.logging_steps
|
training_args_kwargs["logging_steps"] = self.cfg.logging_steps
|
||||||
if self.cfg.logging_steps is not None
|
else:
|
||||||
else (
|
training_args_kwargs["logging_steps"] = (
|
||||||
500 # transformers defaults to 500
|
500 # transformers defaults to 500
|
||||||
if not total_num_steps
|
if not total_num_steps
|
||||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
training_args_kwargs["warmup_ratio"] = warmup_ratio
|
training_args_kwargs["warmup_ratio"] = warmup_ratio
|
||||||
training_args_kwargs["warmup_steps"] = warmup_steps
|
training_args_kwargs["warmup_steps"] = warmup_steps
|
||||||
training_args_kwargs["logging_steps"] = logging_steps
|
|
||||||
|
|
||||||
return training_args_kwargs
|
|
||||||
|
|
||||||
def _configure_precision_settings(self):
|
|
||||||
training_args_kwargs = {}
|
|
||||||
|
|
||||||
|
def _configure_precision_settings(self, training_args_kwargs):
|
||||||
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
|
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
|
||||||
training_args_kwargs["tf32"] = self.cfg.tf32
|
training_args_kwargs["tf32"] = self.cfg.tf32
|
||||||
if self.cfg.bf16 == "full":
|
if self.cfg.bf16 == "full":
|
||||||
@@ -224,11 +216,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
else:
|
else:
|
||||||
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
||||||
|
|
||||||
return training_args_kwargs
|
def _configure_optimizer_and_scheduler(self, training_args_kwargs):
|
||||||
|
|
||||||
def _configure_optimizer_and_scheduler(self):
|
|
||||||
training_args_kwargs = {}
|
|
||||||
|
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep", "rex"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep", "rex"]:
|
||||||
training_args_kwargs["lr_scheduler_type"] = "cosine"
|
training_args_kwargs["lr_scheduler_type"] = "cosine"
|
||||||
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
|
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
|
||||||
@@ -361,11 +349,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.optim_target_modules:
|
if self.cfg.optim_target_modules:
|
||||||
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
|
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
|
||||||
|
|
||||||
return training_args_kwargs
|
def _configure_hub_parameters(self, training_args_kwargs):
|
||||||
|
|
||||||
def _configure_hub_parameters(self):
|
|
||||||
training_args_kwargs = {}
|
|
||||||
|
|
||||||
if self.cfg.hub_model_id:
|
if self.cfg.hub_model_id:
|
||||||
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||||
training_args_kwargs["push_to_hub"] = True
|
training_args_kwargs["push_to_hub"] = True
|
||||||
@@ -375,11 +359,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.hub_strategy:
|
if self.cfg.hub_strategy:
|
||||||
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||||
|
|
||||||
return training_args_kwargs
|
def _configure_save_and_eval_strategy(self, training_args_kwargs):
|
||||||
|
|
||||||
def _configure_save_and_eval_strategy(self):
|
|
||||||
training_args_kwargs = {}
|
|
||||||
|
|
||||||
# save_strategy and save_steps
|
# save_strategy and save_steps
|
||||||
if self.cfg.save_steps:
|
if self.cfg.save_steps:
|
||||||
training_args_kwargs["save_strategy"] = "steps"
|
training_args_kwargs["save_strategy"] = "steps"
|
||||||
@@ -404,11 +384,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
elif self.cfg.eval_strategy:
|
elif self.cfg.eval_strategy:
|
||||||
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||||
|
|
||||||
return training_args_kwargs
|
def _configure_reporting(self, training_args_kwargs):
|
||||||
|
|
||||||
def _configure_reporting(self):
|
|
||||||
training_args_kwargs = {}
|
|
||||||
|
|
||||||
report_to = []
|
report_to = []
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
report_to.append("wandb")
|
report_to.append("wandb")
|
||||||
@@ -428,11 +404,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
else:
|
else:
|
||||||
training_args_kwargs["run_name"] = None
|
training_args_kwargs["run_name"] = None
|
||||||
|
|
||||||
return training_args_kwargs
|
def _configure_torch_compile(self, training_args_kwargs):
|
||||||
|
|
||||||
def _configure_torch_compile(self):
|
|
||||||
training_args_kwargs = {}
|
|
||||||
|
|
||||||
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
|
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
|
||||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||||
True
|
True
|
||||||
@@ -445,11 +417,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.torch_compile_mode:
|
if self.cfg.torch_compile_mode:
|
||||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
return training_args_kwargs
|
def _configure_gradient_checkpointing(self, training_args_kwargs):
|
||||||
|
|
||||||
def _configure_gradient_checkpointing(self):
|
|
||||||
training_args_kwargs = {}
|
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing:
|
if self.cfg.gradient_checkpointing:
|
||||||
training_args_kwargs["gradient_checkpointing"] = (
|
training_args_kwargs["gradient_checkpointing"] = (
|
||||||
self.cfg.gradient_checkpointing
|
self.cfg.gradient_checkpointing
|
||||||
@@ -463,18 +431,16 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
"use_reentrant": False
|
"use_reentrant": False
|
||||||
}
|
}
|
||||||
|
|
||||||
return training_args_kwargs
|
|
||||||
|
|
||||||
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
|
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
|
||||||
training_args_kwargs: Dict[str, Any] = {}
|
training_args_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
training_args_kwargs.update(self._configure_warmup_and_logging(total_num_steps))
|
self._configure_warmup_and_logging(total_num_steps, training_args_kwargs)
|
||||||
|
|
||||||
training_args_kwargs.update(self._configure_precision_settings())
|
self._configure_precision_settings(training_args_kwargs)
|
||||||
|
|
||||||
training_args_kwargs.update(self._configure_save_and_eval_strategy())
|
self._configure_save_and_eval_strategy(training_args_kwargs)
|
||||||
|
|
||||||
training_args_kwargs.update(self._configure_gradient_checkpointing())
|
self._configure_gradient_checkpointing(training_args_kwargs)
|
||||||
|
|
||||||
# set arg into trainer_args_kwargs with same name if value not None
|
# set arg into trainer_args_kwargs with same name if value not None
|
||||||
for arg in [
|
for arg in [
|
||||||
@@ -521,12 +487,12 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.reward_model or self.cfg.rl:
|
if self.cfg.reward_model or self.cfg.rl:
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
training_args_kwargs.update(self._configure_reporting())
|
self._configure_reporting(training_args_kwargs)
|
||||||
|
|
||||||
training_args_kwargs.update(self._configure_hub_parameters())
|
self._configure_hub_parameters(training_args_kwargs)
|
||||||
|
|
||||||
training_args_kwargs.update(self._configure_optimizer_and_scheduler())
|
self._configure_optimizer_and_scheduler(training_args_kwargs)
|
||||||
|
|
||||||
training_args_kwargs.update(self._configure_torch_compile())
|
self._configure_torch_compile(training_args_kwargs)
|
||||||
|
|
||||||
return training_args_kwargs
|
return training_args_kwargs
|
||||||
|
|||||||
Reference in New Issue
Block a user