Multipack simplify for Mixtral (#1142)
This commit is contained in:
@@ -7,8 +7,6 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
@@ -60,12 +58,9 @@ class TestMixtral(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"sample_packing": True,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -101,23 +96,16 @@ class TestMixtral(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"sample_packing": True,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
"axolotl.monkeypatch.mixtral.modeling_mixtral"
|
||||
in model.model.layers[0].self_attn.__class__.__module__
|
||||
)
|
||||
assert (
|
||||
"MixtralMultipackFlashAttention2"
|
||||
"MixtralFlashAttention2"
|
||||
in model.model.layers[0].self_attn.__class__.__name__
|
||||
)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@@ -52,11 +52,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
|
||||
assert (
|
||||
"axolotl.monkeypatch.mixtral.modeling_mixtral"
|
||||
in model.model.layers[0].self_attn.__class__.__module__
|
||||
)
|
||||
assert (
|
||||
"MixtralMultipackFlashAttention2"
|
||||
"MixtralFlashAttention2"
|
||||
in model.model.layers[0].self_attn.__class__.__name__
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user