diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 1467f9e29..e590dbdaa 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -42,6 +42,10 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False): if has_remote_code: patch_remote(model_name) elif hasattr(transformers, "modeling_flash_attention_utils"): + # sanity check in case upstream api changes on this + assert hasattr( + transformers.modeling_flash_attention_utils, "_get_unpad_data" + ), "transformers api changed for _get_unpad_data for flash attention" transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data )