add e2e tests for Unsloth qlora and test the builds (#2093)
* see if unsloth installs cleanly in ci * check unsloth install on regular tests, not sdist * fix ampere check exception for ci * use cached_property instead * add an e2e test for unsloth qlora * reduce seq len and mbsz to prevent oom in ci * add checks for fp16 and sdp_attention * pin unsloth to a specific release * add unsloth to docker image too * fix flash attn xentropy patch * fix loss, add check for loss when using fa_xentropy * fix special tokens for test * typo * test fa xentropy with and without gradient accum * pr feedback changes
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -94,13 +93,32 @@ def replace_llama_qkv_with_fused(model):
|
||||
set_module_name(model, name, qkv)
|
||||
|
||||
|
||||
def patch_llama_cross_entropy():
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
def patch_fa_llama_cross_entropy():
|
||||
LOG.info(
|
||||
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
|
||||
)
|
||||
from flash_attn.ops.triton.cross_entropy import (
|
||||
cross_entropy_loss as flash_attn_cross_entropy_loss,
|
||||
)
|
||||
|
||||
def fa2_fixed_cross_entropy(
|
||||
source,
|
||||
target,
|
||||
num_items_in_batch: int = None,
|
||||
ignore_index: int = -100,
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||
loss, _ = flash_attn_cross_entropy_loss(
|
||||
source, target, ignore_index=ignore_index
|
||||
)
|
||||
if reduction == "sum":
|
||||
loss = loss.sum() / num_items_in_batch
|
||||
else:
|
||||
loss = loss.sum() / (target != ignore_index).sum()
|
||||
return loss
|
||||
|
||||
transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy
|
||||
|
||||
|
||||
def patch_llama_rms_norm():
|
||||
@@ -147,7 +165,7 @@ def replace_llama_attn_with_flash_attn(
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
patch_llama_cross_entropy()
|
||||
patch_fa_llama_cross_entropy()
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if rms_norm:
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
import gc
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import types
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||
|
||||
import addict
|
||||
@@ -409,7 +411,7 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
if self.cfg.is_llama_derived_model:
|
||||
self.patch_loss()
|
||||
self.patch_loss_llama()
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
@@ -451,27 +453,34 @@ class ModelLoader:
|
||||
|
||||
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
|
||||
|
||||
def patch_loss(self) -> None:
|
||||
@cached_property
|
||||
def has_flash_attn(self) -> bool:
|
||||
"""Check if flash attention is installed"""
|
||||
return importlib.util.find_spec("flash_attn") is not None
|
||||
|
||||
def patch_loss_llama(self) -> None:
|
||||
"""
|
||||
Patch loss functions
|
||||
"""
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_llama_cross_entropy,
|
||||
patch_llama_rms_norm,
|
||||
)
|
||||
if self.has_flash_attn:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_fa_llama_cross_entropy,
|
||||
patch_llama_rms_norm,
|
||||
)
|
||||
|
||||
if self.cfg.flash_attn_cross_entropy:
|
||||
patch_llama_cross_entropy()
|
||||
if self.cfg.flash_attn_rms_norm:
|
||||
if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
|
||||
patch_fa_llama_cross_entropy()
|
||||
elif self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
|
||||
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
|
||||
patch_llama_rms_norm()
|
||||
elif self.cfg.unsloth_rms_norm:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||
|
||||
patch_unsloth_layernorm()
|
||||
if self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
@@ -481,6 +490,7 @@ class ModelLoader:
|
||||
"""
|
||||
Modify all llama derived models in one block
|
||||
"""
|
||||
self.patch_loss_llama()
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
@@ -528,16 +538,6 @@ class ModelLoader:
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
)
|
||||
|
||||
if self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def set_auto_model_loader(self) -> None:
|
||||
"""set self.AutoModelLoader
|
||||
- default value: AutoModelForCausalLM (set at __init__)
|
||||
|
||||
Reference in New Issue
Block a user