misc fixes to add gptq tests (#621)
* misc fixes to add gptq tests * set bf16 needed for fa2
This commit is contained in:
@@ -19,7 +19,11 @@ def check_cuda_device(default_value):
|
|||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
device = kwargs.get("device", args[0] if args else None)
|
device = kwargs.get("device", args[0] if args else None)
|
||||||
|
|
||||||
if not torch.cuda.is_available() or device == "auto" or device == "cpu":
|
if (
|
||||||
|
not torch.cuda.is_available()
|
||||||
|
or device == "auto"
|
||||||
|
or torch.device(device).type == "cpu"
|
||||||
|
):
|
||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from peft import PeftConfig, prepare_model_for_kbit_training
|
from peft import PeftConfig, prepare_model_for_kbit_training
|
||||||
|
from peft.tuners.lora import QuantLinear
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -309,16 +310,26 @@ def load_model(
|
|||||||
):
|
):
|
||||||
config.max_sequence_length = cfg.sequence_len
|
config.max_sequence_length = cfg.sequence_len
|
||||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
if cfg.gptq:
|
||||||
base_model,
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
config=config,
|
base_model,
|
||||||
device_map=cfg.device_map,
|
config=config,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
device_map=cfg.device_map,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
torch_dtype=cfg.torch_dtype,
|
||||||
torch_dtype=cfg.torch_dtype,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
**model_kwargs,
|
||||||
**model_kwargs,
|
)
|
||||||
)
|
else:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
config=config,
|
||||||
|
device_map=cfg.device_map,
|
||||||
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
|
torch_dtype=cfg.torch_dtype,
|
||||||
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
except Exception as err: # pylint: disable=broad-exception-caught
|
except Exception as err: # pylint: disable=broad-exception-caught
|
||||||
LOG.error(
|
LOG.error(
|
||||||
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
||||||
@@ -466,10 +477,10 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, cls):
|
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
|
||||||
names = name.split(".")
|
names = name.split(".")
|
||||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||||
|
|
||||||
|
|||||||
@@ -676,6 +676,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
|
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
|
||||||
and cfg.val_set_size > 0
|
and cfg.val_set_size > 0
|
||||||
and cfg.save_steps
|
and cfg.save_steps
|
||||||
|
and cfg.eval_steps
|
||||||
and cfg.save_steps % cfg.eval_steps == 0
|
and cfg.save_steps % cfg.eval_steps == 0
|
||||||
)
|
)
|
||||||
or False,
|
or False,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
@@ -24,6 +25,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
|
|
||||||
def test_lora(self):
|
def test_lora(self):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
output_dir = tempfile.mkdtemp()
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "JackFram/llama-68m",
|
"base_model": "JackFram/llama-68m",
|
||||||
@@ -51,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"num_epochs": 2,
|
"num_epochs": 2,
|
||||||
"micro_batch_size": 8,
|
"micro_batch_size": 8,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": tempfile.mkdtemp(),
|
"output_dir": output_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -62,9 +64,11 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
def test_lora_packing(self):
|
def test_lora_packing(self):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
output_dir = tempfile.mkdtemp()
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "JackFram/llama-68m",
|
"base_model": "JackFram/llama-68m",
|
||||||
@@ -94,7 +98,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"num_epochs": 2,
|
"num_epochs": 2,
|
||||||
"micro_batch_size": 8,
|
"micro_batch_size": 8,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": tempfile.mkdtemp(),
|
"output_dir": output_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -105,3 +109,53 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
def test_lora_gptq(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
output_dir = tempfile.mkdtemp()
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
||||||
|
"base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
||||||
|
"model_type": "AutoModelForCausalLM",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"gptq": True,
|
||||||
|
"gptq_disable_exllama": True,
|
||||||
|
"lora_r": 32,
|
||||||
|
"lora_alpha": 64,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"save_steps": 0.5,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": output_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||||
|
|||||||
@@ -31,9 +31,9 @@ class TestPhi(unittest.TestCase):
|
|||||||
"trust_remote_code": True,
|
"trust_remote_code": True,
|
||||||
"model_type": "MixFormerSequentialForCausalLM",
|
"model_type": "MixFormerSequentialForCausalLM",
|
||||||
"tokenizer_type": "AutoTokenizer",
|
"tokenizer_type": "AutoTokenizer",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 512,
|
||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": False,
|
||||||
"adapter": None,
|
"adapter": None,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.1,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
@@ -55,8 +55,9 @@ class TestPhi(unittest.TestCase):
|
|||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": tempfile.mkdtemp(),
|
"output_dir": tempfile.mkdtemp(),
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_bnb_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
"bf16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
@@ -74,9 +75,9 @@ class TestPhi(unittest.TestCase):
|
|||||||
"trust_remote_code": True,
|
"trust_remote_code": True,
|
||||||
"model_type": "MixFormerSequentialForCausalLM",
|
"model_type": "MixFormerSequentialForCausalLM",
|
||||||
"tokenizer_type": "AutoTokenizer",
|
"tokenizer_type": "AutoTokenizer",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 512,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": False,
|
||||||
"adapter": None,
|
"adapter": None,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.1,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
@@ -98,8 +99,9 @@ class TestPhi(unittest.TestCase):
|
|||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": tempfile.mkdtemp(),
|
"output_dir": tempfile.mkdtemp(),
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_bnb_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
"bf16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user