feat: LoRA kernel support for bias, dropout, dora, embeddings (#3528) [skip ci]

* feat: LoRA kernel support for bias, dropout, dora, embeddings

* chore: lint

* chore: lint

* address PR feedback, add regression tests, add fsdp2 tests for lora kernels

* update tests for new sigs

* update tests now that bias and dropout are supported
Wing Lian
2026-03-22 13:53:19 -04:00
committed by GitHub
parent a67392c427
commit b3289fd190
13 changed files with 2847 additions and 448 deletions


@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
         proj.base_layer = base_layer
-        W, b, quant_state, A, B, s = get_lora_parameters(proj)
+        W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
         # quant_state should be None since weight is bf16, not FP8
         self.assertIsNone(quant_state)
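
The `*_` added at these call sites is Python's extended unpacking: it absorbs any trailing return values, so the tests keep passing as `get_lora_parameters` widens its return tuple in this PR. A minimal, self-contained sketch of the pattern (the extra return values and their names below are hypothetical, not the function's actual signature):

# Sketch: extended unpacking tolerates a widened return tuple.
def get_params_v1():
    return ("W", "b", None, "A", "B", 1.0)

def get_params_v2():
    # Hypothetical widened return: the same six values plus extras
    # (names like "dropout_p" / "dora_scale" are assumptions).
    return ("W", "b", None, "A", "B", 1.0, "dropout_p", "dora_scale")

for fn in (get_params_v1, get_params_v2):
    W, b, quant_state, A, B, s, *_ = fn()  # works for both arities
    assert quant_state is None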
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
         scale_inv = torch.ones(1)
         base_layer.weight_scale_inv = scale_inv
-        W, b, quant_state, A, B, s = get_lora_parameters(proj)
+        W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
         self.assertIs(quant_state, scale_inv)
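
For context on what this second hunk asserts: the FP8 guard appears to key on the presence of a `weight_scale_inv` tensor, returning it as the quant state, while plain bf16 weights yield None. A hedged sketch of that check, assuming attribute presence is the discriminator (`fp8_quant_state` is a hypothetical helper, not the library's API):

import torch

def fp8_quant_state(base_layer):
    # Hypothetical helper: report a quant state only when the layer
    # carries a `weight_scale_inv` tensor (the FP8 case); bf16 weights
    # have no such attribute and yield None.
    return getattr(base_layer, "weight_scale_inv", None)

layer = torch.nn.Linear(4, 4, dtype=torch.bfloat16)
assert fp8_quant_state(layer) is None        # bf16: no quant state

scale_inv = torch.ones(1)
layer.weight_scale_inv = scale_inv
assert fp8_quant_state(layer) is scale_inv   # FP8: scale is the quant state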