add e2e tests for Unsloth qlora and test the builds (#2093)

* see if unsloth installs cleanly in ci

* check unsloth install on regular tests, not sdist

* fix ampere check exception for ci

* use cached_property instead

* add an e2e test for unsloth qlora

* reduce seq len and mbsz to prevent oom in ci

* add checks for fp16 and sdp_attention

* pin unsloth to a specific release

* add unsloth to docker image too

* fix flash attn xentropy patch

* fix loss, add check for loss when using fa_xentropy

* fix special tokens for test

* typo

* test fa xentropy with and without gradient accum

* pr feedback changes
This commit is contained in:
Wing Lian
2024-11-29 20:38:49 -05:00
committed by GitHub
parent 1cf7075d18
commit 5f1d98e8fc
8 changed files with 275 additions and 50 deletions

View File

@@ -67,6 +67,7 @@ jobs:
run: | run: |
pip3 show torch pip3 show torch
pip3 install -U -e . pip3 install -U -e .
python scripts/unsloth_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests - name: Run tests

View File

@@ -37,6 +37,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi fi
RUN python scripts/unsloth_install.py | sh
# So we can test the Docker image # So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt RUN pip install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -26,6 +26,8 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi fi
RUN python scripts/unsloth_install.py | sh
# So we can test the Docker image # So we can test the Docker image
RUN pip install pytest RUN pip install pytest

View File

@@ -8,7 +8,10 @@ from packaging.version import Version as V
v = V(torch.__version__) v = V(torch.__version__)
cuda = str(torch.version.cuda) cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8 try:
is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
is_ampere = False
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
raise RuntimeError(f"CUDA = {cuda} not supported!") raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"): if v <= V("2.1.0"):
@@ -29,5 +32,5 @@ else:
raise RuntimeError(f"Torch = {v} too new!") raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print( print(
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"' f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
) )

View File

@@ -4,7 +4,6 @@
import logging import logging
import warnings import warnings
from functools import partial
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
@@ -94,13 +93,32 @@ def replace_llama_qkv_with_fused(model):
set_module_name(model, name, qkv) set_module_name(model, name, qkv)
def patch_llama_cross_entropy(): def patch_fa_llama_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss LOG.info(
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
) )
from flash_attn.ops.triton.cross_entropy import (
cross_entropy_loss as flash_attn_cross_entropy_loss,
)
def fa2_fixed_cross_entropy(
source,
target,
num_items_in_batch: int = None,
ignore_index: int = -100,
**kwargs,
): # pylint: disable=unused-argument
reduction = "sum" if num_items_in_batch is not None else "mean"
loss, _ = flash_attn_cross_entropy_loss(
source, target, ignore_index=ignore_index
)
if reduction == "sum":
loss = loss.sum() / num_items_in_batch
else:
loss = loss.sum() / (target != ignore_index).sum()
return loss
transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy
def patch_llama_rms_norm(): def patch_llama_rms_norm():
@@ -147,7 +165,7 @@ def replace_llama_attn_with_flash_attn(
# skip only if explicitly disabled # skip only if explicitly disabled
if cross_entropy: if cross_entropy:
patch_llama_cross_entropy() patch_fa_llama_cross_entropy()
# skip only if explicitly disabled # skip only if explicitly disabled
if rms_norm: if rms_norm:

View File

@@ -2,10 +2,12 @@
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
import gc import gc
import importlib
import logging import logging
import math import math
import os import os
import types import types
from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401 from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict import addict
@@ -409,7 +411,7 @@ class ModelLoader:
) )
if self.cfg.is_llama_derived_model: if self.cfg.is_llama_derived_model:
self.patch_loss() self.patch_loss_llama()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -451,27 +453,34 @@ class ModelLoader:
replace_stablelm_attn_with_flash_attn(self.cfg.base_model) replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
def patch_loss(self) -> None: @cached_property
def has_flash_attn(self) -> bool:
"""Check if flash attention is installed"""
return importlib.util.find_spec("flash_attn") is not None
def patch_loss_llama(self) -> None:
""" """
Patch loss functions Patch loss functions
""" """
from axolotl.monkeypatch.llama_attn_hijack_flash import ( if self.has_flash_attn:
patch_llama_cross_entropy, from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_llama_rms_norm, patch_fa_llama_cross_entropy,
) patch_llama_rms_norm,
)
if self.cfg.flash_attn_cross_entropy: if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
patch_llama_cross_entropy() patch_fa_llama_cross_entropy()
if self.cfg.flash_attn_rms_norm: elif self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
patch_llama_rms_norm() patch_llama_rms_norm()
elif self.cfg.unsloth_rms_norm: elif self.cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm() patch_unsloth_layernorm()
if self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -481,6 +490,7 @@ class ModelLoader:
""" """
Modify all llama derived models in one block Modify all llama derived models in one block
""" """
self.patch_loss_llama()
if self.cfg.flash_attention: if self.cfg.flash_attention:
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -528,16 +538,6 @@ class ModelLoader:
"Shifted-sparse attention not currently implemented without flash attention." "Shifted-sparse attention not currently implemented without flash attention."
) )
if self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
def set_auto_model_loader(self) -> None: def set_auto_model_loader(self) -> None:
"""set self.AutoModelLoader """set self.AutoModelLoader
- default value: AutoModelForCausalLM (set at __init__) - default value: AutoModelForCausalLM (set at __init__)

View File

@@ -4,11 +4,11 @@ E2E tests for lora llama
import logging import logging
import os import os
import unittest
from importlib import reload from importlib import reload
from pathlib import Path from pathlib import Path
import pytest import pytest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets from axolotl.cli import load_datasets
@@ -17,7 +17,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir from ..utils import most_recent_subdir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -31,18 +31,20 @@ def reload_transformers():
reload(transformers.models.llama.modeling_llama) reload(transformers.models.llama.modeling_llama)
class TestFAXentropyLlama(unittest.TestCase): class TestFAXentropyLlama:
""" """
Test case for Llama models using LoRA w multipack Test case for Llama models using LoRA w multipack
""" """
@with_temp_dir @pytest.mark.parametrize(
def test_lora_packing_fa_cross_entropy(self, temp_dir): "gradient_accumulation_steps",
[1, 4],
)
def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": True, "sample_packing": True,
"flash_attention": True, "flash_attention": True,
@@ -55,25 +57,29 @@ class TestFAXentropyLlama(unittest.TestCase):
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.2, "val_set_size": 0.2,
"special_tokens": { "special_tokens": {
"unk_token": "<unk>", "pad_token": "<|endoftext|>",
"bos_token": "<s>",
"eos_token": "</s>",
}, },
"chat_template": "chatml",
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "mlabonne/FineTome-100k",
"type": "alpaca", "field_messages": "conversations",
"message_field_content": "value",
"message_field_role": "from",
"type": "chat_template",
"split": "train[:2%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 10, "max_steps": 5,
"save_steps": 10, "save_steps": 5,
"micro_batch_size": 8, "micro_batch_size": 2,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"use_tensorboard": True,
} }
) )
if is_torch_bf16_gpu_available(): if is_torch_bf16_gpu_available():
@@ -87,3 +93,10 @@ class TestFAXentropyLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 1.5, "Loss is too high"

View File

@@ -0,0 +1,186 @@
"""
e2e tests for unsloth qlora
"""
import logging
import os
from pathlib import Path
import pytest
from e2e.utils import most_recent_subdir
from tbparse import SummaryReader
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
# Logger registered under the shared "axolotl.tests.e2e" name.
LOG = logging.getLogger("axolotl.tests.e2e")
# Set at import time, before any test in this module runs.
# NOTE(review): presumably this keeps Weights & Biases reporting off in CI — confirm.
os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code
class TestUnslothQLoRA:
    """
    Test class for Unsloth QLoRA Llama models

    Every test trains HuggingFaceTB/SmolLM2-135M with a 4-bit QLoRA adapter
    for at most five steps, then checks that the adapter weights were written
    and that the last logged training loss is below a threshold.
    """

    @staticmethod
    def _qlora_cfg(temp_dir, **overrides):
        """
        Build the configuration shared by all tests in this class.

        Args:
            temp_dir: directory used as the run's ``output_dir``.
            **overrides: per-test settings (attention backend, precision,
                sample packing) that are added to, or replace, the shared ones.

        Returns:
            A new ``DictDefault``; the underlying dict and its nested values
            are rebuilt on every call, so tests never share mutable state.
        """
        settings = {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "sequence_len": 1024,
            "load_in_4bit": True,
            "adapter": "qlora",
            "lora_r": 16,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lora_target_linear": True,
            "val_set_size": 0.2,
            "special_tokens": {
                "pad_token": "<|endoftext|>",
            },
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                },
            ],
            "num_epochs": 1,
            "max_steps": 5,
            "save_steps": 10,
            "micro_batch_size": 4,
            "gradient_accumulation_steps": 2,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_8bit",
            "lr_scheduler": "cosine",
            "use_tensorboard": True,
        }
        settings.update(overrides)
        return DictDefault(settings)

    @staticmethod
    def _train_and_check_loss(cfg, temp_dir, max_loss=2.0):
        """
        Run training for ``cfg`` and validate its outputs.

        Args:
            cfg: configuration produced by ``_qlora_cfg``; normalized in place.
            temp_dir: the run's output directory (same value as
                ``cfg.output_dir``).
            max_loss: exclusive upper bound for the last ``train/train_loss``
                scalar found in the TensorBoard event file.

        Raises:
            AssertionError: if ``adapter_model.bin`` is missing, if no
                ``train/train_loss`` scalar was logged, or if the last logged
                loss is not below ``max_loss``.
        """
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.bin").exists()

        # Read the first (name-sorted) event file of the most recent run.
        tb_log_path = most_recent_subdir(temp_dir + "/runs")
        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
        scalars = SummaryReader(event_file).scalars
        train_loss = scalars[scalars.tag == "train/train_loss"]
        # Fail with a readable message instead of an IndexError on values[-1].
        assert not train_loss.empty, "No train/train_loss scalar was logged"
        assert train_loss.value.values[-1] < max_loss, "Loss is too high"

    @pytest.mark.parametrize(
        "sample_packing",
        [True, False],
    )
    def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
        """Flash attention enabled, with and without sample packing, bf16 auto."""
        cfg = self._qlora_cfg(
            temp_dir,
            sample_packing=sample_packing,
            flash_attention=True,
            bf16="auto",
        )
        self._train_and_check_loss(cfg, temp_dir)

    def test_unsloth_llama_qlora_unpacked(self, temp_dir):
        """No flash attention setting, no sample packing, bf16 auto."""
        cfg = self._qlora_cfg(
            temp_dir,
            sample_packing=False,
            bf16="auto",
        )
        self._train_and_check_loss(cfg, temp_dir)

    @pytest.mark.parametrize(
        "sdp_attention",
        [True, False],
    )
    def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention):
        """fp16 without flash attention, toggling sdp_attention on and off."""
        cfg = self._qlora_cfg(
            temp_dir,
            sample_packing=False,
            sdp_attention=sdp_attention,
            fp16=True,
        )
        self._train_and_check_loss(cfg, temp_dir)