Compare commits

..

5 Commits

Author SHA1 Message Date
Dan Saunders
822a8a6931 pylint 2025-02-18 19:59:17 +00:00
Dan Saunders
1a51180637 removing unused function 2025-02-18 19:36:03 +00:00
Dan Saunders
7562aadf89 fix 2025-02-18 19:13:09 +00:00
Dan Saunders
479f5e18dd Small updates 2025-02-18 19:08:27 +00:00
Dan Saunders
945dcc5020 move patching to post-model load to improve applicability 2025-02-18 19:00:12 +00:00
7 changed files with 123 additions and 209 deletions

View File

@@ -4,13 +4,12 @@ import importlib
import inspect
import logging
import types
from typing import Type
import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers import AutoConfig
from transformers.modeling_utils import PreTrainedModel
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
@@ -96,108 +95,90 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tens
return attn_output
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
"""
Get the appropriate attention class by inspecting the model config.
Uses dynamic import to support any model architecture that follows
the standard transformers naming convention.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
The appropriate attention class for the model.
Raises:
ValueError: If `base_model` not specified or attention class cannot be imported
ImportError: If the model module or attention class doesn't exist
"""
if "base_model" not in cfg:
raise ValueError("base_model must be specified in config")
# Get model config without loading the model
model_config = AutoConfig.from_pretrained(cfg["base_model"])
model_type = model_config.model_type
# Special case for model_type = "qwen2"
if model_type == "qwen2":
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
return Qwen2Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
)
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
return attention_cls
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
# pylint: disable=protected-access
def patch_self_attn_lora(cfg: DictDefault):
def patch_self_attn_lora(model: PreTrainedModel):
"""
Given an `axolotl` config, this method patches the inferred attention class forward
pass with optimized LoRA implementations.
Patches the attention classes in a transformer model with optimized LoRA implementations.
It modifies the attention class to use optimized QKV and output projections. The
original implementation is preserved and can be restored if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: A HuggingFace transformers model.
Raises:
AssertionError: If the required code blocks are not found in the attention
implementation.
"""
attention_cls = get_attention_cls_from_config(cfg)
# Find all attention modules in the model
attention_modules = [
module
for module in model.modules()
if "attention" in module.__class__.__name__.lower()
and hasattr(module, "forward")
]
# Check if already patched
if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
if not attention_modules:
LOG.warning("No attention modules found in model")
return
self_attn_forward = inspect.getsource(attention_cls.forward)
attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
attention_classes = {type(module) for module in attention_modules}
LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
for attention_cls in attention_classes:
# Skip if already patched
if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
continue
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
1,
)
# Get and store original forward implementation
self_attn_forward = inspect.getsource(attention_cls.forward)
attention_cls._original_forward = self_attn_forward
# Load necessary imports
module_name = attention_cls.__module__
module = importlib.import_module(module_name)
# Remove indentation
self_attn_forward, _ = detab_code(self_attn_forward)
items_to_import = []
for item in dir(module):
if item in self_attn_forward:
items_to_import.append(item)
# Verify required code blocks exist
assert (
ORIGINAL_QKV_CODE in self_attn_forward
), f"Original QKV code not found in {attention_cls.__name__}"
assert (
ORIGINAL_O_CODE in self_attn_forward
), f"Original O code not found in {attention_cls.__name__}"
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
# Replace code blocks
self_attn_forward = self_attn_forward.replace(
ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
1,
)
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
attention_cls.forward = (
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
# Import necessary symbols from the attention module
module_name = attention_cls.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in self_attn_forward:
items_to_import.append(item)
if items_to_import:
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
# Execute the new implementation
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
attention_cls.forward = (
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
def apply_lora_kernel_patches(

View File

@@ -439,11 +439,6 @@ class ModelLoader:
patch_mistral_cross_entropy()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -1028,6 +1023,12 @@ class ModelLoader:
integrate_rope_embeddings()
def apply_lora_patch(self) -> None:
"""Applies patching relevant to LoRA Triton kernels if enabled."""
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.model)
if (
self.cfg.lora_mlp_kernel
or self.cfg.lora_qkv_kernel
@@ -1181,6 +1182,7 @@ class ModelLoader:
if self.cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
# TODO: Deprecate this.
self.apply_unsloth_lora_patch()
self.apply_lora_patch()
@@ -1201,9 +1203,7 @@ def load_model(
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
Load a model for a given configuration and tokenizer.
"""
"""Load a model for a given configuration and tokenizer."""
loader = ModelLoader(
cfg,
tokenizer,

View File

@@ -5,12 +5,12 @@ import numpy as np
def get_dataset_lengths(dataset):
if "length" in dataset.column_names:
lengths = np.array(dataset["length"])
elif "position_ids" in dataset.column_names:
position_ids = dataset["position_ids"]
if "length" in dataset.data.column_names:
lengths = np.array(dataset.data.column("length"))
elif "position_ids" in dataset.data.column_names:
position_ids = dataset.data.column("position_ids")
lengths = np.array([x[-1] + 1 for x in position_ids])
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths

View File

@@ -9,16 +9,14 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qkv,
)
from axolotl.monkeypatch.lora_kernels import (
apply_lora_kernel_patches,
patch_self_attn_lora,
)
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
@@ -65,15 +63,45 @@ def small_llama_model():
return LlamaForCausalLM(LlamaConfig(**config))
def test_attention_patching_integration():
"""Test attention patching in integration context."""
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
# pylint: disable=duplicate-code
@pytest.fixture
def minimal_cfg():
"Config of real HuggingFace Hub model"
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
return cfg
def test_attention_patching_integration(minimal_cfg):
"""Test attention patching in integration context."""
# Store the original implementation
original_forward = getattr(LlamaAttention, "forward")
# Apply patch
patch_self_attn_lora(cfg)
# Load model
_, _ = load_model_and_tokenizer(cfg=minimal_cfg)
# Get the new forward method
patched_forward = LlamaAttention.forward
@@ -376,38 +404,10 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code
def test_kernel_training_integration():
def test_kernel_training_integration(minimal_cfg):
"""Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
# Create minimal config
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# Load model
model, _ = load_model_and_tokenizer(cfg=cfg)
model, _ = load_model_and_tokenizer(cfg=minimal_cfg)
# Verify correct activation function
layer = model.model.model.layers[0]

View File

@@ -125,12 +125,6 @@ def fixture_llama3_tokenizer():
return tokenizer
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
def fixture_smollm2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
def fixture_mistralv03_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(

View File

@@ -1,61 +0,0 @@
"""
Tests for loading DPO preference datasets with chatml formatting
"""
import unittest
import pytest
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_dpo_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"rl": "dpo",
"learning_rate": 0.000001,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"sequence_len": 2048,
}
)
class TestDPOChatml:
"""
Test loading DPO preference datasets with chatml formatting
"""
def test_default(self, minimal_dpo_cfg):
cfg = DictDefault(
{
"datasets": [
{
"path": "argilla/distilabel-intel-orca-dpo-pairs",
"type": "chatml",
"split": "train[:1%]",
}
]
}
| minimal_dpo_cfg
)
# test that dpo.load works
load_dpo("chatml", cfg)
# now actually load the datasets with the strategy
train_ds, _ = load_prepare_preference_datasets(cfg)
assert train_ds[0]["prompt"].startswith("<|im_start|>")
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
assert "chosen" in train_ds[0]
assert "rejected" in train_ds[0]
if __name__ == "__main__":
unittest.main()

View File

@@ -7,7 +7,6 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -19,6 +18,11 @@ def fixture_tokenizer():
return tokenizer
@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
return 4096
class TestBatchedSamplerPacking:
"""
Test class for packing streaming dataset sequences
@@ -33,7 +37,6 @@ class TestBatchedSamplerPacking:
(2, 2),
],
)
@pytest.mark.parametrize("max_seq_length", [4096, 512])
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
@@ -59,9 +62,6 @@ class TestBatchedSamplerPacking:
dataset,
)
train_dataset = concatenate_datasets([dataset_wrapper])
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
batch_idxs.extend(pack)
for batch in loader:
assert batch["input_ids"].numel() <= batch_size * max_seq_length
assert len(batch["input_ids"]) <= batch_size * max_seq_length
assert batch["input_ids"].shape[1] == max_seq_length
original_idxs = set(range(len(train_dataset)))