4bit quantized support (wip)
@@ -29,8 +29,8 @@ shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
 ```
 
 - Create a new or update the existing YAML config (config/pythia_1_2B_alpaca.yml)[config/pythia_1_2B_alpaca.yml]
-- Install python dependencies `pip3 install -r requirements.txt`
-- Configure accelerate `accelerate launch` or update `~/.cache/huggingface/accelerate/default_config.yaml`
+- Install python dependencies `pip3 install -e .[triton]` or `pip3 install -e .[cuda]`
+- Configure accelerate `accelerate config` or update `~/.cache/huggingface/accelerate/default_config.yaml`
 
 ```yaml
 compute_environment: LOCAL_MACHINE
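Note: the `accelerate config` step referenced above writes `~/.cache/huggingface/accelerate/default_config.yaml`. A minimal single-GPU, single-process config might look like the sketch below; the field values are illustrative and are not taken from this commit.

```yaml
# illustrative accelerate default_config.yaml for one local GPU (not part of this commit)
compute_environment: LOCAL_MACHINE
distributed_type: 'NO'
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
use_cpu: false
```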
@@ -11,7 +11,7 @@ datasets:
     type: gpteacher
   - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
     type: gpteacher
-dataset_prepared_path: data/last_run
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
 adapter: lora
 sequence_len: 2048
@@ -34,6 +34,7 @@ train_on_inputs: false
 group_by_length: false
 bf16: True
 tf32: True
+gradient_checkpointing:
 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
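Note: the `gradient_checkpointing:` key added above is left unset in the shipped configs. Judging by the trainer change later in this commit (the value is only forwarded to `TrainingArguments` when it is not None), enabling it would presumably be a one-line change:

```yaml
# illustrative; the configs in this commit leave the key empty
gradient_checkpointing: true
```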
@@ -11,7 +11,7 @@ datasets:
     type: gpteacher
   - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
     type: gpteacher
-dataset_prepared_path: data/last_run
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.04
 adapter: lora
 lora_model_dir:
@@ -11,7 +11,7 @@ datasets:
     type: gpteacher
   - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
     type: gpteacher
-dataset_prepared_path: data/last_run
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.04
 adapter: lora
 lora_model_dir:
@@ -11,7 +11,7 @@ datasets:
     type: gpteacher
   - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
     type: gpteacher
-dataset_prepared_path: data/last_run
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.05
 adapter: lora
 lora_model_dir:
@@ -1,3 +0,0 @@
-[build-system]
-requires = ["setuptools", "wheel"]
-build-backend = "setuptools.build_meta"
@@ -1,5 +1,4 @@
-git+https://github.com/huggingface/peft.git
-git+https://github.com/huggingface/transformers.git
+transformers @ git+https://github.com/huggingface/transformers.git
 attrdict
 fire
 PyYAML==6.0
@@ -13,12 +13,6 @@ import transformers
 import yaml
 from attrdict import AttrDefault
 from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
-from peft import (
-    LoraConfig,
-    get_peft_model,
-    prepare_model_for_int8_training,
-    PeftModel,
-)
 from torch import nn
 from transformers import (
     AutoModelForCausalLM,
@@ -45,7 +39,7 @@ from axolotl.prompt_tokenizers import (
 from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
 
 logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
-DEFAULT_DATASET_PREPARED_PATH = "data/last_run"
+DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 
 
 def setup_wandb_env_vars(cfg):
@@ -60,7 +54,11 @@ def setup_wandb_env_vars(cfg):
         os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
 
 
-def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
+def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
+    # TODO refactor as a kwarg
+    load_in_8bit = cfg.load_in_8bit
+    tokenizer = None
+
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
     if "llama" in base_model:
@@ -70,7 +68,43 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
 
     torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
     try:
-        if "llama" in base_model:
+        if cfg.load_4bit:
+            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
+            replace_peft_model_with_int4_lora_model()
+
+        from peft import (
+            LoraConfig,
+            get_peft_model,
+            prepare_model_for_int8_training,
+            PeftModel,
+        )
+    except Exception as e:
+        logging.exception(e)
+        raise e
+
+    try:
+        if cfg.load_4bit and "llama" in base_model:
+            from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
+            from huggingface_hub import snapshot_download
+
+            cache_model_path = Path(snapshot_download(base_model))
+            # TODO search .glob for a .pt, .safetensor, or .bin
+            cache_model_path.glob("*.pt")
+            files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensor')) + list(cache_model_path.glob('*.bin'))
+            if len(files) > 0:
+                model_path = str(files[0])
+            else:
+                logging.warning("unable to find a cached model file, this will likely fail...")
+                model_path = str(cache_model_path)
+            model, tokenizer = load_llama_model_4bit_low_ram(
+                base_model_config if base_model_config else base_model,
+                model_path,
+                device_map=cfg.device_map,
+                groupsize=-1,
+                is_v1_model=True,
+            )
+            load_in_8bit = False
+        elif "llama" in base_model:
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
@@ -92,13 +126,14 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
             device_map=cfg.device_map,
         )
 
-    try:
-        if "llama" in base_model:
-            tokenizer = LlamaTokenizer.from_pretrained(model)
-        else:
-            tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
-    except:
-        tokenizer = AutoTokenizer.from_pretrained(base_model)
+    if not tokenizer:
+        try:
+            if "llama" in base_model:
+                tokenizer = LlamaTokenizer.from_pretrained(model)
+            else:
+                tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
+        except:
+            tokenizer = AutoTokenizer.from_pretrained(base_model)
 
     if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
@@ -107,7 +142,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-    if cfg.load_in_8bit:
+    if load_in_8bit:
         model = prepare_model_for_int8_training(model)
 
     lora_config = LoraConfig(
@@ -128,6 +163,16 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
     if cfg.ddp:
         model.to(f"cuda:{cfg.local_rank}")
 
+    if cfg.load_4bit:
+        # Scales to half
+        print('Fitting 4bit scales and zeros to half')
+        for n, m in model.named_modules():
+            if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
+                if hasattr(m, "is_v1_model") and m.is_v1_model:
+                    m.zeros = m.zeros.half()
+                m.scales = m.scales.half()
+                m.bias = m.bias.half()
+
     # TODO resume_from_checkpoint handling
     model.print_trainable_parameters()
     return model, tokenizer, lora_config
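Note: the new 4-bit code paths above are gated on `cfg.load_4bit`, which, like the other `cfg.*` flags, would come from the YAML config. None of the configs touched by this commit set it, but a fragment enabling the 4-bit path might look roughly like this (repo names are placeholders; `base_model_config` is the new argument used to locate the HF config/tokenizer files when the quantized checkpoint repo does not carry them):

```yaml
# illustrative fragment, not part of this commit; repo names are placeholders
base_model: someuser/llama-7b-4bit        # quantized .pt/.safetensor/.bin checkpoint
base_model_config: someuser/llama-7b-hf   # where config.json / tokenizer files live
load_4bit: true
load_in_8bit: false
adapter: lora
```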
@@ -243,6 +288,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["tf32"] = cfg.tf32
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
+    if cfg.gradient_checkpointing is not None:
+        training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
 
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
@@ -260,7 +307,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
-        gradient_checkpointing=cfg.gradient_checkpointing,
         **training_arguments_kwargs,
     )
@@ -356,11 +402,13 @@ def train(
         cfg.bf16 = False
 
     # Load the model and tokenizer
+    logging.info("loading model, tokenizer, and lora_config...")
     model, tokenizer, lora_config = load_model(
-        cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
+        cfg.base_model, cfg.base_model_config, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
     )
 
     if "inference" in kwargs:
+        logging.info("calling do_inference function")
         do_inference(cfg, model, tokenizer)
         return
@@ -369,6 +417,7 @@ def train(
         dataset = load_from_disk(cfg.dataset_prepared_path)
         logging.info("Prepared dataset loaded from disk...")
     else:
+        logging.info("Loading raw datasets...")
         datasets = []
         for d in cfg.datasets:
             if Path(d.path).exists():
@@ -402,6 +451,7 @@ def train(
         constant_len_dataset = ConstantLengthDataset(
             tokenizer, datasets, seq_length=cfg.sequence_len
         )
+        logging.info("merging, packing, shuffling, and splitting master dataset")
        dataset = Dataset.from_list(
            [_ for _ in constant_len_dataset]
        ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
setup.cfg
@@ -1,33 +0,0 @@
-[metadata]
-name = axolotl
-version = 0.1.0
-description = You know you're going to axolotl questions
-author = Wing Lian
-author_email = wing.lian@gmail.com
-license = MIT
-
-[options]
-package_dir =
-    =src
-packages = find:
-install_requires =
-    transformers @ git+https://github.com/huggingface/transformers.git@main
-    peft @ git+https://github.com/huggingface/peft.git@main
-    attrdict
-    fire
-    PyYAML == 6.0
-    black
-    bitsandbytes
-    datasets
-    accelerate
-    sentencepiece
-    wandb
-    flash-attn
-    einops
-
-[options.packages.find]
-where = src
-
-[options.extras_require]
-gptq_cuda = alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]
-gptq_triton = alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]
setup.py (new file)
@@ -0,0 +1,30 @@
+import sys
+from setuptools import setup, find_packages
+
+install_requires = []
+with open("./requirements.txt", "r") as requirements_file:
+    # don't include peft yet until we check the int4
+    reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
+    reqs = [r for r in reqs if r[0] != "#"]
+    for r in reqs:
+        install_requires.append(r)
+
+setup(
+    name='axolotl',
+    version='0.1',
+    description="You know you're going to axolotl questions",
+    package_dir={'': 'src'},
+    packages=find_packages(),
+    install_requires=install_requires,
+    extras_require={
+        None: [
+            "peft @ git+https://github.com/huggingface/peft.git",
+        ],
+        'int4_cuda': [
+            "alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]",
+        ],
+        'int4_triton': [
+            "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]",
+        ],
+    },
+)
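Note: given the `extras_require` keys defined above, the editable installs mentioned in the README hunk would presumably be spelled `pip3 install -e .[int4_triton]` or `pip3 install -e .[int4_cuda]`; the README hunk itself refers to the shorter `[triton]`/`[cuda]` names, which do not match the extras defined here.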
@@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset):
         pass
 
 
+# TODO this isn't the best since it can't interleave datasets
 class ConstantLengthDataset(IterableDataset):
     """
     Iterable dataset that returns constant length chunks of tokens from stream of text files.