prevent usage of low bit ao optimizers with configurations that use parameter groups (#3003)
* prevent usage of low bit ao optimizers with configurations that use parameter groups
* use optimizer enum value
* fix validation
@@ -880,6 +880,23 @@ class OptimizationValidationMixin:
 
         return self
 
+    @model_validator(mode="after")
+    def lr_groups_ao_optimizer(self):
+        if (
+            self.loraplus_lr_ratio is not None
+            or self.embedding_lr_scale is not None
+            or self.embedding_lr is not None
+            or self.lr_groups is not None
+        ) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]:
+            # TODO(wing): remove this once ao>0.12.0
+            # requires https://github.com/pytorch/ao/pull/2606 in an ao release
+            raise ValueError(
+                "lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not "
+                "supported with ao low-bit optimizers until ao>0.12.0. "
+                "Please refer to https://github.com/pytorch/ao/pull/2606."
+            )
+        return self
+
     @model_validator(mode="before")
     @classmethod
     def check_tensor_parallel_size_update_ds_json(cls, data):
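For context, here is a minimal, self-contained sketch of how a pydantic `model_validator(mode="after")` like the one added above rejects the bad combination at config-construction time. The `Config` model, `OptimizerType` enum, and field types below are illustrative stand-ins, not the project's actual config classes; only the validator logic mirrors the diff.

# minimal_sketch.py -- illustrative only; names are hypothetical
from enum import Enum

from pydantic import BaseModel, model_validator


class OptimizerType(str, Enum):
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_8BIT = "adamw_torch_8bit"
    ADAMW_TORCH_4BIT = "adamw_torch_4bit"


class Config(BaseModel):
    optimizer: OptimizerType = OptimizerType.ADAMW_TORCH
    lr_groups: list | None = None
    loraplus_lr_ratio: float | None = None
    embedding_lr_scale: float | None = None
    embedding_lr: float | None = None

    @model_validator(mode="after")
    def lr_groups_ao_optimizer(self):
        # Same check as the diff: any parameter-group option combined with
        # an ao low-bit optimizer is rejected before training starts.
        if (
            self.loraplus_lr_ratio is not None
            or self.embedding_lr_scale is not None
            or self.embedding_lr is not None
            or self.lr_groups is not None
        ) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]:
            raise ValueError(
                "lr groups are not supported with ao low-bit optimizers "
                "until ao>0.12.0."
            )
        return self


# Rejected: parameter groups with an ao low-bit optimizer.
try:
    Config(
        optimizer=OptimizerType.ADAMW_TORCH_8BIT,
        lr_groups=[{"name": "embeddings"}],
    )
except ValueError as exc:
    print(f"validation error: {exc}")

# Accepted: the same groups with a full-precision optimizer.
Config(optimizer=OptimizerType.ADAMW_TORCH, lr_groups=[{"name": "embeddings"}])

Running the validation in `mode="after"` means the check sees the fully parsed model, so `self.optimizer` is already an enum member and can be compared via `.value`, matching the "use optimizer enum value" fix in the commit message.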