Compare commits

..

3 Commits

Author SHA1 Message Date
Dan Saunders
02efd7e83d quick formatting fix for LoRA optims doc 2025-02-19 14:17:20 +00:00
Tobias
8dfadc2b3c Fix sample packing producing longer sequences than specified by sequence_len (#2332)
* Extend MultiPackBatchSampler test to include shorter sequence length and drop long sequences filter

* Fix get_dataset_lengths for datasets that were previously filtered (e.g., with drop_long_seq_in_dataset)

* Update src/axolotl/utils/samplers/utils.py

Fix get_dataset_lengths for datasets that do not have position_ids or length attributes

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2025-02-19 12:02:35 +07:00
Wing Lian
23a9fcb0a7 make sure chatml dpo dataset loading works (#2333) 2025-02-18 16:08:40 -05:00
8 changed files with 211 additions and 124 deletions

View File

@@ -12,6 +12,7 @@ to leverage operator fusion and tensor re-use in order to improve speed and redu
memory usage during the forward and backward passes of these calculations. memory usage during the forward and backward passes of these calculations.
We currently support several common model architectures, including (but not limited to): We currently support several common model architectures, including (but not limited to):
- `llama` - `llama`
- `mistral` - `mistral`
- `qwen2` - `qwen2`

View File

@@ -4,12 +4,13 @@ import importlib
import inspect import inspect
import logging import logging
import types import types
from typing import Type
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from peft import PeftModelForCausalLM from peft import PeftModelForCausalLM
from torch import nn from torch import nn
from transformers.modeling_utils import PreTrainedModel from transformers import AutoConfig
from axolotl.kernels.lora import ( from axolotl.kernels.lora import (
apply_lora_mlp_geglu, apply_lora_mlp_geglu,
@@ -95,90 +96,108 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tens
return attn_output return attn_output
# pylint: disable=protected-access def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
def patch_self_attn_lora(model: PreTrainedModel):
""" """
Patches the attention classes in a transformer model with optimized LoRA implementations. Get the appropriate attention class by inspecting the model config.
Uses dynamic import to support any model architecture that follows
the standard transformers naming convention.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
The appropriate attention class for the model.
Raises:
ValueError: If `base_model` not specified or attention class cannot be imported
ImportError: If the model module or attention class doesn't exist
"""
if "base_model" not in cfg:
raise ValueError("base_model must be specified in config")
# Get model config without loading the model
model_config = AutoConfig.from_pretrained(cfg["base_model"])
model_type = model_config.model_type
# Special case for model_type = "qwen2"
if model_type == "qwen2":
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
return Qwen2Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
)
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
return attention_cls
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
# pylint: disable=protected-access
def patch_self_attn_lora(cfg: DictDefault):
"""
Given an `axolotl` config, this method patches the inferred attention class forward
pass with optimized LoRA implementations.
It modifies the attention class to use optimized QKV and output projections. The It modifies the attention class to use optimized QKV and output projections. The
original implementation is preserved and can be restored if needed. original implementation is preserved and can be restored if needed.
Args: Args:
model: A HuggingFace transformers model. cfg: Dictionary mapping `axolotl` config keys to values.
Raises: Raises:
AssertionError: If the required code blocks are not found in the attention AssertionError: If the required code blocks are not found in the attention
implementation. implementation.
""" """
# Find all attention modules in the model attention_cls = get_attention_cls_from_config(cfg)
attention_modules = [
module
for module in model.modules()
if "attention" in module.__class__.__name__.lower()
and hasattr(module, "forward")
]
if not attention_modules: # Check if already patched
LOG.warning("No attention modules found in model") if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
return return
attention_classes = {type(module) for module in attention_modules} self_attn_forward = inspect.getsource(attention_cls.forward)
LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}") attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
for attention_cls in attention_classes: assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
# Skip if already patched assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
continue
# Get and store original forward implementation self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
self_attn_forward = inspect.getsource(attention_cls.forward) self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
attention_cls._original_forward = self_attn_forward self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
1,
)
# Remove indentation # Load necessary imports
self_attn_forward, _ = detab_code(self_attn_forward) module_name = attention_cls.__module__
module = importlib.import_module(module_name)
# Verify required code blocks exist items_to_import = []
assert ( for item in dir(module):
ORIGINAL_QKV_CODE in self_attn_forward if item in self_attn_forward:
), f"Original QKV code not found in {attention_cls.__name__}" items_to_import.append(item)
assert (
ORIGINAL_O_CODE in self_attn_forward
), f"Original O code not found in {attention_cls.__name__}"
# Replace code blocks exec( # pylint: disable=exec-used # nosec B102
self_attn_forward = self_attn_forward.replace( f"from {module_name} import ({', '.join(items_to_import)})",
ORIGINAL_QKV_CODE, PATCHED_QKV_CODE globals(),
) )
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
1,
)
# Import necessary symbols from the attention module LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
module_name = attention_cls.__module__ attention_cls.forward = (
module = importlib.import_module(module_name) axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
items_to_import = []
for item in dir(module):
if item in self_attn_forward:
items_to_import.append(item)
if items_to_import:
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
# Execute the new implementation
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
attention_cls.forward = (
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
def apply_lora_kernel_patches( def apply_lora_kernel_patches(

View File

@@ -439,6 +439,11 @@ class ModelLoader:
patch_mistral_cross_entropy() patch_mistral_cross_entropy()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
def patch_attention(self) -> None: def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"): if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention: if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -1023,12 +1028,6 @@ class ModelLoader:
integrate_rope_embeddings() integrate_rope_embeddings()
def apply_lora_patch(self) -> None: def apply_lora_patch(self) -> None:
"""Applies patching relevant to LoRA Triton kernels if enabled."""
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.model)
if ( if (
self.cfg.lora_mlp_kernel self.cfg.lora_mlp_kernel
or self.cfg.lora_qkv_kernel or self.cfg.lora_qkv_kernel
@@ -1182,7 +1181,6 @@ class ModelLoader:
if self.cfg.adapter is not None: if self.cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", self.model.device) log_gpu_memory_usage(LOG, "after adapters", self.model.device)
# TODO: Deprecate this.
self.apply_unsloth_lora_patch() self.apply_unsloth_lora_patch()
self.apply_lora_patch() self.apply_lora_patch()
@@ -1203,7 +1201,9 @@ def load_model(
reference_model: bool = False, reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""Load a model for a given configuration and tokenizer.""" """
Load a model for a given configuration and tokenizer.
"""
loader = ModelLoader( loader = ModelLoader(
cfg, cfg,
tokenizer, tokenizer,

View File

@@ -5,12 +5,12 @@ import numpy as np
def get_dataset_lengths(dataset): def get_dataset_lengths(dataset):
if "length" in dataset.data.column_names: if "length" in dataset.column_names:
lengths = np.array(dataset.data.column("length")) lengths = np.array(dataset["length"])
elif "position_ids" in dataset.data.column_names: elif "position_ids" in dataset.column_names:
position_ids = dataset.data.column("position_ids") position_ids = dataset["position_ids"]
lengths = np.array([x[-1] + 1 for x in position_ids]) lengths = np.array([x[-1] + 1 for x in position_ids])
else: else:
input_ids = dataset.data.column("input_ids") input_ids = dataset["input_ids"]
lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) lengths = np.array([len(seq) for seq in input_ids])
return lengths return lengths

View File

@@ -9,14 +9,16 @@ from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.kernels.lora import ( from axolotl.kernels.lora import (
apply_lora_mlp_geglu, apply_lora_mlp_geglu,
apply_lora_mlp_swiglu, apply_lora_mlp_swiglu,
apply_lora_o, apply_lora_o,
apply_lora_qkv, apply_lora_qkv,
) )
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches from axolotl.monkeypatch.lora_kernels import (
apply_lora_kernel_patches,
patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [ MODEL_CONFIGS = [
@@ -63,45 +65,15 @@ def small_llama_model():
return LlamaForCausalLM(LlamaConfig(**config)) return LlamaForCausalLM(LlamaConfig(**config))
# pylint: disable=duplicate-code def test_attention_patching_integration():
@pytest.fixture
def minimal_cfg():
"Config of real HuggingFace Hub model"
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
return cfg
def test_attention_patching_integration(minimal_cfg):
"""Test attention patching in integration context.""" """Test attention patching in integration context."""
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
# Store the original implementation # Store the original implementation
original_forward = getattr(LlamaAttention, "forward") original_forward = getattr(LlamaAttention, "forward")
# Load model # Apply patch
_, _ = load_model_and_tokenizer(cfg=minimal_cfg) patch_self_attn_lora(cfg)
# Get the new forward method # Get the new forward method
patched_forward = LlamaAttention.forward patched_forward = LlamaAttention.forward
@@ -404,10 +376,38 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
def test_kernel_training_integration(minimal_cfg): def test_kernel_training_integration():
"""Test model loading with kernel patches enabled.""" """Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
# Create minimal config
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# Load model # Load model
model, _ = load_model_and_tokenizer(cfg=minimal_cfg) model, _ = load_model_and_tokenizer(cfg=cfg)
# Verify correct activation function # Verify correct activation function
layer = model.model.model.layers[0] layer = model.model.model.layers[0]

View File

@@ -125,6 +125,12 @@ def fixture_llama3_tokenizer():
return tokenizer return tokenizer
@pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
def fixture_smollm2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True) @pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
def fixture_mistralv03_tokenizer(): def fixture_mistralv03_tokenizer():
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(

View File

@@ -0,0 +1,61 @@
"""
Tests for loading DPO preference datasets with chatml formatting
"""
import unittest
import pytest
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_dpo_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"rl": "dpo",
"learning_rate": 0.000001,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"sequence_len": 2048,
}
)
class TestDPOChatml:
"""
Test loading DPO preference datasets with chatml formatting
"""
def test_default(self, minimal_dpo_cfg):
cfg = DictDefault(
{
"datasets": [
{
"path": "argilla/distilabel-intel-orca-dpo-pairs",
"type": "chatml",
"split": "train[:1%]",
}
]
}
| minimal_dpo_cfg
)
# test that dpo.load works
load_dpo("chatml", cfg)
# now actually load the datasets with the strategy
train_ds, _ = load_prepare_preference_datasets(cfg)
assert train_ds[0]["prompt"].startswith("<|im_start|>")
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
assert "chosen" in train_ds[0]
assert "rejected" in train_ds[0]
if __name__ == "__main__":
unittest.main()

View File

@@ -7,6 +7,7 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -18,11 +19,6 @@ def fixture_tokenizer():
return tokenizer return tokenizer
@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
return 4096
class TestBatchedSamplerPacking: class TestBatchedSamplerPacking:
""" """
Test class for packing streaming dataset sequences Test class for packing streaming dataset sequences
@@ -37,6 +33,7 @@ class TestBatchedSamplerPacking:
(2, 2), (2, 2),
], ],
) )
@pytest.mark.parametrize("max_seq_length", [4096, 512])
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length): def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
@@ -62,6 +59,9 @@ class TestBatchedSamplerPacking:
dataset, dataset,
) )
train_dataset = concatenate_datasets([dataset_wrapper]) train_dataset = concatenate_datasets([dataset_wrapper])
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
lengths = get_dataset_lengths(train_dataset) lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler( batch_sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset), sampler=RandomSampler(train_dataset),
@@ -90,7 +90,7 @@ class TestBatchedSamplerPacking:
batch_idxs.extend(pack) batch_idxs.extend(pack)
for batch in loader: for batch in loader:
assert len(batch["input_ids"]) <= batch_size * max_seq_length assert batch["input_ids"].numel() <= batch_size * max_seq_length
assert batch["input_ids"].shape[1] == max_seq_length assert batch["input_ids"].shape[1] == max_seq_length
original_idxs = set(range(len(train_dataset))) original_idxs = set(range(len(train_dataset)))