Compare commits: kernelize-... → lora_kerne
1 Commits

| Author | SHA1 | Date |
|---|---|---|
| | ede973b76c | |
```diff
@@ -9,7 +9,6 @@ liger_rms_norm: true
 liger_glu_activation: true
 liger_fused_linear_cross_entropy: true
-

 chat_template: llama3
 datasets:
   - path: mlabonne/FineTome-100k
```
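For context, these `liger_*` flags turn on Liger's fused Triton kernels for the model. Below is a rough sketch of what they correspond to through Liger-Kernel's public patching API, assuming a Llama-family model (the flag-to-argument mapping is an assumption, not taken from this diff):

```python
# Sketch: roughly what the liger_* flags above enable, assuming a
# Llama-family model. Argument names are Liger-Kernel's; mapping the
# YAML flags onto them is an assumption.
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama(
    rms_norm=True,                    # liger_rms_norm: true
    swiglu=True,                      # liger_glu_activation: true
    fused_linear_cross_entropy=True,  # liger_fused_linear_cross_entropy: true
)
```

Note that the patch must run before the model is instantiated, since it swaps the Hugging Face modeling classes in place.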
```diff
@@ -15,7 +15,6 @@ lora_model_dir:
 sequence_len: 2048
 sample_packing: true
-

 lora_r: 16
 lora_alpha: 32
 # Currently, we don't support dropout with our custom Triton kernels
```
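The dropout comment implies a validation step somewhere upstream of the kernels. A minimal sketch of such a guard, assuming a parsed config dict (the helper and its name are hypothetical, not part of this diff):

```python
# Hypothetical guard for the limitation noted in the config comment;
# the function name and config shape are assumptions.
def check_lora_kernel_config(cfg: dict) -> None:
    if cfg.get("lora_dropout", 0.0) != 0.0:
        raise ValueError(
            "The custom Triton LoRA kernels do not support dropout; "
            "set lora_dropout: 0 or disable the kernels."
        )
```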
```diff
@@ -102,8 +102,8 @@ def matmul_lora(
     del W

     if A is not None:
-        A, B = A.t(), B.t()
-        out += (X @ A.to(dtype)) @ (s * B.to(dtype))
+        A, B = A.t().to(dtype), B.t().to(dtype)
+        out += (X @ A) @ (s * B)

     return out.view(batch, seq_len, -1) if reshape else out
```
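The two changed lines are a pure refactor: the casts to `dtype` move out of the matmul expression, and the low-rank update `(X @ Aᵀ) @ (s · Bᵀ)` is unchanged. A self-contained sketch under a simplified signature (the name, and the omission of quantized-weight handling and reshaping, are mine):

```python
# Minimal sketch of the LoRA matmul above, with a simplified signature
# (the real function also handles quantized W and 3-D reshaping).
import torch

def matmul_lora_sketch(X, W, A, B, s):
    dtype = X.dtype
    out = X @ W.t()                               # base projection
    if A is not None:
        A, B = A.t().to(dtype), B.t().to(dtype)  # (in, r), (r, out)
        out += (X @ A) @ (s * B)                  # low-rank update, scaled by s
    return out

# Equivalence check against the pre-change formulation:
X = torch.randn(4, 64)
W = torch.randn(64, 64)
A = torch.randn(8, 64)   # LoRA A: (r, in_features)
B = torch.randn(64, 8)   # LoRA B: (out_features, r)
old = X @ W.t() + (X @ A.t().to(X.dtype)) @ (2.0 * B.t().to(X.dtype))
new = matmul_lora_sketch(X, W, A, B, 2.0)
assert torch.allclose(old, new, atol=1e-5)
```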
```diff
@@ -221,7 +221,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
         transformer_auto_wrap_policy,
     )

-    # We need the `auto_wrap_policy` original type to create a custom poilicy function for sharding
+    # We need the `auto_wrap_policy` original type to create a custom policy function for sharding
     # This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour
     if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy:
         pass  # auto_wrap_policy_type = "transformer"
```
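The comment being fixed describes the surrounding logic: `fully_shard` (FSDP2) no longer accepts the legacy callable wrap policies, so the code detects which legacy policy was configured (here by identity with `transformer_auto_wrap_policy`) and imitates its module selection by hand. A minimal sketch of that imitation (the helper name and `transformer_cls_to_wrap` are assumptions, and torch >= 2.6 is assumed for the import path):

```python
# Sketch: imitating transformer_auto_wrap_policy under FSDP2, where
# fully_shard is applied per module rather than via a wrap policy.
# `transformer_cls_to_wrap` stands in for the block classes the
# plugin would name; torch >= 2.6 is assumed for the import path.
import torch
from torch.distributed.fsdp import fully_shard

def shard_like_transformer_policy(model: torch.nn.Module, transformer_cls_to_wrap):
    for module in model.modules():
        # The legacy policy wrapped every instance of the listed
        # transformer block classes; here we shard them directly.
        if isinstance(module, tuple(transformer_cls_to_wrap)):
            fully_shard(module)
    fully_shard(model)  # finally, shard the root module
    return model
```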