add tests so CI can catch updates where patches will break with unsloth (#1737) [skip ci]

This commit is contained in:
Wing Lian
2024-07-11 16:43:19 -04:00
committed by GitHub
parent 1194c2e0b1
commit 47e1916484
2 changed files with 30 additions and 3 deletions

View File

@@ -80,8 +80,9 @@ def get_forward_code() -> str:
return forward
def test_cel_is_patchable() -> bool:
def check_cel_is_patchable() -> bool:
forward = get_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_CEL_CODE in forward
@@ -90,9 +91,10 @@ def get_self_attn_code() -> str:
return forward
def test_self_attn_is_patchable() -> bool:
def check_self_attn_is_patchable() -> bool:
qkv = get_self_attn_code()
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
qkv, _ = detab_code(qkv)
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
def integrate_cross_entropy_loss_patch():