4bit quantized support (wip)

This commit is contained in:
Wing Lian
2023-04-17 11:37:39 -04:00
parent 12de7b7cf7
commit 77fca25f1b
11 changed files with 108 additions and 63 deletions

View File

@@ -29,8 +29,8 @@ shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
```
- Create a new or update the existing YAML config (config/pythia_1_2B_alpaca.yml)[config/pythia_1_2B_alpaca.yml]
- Install python dependencies `pip3 install -r requirements.txt`
- Configure accelerate `accelerate launch` or update `~/.cache/huggingface/accelerate/default_config.yaml`
- Install python dependencies `pip3 install -e .[triton]` or `pip3 install -e .[cuda]`
- Configure accelerate `accelerate config` or update `~/.cache/huggingface/accelerate/default_config.yaml`
```yaml
compute_environment: LOCAL_MACHINE

View File

@@ -11,7 +11,7 @@ datasets:
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: data/last_run
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
sequence_len: 2048
@@ -34,6 +34,7 @@ train_on_inputs: false
group_by_length: false
bf16: True
tf32: True
gradient_checkpointing:
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -11,7 +11,7 @@ datasets:
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: data/last_run
dataset_prepared_path: last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: data/last_run
dataset_prepared_path: last_run_prepared
val_set_size: 0.04
adapter: lora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: gpteacher
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
dataset_prepared_path: data/last_run
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
lora_model_dir:

View File

@@ -1,3 +0,0 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

View File

@@ -1,5 +1,4 @@
git+https://github.com/huggingface/peft.git
git+https://github.com/huggingface/transformers.git
transformers @ git+https://github.com/huggingface/transformers.git
attrdict
fire
PyYAML==6.0

View File

@@ -13,12 +13,6 @@ import transformers
import yaml
from attrdict import AttrDefault
from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_int8_training,
PeftModel,
)
from torch import nn
from transformers import (
AutoModelForCausalLM,
@@ -45,7 +39,7 @@ from axolotl.prompt_tokenizers import (
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "data/last_run"
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def setup_wandb_env_vars(cfg):
@@ -60,7 +54,11 @@ def setup_wandb_env_vars(cfg):
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
tokenizer = None
if adapter != "lora":
raise NotImplementedError(f"{adapter} peft adapter not available")
if "llama" in base_model:
@@ -70,7 +68,43 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
try:
if "llama" in base_model:
if cfg.load_4bit:
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
replace_peft_model_with_int4_lora_model()
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_int8_training,
PeftModel,
)
except Exception as e:
logging.exception(e)
raise e
try:
if cfg.load_4bit and "llama" in base_model:
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download
cache_model_path = Path(snapshot_download(base_model))
# TODO search .glob for a .pt, .safetensor, or .bin
cache_model_path.glob("*.pt")
files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensor')) + list(cache_model_path.glob('*.bin'))
if len(files) > 0:
model_path = str(files[0])
else:
logging.warning("unable to find a cached model file, this will likely fail...")
model_path = str(cache_model_path)
model, tokenizer = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model,
model_path,
device_map=cfg.device_map,
groupsize=-1,
is_v1_model=True,
)
load_in_8bit = False
elif "llama" in base_model:
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
@@ -92,13 +126,14 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
device_map=cfg.device_map,
)
try:
if "llama" in base_model:
tokenizer = LlamaTokenizer.from_pretrained(model)
else:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
except:
tokenizer = AutoTokenizer.from_pretrained(base_model)
if not tokenizer:
try:
if "llama" in base_model:
tokenizer = LlamaTokenizer.from_pretrained(model)
else:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
except:
tokenizer = AutoTokenizer.from_pretrained(base_model)
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
@@ -107,7 +142,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.load_in_8bit:
if load_in_8bit:
model = prepare_model_for_int8_training(model)
lora_config = LoraConfig(
@@ -128,6 +163,16 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
if cfg.ddp:
model.to(f"cuda:{cfg.local_rank}")
if cfg.load_4bit:
# Scales to half
print('Fitting 4bit scales and zeros to half')
for n, m in model.named_modules():
if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
if hasattr(m, "is_v1_model") and m.is_v1_model:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
m.bias = m.bias.half()
# TODO resume_from_checkpoint handling
model.print_trainable_parameters()
return model, tokenizer, lora_config
@@ -243,6 +288,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if cfg.gradient_checkpointing is not None:
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
training_args = transformers.TrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size,
@@ -260,7 +307,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
gradient_checkpointing=cfg.gradient_checkpointing,
**training_arguments_kwargs,
)
@@ -356,11 +402,13 @@ def train(
cfg.bf16 = False
# Load the model and tokenizer
logging.info("loading model, tokenizer, and lora_config...")
model, tokenizer, lora_config = load_model(
cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
cfg.base_model, cfg.base_model_config, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
)
if "inference" in kwargs:
logging.info("calling do_inference function")
do_inference(cfg, model, tokenizer)
return
@@ -369,6 +417,7 @@ def train(
dataset = load_from_disk(cfg.dataset_prepared_path)
logging.info("Prepared dataset loaded from disk...")
else:
logging.info("Loading raw datasets...")
datasets = []
for d in cfg.datasets:
if Path(d.path).exists():
@@ -402,6 +451,7 @@ def train(
constant_len_dataset = ConstantLengthDataset(
tokenizer, datasets, seq_length=cfg.sequence_len
)
logging.info("merging, packing, shuffling, and splitting master dataset")
dataset = Dataset.from_list(
[_ for _ in constant_len_dataset]
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)

View File

@@ -1,33 +0,0 @@
[metadata]
name = axolotl
version = 0.1.0
description = You know you're going to axolotl questions
author = Wing Lian
author_email = wing.lian@gmail.com
license = MIT
[options]
package_dir =
=src
packages = find:
install_requires =
transformers @ git+https://github.com/huggingface/transformers.git@main
peft @ git+https://github.com/huggingface/peft.git@main
attrdict
fire
PyYAML == 6.0
black
bitsandbytes
datasets
accelerate
sentencepiece
wandb
flash-attn
einops
[options.packages.find]
where = src
[options.extras_require]
gptq_cuda = alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]
gptq_triton = alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]

30
setup.py Normal file
View File

@@ -0,0 +1,30 @@
import sys
from setuptools import setup, find_packages
install_requires = []
with open("./requirements.txt", "r") as requirements_file:
# don't include peft yet until we check the int4
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
reqs = [r for r in reqs if r[0] != "#"]
for r in reqs:
install_requires.append(r)
setup(
name='axolotl',
version='0.1',
description="You know you're going to axolotl questions",
package_dir={'': 'src'},
packages=find_packages(),
install_requires=install_requires,
extras_require={
None: [
"peft @ git+https://github.com/huggingface/peft.git",
],
'int4_cuda': [
"alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]",
],
'int4_triton': [
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]",
],
},
)

View File

@@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset):
pass
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.