* keep the MoE gate in fp32 for LoRAs
* add an e2e check for LoRA w/o flash attention for Mixtral to check the gate
* add checks that the Mixtral gate stays in fp32; add type hints to train outputs
* Mixtral doesn't support basic LoRA 🤦
* add LoRA tests at 16-bit and fix the gate layer check
* fix the parameter name, which was still using the old disco name
* don't apply LoRA over the gate so we can check that it stays in fp32
* fix the dtype check
* ensure we're using fp16/bf16 for 16-bit runs; QLoRA weights are always stored in uint8 (see the sketch below)
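
The dtype invariants these changes enforce all reduce to one assertion over the MoE router ("gate") weight. A minimal sketch of that check, assuming the PEFT-wrapped module layout used in the tests below (`expect_gate_dtype` is a hypothetical helper, not part of the test suite):

import torch

def expect_gate_dtype(model, dtype: torch.dtype) -> None:
    # PEFT wraps the model, hence base_model.model; the attribute chain
    # mirrors the asserts in the tests below
    gate = model.base_model.model.model.layers[0].block_sparse_moe.gate
    assert (
        gate.weight.dtype == dtype
    ), f"expected gate weight in {dtype}, got {gate.weight.dtype}"

A 16-bit LoRA run should pass expect_gate_dtype(model, torch.float32), while a QLoRA run should pass expect_gate_dtype(model, torch.uint8), since bitsandbytes packs 4-bit weights into uint8 storage. The full test module follows.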
"""
|
|
E2E tests for mixtral
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from transformers.utils import is_torch_bf16_gpu_available
|
|
|
|
from axolotl.cli import load_datasets
|
|
from axolotl.common.cli import TrainerCliArgs
|
|
from axolotl.train import train
|
|
from axolotl.utils.config import normalize_config
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
from .utils import with_temp_dir
|
|
|
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
|
os.environ["WANDB_DISABLED"] = "true"


class TestMixtral(unittest.TestCase):
    """
    Test case for Mixtral models using LoRA and QLoRA
    """

    @with_temp_dir
    def test_qlora_w_fa2(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                "flash_attention": True,
                "sequence_len": 1024,
                "load_in_4bit": True,
                "adapter": "qlora",
                "lora_r": 4,
                "lora_alpha": 8,
                "lora_dropout": 0.1,
                # the MoE gate is deliberately excluded from the target
                # modules so its dtype can be verified after training
                "lora_target_modules": [
                    "o_proj",
                    "w3",
                    "k_proj",
                    "v_proj",
                    "w1",
                    "q_proj",
                    "w2",
                ],
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        # under QLoRA, the 4-bit quantized gate weight is stored as uint8
        assert (
            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
            == torch.uint8
        )
        assert (Path(temp_dir) / "adapter_model.bin").exists()

    @with_temp_dir
    def test_qlora_wo_fa2(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                "flash_attention": False,
                "sequence_len": 1024,
                "load_in_4bit": True,
                "adapter": "qlora",
                "lora_r": 4,
                "lora_alpha": 8,
                "lora_dropout": 0.1,
                "lora_target_modules": [
                    "o_proj",
                    "w3",
                    "k_proj",
                    "v_proj",
                    "w1",
                    "q_proj",
                    "w2",
                ],
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (
            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
            == torch.uint8
        )
        assert (Path(temp_dir) / "adapter_model.bin").exists()

    @with_temp_dir
    def test_16bit_lora_w_fa2(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                "flash_attention": True,
                "sequence_len": 1024,
                "adapter": "lora",
                "lora_r": 4,
                "lora_alpha": 8,
                "lora_dropout": 0.1,
                "lora_target_modules": [
                    "o_proj",
                    "w3",
                    "k_proj",
                    "v_proj",
                    "w1",
                    "q_proj",
                    "w2",
                ],
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
            }
        )
        # use bf16 when the GPU supports it, otherwise fall back to fp16
        if is_torch_bf16_gpu_available():
            cfg.bf16 = True
        else:
            cfg.fp16 = True
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        # the gate is not a LoRA target, so it must remain in fp32
        # even in a 16-bit run
        assert (
            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
            == torch.float32
        )
        assert (Path(temp_dir) / "adapter_model.bin").exists()

    @with_temp_dir
    def test_16bit_lora_wo_fa2(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                "flash_attention": False,
                "sequence_len": 1024,
                "adapter": "lora",
                "lora_r": 4,
                "lora_alpha": 8,
                "lora_dropout": 0.1,
                "lora_target_modules": [
                    "o_proj",
                    "w3",
                    "k_proj",
                    "v_proj",
                    "w1",
                    "q_proj",
                    "w2",
                ],
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
            }
        )
        # set the precision flags before normalize_config, matching the
        # other tests, so the derived dtype settings are consistent
        if is_torch_bf16_gpu_available():
            cfg.bf16 = True
        else:
            cfg.fp16 = True
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (
            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
            == torch.float32
        )
        assert (Path(temp_dir) / "adapter_model.bin").exists()

    @with_temp_dir
    def test_ft(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
                "flash_attention": True,
                "sequence_len": 1024,
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
            }
        )
        if is_torch_bf16_gpu_available():
            cfg.bf16 = True
        else:
            cfg.fp16 = True
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        # a full fine-tune saves the whole model rather than an adapter
        assert (Path(temp_dir) / "pytorch_model.bin").exists()
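
Assuming this module lives in the repository's tests/e2e/ package (the relative import of `.utils` and the `axolotl.tests.e2e` logger name suggest it does), a single case can be exercised with pytest, e.g. `pytest tests/e2e/test_mixtral.py -k test_16bit_lora_w_fa2`.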