fix: improve lora kernels failure message and handle trust_remote_code (#3378) [skip ci]
* fix: improve lora kernels failure message and handle trust_remote_code * chore: re-order model guides
This commit is contained in:
@@ -90,3 +90,62 @@ class TestLoRAConfigValidation:
|
||||
}
|
||||
)
|
||||
validate_config(invalid_config)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kernel_field", ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
||||
)
|
||||
def test_lora_kernels_trust_remote_code_incompatible(self, kernel_field):
|
||||
"""Test that lora kernels are incompatible with trust_remote_code"""
|
||||
with pytest.raises(ValueError, match="not compatible with trust_remote_code"):
|
||||
invalid_config = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
kernel_field: True,
|
||||
"trust_remote_code": True,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
validate_config(invalid_config)
|
||||
|
||||
def test_lora_kernels_trust_remote_code_false(self):
|
||||
"""Test that lora kernels work when trust_remote_code is false"""
|
||||
# Test with trust_remote_code=False, lora kernels should be allowed
|
||||
valid_config = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
"lora_o_kernel": True,
|
||||
"trust_remote_code": False,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
result = validate_config(valid_config)
|
||||
assert result["lora_mlp_kernel"] is True
|
||||
assert result["lora_qkv_kernel"] is True
|
||||
assert result["lora_o_kernel"] is True
|
||||
|
||||
# Test with trust_remote_code=None (unset), kernels should be allowed
|
||||
valid_config = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
"lora_qkv_kernel": True,
|
||||
"trust_remote_code": None,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
result = validate_config(valid_config)
|
||||
assert result["lora_qkv_kernel"] is True
|
||||
assert result["trust_remote_code"] is None
|
||||
|
||||
Reference in New Issue
Block a user