misc fixes to add gptq tests (#621)

* misc fixes to add gptq tests

* set bf16 needed for fa2
This commit is contained in:
Wing Lian
2023-09-21 21:52:12 -04:00
committed by GitHub
parent 97d3776ce6
commit 03e59077a0
5 changed files with 93 additions and 21 deletions

View File

@@ -19,7 +19,11 @@ def check_cuda_device(default_value):
def wrapper(*args, **kwargs):
device = kwargs.get("device", args[0] if args else None)
if not torch.cuda.is_available() or device == "auto" or device == "cpu":
if (
not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
):
return default_value
return func(*args, **kwargs)

View File

@@ -10,6 +10,7 @@ import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401
AutoConfig,
AutoModelForCausalLM,
@@ -309,16 +310,26 @@ def load_model(
):
config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
except Exception as err: # pylint: disable=broad-exception-caught
LOG.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -466,10 +477,10 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, cls):
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])

View File

@@ -676,6 +676,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0
and cfg.save_steps
and cfg.eval_steps
and cfg.save_steps % cfg.eval_steps == 0
)
or False,