add assertion for packing patch to _get_unpad_data (#2840)
This commit is contained in:
@@ -42,6 +42,10 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
|||||||
if has_remote_code:
|
if has_remote_code:
|
||||||
patch_remote(model_name)
|
patch_remote(model_name)
|
||||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||||
|
# sanity check in case upstream api changes on this
|
||||||
|
assert hasattr(
|
||||||
|
transformers.modeling_flash_attention_utils, "_get_unpad_data"
|
||||||
|
), "transformers api changed for _get_unpad_data for flash attention"
|
||||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user