add tests so CI can catch updates where patches will break with unsloth (#1737) [skip ci]

This commit is contained in:
Wing Lian
2024-07-11 16:43:19 -04:00
committed by GitHub
parent 1194c2e0b1
commit 47e1916484
2 changed files with 30 additions and 3 deletions

View File

@@ -80,8 +80,9 @@ def get_forward_code() -> str:
return forward
def test_cel_is_patchable() -> bool:
def check_cel_is_patchable() -> bool:
forward = get_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_CEL_CODE in forward
@@ -90,9 +91,10 @@ def get_self_attn_code() -> str:
return forward
def test_self_attn_is_patchable() -> bool:
def check_self_attn_is_patchable() -> bool:
qkv = get_self_attn_code()
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
qkv, _ = detab_code(qkv)
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
def integrate_cross_entropy_loss_patch():

View File

@@ -0,0 +1,25 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.unsloth_ import (
check_cel_is_patchable,
check_self_attn_is_patchable,
)
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""
def test_is_cel_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_cel_is_patchable(),
"HF transformers loss code has changed and isn't patchable",
)
def test_is_self_attn_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_self_attn_is_patchable(),
"HF transformers self attention code has changed and isn't patchable",
)