From 84dad0bd12974325be9390463084fc563a976eed Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 18 Sep 2024 10:42:26 -0700 Subject: [PATCH] ensure epsilon is cast to float --- src/axolotl/core/trainer_builder.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4b400e273..3e7255c89 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -506,6 +506,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer): optim_args.get("use_decoupled_weight_decay") ) + if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs: + if isinstance( + self.args.optim_shampoo_grafting_config_kwargs["epsilon"], str + ): + self.args.optim_shampoo_grafting_config_kwargs[ + "epsilon" + ] = float( + self.args.optim_shampoo_grafting_config_kwargs["epsilon"] + ) if self.args.optim_shampoo_grafting_config_type == "adam": grafting_config = AdamGraftingConfig( **self.args.optim_shampoo_grafting_config_kwargs