ensure epsilon is cast to float

This commit is contained in:
Wing Lian
2024-09-18 10:42:26 -07:00
parent 05f61a0ea5
commit 84dad0bd12

View File

@@ -506,6 +506,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
optim_args.get("use_decoupled_weight_decay")
)
if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs:
if isinstance(
self.args.optim_shampoo_grafting_config_kwargs["epsilon"], str
):
self.args.optim_shampoo_grafting_config_kwargs[
"epsilon"
] = float(
self.args.optim_shampoo_grafting_config_kwargs["epsilon"]
)
if self.args.optim_shampoo_grafting_config_type == "adam":
grafting_config = AdamGraftingConfig(
**self.args.optim_shampoo_grafting_config_kwargs