diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 3e7255c89..02b68c7d7 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -498,13 +498,22 @@ class AxolotlTrainer(SchedulerMixin, Trainer): optim_args[key] = value optim_args["betas"] = self.args.optim_shampoo_betas + if "max_preconditioner_dim" in optim_args: + optim_args["max_preconditioner_dim"] = int( + optim_args["max_preconditioner_dim"] + ) + if "precondition_frequency" in optim_args: + optim_args["precondition_frequency"] = int( + optim_args["precondition_frequency"] + ) + if "use_decoupled_weight_decay" in optim_args: + optim_args["use_decoupled_weight_decay"] = bool( + optim_args["use_decoupled_weight_decay"] + ) if isinstance(optim_args["epsilon"], str): optim_args["epsilon"] = float(optim_args["epsilon"]) optim_args["lr"] = self.args.learning_rate optim_args["weight_decay"] = self.args.weight_decay - optim_args["use_decoupled_weight_decay"] = bool( - optim_args.get("use_decoupled_weight_decay") - ) if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs: if isinstance(