diff --git a/docs/optimizers.qmd b/docs/optimizers.qmd
index d59ed43ab..fdcaac7cb 100644
--- a/docs/optimizers.qmd
+++ b/docs/optimizers.qmd
@@ -10,8 +10,8 @@ optim_args:
   max_preconditioner_dim: 8192
   precondition_frequency: 100
   use_decoupled_weight_decay: true
-  optim_shampoo_grafting_config_type: adam
-  optim_shampoo_grafting_config_kwargs:
-    beta2: 0.999
-    epsilon: 1e-12
+optim_shampoo_grafting_config_type: adam
+optim_shampoo_grafting_config_kwargs:
+  beta2: 0.999
+  epsilon: 1e-12
 ```
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 43f0256fc..4b400e273 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -498,6 +498,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
                     optim_args[key] = value
 
             optim_args["betas"] = self.args.optim_shampoo_betas
+            if isinstance(optim_args.get("epsilon"), str):
+                optim_args["epsilon"] = float(optim_args["epsilon"])
             optim_args["lr"] = self.args.learning_rate
             optim_args["weight_decay"] = self.args.weight_decay
             optim_args["use_decoupled_weight_decay"] = bool(
@@ -506,15 +506,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
 
             if self.args.optim_shampoo_grafting_config_type == "adam":
                 grafting_config = AdamGraftingConfig(
-                    self.args.optim_shampoo_grafting_config_kwargs
+                    **self.args.optim_shampoo_grafting_config_kwargs
                 )
             elif self.args.optim_shampoo_grafting_config_type == "sgd":
                 grafting_config = SGDGraftingConfig(
-                    self.args.optim_shampoo_grafting_config_kwargs
+                    **self.args.optim_shampoo_grafting_config_kwargs
                 )
             elif self.args.optim_shampoo_grafting_config_type == "adagrad":
                 grafting_config = AdaGradGraftingConfig(
-                    self.args.optim_shampoo_grafting_config_kwargs
+                    **self.args.optim_shampoo_grafting_config_kwargs
                 )
 
             distributed_config = None