prevent usage of low-bit ao optimizers with configurations that use parameter groups (#3003)

* prevent usage of low-bit ao optimizers with configurations that use parameter groups

* use optimizer enum value

* fix validation
Author: Wing Lian
Date:   2025-08-01 17:54:04 -04:00 (committed by GitHub)
Parent: cda3c82351
Commit: 5639552064


@@ -880,6 +880,23 @@ class OptimizationValidationMixin:
         return self

+    @model_validator(mode="after")
+    def lr_groups_ao_optimizer(self):
+        if (
+            self.loraplus_lr_ratio is not None
+            or self.embedding_lr_scale is not None
+            or self.embedding_lr is not None
+            or self.lr_groups is not None
+        ) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]:
+            # TODO(wing): remove this once ao>0.12.0
+            # requires https://github.com/pytorch/ao/pull/2606 in an ao release
+            raise ValueError(
+                "lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not "
+                "supported with ao low-bit optimizers until ao>0.12.0. "
+                "Please refer to https://github.com/pytorch/ao/pull/2606."
+            )
+        return self
+
     @model_validator(mode="before")
     @classmethod
     def check_tensor_parallel_size_update_ds_json(cls, data):
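
For readers who want to exercise the guard outside the full config class, here is a minimal, self-contained sketch of the same pydantic `model_validator(mode="after")` pattern. `OptimizerNames` and `TrainConfig` are hypothetical stand-ins (the real mixin reads the optimizer via `self.optimizer.value` on its own enum); only the two ao optimizer names and the lr-group field names come from the diff.

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel, model_validator


class OptimizerNames(str, Enum):
    # Hypothetical stand-in for the optimizer enum used by the config.
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_8BIT = "adamw_torch_8bit"
    ADAMW_TORCH_4BIT = "adamw_torch_4bit"


class TrainConfig(BaseModel):
    # Field names mirror the diff; a hypothetical minimal config.
    optimizer: OptimizerNames = OptimizerNames.ADAMW_TORCH
    loraplus_lr_ratio: Optional[float] = None
    embedding_lr_scale: Optional[float] = None
    embedding_lr: Optional[float] = None
    lr_groups: Optional[list] = None

    @model_validator(mode="after")
    def lr_groups_ao_optimizer(self):
        # Any of these fields being set means the trainer will build
        # per-parameter-group learning rates.
        uses_lr_groups = any(
            value is not None
            for value in (
                self.loraplus_lr_ratio,
                self.embedding_lr_scale,
                self.embedding_lr,
                self.lr_groups,
            )
        )
        if uses_lr_groups and self.optimizer.value in (
            "adamw_torch_8bit",
            "adamw_torch_4bit",
        ):
            raise ValueError(
                "lr groups are not supported with ao low-bit optimizers "
                "until ao>0.12.0"
            )
        return self


# Accepted: low-bit optimizer without any lr-group settings.
TrainConfig(optimizer=OptimizerNames.ADAMW_TORCH_8BIT)

# Rejected: lr groups combined with an ao low-bit optimizer.
try:
    TrainConfig(
        optimizer=OptimizerNames.ADAMW_TORCH_8BIT,
        loraplus_lr_ratio=16,
    )
except Exception as exc:  # pydantic surfaces this as a ValidationError
    print(exc)
```

Using `mode="after"` means the check runs once all fields are parsed and typed, so the validator can safely compare the resolved enum value against the two affected ao optimizer names.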