From a2da85257684261c26a545726c6b2cff0b7352e0 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Tue, 10 Feb 2026 17:58:40 +0700
Subject: [PATCH] fix: improve lora kernels failure message and handle
 trust_remote_code (#3378) [skip ci]

* fix: improve lora kernels failure message and handle trust_remote_code

* chore: re-order model guides
---
 _quarto.yml                                    |  2 +-
 docs/lora_optims.qmd                           |  4 ++
 src/axolotl/monkeypatch/lora_kernels.py        |  3 +-
 src/axolotl/utils/schemas/config.py            |  4 ++
 src/axolotl/utils/schemas/validation.py        | 15 +++++
 .../utils/lora/test_config_validation_lora.py  | 59 +++++++++++++++++++
 6 files changed, 85 insertions(+), 2 deletions(-)

diff --git a/_quarto.yml b/_quarto.yml
index 7de2be6a7..4534c0a0e 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -251,7 +251,6 @@ website:
           - docs/models/olmo3.qmd
           - docs/models/trinity.qmd
           - docs/models/arcee.qmd
-          - docs/models/mistral.qmd
           - section: "Ministral3"
             contents:
               - docs/models/ministral3.qmd
@@ -266,6 +265,7 @@ website:
           - docs/models/mistral-small.qmd
           - docs/models/voxtral.qmd
           - docs/models/devstral.qmd
+          - docs/models/mistral.qmd
           - docs/models/llama-4.qmd
           - docs/models/llama-2.qmd
           - docs/models/qwen3-next.qmd
diff --git a/docs/lora_optims.qmd b/docs/lora_optims.qmd
index 40893387b..7006b5d19 100644
--- a/docs/lora_optims.qmd
+++ b/docs/lora_optims.qmd
@@ -89,6 +89,10 @@ lora_o_kernel: true
 Currently, LoRA kernels are not supported for RLHF training, only SFT.
 :::
 
+::: {.callout-warning}
+LoRA kernels do not support remote modeling code.
+:::
+
 ## Requirements
 
 - One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index 8e335fe4c..2972c6285 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -169,7 +169,8 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
         return attention_cls
     except (ImportError, AttributeError) as e:
         raise ValueError(
-            f"Could not import attention class for model_type: {model_type}. "
+            f"Axolotl could not import attention class for model_type: {model_type}. "
+            "Please raise an Issue and turn off lora kernels to continue training. "
             f"Error: {str(e)}"
         ) from e
 
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index d858fdbce..91bc221ac 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1271,6 +1271,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         ):
             return data
 
+        # Skip if trust_remote_code is enabled, as lora kernels are not compatible
+        if data.get("trust_remote_code"):
+            return data
+
         # Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
         if data.get("lora_dropout") != 0:
             return data
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index bde367e0e..783017405 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -690,6 +690,21 @@ class LoRAValidationMixin:
         )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_lora_kernels_trust_remote_code(cls, data):
+        if (
+            data.get("lora_mlp_kernel")
+            or data.get("lora_qkv_kernel")
+            or data.get("lora_o_kernel")
+        ) and data.get("trust_remote_code"):
+            raise ValueError(
+                "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
+                "compatible with trust_remote_code. Please disable trust_remote_code "
+                "or explicitly set lora_*_kernel to false."
+            )
+        return data
+
 
 class RLValidationMixin:
     """Validation methods related to RL training configuration."""
diff --git a/tests/utils/lora/test_config_validation_lora.py b/tests/utils/lora/test_config_validation_lora.py
index a22e2a5b7..9d97288b6 100644
--- a/tests/utils/lora/test_config_validation_lora.py
+++ b/tests/utils/lora/test_config_validation_lora.py
@@ -90,3 +90,62 @@ class TestLoRAConfigValidation:
             }
         )
         validate_config(invalid_config)
+
+    @pytest.mark.parametrize(
+        "kernel_field", ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
+    )
+    def test_lora_kernels_trust_remote_code_incompatible(self, kernel_field):
+        """Test that lora kernels are incompatible with trust_remote_code"""
+        with pytest.raises(ValueError, match="not compatible with trust_remote_code"):
+            invalid_config = DictDefault(
+                {
+                    "adapter": "lora",
+                    kernel_field: True,
+                    "trust_remote_code": True,
+                    "datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
+                    "micro_batch_size": 1,
+                    "gradient_accumulation_steps": 1,
+                    "learning_rate": 1e-5,
+                    "base_model": "dummy_model",
+                }
+            )
+            validate_config(invalid_config)
+
+    def test_lora_kernels_trust_remote_code_false(self):
+        """Test that lora kernels work when trust_remote_code is false"""
+        # Test with trust_remote_code=False, lora kernels should be allowed
+        valid_config = DictDefault(
+            {
+                "adapter": "lora",
+                "lora_mlp_kernel": True,
+                "lora_qkv_kernel": True,
+                "lora_o_kernel": True,
+                "trust_remote_code": False,
+                "datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "learning_rate": 1e-5,
+                "base_model": "dummy_model",
+            }
+        )
+        result = validate_config(valid_config)
+        assert result["lora_mlp_kernel"] is True
+        assert result["lora_qkv_kernel"] is True
+        assert result["lora_o_kernel"] is True
+
+        # Test with trust_remote_code=None (unset), kernels should be allowed
+        valid_config = DictDefault(
+            {
+                "adapter": "lora",
+                "lora_qkv_kernel": True,
+                "trust_remote_code": None,
+                "datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "learning_rate": 1e-5,
+                "base_model": "dummy_model",
+            }
+        )
+        result = validate_config(valid_config)
+        assert result["lora_qkv_kernel"] is True
+        assert result["trust_remote_code"] is None
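
For illustration, a minimal sketch of how the new `check_lora_kernels_trust_remote_code` validator surfaces to users, mirroring the test cases above. The import paths (`axolotl.utils.config.validate_config`, `axolotl.utils.dict.DictDefault`) are assumptions about the project layout and are not shown in this patch:

```python
# Minimal sketch of triggering the validator added in
# src/axolotl/utils/schemas/validation.py. Import paths are assumed,
# not confirmed by this patch.
from axolotl.utils.config import validate_config  # assumed import path
from axolotl.utils.dict import DictDefault  # assumed import path

cfg = DictDefault(
    {
        "adapter": "lora",
        "lora_qkv_kernel": True,  # any of the three lora_*_kernel flags
        "trust_remote_code": True,  # now rejected alongside the kernels
        "datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "learning_rate": 1e-5,
        "base_model": "dummy_model",
    }
)

try:
    validate_config(cfg)
except ValueError as err:
    # "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not
    #  compatible with trust_remote_code. ..."
    print(err)
```

Dropping `trust_remote_code` (or setting it to `false`) lets the same config pass validation, and the auto-enable logic in `config.py` now also skips turning the kernels on when remote code is trusted.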