prevent usage of low bit ao optimizers with configurations that use parameter groups (#3003)
* prevent usage of low bit ao optimizers with configurations that use parameter groups * use optimizer enum value * fix validation
This commit is contained in:
@@ -880,6 +880,23 @@ class OptimizationValidationMixin:
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def lr_groups_ao_optimizer(self):
|
||||||
|
if (
|
||||||
|
self.loraplus_lr_ratio is not None
|
||||||
|
or self.embedding_lr_scale is not None
|
||||||
|
or self.embedding_lr is not None
|
||||||
|
or self.lr_groups is not None
|
||||||
|
) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]:
|
||||||
|
# TODO(wing): remove this once ao>0.12.0
|
||||||
|
# requires https://github.com/pytorch/ao/pull/2606 in an ao release
|
||||||
|
raise ValueError(
|
||||||
|
"lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not "
|
||||||
|
"supported with ao low-bit optimizers until ao>0.12.0. "
|
||||||
|
"Please refer to https://github.com/pytorch/ao/pull/2606."
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_tensor_parallel_size_update_ds_json(cls, data):
|
def check_tensor_parallel_size_update_ds_json(cls, data):
|
||||||
|
|||||||
Reference in New Issue
Block a user