add e2e tests for Unsloth qlora and test the builds (#2093)

* see if unsloth installs cleanly in ci

* check unsloth install on regular tests, not sdist

* fix ampere check exception for ci

* use cached_property instead

* add an e2e test for unsloth qlora

* reduce seq len and mbsz to prevent oom in ci

* add checks for fp16 and sdp_attention

* pin unsloth to a specific release

* add unsloth to docker image too

* fix flash attn xentropy patch

* fix loss, add check for loss when using fa_xentropy

* fix special tokens for test

* typo

* test fa xentropy with and without gradient accum

* pr feedback changes
Author: Wing Lian
Date: 2024-11-29 20:38:49 -05:00 (committed by GitHub)
Parent: 1cf7075d18
Commit: 5f1d98e8fc
8 changed files with 275 additions and 50 deletions
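The e2e test itself is not among the two diffs shown below, but the commit message pins down its shape: Unsloth QLoRA with a reduced sequence length and micro batch size to fit CI memory, plus checks for fp16 and sdp_attention. A hypothetical sketch of such a config follows; the key names are axolotl cfg flags (most of them visible in the diffs below), while the model and values are illustrative placeholders, not the committed test:

# Hypothetical sketch of an Unsloth QLoRA e2e config -- NOT the committed test.
# The unsloth_* keys match the cfg flags referenced in the diffs below.
cfg = {
    "base_model": "path/to/small-llama",  # placeholder; a tiny model keeps CI fast
    "adapter": "qlora",
    "load_in_4bit": True,
    "sequence_len": 1024,      # kept small to prevent OOM in CI (per commit message)
    "micro_batch_size": 1,     # likewise reduced
    "fp16": True,              # the commit adds checks for fp16 and sdp_attention
    "sdp_attention": True,
    "unsloth_cross_entropy_loss": True,
    "unsloth_rms_norm": True,
    "unsloth_lora_qkv": True,
    "unsloth_lora_o": True,
}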

src/axolotl/monkeypatch/llama_attn_hijack_flash.py

@@ -4,7 +4,6 @@
 import logging
 import warnings
-from functools import partial
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -94,13 +93,32 @@ def replace_llama_qkv_with_fused(model):
         set_module_name(model, name, qkv)
 
 
-def patch_llama_cross_entropy():
-    from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-    LOG.info("patching with flash_attn.losses.cross_entropy")
-    transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-        CrossEntropyLoss, inplace_backward=True
+def patch_fa_llama_cross_entropy():
+    LOG.info(
+        "patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
     )
+    from flash_attn.ops.triton.cross_entropy import (
+        cross_entropy_loss as flash_attn_cross_entropy_loss,
+    )
+
+    def fa2_fixed_cross_entropy(
+        source,
+        target,
+        num_items_in_batch: int = None,
+        ignore_index: int = -100,
+        **kwargs,
+    ):  # pylint: disable=unused-argument
+        reduction = "sum" if num_items_in_batch is not None else "mean"
+        loss, _ = flash_attn_cross_entropy_loss(
+            source, target, ignore_index=ignore_index
+        )
+        if reduction == "sum":
+            loss = loss.sum() / num_items_in_batch
+        else:
+            loss = loss.sum() / (target != ignore_index).sum()
+        return loss
+
+    transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy
 
 
 def patch_llama_rms_norm():
@@ -147,7 +165,7 @@ def replace_llama_attn_with_flash_attn(
 
     # skip only if explicitly disabled
     if cross_entropy:
-        patch_llama_cross_entropy()
+        patch_fa_llama_cross_entropy()
 
     # skip only if explicitly disabled
     if rms_norm:
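Why the reduction logic in fa2_fixed_cross_entropy above looks the way it does: when the trainer passes num_items_in_batch (as it does under gradient accumulation), per-token losses must be summed and divided by the global token count so every accumulation step is scaled consistently; otherwise the loss is averaged over the non-ignored tokens only. A minimal pure-PyTorch sketch of that contract (an illustrative reference, not part of the commit; the real patch routes through flash_attn's Triton kernel):

import torch
import torch.nn.functional as F

def reference_fixed_cross_entropy(source, target, num_items_in_batch=None, ignore_index=-100):
    # Per-token losses; positions equal to ignore_index contribute zero.
    per_token = F.cross_entropy(
        source, target, ignore_index=ignore_index, reduction="none"
    )
    if num_items_in_batch is not None:
        # "sum" path: normalize by the global token count supplied by the
        # trainer so gradient accumulation matches one large batch.
        return per_token.sum() / num_items_in_batch
    # "mean" path: average over tokens that actually contribute.
    return per_token.sum() / (target != ignore_index).sum()

# e.g. logits of shape (num_tokens, vocab) and integer labels of shape (num_tokens,)
logits = torch.randn(8, 32)
labels = torch.randint(0, 32, (8,))
labels[0] = -100  # padded position, excluded from the mean
print(reference_fixed_cross_entropy(logits, labels))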

src/axolotl/utils/models.py

@@ -2,10 +2,12 @@
 # pylint: disable=too-many-lines
 import gc
+import importlib
 import logging
 import math
 import os
 import types
+from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
 
 import addict
@@ -409,7 +411,7 @@ class ModelLoader:
         )
 
         if self.cfg.is_llama_derived_model:
-            self.patch_loss()
+            self.patch_loss_llama()
 
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -451,27 +453,34 @@ class ModelLoader:
             replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
 
-    def patch_loss(self) -> None:
+    @cached_property
+    def has_flash_attn(self) -> bool:
+        """Check if flash attention is installed"""
+        return importlib.util.find_spec("flash_attn") is not None
+
+    def patch_loss_llama(self) -> None:
         """
         Patch loss functions
         """
-        from axolotl.monkeypatch.llama_attn_hijack_flash import (
-            patch_llama_cross_entropy,
-            patch_llama_rms_norm,
-        )
+        if self.has_flash_attn:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_fa_llama_cross_entropy,
+                patch_llama_rms_norm,
+            )
 
-        if self.cfg.flash_attn_cross_entropy:
-            patch_llama_cross_entropy()
-        if self.cfg.flash_attn_rms_norm:
+        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
+            patch_fa_llama_cross_entropy()
+        elif self.cfg.unsloth_cross_entropy_loss:
+            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
+
+            integrate_cross_entropy_loss_patch(model_type="llama")
+
+        if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
             patch_llama_rms_norm()
+        elif self.cfg.unsloth_rms_norm:
+            from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
+
+            patch_unsloth_layernorm()
 
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
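The has_flash_attn guard above is worth calling out: it is a functools.cached_property, so the importlib.util.find_spec probe runs at most once per ModelLoader, and the Unsloth patches become an elif fallback instead of stacking on top of the flash-attn ones. A standalone sketch of the same detection pattern (the class name is illustrative, not from the commit):

import importlib.util
from functools import cached_property

class PatchSelector:
    """Illustrative stand-in for ModelLoader's flash-attn probe."""

    @cached_property
    def has_flash_attn(self) -> bool:
        # find_spec returns None when the package cannot be found, and unlike
        # a bare import it does not execute the package, so the probe is
        # cheap; cached_property memoizes the result per instance.
        return importlib.util.find_spec("flash_attn") is not None

selector = PatchSelector()
if selector.has_flash_attn:
    print("use the flash-attn cross-entropy / RMSNorm patches")
else:
    print("fall back to the Unsloth patches when configured")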
@@ -481,6 +490,7 @@ class ModelLoader:
         """
         Modify all llama derived models in one block
         """
+        self.patch_loss_llama()
 
         if self.cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -528,16 +538,6 @@ class ModelLoader:
                 "Shifted-sparse attention not currently implemented without flash attention."
             )
 
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
-
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
-
-            patch_self_attn_lora()
-
     def set_auto_model_loader(self) -> None:
         """set self.AutoModelLoader
         - default value: AutoModelForCausalLM (set at __init__)