ensure epsilon is cast to float
This commit is contained in:
@@ -506,6 +506,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
optim_args.get("use_decoupled_weight_decay")
|
optim_args.get("use_decoupled_weight_decay")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs:
|
||||||
|
if isinstance(
|
||||||
|
self.args.optim_shampoo_grafting_config_kwargs["epsilon"], str
|
||||||
|
):
|
||||||
|
self.args.optim_shampoo_grafting_config_kwargs[
|
||||||
|
"epsilon"
|
||||||
|
] = float(
|
||||||
|
self.args.optim_shampoo_grafting_config_kwargs["epsilon"]
|
||||||
|
)
|
||||||
if self.args.optim_shampoo_grafting_config_type == "adam":
|
if self.args.optim_shampoo_grafting_config_type == "adam":
|
||||||
grafting_config = AdamGraftingConfig(
|
grafting_config = AdamGraftingConfig(
|
||||||
**self.args.optim_shampoo_grafting_config_kwargs
|
**self.args.optim_shampoo_grafting_config_kwargs
|
||||||
|
|||||||
Reference in New Issue
Block a user