diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4b400e273..3e7255c89 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -506,6 +506,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer): optim_args.get("use_decoupled_weight_decay") ) + if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs: + if isinstance( + self.args.optim_shampoo_grafting_config_kwargs["epsilon"], str + ): + self.args.optim_shampoo_grafting_config_kwargs[ + "epsilon" + ] = float( + self.args.optim_shampoo_grafting_config_kwargs["epsilon"] + ) if self.args.optim_shampoo_grafting_config_type == "adam": grafting_config = AdamGraftingConfig( **self.args.optim_shampoo_grafting_config_kwargs