4bit quantized support (wip)
This commit is contained in:
@@ -29,8 +29,8 @@ shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
|
|||||||
```
|
```
|
||||||
|
|
||||||
- Create a new YAML config or update the existing one: [config/pythia_1_2B_alpaca.yml](config/pythia_1_2B_alpaca.yml)
|
- Create a new YAML config or update the existing one: [config/pythia_1_2B_alpaca.yml](config/pythia_1_2B_alpaca.yml)
|
||||||
- Install python dependencies `pip3 install -r requirements.txt`
|
- Install python dependencies `pip3 install -e .[int4_triton]` or `pip3 install -e .[int4_cuda]`
|
||||||
- Configure accelerate `accelerate launch` or update `~/.cache/huggingface/accelerate/default_config.yaml`
|
- Configure accelerate `accelerate config` or update `~/.cache/huggingface/accelerate/default_config.yaml`
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
compute_environment: LOCAL_MACHINE
|
compute_environment: LOCAL_MACHINE
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: gpteacher
|
type: gpteacher
|
||||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||||
type: gpteacher
|
type: gpteacher
|
||||||
dataset_prepared_path: data/last_run
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
@@ -34,6 +34,7 @@ train_on_inputs: false
|
|||||||
group_by_length: false
|
group_by_length: false
|
||||||
bf16: True
|
bf16: True
|
||||||
tf32: True
|
tf32: True
|
||||||
|
gradient_checkpointing:
|
||||||
early_stopping_patience:
|
early_stopping_patience:
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: gpteacher
|
type: gpteacher
|
||||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||||
type: gpteacher
|
type: gpteacher
|
||||||
dataset_prepared_path: data/last_run
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.04
|
val_set_size: 0.04
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: gpteacher
|
type: gpteacher
|
||||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||||
type: gpteacher
|
type: gpteacher
|
||||||
dataset_prepared_path: data/last_run
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.04
|
val_set_size: 0.04
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: gpteacher
|
type: gpteacher
|
||||||
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
||||||
type: gpteacher
|
type: gpteacher
|
||||||
dataset_prepared_path: data/last_run
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
[build-system]
|
|
||||||
requires = ["setuptools", "wheel"]
|
|
||||||
build-backend = "setuptools.build_meta"
|
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
git+https://github.com/huggingface/peft.git
|
transformers @ git+https://github.com/huggingface/transformers.git
|
||||||
git+https://github.com/huggingface/transformers.git
|
|
||||||
attrdict
|
attrdict
|
||||||
fire
|
fire
|
||||||
PyYAML==6.0
|
PyYAML==6.0
|
||||||
|
|||||||
@@ -13,12 +13,6 @@ import transformers
|
|||||||
import yaml
|
import yaml
|
||||||
from attrdict import AttrDefault
|
from attrdict import AttrDefault
|
||||||
from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
|
from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
|
||||||
from peft import (
|
|
||||||
LoraConfig,
|
|
||||||
get_peft_model,
|
|
||||||
prepare_model_for_int8_training,
|
|
||||||
PeftModel,
|
|
||||||
)
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -45,7 +39,7 @@ from axolotl.prompt_tokenizers import (
|
|||||||
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
|
||||||
|
|
||||||
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
||||||
DEFAULT_DATASET_PREPARED_PATH = "data/last_run"
|
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||||
|
|
||||||
|
|
||||||
def setup_wandb_env_vars(cfg):
|
def setup_wandb_env_vars(cfg):
|
||||||
@@ -60,7 +54,11 @@ def setup_wandb_env_vars(cfg):
|
|||||||
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||||
|
|
||||||
|
|
||||||
def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
|
def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
|
||||||
|
# TODO refactor as a kwarg
|
||||||
|
load_in_8bit = cfg.load_in_8bit
|
||||||
|
tokenizer = None
|
||||||
|
|
||||||
if adapter != "lora":
|
if adapter != "lora":
|
||||||
raise NotImplementedError(f"{adapter} peft adapter not available")
|
raise NotImplementedError(f"{adapter} peft adapter not available")
|
||||||
if "llama" in base_model:
|
if "llama" in base_model:
|
||||||
@@ -70,7 +68,43 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
|
|||||||
|
|
||||||
torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
|
torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
|
||||||
try:
|
try:
|
||||||
if "llama" in base_model:
|
if cfg.load_4bit:
|
||||||
|
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
|
||||||
|
replace_peft_model_with_int4_lora_model()
|
||||||
|
|
||||||
|
from peft import (
|
||||||
|
LoraConfig,
|
||||||
|
get_peft_model,
|
||||||
|
prepare_model_for_int8_training,
|
||||||
|
PeftModel,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
try:
|
||||||
|
if cfg.load_4bit and "llama" in base_model:
|
||||||
|
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
cache_model_path = Path(snapshot_download(base_model))
|
||||||
|
# TODO search .glob for a .pt, .safetensor, or .bin
|
||||||
|
cache_model_path.glob("*.pt")
|
||||||
|
files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensor')) + list(cache_model_path.glob('*.bin'))
|
||||||
|
if len(files) > 0:
|
||||||
|
model_path = str(files[0])
|
||||||
|
else:
|
||||||
|
logging.warning("unable to find a cached model file, this will likely fail...")
|
||||||
|
model_path = str(cache_model_path)
|
||||||
|
model, tokenizer = load_llama_model_4bit_low_ram(
|
||||||
|
base_model_config if base_model_config else base_model,
|
||||||
|
model_path,
|
||||||
|
device_map=cfg.device_map,
|
||||||
|
groupsize=-1,
|
||||||
|
is_v1_model=True,
|
||||||
|
)
|
||||||
|
load_in_8bit = False
|
||||||
|
elif "llama" in base_model:
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
load_in_8bit=cfg.load_in_8bit,
|
load_in_8bit=cfg.load_in_8bit,
|
||||||
@@ -92,13 +126,14 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
|
|||||||
device_map=cfg.device_map,
|
device_map=cfg.device_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
if not tokenizer:
|
||||||
if "llama" in base_model:
|
try:
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(model)
|
if "llama" in base_model:
|
||||||
else:
|
tokenizer = LlamaTokenizer.from_pretrained(model)
|
||||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
|
else:
|
||||||
except:
|
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
except:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
|
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
|
||||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||||
@@ -107,7 +142,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
|
|||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
if cfg.load_in_8bit:
|
if load_in_8bit:
|
||||||
model = prepare_model_for_int8_training(model)
|
model = prepare_model_for_int8_training(model)
|
||||||
|
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
@@ -128,6 +163,16 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
|
|||||||
if cfg.ddp:
|
if cfg.ddp:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
|
|
||||||
|
if cfg.load_4bit:
|
||||||
|
# Scales to half
|
||||||
|
print('Fitting 4bit scales and zeros to half')
|
||||||
|
for n, m in model.named_modules():
|
||||||
|
if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
|
||||||
|
if hasattr(m, "is_v1_model") and m.is_v1_model:
|
||||||
|
m.zeros = m.zeros.half()
|
||||||
|
m.scales = m.scales.half()
|
||||||
|
m.bias = m.bias.half()
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
return model, tokenizer, lora_config
|
return model, tokenizer, lora_config
|
||||||
@@ -243,6 +288,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
training_arguments_kwargs["tf32"] = cfg.tf32
|
training_arguments_kwargs["tf32"] = cfg.tf32
|
||||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||||
|
if cfg.gradient_checkpointing is not None:
|
||||||
|
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
|
||||||
|
|
||||||
training_args = transformers.TrainingArguments(
|
training_args = transformers.TrainingArguments(
|
||||||
per_device_train_batch_size=cfg.micro_batch_size,
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
@@ -260,7 +307,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
group_by_length=cfg.group_by_length,
|
group_by_length=cfg.group_by_length,
|
||||||
report_to="wandb" if cfg.use_wandb else None,
|
report_to="wandb" if cfg.use_wandb else None,
|
||||||
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
||||||
gradient_checkpointing=cfg.gradient_checkpointing,
|
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -356,11 +402,13 @@ def train(
|
|||||||
cfg.bf16 = False
|
cfg.bf16 = False
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
|
logging.info("loading model, tokenizer, and lora_config...")
|
||||||
model, tokenizer, lora_config = load_model(
|
model, tokenizer, lora_config = load_model(
|
||||||
cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
|
cfg.base_model, cfg.base_model_config, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
if "inference" in kwargs:
|
if "inference" in kwargs:
|
||||||
|
logging.info("calling do_inference function")
|
||||||
do_inference(cfg, model, tokenizer)
|
do_inference(cfg, model, tokenizer)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -369,6 +417,7 @@ def train(
|
|||||||
dataset = load_from_disk(cfg.dataset_prepared_path)
|
dataset = load_from_disk(cfg.dataset_prepared_path)
|
||||||
logging.info("Prepared dataset loaded from disk...")
|
logging.info("Prepared dataset loaded from disk...")
|
||||||
else:
|
else:
|
||||||
|
logging.info("Loading raw datasets...")
|
||||||
datasets = []
|
datasets = []
|
||||||
for d in cfg.datasets:
|
for d in cfg.datasets:
|
||||||
if Path(d.path).exists():
|
if Path(d.path).exists():
|
||||||
@@ -402,6 +451,7 @@ def train(
|
|||||||
constant_len_dataset = ConstantLengthDataset(
|
constant_len_dataset = ConstantLengthDataset(
|
||||||
tokenizer, datasets, seq_length=cfg.sequence_len
|
tokenizer, datasets, seq_length=cfg.sequence_len
|
||||||
)
|
)
|
||||||
|
logging.info("merging, packing, shuffling, and splitting master dataset")
|
||||||
dataset = Dataset.from_list(
|
dataset = Dataset.from_list(
|
||||||
[_ for _ in constant_len_dataset]
|
[_ for _ in constant_len_dataset]
|
||||||
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
|
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
|
||||||
|
|||||||
33
setup.cfg
33
setup.cfg
@@ -1,33 +0,0 @@
|
|||||||
[metadata]
|
|
||||||
name = axolotl
|
|
||||||
version = 0.1.0
|
|
||||||
description = You know you're going to axolotl questions
|
|
||||||
author = Wing Lian
|
|
||||||
author_email = wing.lian@gmail.com
|
|
||||||
license = MIT
|
|
||||||
|
|
||||||
[options]
|
|
||||||
package_dir =
|
|
||||||
=src
|
|
||||||
packages = find:
|
|
||||||
install_requires =
|
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@main
|
|
||||||
peft @ git+https://github.com/huggingface/peft.git@main
|
|
||||||
attrdict
|
|
||||||
fire
|
|
||||||
PyYAML == 6.0
|
|
||||||
black
|
|
||||||
bitsandbytes
|
|
||||||
datasets
|
|
||||||
accelerate
|
|
||||||
sentencepiece
|
|
||||||
wandb
|
|
||||||
flash-attn
|
|
||||||
einops
|
|
||||||
|
|
||||||
[options.packages.find]
|
|
||||||
where = src
|
|
||||||
|
|
||||||
[options.extras_require]
|
|
||||||
gptq_cuda = alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]
|
|
||||||
gptq_triton = alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]
|
|
||||||
30
setup.py
Normal file
30
setup.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import sys
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
install_requires = []
|
||||||
|
with open("./requirements.txt", "r") as requirements_file:
|
||||||
|
# don't include peft yet until we check the int4
|
||||||
|
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
|
||||||
|
reqs = [r for r in reqs if r[0] != "#"]
|
||||||
|
for r in reqs:
|
||||||
|
install_requires.append(r)
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='axolotl',
|
||||||
|
version='0.1',
|
||||||
|
description="You know you're going to axolotl questions",
|
||||||
|
package_dir={'': 'src'},
|
||||||
|
packages=find_packages(),
|
||||||
|
install_requires=install_requires,
|
||||||
|
extras_require={
|
||||||
|
None: [
|
||||||
|
"peft @ git+https://github.com/huggingface/peft.git",
|
||||||
|
],
|
||||||
|
'int4_cuda': [
|
||||||
|
"alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]",
|
||||||
|
],
|
||||||
|
'int4_triton': [
|
||||||
|
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
@@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# TODO this isn't the best since it can't interleave datasets
|
||||||
class ConstantLengthDataset(IterableDataset):
|
class ConstantLengthDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||||
|
|||||||
Reference in New Issue
Block a user