Compare commits
5 Commits
fix_kto
...
822a8a6931
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
822a8a6931 | ||
|
|
1a51180637 | ||
|
|
7562aadf89 | ||
|
|
479f5e18dd | ||
|
|
945dcc5020 |
@@ -4,13 +4,12 @@ import importlib
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import types
|
import types
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from peft import PeftModelForCausalLM
|
from peft import PeftModelForCausalLM
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoConfig
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
from axolotl.kernels.lora import (
|
from axolotl.kernels.lora import (
|
||||||
apply_lora_mlp_geglu,
|
apply_lora_mlp_geglu,
|
||||||
@@ -96,108 +95,90 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tens
|
|||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|
||||||
"""
|
|
||||||
Get the appropriate attention class by inspecting the model config.
|
|
||||||
Uses dynamic import to support any model architecture that follows
|
|
||||||
the standard transformers naming convention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The appropriate attention class for the model.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `base_model` not specified or attention class cannot be imported
|
|
||||||
ImportError: If the model module or attention class doesn't exist
|
|
||||||
"""
|
|
||||||
if "base_model" not in cfg:
|
|
||||||
raise ValueError("base_model must be specified in config")
|
|
||||||
|
|
||||||
# Get model config without loading the model
|
|
||||||
model_config = AutoConfig.from_pretrained(cfg["base_model"])
|
|
||||||
model_type = model_config.model_type
|
|
||||||
|
|
||||||
# Special case for model_type = "qwen2"
|
|
||||||
if model_type == "qwen2":
|
|
||||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
|
|
||||||
|
|
||||||
return Qwen2Attention
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Dynamically import the module and attention class
|
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
|
||||||
module = __import__(
|
|
||||||
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
|
|
||||||
)
|
|
||||||
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
|
|
||||||
|
|
||||||
return attention_cls
|
|
||||||
except (ImportError, AttributeError) as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Could not import attention class for model_type: {model_type}. "
|
|
||||||
f"Error: {str(e)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
def patch_self_attn_lora(cfg: DictDefault):
|
def patch_self_attn_lora(model: PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
Given an `axolotl` config, this method patches the inferred attention class forward
|
Patches the attention classes in a transformer model with optimized LoRA implementations.
|
||||||
pass with optimized LoRA implementations.
|
|
||||||
|
|
||||||
It modifies the attention class to use optimized QKV and output projections. The
|
It modifies the attention class to use optimized QKV and output projections. The
|
||||||
original implementation is preserved and can be restored if needed.
|
original implementation is preserved and can be restored if needed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
model: A HuggingFace transformers model.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AssertionError: If the required code blocks are not found in the attention
|
AssertionError: If the required code blocks are not found in the attention
|
||||||
implementation.
|
implementation.
|
||||||
"""
|
"""
|
||||||
attention_cls = get_attention_cls_from_config(cfg)
|
# Find all attention modules in the model
|
||||||
|
attention_modules = [
|
||||||
|
module
|
||||||
|
for module in model.modules()
|
||||||
|
if "attention" in module.__class__.__name__.lower()
|
||||||
|
and hasattr(module, "forward")
|
||||||
|
]
|
||||||
|
|
||||||
# Check if already patched
|
if not attention_modules:
|
||||||
if hasattr(attention_cls, "_original_forward"):
|
LOG.warning("No attention modules found in model")
|
||||||
LOG.info(f"{attention_cls.__name__} already patched")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self_attn_forward = inspect.getsource(attention_cls.forward)
|
attention_classes = {type(module) for module in attention_modules}
|
||||||
attention_cls._original_forward = self_attn_forward
|
LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")
|
||||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
|
||||||
|
|
||||||
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
|
for attention_cls in attention_classes:
|
||||||
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
|
# Skip if already patched
|
||||||
|
if hasattr(attention_cls, "_original_forward"):
|
||||||
|
LOG.info(f"{attention_cls.__name__} already patched")
|
||||||
|
continue
|
||||||
|
|
||||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
# Get and store original forward implementation
|
||||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
self_attn_forward = inspect.getsource(attention_cls.forward)
|
||||||
self_attn_forward = self_attn_forward.replace(
|
attention_cls._original_forward = self_attn_forward
|
||||||
"def forward(",
|
|
||||||
"def axolotl_attn_forward(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load necessary imports
|
# Remove indentation
|
||||||
module_name = attention_cls.__module__
|
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
|
|
||||||
items_to_import = []
|
# Verify required code blocks exist
|
||||||
for item in dir(module):
|
assert (
|
||||||
if item in self_attn_forward:
|
ORIGINAL_QKV_CODE in self_attn_forward
|
||||||
items_to_import.append(item)
|
), f"Original QKV code not found in {attention_cls.__name__}"
|
||||||
|
assert (
|
||||||
|
ORIGINAL_O_CODE in self_attn_forward
|
||||||
|
), f"Original O code not found in {attention_cls.__name__}"
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
# Replace code blocks
|
||||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
self_attn_forward = self_attn_forward.replace(
|
||||||
globals(),
|
ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
|
||||||
)
|
)
|
||||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||||
|
self_attn_forward = self_attn_forward.replace(
|
||||||
|
"def forward(",
|
||||||
|
"def axolotl_attn_forward(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
# Import necessary symbols from the attention module
|
||||||
attention_cls.forward = (
|
module_name = attention_cls.__module__
|
||||||
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
module = importlib.import_module(module_name)
|
||||||
)
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(module):
|
||||||
|
if item in self_attn_forward:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
if items_to_import:
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute the new implementation
|
||||||
|
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
|
||||||
|
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
||||||
|
attention_cls.forward = (
|
||||||
|
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_lora_kernel_patches(
|
def apply_lora_kernel_patches(
|
||||||
|
|||||||
@@ -439,11 +439,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_mistral_cross_entropy()
|
patch_mistral_cross_entropy()
|
||||||
|
|
||||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
|
||||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
|
||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||||
@@ -1028,6 +1023,12 @@ class ModelLoader:
|
|||||||
integrate_rope_embeddings()
|
integrate_rope_embeddings()
|
||||||
|
|
||||||
def apply_lora_patch(self) -> None:
|
def apply_lora_patch(self) -> None:
|
||||||
|
"""Applies patching relevant to LoRA Triton kernels if enabled."""
|
||||||
|
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
||||||
|
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||||
|
|
||||||
|
patch_self_attn_lora(self.model)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.cfg.lora_mlp_kernel
|
self.cfg.lora_mlp_kernel
|
||||||
or self.cfg.lora_qkv_kernel
|
or self.cfg.lora_qkv_kernel
|
||||||
@@ -1181,6 +1182,7 @@ class ModelLoader:
|
|||||||
if self.cfg.adapter is not None:
|
if self.cfg.adapter is not None:
|
||||||
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
||||||
|
|
||||||
|
# TODO: Deprecate this.
|
||||||
self.apply_unsloth_lora_patch()
|
self.apply_unsloth_lora_patch()
|
||||||
self.apply_lora_patch()
|
self.apply_lora_patch()
|
||||||
|
|
||||||
@@ -1201,9 +1203,7 @@ def load_model(
|
|||||||
reference_model: bool = False,
|
reference_model: bool = False,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||||
"""
|
"""Load a model for a given configuration and tokenizer."""
|
||||||
Load a model for a given configuration and tokenizer.
|
|
||||||
"""
|
|
||||||
loader = ModelLoader(
|
loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
|||||||
@@ -9,16 +9,14 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM
|
|||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
|
|
||||||
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.kernels.lora import (
|
from axolotl.kernels.lora import (
|
||||||
apply_lora_mlp_geglu,
|
apply_lora_mlp_geglu,
|
||||||
apply_lora_mlp_swiglu,
|
apply_lora_mlp_swiglu,
|
||||||
apply_lora_o,
|
apply_lora_o,
|
||||||
apply_lora_qkv,
|
apply_lora_qkv,
|
||||||
)
|
)
|
||||||
from axolotl.monkeypatch.lora_kernels import (
|
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
|
||||||
apply_lora_kernel_patches,
|
|
||||||
patch_self_attn_lora,
|
|
||||||
)
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
MODEL_CONFIGS = [
|
MODEL_CONFIGS = [
|
||||||
@@ -65,15 +63,45 @@ def small_llama_model():
|
|||||||
return LlamaForCausalLM(LlamaConfig(**config))
|
return LlamaForCausalLM(LlamaConfig(**config))
|
||||||
|
|
||||||
|
|
||||||
def test_attention_patching_integration():
|
# pylint: disable=duplicate-code
|
||||||
"""Test attention patching in integration context."""
|
@pytest.fixture
|
||||||
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
|
def minimal_cfg():
|
||||||
|
"Config of real HuggingFace Hub model"
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"learning_rate": 0.000001,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.0,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"lora_mlp_kernel": True,
|
||||||
|
"lora_qkv_kernel": True,
|
||||||
|
"lora_o_kernel": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def test_attention_patching_integration(minimal_cfg):
|
||||||
|
"""Test attention patching in integration context."""
|
||||||
# Store the original implementation
|
# Store the original implementation
|
||||||
original_forward = getattr(LlamaAttention, "forward")
|
original_forward = getattr(LlamaAttention, "forward")
|
||||||
|
|
||||||
# Apply patch
|
# Load model
|
||||||
patch_self_attn_lora(cfg)
|
_, _ = load_model_and_tokenizer(cfg=minimal_cfg)
|
||||||
|
|
||||||
# Get the new forward method
|
# Get the new forward method
|
||||||
patched_forward = LlamaAttention.forward
|
patched_forward = LlamaAttention.forward
|
||||||
@@ -376,38 +404,10 @@ def test_model_architecture(model_config):
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
def test_kernel_training_integration():
|
def test_kernel_training_integration(minimal_cfg):
|
||||||
"""Test model loading with kernel patches enabled."""
|
"""Test model loading with kernel patches enabled."""
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
|
||||||
|
|
||||||
# Create minimal config
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"learning_rate": 0.000001,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"lora_mlp_kernel": True,
|
|
||||||
"lora_qkv_kernel": True,
|
|
||||||
"lora_o_kernel": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model, _ = load_model_and_tokenizer(cfg=cfg)
|
model, _ = load_model_and_tokenizer(cfg=minimal_cfg)
|
||||||
|
|
||||||
# Verify correct activation function
|
# Verify correct activation function
|
||||||
layer = model.model.model.layers[0]
|
layer = model.model.model.layers[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user