removing deepspeed guard for LoRA Triton kernels
This commit is contained in:
@@ -1224,17 +1224,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
):
|
):
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_fsdp = data.get("fsdp") is not None
|
is_fsdp = data.get("fsdp") is not None
|
||||||
is_deepspeed = data.get("deepspeed") is not None
|
|
||||||
|
|
||||||
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
||||||
if is_fsdp:
|
if is_fsdp:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
|
||||||
)
|
)
|
||||||
if is_deepspeed:
|
|
||||||
raise ValueError(
|
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
|
|
||||||
)
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
Reference in New Issue
Block a user