fix: improve lora kernels failure message and handle trust_remote_code (#3378) [skip ci]
* fix: improve lora kernels failure message and handle trust_remote_code * chore: re-order model guides
This commit is contained in:
@@ -251,7 +251,6 @@ website:
|
||||
- docs/models/olmo3.qmd
|
||||
- docs/models/trinity.qmd
|
||||
- docs/models/arcee.qmd
|
||||
- docs/models/mistral.qmd
|
||||
- section: "Ministral3"
|
||||
contents:
|
||||
- docs/models/ministral3.qmd
|
||||
@@ -266,6 +265,7 @@ website:
|
||||
- docs/models/mistral-small.qmd
|
||||
- docs/models/voxtral.qmd
|
||||
- docs/models/devstral.qmd
|
||||
- docs/models/mistral.qmd
|
||||
- docs/models/llama-4.qmd
|
||||
- docs/models/llama-2.qmd
|
||||
- docs/models/qwen3-next.qmd
|
||||
|
||||
@@ -89,6 +89,10 @@ lora_o_kernel: true
|
||||
Currently, LoRA kernels are not supported for RLHF training, only SFT.
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
LoRA kernels do not support remote modeling code.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||
|
||||
@@ -169,7 +169,8 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
return attention_cls
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Could not import attention class for model_type: {model_type}. "
|
||||
f"Axolotl could not import attention class for model_type: {model_type}. "
|
||||
"Please raise an Issue and turn off lora kernels to continue training. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
@@ -1271,6 +1271,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
):
|
||||
return data
|
||||
|
||||
# Skip if trust_remote_code is enabled, as lora kernels are not compatible
|
||||
if data.get("trust_remote_code"):
|
||||
return data
|
||||
|
||||
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
|
||||
if data.get("lora_dropout") != 0:
|
||||
return data
|
||||
|
||||
@@ -690,6 +690,21 @@ class LoRAValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_kernels_trust_remote_code(cls, data):
|
||||
if (
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
) and data.get("trust_remote_code"):
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||
"compatible with trust_remote_code. Please disable trust_remote_code "
|
||||
"or explicitly set lora_*_kernel to false."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class RLValidationMixin:
|
||||
"""Validation methods related to RL training configuration."""
|
||||
|
||||
@@ -90,3 +90,62 @@ class TestLoRAConfigValidation:
|
||||
}
|
||||
)
|
||||
validate_config(invalid_config)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kernel_field", ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
||||
)
|
||||
def test_lora_kernels_trust_remote_code_incompatible(self, kernel_field):
|
||||
"""Test that lora kernels are incompatible with trust_remote_code"""
|
||||
with pytest.raises(ValueError, match="not compatible with trust_remote_code"):
|
||||
invalid_config = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
kernel_field: True,
|
||||
"trust_remote_code": True,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
validate_config(invalid_config)
|
||||
|
||||
def test_lora_kernels_trust_remote_code_false(self):
|
||||
"""Test that lora kernels work when trust_remote_code is false"""
|
||||
# Test with trust_remote_code=False, lora kernels should be allowed
|
||||
valid_config = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
"lora_o_kernel": True,
|
||||
"trust_remote_code": False,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
result = validate_config(valid_config)
|
||||
assert result["lora_mlp_kernel"] is True
|
||||
assert result["lora_qkv_kernel"] is True
|
||||
assert result["lora_o_kernel"] is True
|
||||
|
||||
# Test with trust_remote_code=None (unset), kernels should be allowed
|
||||
valid_config = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
"lora_qkv_kernel": True,
|
||||
"trust_remote_code": None,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
result = validate_config(valid_config)
|
||||
assert result["lora_qkv_kernel"] is True
|
||||
assert result["trust_remote_code"] is None
|
||||
|
||||
Reference in New Issue
Block a user