add assertion for packing patch to _get_unpad_data (#2840)

This commit is contained in:
Wing Lian
2025-06-27 11:20:23 -04:00
committed by GitHub
parent ec15a7a691
commit a1a740608d

View File

@@ -42,6 +42,10 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
if has_remote_code:
patch_remote(model_name)
elif hasattr(transformers, "modeling_flash_attention_utils"):
# sanity check in case upstream api changes on this
assert hasattr(
transformers.modeling_flash_attention_utils, "_get_unpad_data"
), "transformers api changed for _get_unpad_data for flash attention"
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)