misc fixes to add gptq tests (#621)

* misc fixes to add gptq tests

* set bf16 needed for fa2
This commit is contained in:
Wing Lian
2023-09-21 21:52:12 -04:00
committed by GitHub
parent 97d3776ce6
commit 03e59077a0
5 changed files with 93 additions and 21 deletions

View File

@@ -19,7 +19,11 @@ def check_cuda_device(default_value):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
device = kwargs.get("device", args[0] if args else None) device = kwargs.get("device", args[0] if args else None)
if not torch.cuda.is_available() or device == "auto" or device == "cpu": if (
not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
):
return default_value return default_value
return func(*args, **kwargs) return func(*args, **kwargs)

View File

@@ -10,6 +10,7 @@ import torch
import transformers import transformers
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training from peft import PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
@@ -309,16 +310,26 @@ def load_model(
): ):
config.max_sequence_length = cfg.sequence_len config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}") LOG.warning(f"increasing context length to {cfg.sequence_len}")
model = AutoModelForCausalLM.from_pretrained( if cfg.gptq:
base_model, model = AutoModelForCausalLM.from_pretrained(
config=config, base_model,
device_map=cfg.device_map, config=config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, device_map=cfg.device_map,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=cfg.torch_dtype,
torch_dtype=cfg.torch_dtype, trust_remote_code=cfg.trust_remote_code or False,
trust_remote_code=cfg.trust_remote_code or False, **model_kwargs,
**model_kwargs, )
) else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
except Exception as err: # pylint: disable=broad-exception-caught except Exception as err: # pylint: disable=broad-exception-caught
LOG.error( LOG.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM" "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -466,10 +477,10 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model): def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear) cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, cls): if isinstance(module, cls) or "Linear" in module.__class__.__name__:
names = name.split(".") names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1]) lora_module_names.add(names[0] if len(names) == 1 else names[-1])

View File

@@ -676,6 +676,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience) (cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0 and cfg.val_set_size > 0
and cfg.save_steps and cfg.save_steps
and cfg.eval_steps
and cfg.save_steps % cfg.eval_steps == 0 and cfg.save_steps % cfg.eval_steps == 0
) )
or False, or False,

View File

@@ -6,6 +6,7 @@ import logging
import os import os
import tempfile import tempfile
import unittest import unittest
from pathlib import Path
from axolotl.cli import load_datasets from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
@@ -24,6 +25,7 @@ class TestLoraLlama(unittest.TestCase):
def test_lora(self): def test_lora(self):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "JackFram/llama-68m",
@@ -51,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 8, "micro_batch_size": 8,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(), "output_dir": output_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -62,9 +64,11 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_packing(self): def test_lora_packing(self):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "JackFram/llama-68m",
@@ -94,7 +98,7 @@ class TestLoraLlama(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 8, "micro_batch_size": 8,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(), "output_dir": output_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -105,3 +109,53 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_gptq(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,
"gptq_disable_exllama": True,
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"save_steps": 0.5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()

View File

@@ -31,9 +31,9 @@ class TestPhi(unittest.TestCase):
"trust_remote_code": True, "trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM", "model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer", "tokenizer_type": "AutoTokenizer",
"sequence_len": 2048, "sequence_len": 512,
"sample_packing": False, "sample_packing": False,
"load_in_8bit": True, "load_in_8bit": False,
"adapter": None, "adapter": None,
"val_set_size": 0.1, "val_set_size": 0.1,
"special_tokens": { "special_tokens": {
@@ -55,8 +55,9 @@ class TestPhi(unittest.TestCase):
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(), "output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"bf16": True,
} }
) )
normalize_config(cfg) normalize_config(cfg)
@@ -74,9 +75,9 @@ class TestPhi(unittest.TestCase):
"trust_remote_code": True, "trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM", "model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer", "tokenizer_type": "AutoTokenizer",
"sequence_len": 2048, "sequence_len": 512,
"sample_packing": True, "sample_packing": True,
"load_in_8bit": True, "load_in_8bit": False,
"adapter": None, "adapter": None,
"val_set_size": 0.1, "val_set_size": 0.1,
"special_tokens": { "special_tokens": {
@@ -98,8 +99,9 @@ class TestPhi(unittest.TestCase):
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(), "output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"bf16": True,
} }
) )
normalize_config(cfg) normalize_config(cfg)