4bit quantized support (wip)

Wing Lian
2023-04-17 11:37:39 -04:00
parent 12de7b7cf7
commit 77fca25f1b
11 changed files with 108 additions and 63 deletions
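Before the diff itself, a rough sketch of how the pieces added here are meant to be exercised: the new 4-bit path is gated entirely by keys on the YAML-backed `cfg` object. This is a minimal, hypothetical illustration only; the model id and any config keys not visible in the diff below are placeholder assumptions, not part of the commit.

```python
# Hypothetical usage sketch: only load_4bit, base_model_config, load_in_8bit,
# and gradient_checkpointing are keys actually touched by this commit; the
# model id and the remaining values are placeholder assumptions.
import yaml
from attrdict import AttrDefault

raw_cfg = yaml.safe_load("""
base_model: decapoda-research/llama-7b-hf         # placeholder model id
base_model_config: decapoda-research/llama-7b-hf  # new value threaded into load_model
load_4bit: true               # route model loading through alpaca_lora_4bit
load_in_8bit: false           # forced off internally when the 4bit path is taken
gradient_checkpointing: true  # now forwarded via training_arguments_kwargs
""")
cfg = AttrDefault(lambda: None, raw_cfg)

# load_model (the function changed in this commit) now takes base_model_config
# as its second positional argument.
model, tokenizer, lora_config = load_model(
    cfg.base_model,
    cfg.base_model_config,
    cfg.model_type,
    cfg.tokenizer_type,
    cfg,
    adapter=cfg.adapter,
)
```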


@@ -13,12 +13,6 @@ import transformers
import yaml
from attrdict import AttrDefault
from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
    PeftModel,
)
from torch import nn
from transformers import (
    AutoModelForCausalLM,
@@ -45,7 +39,7 @@ from axolotl.prompt_tokenizers import (
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "data/last_run"
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def setup_wandb_env_vars(cfg):
@@ -60,7 +54,11 @@ def setup_wandb_env_vars(cfg):
        os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
    # TODO refactor as a kwarg
    load_in_8bit = cfg.load_in_8bit
    tokenizer = None

    if adapter != "lora":
        raise NotImplementedError(f"{adapter} peft adapter not available")
    if "llama" in base_model:
@@ -70,7 +68,43 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
    try:
        if "llama" in base_model:
        if cfg.load_4bit:
            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
            replace_peft_model_with_int4_lora_model()
        from peft import (
            LoraConfig,
            get_peft_model,
            prepare_model_for_int8_training,
            PeftModel,
        )
    except Exception as e:
        logging.exception(e)
        raise e

    try:
        if cfg.load_4bit and "llama" in base_model:
            from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
            from huggingface_hub import snapshot_download

            cache_model_path = Path(snapshot_download(base_model))
            # TODO search .glob for a .pt, .safetensor, or .bin
            cache_model_path.glob("*.pt")
            files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensor')) + list(cache_model_path.glob('*.bin'))
            if len(files) > 0:
                model_path = str(files[0])
            else:
                logging.warning("unable to find a cached model file, this will likely fail...")
                model_path = str(cache_model_path)
            model, tokenizer = load_llama_model_4bit_low_ram(
                base_model_config if base_model_config else base_model,
                model_path,
                device_map=cfg.device_map,
                groupsize=-1,
                is_v1_model=True,
            )
            load_in_8bit = False
        elif "llama" in base_model:
            model = LlamaForCausalLM.from_pretrained(
                base_model,
                load_in_8bit=cfg.load_in_8bit,
@@ -92,13 +126,14 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
            device_map=cfg.device_map,
        )
    try:
        if "llama" in base_model:
            tokenizer = LlamaTokenizer.from_pretrained(model)
        else:
            tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
    except:
        tokenizer = AutoTokenizer.from_pretrained(base_model)
    if not tokenizer:
        try:
            if "llama" in base_model:
                tokenizer = LlamaTokenizer.from_pretrained(model)
            else:
                tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
        except:
            tokenizer = AutoTokenizer.from_pretrained(base_model)

    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
@@ -107,7 +142,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    if cfg.load_in_8bit:
    if load_in_8bit:
        model = prepare_model_for_int8_training(model)

    lora_config = LoraConfig(
@@ -128,6 +163,16 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
    if cfg.ddp:
        model.to(f"cuda:{cfg.local_rank}")

    if cfg.load_4bit:
        # Scales to half
        print('Fitting 4bit scales and zeros to half')
        for n, m in model.named_modules():
            if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
                if hasattr(m, "is_v1_model") and m.is_v1_model:
                    m.zeros = m.zeros.half()
                m.scales = m.scales.half()
                m.bias = m.bias.half()

    # TODO resume_from_checkpoint handling
    model.print_trainable_parameters()
    return model, tokenizer, lora_config
@@ -243,6 +288,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if cfg.gradient_checkpointing is not None:
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
training_args = transformers.TrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size,
@@ -260,7 +307,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
        group_by_length=cfg.group_by_length,
        report_to="wandb" if cfg.use_wandb else None,
        run_name=cfg.wandb_run_id if cfg.use_wandb else None,
        gradient_checkpointing=cfg.gradient_checkpointing,
        **training_arguments_kwargs,
    )
@@ -356,11 +402,13 @@ def train(
        cfg.bf16 = False

    # Load the model and tokenizer
    logging.info("loading model, tokenizer, and lora_config...")
    model, tokenizer, lora_config = load_model(
        cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
        cfg.base_model, cfg.base_model_config, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
    )

    if "inference" in kwargs:
        logging.info("calling do_inference function")
        do_inference(cfg, model, tokenizer)
        return
@@ -369,6 +417,7 @@ def train(
        dataset = load_from_disk(cfg.dataset_prepared_path)
        logging.info("Prepared dataset loaded from disk...")
    else:
        logging.info("Loading raw datasets...")
        datasets = []
        for d in cfg.datasets:
            if Path(d.path).exists():
@@ -402,6 +451,7 @@ def train(
        constant_len_dataset = ConstantLengthDataset(
            tokenizer, datasets, seq_length=cfg.sequence_len
        )
        logging.info("merging, packing, shuffling, and splitting master dataset")
        dataset = Dataset.from_list(
            [_ for _ in constant_len_dataset]
        ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)