keep gate in fp32 for 16-bit loras (#1105)

* keep gate in fp32 for loras

* add e2e check for lora w/o flash attention for mixtral to check gate

* add checks for gate in fp32 for mixtral, add typehints to train outputs

* mixtral doesn't support basic lora 🤦

add lora tests @ 16bit and fix gate layer check
fix the parameter name, was using the old disco name
don't lora over the gate so we can check that it is in fp32
fix dtype check

* ensure we're using fp16/bf16 for 16-bit LoRA, and with qlora the gate is always going to be in uint8
Wing Lian, 2024-01-12 14:58:21 -05:00 (committed by GitHub)
commit da97285e63, parent 2dc431078c
3 changed files with 191 additions and 8 deletions
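For context: in the Hugging Face Mixtral implementation each decoder layer has a small router Linear at `block_sparse_moe.gate` that scores the experts, and in a 16-bit LoRA run it would previously end up in fp16/bf16 along with the rest of the base weights. A minimal sketch of where that module lives and what this change enforces (the tiny checkpoint and the bf16 load here are illustrative only, not part of the diff):

import torch
from transformers import AutoModelForCausalLM

# Illustrative only: load a tiny Mixtral checkpoint in bf16.
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/Mixtral-tiny", torch_dtype=torch.bfloat16
)

# The MoE router ("gate") is a plain nn.Linear(hidden_size, num_local_experts).
gate = model.model.layers[0].block_sparse_moe.gate
print(gate.weight.dtype)  # bf16 on load; after this change axolotl upcasts it to torch.float32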


@@ -5,14 +5,16 @@ import signal
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch
 import transformers.modelcard
 from accelerate.logging import get_logger
 from datasets import Dataset
 from optimum.bettertransformer import BetterTransformer
+from peft import PeftModel
 from pkg_resources import get_distribution  # type: ignore
+from transformers import PreTrainedModel, PreTrainedTokenizer
 from transformers.deepspeed import is_deepspeed_zero3_enabled

 from axolotl.common.cli import TrainerCliArgs
@@ -43,7 +45,7 @@ class TrainDatasetMeta:

 def train(
     *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
-):
+) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
     # load the tokenizer first
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",


@@ -590,7 +590,7 @@ def load_model(
     # make sure these are fp32 per Ramesh et al. (2021)
     embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
     for name, module in model.named_modules():
-        if "norm" in name:
+        if any(m in name for m in ["norm", "gate"]):
             module.to(torch.float32)
         if model_config.model_type == "btlm":
             # don't upcast lm_head for btlm
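Since the upcast loop above matches on substrings of module names, it now catches the Mixtral router (`...block_sparse_moe.gate`) in addition to the norm layers; note that the same substring check would also match names like `gate_proj` on non-MoE architectures. A small standalone sketch of the idea (the helper name and the standalone usage are illustrative, not axolotl's actual code path):

from typing import List

import torch
from torch import nn


def upcast_norm_and_gate(model: nn.Module) -> List[str]:
    """Move any module whose name contains "norm" or "gate" to fp32,
    mirroring the condition in the diff above, and report what was touched."""
    touched = []
    for name, module in model.named_modules():
        if any(m in name for m in ["norm", "gate"]):
            module.to(torch.float32)
            touched.append(name)
    return touched


# Usage sketch (assumes a Mixtral-style model loaded in bf16/fp16):
#   upcast_norm_and_gate(model)
#   -> ['model.layers.0.input_layernorm', 'model.layers.0.block_sparse_moe.gate', ...]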


@@ -7,6 +7,7 @@ import os
 import unittest
 from pathlib import Path

 import torch
+from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.cli import load_datasets
@@ -27,7 +28,7 @@ class TestMixtral(unittest.TestCase):
     """

     @with_temp_dir
-    def test_qlora(self, temp_dir):
+    def test_qlora_w_fa2(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -37,10 +38,18 @@
                 "sequence_len": 1024,
                 "load_in_4bit": True,
                 "adapter": "qlora",
-                "lora_r": 16,
-                "lora_alpha": 32,
+                "lora_r": 4,
+                "lora_alpha": 8,
                 "lora_dropout": 0.1,
-                "lora_target_linear": True,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
                 "val_set_size": 0.1,
                 "special_tokens": {},
                 "datasets": [
@@ -65,7 +74,179 @@
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.uint8
+        )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_qlora_wo_fa2(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": False,
+                "sequence_len": 1024,
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 4,
+                "lora_alpha": 8,
+                "lora_dropout": 0.1,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.uint8
+        )
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_16bit_lora_w_fa2(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "adapter": "lora",
+                "lora_r": 4,
+                "lora_alpha": 8,
+                "lora_dropout": 0.1,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.float32
+        )
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_16bit_lora_wo_fa2(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": False,
+                "sequence_len": 1024,
+                "adapter": "lora",
+                "lora_r": 4,
+                "lora_alpha": 8,
+                "lora_dropout": 0.1,
+                "lora_target_modules": [
+                    "o_proj",
+                    "w3",
+                    "k_proj",
+                    "v_proj",
+                    "w1",
+                    "q_proj",
+                    "w2",
+                ],
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
+            == torch.float32
+        )
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

     @with_temp_dir