fix casting of optim args
This commit is contained in:
@@ -498,13 +498,22 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
optim_args[key] = value
|
||||
|
||||
optim_args["betas"] = self.args.optim_shampoo_betas
|
||||
if "max_preconditioner_dim" in optim_args:
|
||||
optim_args["max_preconditioner_dim"] = int(
|
||||
optim_args["max_preconditioner_dim"]
|
||||
)
|
||||
if "precondition_frequency" in optim_args:
|
||||
optim_args["precondition_frequency"] = int(
|
||||
optim_args["precondition_frequency"]
|
||||
)
|
||||
if "use_decoupled_weight_decay" in optim_args:
|
||||
optim_args["use_decoupled_weight_decay"] = bool(
|
||||
optim_args["use_decoupled_weight_decay"]
|
||||
)
|
||||
if isinstance(optim_args["epsilon"], str):
|
||||
optim_args["epsilon"] = float(optim_args["epsilon"])
|
||||
optim_args["lr"] = self.args.learning_rate
|
||||
optim_args["weight_decay"] = self.args.weight_decay
|
||||
optim_args["use_decoupled_weight_decay"] = bool(
|
||||
optim_args.get("use_decoupled_weight_decay")
|
||||
)
|
||||
|
||||
if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs:
|
||||
if isinstance(
|
||||
|
||||
Reference in New Issue
Block a user