add tests so CI can catch updates where patches will break with unsloth (#1737) [skip ci]
This commit is contained in:
@@ -80,8 +80,9 @@ def get_forward_code() -> str:
|
||||
return forward
|
||||
|
||||
|
||||
def test_cel_is_patchable() -> bool:
|
||||
def check_cel_is_patchable() -> bool:
|
||||
forward = get_forward_code()
|
||||
forward, _ = detab_code(forward)
|
||||
return ORIGINAL_CEL_CODE in forward
|
||||
|
||||
|
||||
@@ -90,9 +91,10 @@ def get_self_attn_code() -> str:
|
||||
return forward
|
||||
|
||||
|
||||
def test_self_attn_is_patchable() -> bool:
|
||||
def check_self_attn_is_patchable() -> bool:
|
||||
qkv = get_self_attn_code()
|
||||
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
|
||||
qkv, _ = detab_code(qkv)
|
||||
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
||||
|
||||
|
||||
def integrate_cross_entropy_loss_patch():
|
||||
|
||||
Reference in New Issue
Block a user