diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index d85892b43..d04450428 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -9,6 +9,7 @@ on:
       - '.github/workflows/*.yml'
       - "*.[q]md"
       - "examples/**/*.y[a]?ml"
+      - ".pre-commit-config.yaml"
   workflow_dispatch:

 jobs:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index be78d0d3e..195746d2d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -27,7 +27,7 @@ repos:
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.15.0
+    rev: v1.16.0
    hooks:
      - id: mypy
        additional_dependencies:
diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py
index 03fca6df4..63c9e57bd 100644
--- a/src/axolotl/kernels/lora.py
+++ b/src/axolotl/kernels/lora.py
@@ -280,19 +280,19 @@ class LoRA_MLP(torch.autograd.Function):
         # Initialize and compute LoRA gradients
         d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None

-        if down_A is not None:
+        if down_A is not None and down_B is not None:
             d_down_A = h.t() @ (grad_output @ down_B.t())
             d_down_B = (down_A.t() @ h.t()) @ grad_output
             d_down_A *= down_scale
             d_down_B *= down_scale

-        if up_A is not None:
+        if up_A is not None and up_B is not None:
             d_up_A = X.t() @ (grad_up @ up_B.t())
             d_up_B = (up_A.t() @ X.t()) @ grad_up
             d_up_A *= up_scale
             d_up_B *= up_scale

-        if gate_A is not None:
+        if gate_A is not None and gate_B is not None:
             d_gate_A = X.t() @ (grad_gate @ gate_B.t())
             d_gate_B = (gate_A.t() @ X.t()) @ grad_gate
             d_gate_A *= gate_scale
@@ -311,7 +311,7 @@ class LoRA_MLP(torch.autograd.Function):
             del up_weight

         # Note the .to(dtype) only where mixing LoRA with base weights
-        if up_A is not None:
+        if up_A is not None and up_B is not None:
             dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())

         # Gate projection gradients
@@ -319,7 +319,7 @@ class LoRA_MLP(torch.autograd.Function):
             dX += grad_gate @ gate_weight.t()
             del gate_weight

-        if gate_A is not None:
+        if gate_A is not None and gate_B is not None:
             dX += (
                 grad_gate
                 @ gate_B.to(dtype).t()