fixing transformers version

This commit is contained in:
salman
2025-04-08 11:28:52 +01:00
committed by Sung Ching Liu
parent 75c565d476
commit cdb16069af

View File

@@ -10,9 +10,9 @@ import transformers
def patch_flex_wrapper():
# TODO remove this patch when transformers#37285 is merged and in a release
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_below_4_52 = transformers.__version__ < "4.52.0"
is_transformers_below_4_51_1 = transformers.__version__ < "4.51.1"
if not (is_torch_2_6 and is_transformers_below_4_52):
if not (is_torch_2_6 and is_transformers_below_4_51_1):
return
from torch.nn.attention.flex_attention import flex_attention