add assertion for packing patch to _get_unpad_data (#2840)
This commit is contained in:
@@ -42,6 +42,10 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
||||
if has_remote_code:
|
||||
patch_remote(model_name)
|
||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||
# sanity check in case upstream api changes on this
|
||||
assert hasattr(
|
||||
transformers.modeling_flash_attention_utils, "_get_unpad_data"
|
||||
), "transformers api changed for _get_unpad_data for flash attention"
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user