swaps to use newer sample packing for mistral (#1773)
* swaps to use newer sample packing for mistral * fix multipack patch test * patch the common fa utils * update for refactor of flash attn unpad * remove un-needed drop attn mask for mistral * bump transformers to main to pick up latest mistral fix for 12b and refactor of fa2 * update test
This commit is contained in:
@@ -4,6 +4,8 @@ E2E smoke tests to check that the monkeypatches are in place for certain configu
|
||||
|
||||
import unittest
|
||||
|
||||
import transformers
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -87,9 +89,9 @@ class TestModelPatches(unittest.TestCase):
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
|
||||
assert (
|
||||
"axolotl.monkeypatch.mistral_attn_hijack_flash"
|
||||
in model.model.layers[0].self_attn.forward.__module__
|
||||
"torch.jit"
|
||||
in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user