fix: improve lora kernels failure message and handle trust_remote_code (#3378) [skip ci]
* fix: improve lora kernels failure message and handle trust_remote_code
* chore: re-order model guides
This commit is contained in:
@@ -169,7 +169,8 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
return attention_cls
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Could not import attention class for model_type: {model_type}. "
|
||||
f"Axolotl could not import attention class for model_type: {model_type}. "
|
||||
"Please raise an Issue and turn off lora kernels to continue training. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@@ -1271,6 +1271,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
):
|
||||
return data
|
||||
|
||||
# Skip if trust_remote_code is enabled, as lora kernels are not compatible
|
||||
if data.get("trust_remote_code"):
|
||||
return data
|
||||
|
||||
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
|
||||
if data.get("lora_dropout") != 0:
|
||||
return data
|
||||
|
||||
@@ -690,6 +690,21 @@ class LoRAValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_lora_kernels_trust_remote_code(cls, data):
    """Reject configs that combine a LoRA kernel flag with trust_remote_code.

    Runs as a "before"-mode model validator on the raw input mapping.

    Args:
        data: Raw config mapping; read via ``.get`` for the keys
            ``trust_remote_code``, ``lora_mlp_kernel``, ``lora_qkv_kernel``
            and ``lora_o_kernel``.

    Returns:
        The same ``data`` object, unmodified, when the combination is allowed.

    Raises:
        ValueError: If ``trust_remote_code`` is truthy and at least one of the
            three ``lora_*_kernel`` flags is truthy.
    """
    # Without trust_remote_code there is no conflict to report.
    if not data.get("trust_remote_code"):
        return data

    kernel_flags = ("lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel")
    if any(data.get(flag) for flag in kernel_flags):
        raise ValueError(
            "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
            "compatible with trust_remote_code. Please disable trust_remote_code "
            "or explicitly set lora_*_kernel to false."
        )

    return data
|
||||
|
||||
|
||||
class RLValidationMixin:
|
||||
"""Validation methods related to RL training configuration."""
|
||||
|
||||
Reference in New Issue
Block a user