diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index aa249c6ce..02e80dd8e 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -880,6 +880,23 @@ class OptimizationValidationMixin: return self + @model_validator(mode="after") + def lr_groups_ao_optimizer(self): + if ( + self.loraplus_lr_ratio is not None + or self.embedding_lr_scale is not None + or self.embedding_lr is not None + or self.lr_groups is not None + ) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]: + # TODO(wing): remove this once ao>0.12.0 + # requires https://github.com/pytorch/ao/pull/2606 in an ao release + raise ValueError( + "lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not " + "supported with ao low-bit optimizers until ao>0.12.0. " + "Please refer to https://github.com/pytorch/ao/pull/2606." + ) + return self + @model_validator(mode="before") @classmethod def check_tensor_parallel_size_update_ds_json(cls, data):