Small updates
This commit is contained in:
@@ -141,6 +141,9 @@ def patch_self_attn_lora(model: PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
Patches the attention classes in a transformer model with optimized LoRA implementations.
|
Patches the attention classes in a transformer model with optimized LoRA implementations.
|
||||||
|
|
||||||
|
It modifies the attention class to use optimized QKV and output projections. The
|
||||||
|
original implementation is preserved and can be restored if needed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: A HuggingFace transformers model.
|
model: A HuggingFace transformers model.
|
||||||
|
|
||||||
|
|||||||
@@ -1023,6 +1023,12 @@ class ModelLoader:
|
|||||||
integrate_rope_embeddings()
|
integrate_rope_embeddings()
|
||||||
|
|
||||||
def apply_lora_patch(self) -> None:
|
def apply_lora_patch(self) -> None:
|
||||||
|
"""Applies patching relevant to LoRA Triton kernels if enabled."""
|
||||||
|
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
||||||
|
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||||
|
|
||||||
|
patch_self_attn_lora(self.model)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.cfg.lora_mlp_kernel
|
self.cfg.lora_mlp_kernel
|
||||||
or self.cfg.lora_qkv_kernel
|
or self.cfg.lora_qkv_kernel
|
||||||
@@ -1176,11 +1182,7 @@ class ModelLoader:
|
|||||||
if self.cfg.adapter is not None:
|
if self.cfg.adapter is not None:
|
||||||
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
||||||
|
|
||||||
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
# TODO: Deprecate this.
|
||||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
|
||||||
|
|
||||||
patch_self_attn_lora(self.model)
|
|
||||||
|
|
||||||
self.apply_unsloth_lora_patch()
|
self.apply_unsloth_lora_patch()
|
||||||
self.apply_lora_patch()
|
self.apply_lora_patch()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user