add validation to prevent 8bit lora finetuning on H100s (#1827)
This commit is contained in:
@@ -1267,6 +1267,19 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_hopper_8bit_lora(cls, data):
|
||||||
|
is_sm_90: bool = (
|
||||||
|
data["capabilities"]
|
||||||
|
and data["capabilities"].get("compute_capability") == "sm_90"
|
||||||
|
)
|
||||||
|
if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
|
||||||
|
# see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
|
||||||
|
raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_fsdp_deepspeed(cls, data):
|
def check_fsdp_deepspeed(cls, data):
|
||||||
|
|||||||
Reference in New Issue
Block a user