feat: LoRA kernel support for bias, dropout, dora, embeddings (#3528) [skip ci]
* feat: LoRA kernel support for bias, dropout, dora, embeddings * chore: lint * chore: lint * address PR feedback, add regression tests, add fsdp2 tests for lora kernels * update tests for new sigs * update tests now that bias and dropout are supported
This commit is contained in:
@@ -102,7 +102,7 @@ def mock_proj():
|
||||
def test_get_lora_parameters(mock_proj):
|
||||
"""Tests get_lora_parameters function"""
|
||||
# Test with LoRA enabled
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||
|
||||
assert isinstance(W, torch.Tensor)
|
||||
assert W.shape == (128, 64)
|
||||
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
|
||||
|
||||
# Test with LoRA disabled
|
||||
mock_proj.disable_adapters = True
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
# Test with merged state
|
||||
mock_proj.disable_adapters = False
|
||||
mock_proj.merged = True
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user