fix casting of optim args

This commit is contained in:
Wing Lian
2024-09-18 10:48:15 -07:00
parent 84dad0bd12
commit 69a29382e1

View File

@@ -498,13 +498,22 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
optim_args[key] = value
optim_args["betas"] = self.args.optim_shampoo_betas
if "max_preconditioner_dim" in optim_args:
optim_args["max_preconditioner_dim"] = int(
optim_args["max_preconditioner_dim"]
)
if "precondition_frequency" in optim_args:
optim_args["precondition_frequency"] = int(
optim_args["precondition_frequency"]
)
if "use_decoupled_weight_decay" in optim_args:
optim_args["use_decoupled_weight_decay"] = bool(
optim_args["use_decoupled_weight_decay"]
)
if isinstance(optim_args["epsilon"], str):
optim_args["epsilon"] = float(optim_args["epsilon"])
optim_args["lr"] = self.args.learning_rate
optim_args["weight_decay"] = self.args.weight_decay
optim_args["use_decoupled_weight_decay"] = bool(
optim_args.get("use_decoupled_weight_decay")
)
if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs:
if isinstance(