latest fixes needed for GA in latest transformers

This commit is contained in:
Wing Lian
2025-01-13 13:36:47 -05:00
parent 49b5501fc2
commit 5b5ba49c46
5 changed files with 72 additions and 17 deletions

View File

@@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase):
)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=False)
assert (
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
load_model(cfg, tokenizer, inference=False)
@with_temp_dir
def test_mistral_multipack(self, temp_dir):