Fix: lora kernel pre-patch applied despite post-patch not applied (#2772)
* fix: do not pre-patch self attention if lora dropout non-zero * fix: add test to check patch not applied * fix: test * fix: test config check * fix where we check so that tests don't break * fix: test --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -25,7 +25,9 @@ from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.monkeypatch.lora_kernels import (
|
||||
apply_lora_kernel_patches,
|
||||
find_self_attn_in_layer,
|
||||
get_attention_cls_from_config,
|
||||
get_layers,
|
||||
patch_self_attn_lora,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -501,3 +503,63 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
||||
break
|
||||
|
||||
assert found_patched_attn
|
||||
|
||||
|
||||
def test_kernel_training_integration_dropout_non_zero():
|
||||
"""Test model loading with dropout non-zero should not patch."""
|
||||
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
|
||||
# Create minimal config
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||
"learning_rate": 0.000001,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_target_linear": True,
|
||||
"sequence_len": 1024,
|
||||
}
|
||||
)
|
||||
|
||||
# Get original attention class
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
# Store original state before patching
|
||||
original_forward_method = attention_cls.forward
|
||||
|
||||
# Load model
|
||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
# We call modelloader as that's where the patches are applied
|
||||
# despite the fact that we're not using it to load the model
|
||||
model_loader = ModelLoader(cfg, tokenizer)
|
||||
|
||||
# Apply patch
|
||||
model_loader.patch_manager._apply_self_attention_lora_patch() # pylint: disable=protected-access
|
||||
|
||||
# Verify patch was not applied
|
||||
assert attention_cls.forward == original_forward_method
|
||||
|
||||
# Apply apply_lora_kernel_patches
|
||||
model_loader.patch_manager._apply_lora_kernel_patch( # pylint: disable=protected-access
|
||||
model
|
||||
)
|
||||
|
||||
# Verify patch was not applied
|
||||
layers = get_layers(model)
|
||||
for layer in layers:
|
||||
for self_attn in find_self_attn_in_layer(layer):
|
||||
assert not hasattr(self_attn, "apply_qkv")
|
||||
assert not hasattr(self_attn, "apply_o")
|
||||
|
||||
Reference in New Issue
Block a user