add validation to prevent 8bit lora finetuning on H100s (#1827)

This commit is contained in:
Wing Lian
2024-08-16 21:32:00 -04:00
committed by GitHub
parent 803fed3e90
commit b1d2921222

View File

@@ -1267,6 +1267,19 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data
@model_validator(mode="before")
@classmethod
def check_hopper_8bit_lora(cls, data):
is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
# see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_deepspeed(cls, data):