diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index de8260414..6af3046e1 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -80,8 +80,9 @@ def get_forward_code() -> str: return forward -def test_cel_is_patchable() -> bool: +def check_cel_is_patchable() -> bool: forward = get_forward_code() + forward, _ = detab_code(forward) return ORIGINAL_CEL_CODE in forward @@ -90,9 +91,10 @@ def get_self_attn_code() -> str: return forward -def test_self_attn_is_patchable() -> bool: +def check_self_attn_is_patchable() -> bool: qkv = get_self_attn_code() - return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv + qkv, _ = detab_code(qkv) + return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv def integrate_cross_entropy_loss_patch(): diff --git a/tests/e2e/patched/test_unsloth_integration.py b/tests/e2e/patched/test_unsloth_integration.py new file mode 100644 index 000000000..39c7abb1c --- /dev/null +++ b/tests/e2e/patched/test_unsloth_integration.py @@ -0,0 +1,25 @@ +"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected.""" +import unittest + +from axolotl.monkeypatch.unsloth_ import ( + check_cel_is_patchable, + check_self_attn_is_patchable, +) + + +class TestUnslothIntegration(unittest.TestCase): + """Unsloth monkeypatch integration tests.""" + + def test_is_cel_patchable(self): + # ensures the current version of transformers has loss code that matches our patching code + self.assertTrue( + check_cel_is_patchable(), + "HF transformers loss code has changed and isn't patchable", + ) + + def test_is_self_attn_patchable(self): + # ensures the current version of transformers has loss code that matches our patching code + self.assertTrue( + check_self_attn_is_patchable(), + "HF transformers self attention code has changed and isn't patchable", + )