update upstream HF deps (#2239)
* bump axolotl contribs for upstream main conflicts: * bump datasets, tokenizer, trl * remove log workarounds in trl * bump lm-eval * remove unsloth_ import from critical path * remove llama fa2 from conftest * unsloth breaks with latest upstream
This commit is contained in:
@@ -120,13 +120,12 @@ def temp_dir():
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def cleanup_monkeypatches():
|
||||
from transformers import Trainer
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
|
||||
LlamaAttention,
|
||||
LlamaFlashAttention2,
|
||||
LlamaForCausalLM,
|
||||
)
|
||||
|
||||
original_fa2_forward = LlamaFlashAttention2.forward
|
||||
# original_fa2_forward = LlamaFlashAttention2.forward
|
||||
original_llama_attn_forward = LlamaAttention.forward
|
||||
original_llama_forward = LlamaForCausalLM.forward
|
||||
original_trainer_inner_training_loop = (
|
||||
@@ -136,7 +135,7 @@ def cleanup_monkeypatches():
|
||||
# monkey patches can happen inside the tests
|
||||
yield
|
||||
# Reset LlamaFlashAttention2 forward
|
||||
LlamaFlashAttention2.forward = original_fa2_forward
|
||||
# LlamaFlashAttention2.forward = original_fa2_forward
|
||||
LlamaAttention.forward = original_llama_attn_forward
|
||||
LlamaForCausalLM.forward = original_llama_forward
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
@@ -149,7 +148,10 @@ def cleanup_monkeypatches():
|
||||
("transformers.models.llama",),
|
||||
(
|
||||
"transformers.models.llama.modeling_llama",
|
||||
["LlamaFlashAttention2", "LlamaAttention"],
|
||||
[
|
||||
# "LlamaFlashAttention2",
|
||||
"LlamaAttention",
|
||||
],
|
||||
),
|
||||
("transformers.trainer",),
|
||||
("transformers", ["Trainer"]),
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Unsloth integration will be broken going into latest transformers"
|
||||
)
|
||||
class TestUnslothIntegration(unittest.TestCase):
|
||||
"""Unsloth monkeypatch integration tests."""
|
||||
|
||||
|
||||
@@ -20,6 +20,9 @@ os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.mark.skip(
|
||||
reason="Unsloth integration will be broken going into latest transformers"
|
||||
)
|
||||
class TestUnslothQLoRA:
|
||||
"""
|
||||
Test class for Unsloth QLoRA Llama models
|
||||
|
||||
Reference in New Issue
Block a user