Compare commits
38 Commits
v0.2.0
...
a6f5e5eaec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a6f5e5eaec | ||
|
|
5a631b305b | ||
|
|
f94dd626f0 | ||
|
|
5079753b7a | ||
|
|
0136f510f2 | ||
|
|
9b8585dc70 | ||
|
|
8eb5811d4e | ||
|
|
e0011fdf55 | ||
|
|
6e9e98720e | ||
|
|
c2a0792680 | ||
|
|
b267d24a2b | ||
|
|
5c3f5db38b | ||
|
|
e3d03745ba | ||
|
|
fac46002d4 | ||
|
|
33d40179ba | ||
|
|
dcb03d6da4 | ||
|
|
0e4be625ae | ||
|
|
bdc4bd7d4e | ||
|
|
2d0ba3b818 | ||
|
|
c7021e191f | ||
|
|
c56818b119 | ||
|
|
2675fb756e | ||
|
|
1076bcbbca | ||
|
|
2daa6835f0 | ||
|
|
e3c494ca7b | ||
|
|
ad0ea6aaab | ||
|
|
876edd83d0 | ||
|
|
6cb2310592 | ||
|
|
6fa40bf8ad | ||
|
|
3aad5f3b3e | ||
|
|
39a208c2bc | ||
|
|
2520ecd6df | ||
|
|
c5b0af1a7e | ||
|
|
988aeb9c34 | ||
|
|
cf61f14bff | ||
|
|
0abcd71a85 | ||
|
|
c43c5c84ff | ||
|
|
36ec6e1a0e |
12
.github/workflows/base.yml
vendored
12
.github/workflows/base.yml
vendored
@@ -16,13 +16,22 @@ jobs:
|
||||
include:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: "117"
|
||||
cuda_version: 11.7.0
|
||||
python_version: "3.9"
|
||||
pytorch: 1.13.1
|
||||
axolotl_extras:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras: gptq
|
||||
steps:
|
||||
@@ -46,12 +55,13 @@ jobs:
|
||||
context: .
|
||||
file: ./docker/Dockerfile-base
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
build-args: |
|
||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTHON_VERSION=${{ matrix.python_version }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras }}
|
||||
|
||||
24
.github/workflows/main.yml
vendored
24
.github/workflows/main.yml
vendored
@@ -15,14 +15,22 @@ jobs:
|
||||
include:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras: gptq
|
||||
- cuda: cu117
|
||||
cuda_version: 11.7.0
|
||||
python_version: "3.9"
|
||||
pytorch: 1.13.1
|
||||
axolotl_extras:
|
||||
runs-on: self-hosted
|
||||
@@ -46,10 +54,10 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
@@ -62,14 +70,22 @@ jobs:
|
||||
include:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras:
|
||||
- cuda: cu118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.0
|
||||
axolotl_extras: gptq
|
||||
- cuda: cu117
|
||||
cuda_version: 11.7.0
|
||||
python_version: "3.9"
|
||||
pytorch: 1.13.1
|
||||
axolotl_extras:
|
||||
runs-on: self-hosted
|
||||
@@ -93,10 +109,10 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
file: ./docker/Dockerfile-runpod
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
@@ -5,6 +5,9 @@ exclude = venv
|
||||
[mypy-alpaca_lora_4bit.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-axolotl.monkeypatch.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-flash_attn.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
@@ -31,3 +34,6 @@ ignore_missing_imports = True
|
||||
|
||||
[mypy-addict]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-xformers.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
24
README.md
24
README.md
@@ -27,7 +27,7 @@
|
||||
|
||||
## Quickstart ⚡
|
||||
|
||||
**Requirements**: Python 3.9.
|
||||
**Requirements**: Python 3.9 and Pytorch 2.0.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
@@ -58,7 +58,9 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
||||
- Conda/Pip venv
|
||||
1. Install python **3.9**
|
||||
|
||||
2. Install python dependencies with ONE of the following:
|
||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||
|
||||
3. Install python dependencies with ONE of the following:
|
||||
- `pip3 install -e .` (recommended, supports QLoRA, no gptq/int4 support)
|
||||
- `pip3 install -e .[gptq]` (next best if you don't need QLoRA, but want to use gptq)
|
||||
- `pip3 install -e .[gptq_triton]`
|
||||
@@ -171,6 +173,9 @@ base_model_ignore_patterns:
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
tokenizer_config:
|
||||
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
model_type: AutoModelForCausalLM
|
||||
# Corresponding tokenizer for the model AutoTokenizer is a good choice
|
||||
@@ -260,7 +265,7 @@ wandb_log_model: # 'checkpoint'
|
||||
output_dir: ./completed-model
|
||||
|
||||
# training hyperparameters
|
||||
batch_size: 8
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
eval_batch_size: 2
|
||||
num_epochs: 3
|
||||
@@ -300,6 +305,9 @@ weight_decay:
|
||||
xformers_attention:
|
||||
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
|
||||
flash_attention: # require a100 for llama
|
||||
# whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
@@ -403,6 +411,16 @@ Try to turn off xformers.
|
||||
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||
|
||||
## Badge ❤🏷️
|
||||
|
||||
Building something cool with Axolotl? Consider adding a badge to your model card.
|
||||
|
||||
```markdown
|
||||
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
||||
```
|
||||
|
||||
[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
|
||||
|
||||
## Contributing 🤝
|
||||
|
||||
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
|
||||
|
||||
@@ -26,7 +26,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca
|
||||
batch_size: 32
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
learning_rate: 0.0003
|
||||
|
||||
@@ -23,7 +23,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 32
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00003
|
||||
|
||||
@@ -25,7 +25,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./gpt4all-neox-20b
|
||||
batch_size: 48
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
|
||||
@@ -23,7 +23,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./llama-13b-sharegpt
|
||||
batch_size: 64
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
warmup_steps: 1000
|
||||
save_steps:
|
||||
|
||||
@@ -29,7 +29,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 128
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
warmup_steps: 1000
|
||||
save_steps:
|
||||
|
||||
@@ -26,7 +26,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-test
|
||||
batch_size: 8
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
|
||||
@@ -28,7 +28,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-llama-alpaca
|
||||
batch_size: 128
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00003
|
||||
|
||||
@@ -24,7 +24,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
batch_size: 4
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
optimizer: adamw_bnb_8bit
|
||||
|
||||
@@ -28,7 +28,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca
|
||||
batch_size: 48
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 5
|
||||
learning_rate: 0.00001
|
||||
|
||||
@@ -26,7 +26,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-test
|
||||
batch_size: 4
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
|
||||
@@ -53,7 +53,8 @@ wandb_log_model:
|
||||
# where to save the finsihed model to
|
||||
output_dir: ./completed-model
|
||||
# training hyperparameters
|
||||
batch_size: 8
|
||||
gradient_accumulation_steps: 1
|
||||
batch_size:
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
|
||||
@@ -22,7 +22,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./stable-alpaca-3b
|
||||
batch_size: 2
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
|
||||
@@ -30,7 +30,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-reflect
|
||||
batch_size: 8
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00003
|
||||
|
||||
@@ -52,6 +52,8 @@ RUN git clone https://github.com/HazyResearch/flash-attention.git && \
|
||||
|
||||
FROM base-builder AS deepspeed-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
|
||||
@@ -26,7 +26,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./llama-7b-lora-int4
|
||||
batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
|
||||
@@ -24,7 +24,7 @@ wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
|
||||
BIN
image/axolotl-badge-web.png
Normal file
BIN
image/axolotl-badge-web.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
@@ -1,6 +1,7 @@
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git
|
||||
bitsandbytes>=0.39.0
|
||||
accelerate
|
||||
addict
|
||||
fire
|
||||
PyYAML==6.0
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
import fire
|
||||
import torch
|
||||
import yaml
|
||||
from transformers import GenerationConfig
|
||||
|
||||
from axolotl.utils.data import load_prepare_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -73,26 +74,33 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
prompt: str = next(prompter_module().build_prompt(instruction=instruction))
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# gc = GenerationConfig() # TODO swap out and use this
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=100,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
generation_config=generation_config,
|
||||
)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
@@ -149,17 +157,23 @@ def train(
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
# setup some derived config / hyperparams
|
||||
cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
|
||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||
cfg.batch_size // cfg.micro_batch_size
|
||||
)
|
||||
cfg.batch_size = (
|
||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||
)
|
||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
choose_device(cfg)
|
||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||
if cfg.ddp:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.gradient_accumulation_steps = (
|
||||
cfg.gradient_accumulation_steps // cfg.world_size
|
||||
)
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
if cfg.device == "mps":
|
||||
cfg.load_in_8bit = False
|
||||
@@ -168,11 +182,10 @@ def train(
|
||||
cfg.fp16 = True
|
||||
cfg.bf16 = False
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
# load the tokenizer first
|
||||
logging.info("loading tokenizer...")
|
||||
tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||
logging.info(f"loading tokenizer... {tokenizer_config}")
|
||||
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
||||
|
||||
if check_not_in(
|
||||
["inference", "shard", "merge_lora"], kwargs
|
||||
|
||||
@@ -127,6 +127,11 @@ class ConstantLengthDataset(IterableDataset):
|
||||
input_ids = example["input_ids"]
|
||||
attention_mask = example["attention_mask"]
|
||||
labels = example["labels"]
|
||||
if (
|
||||
buffer["input_ids"]
|
||||
and input_ids[0] == self.tokenizer.bos_token_id
|
||||
):
|
||||
attention_mask[0] = 0
|
||||
|
||||
if add_concat_token:
|
||||
input_ids.append(self.concat_token_id)
|
||||
|
||||
@@ -25,6 +25,7 @@ def forward(
|
||||
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
|
||||
233
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Normal file
233
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import transformers.models.llama.modeling_llama
|
||||
from torch import nn
|
||||
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
logging.error("xformers not found! Please install it before trying to use it.")
|
||||
|
||||
|
||||
def hijack_llama_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||
|
||||
|
||||
def hijack_llama_sdp_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
sdp_attention_forward
|
||||
)
|
||||
|
||||
|
||||
def xformers_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
(
|
||||
query_states,
|
||||
key_states,
|
||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query_states, key_states, value_states, attn_bias=None
|
||||
)
|
||||
else:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_bias=xformers.ops.LowerTriangularMask(),
|
||||
)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def sdp_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states)
|
||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
(
|
||||
query_states,
|
||||
key_states,
|
||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# We only apply sdp attention if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
is_causal=False,
|
||||
)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
@@ -233,6 +233,7 @@ def load_tokenized_prepared_datasets(
|
||||
datasets.append(ds_wrapper)
|
||||
else:
|
||||
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
||||
raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
|
||||
logging.info("tokenizing, merging, and shuffling master dataset")
|
||||
|
||||
samples: List[int] = []
|
||||
|
||||
@@ -10,9 +10,14 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM # noqa: F401
|
||||
from transformers import PreTrainedModel # noqa: F401
|
||||
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
|
||||
from transformers import ( # noqa: F401
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
LlamaConfig,
|
||||
)
|
||||
|
||||
try:
|
||||
from transformers import LlamaForCausalLM
|
||||
@@ -25,24 +30,23 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from peft import PeftConfig # noqa: F401
|
||||
from transformers import PreTrainedTokenizer # noqa: F401
|
||||
|
||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||
|
||||
|
||||
def load_tokenizer(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
tokenizer_type,
|
||||
cfg,
|
||||
):
|
||||
if tokenizer_type:
|
||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
base_model_config,
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
|
||||
@@ -97,12 +101,19 @@ def load_model(
|
||||
logging.info("patching with flash attention")
|
||||
replace_llama_attn_with_flash_attn()
|
||||
elif is_llama_derived_model and cfg.xformers_attention:
|
||||
from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_attention,
|
||||
)
|
||||
|
||||
logging.info("patching with xformers attention")
|
||||
hijack_llama_attention()
|
||||
elif is_llama_derived_model and cfg.sdp_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_sdp_attention,
|
||||
)
|
||||
|
||||
logging.info("patching with sdp attention")
|
||||
hijack_llama_sdp_attention()
|
||||
|
||||
if cfg.bf16:
|
||||
torch_dtype = torch.bfloat16
|
||||
@@ -172,8 +183,10 @@ def load_model(
|
||||
)
|
||||
load_in_8bit = False
|
||||
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
||||
config = LlamaConfig.from_pretrained(base_model_config)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
|
||||
@@ -4,6 +4,10 @@ import logging
|
||||
|
||||
|
||||
def validate_config(cfg):
|
||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||
raise ValueError(
|
||||
"please set only one of gradient_accumulation_steps or batch_size"
|
||||
)
|
||||
if cfg.load_4bit:
|
||||
raise ValueError(
|
||||
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
||||
|
||||
12
tests/fixtures/alpaca/alpaca.json
vendored
Normal file
12
tests/fixtures/alpaca/alpaca.json
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
[
|
||||
{
|
||||
"instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.",
|
||||
"input": "Words: ['Hello', 'world'].",
|
||||
"output": "['world', 'Hello']"
|
||||
},
|
||||
{
|
||||
"instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.",
|
||||
"input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.",
|
||||
"output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar."
|
||||
}
|
||||
]
|
||||
65
tests/test_packed_dataset.py
Normal file
65
tests/test_packed_dataset.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Module for testing dataset sequence packing"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter
|
||||
|
||||
|
||||
class TestPacking(unittest.TestCase):
|
||||
"""
|
||||
Test class for packing dataset sequences
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
|
||||
def test_resets_attention(self):
|
||||
prompter = AlpacaPrompter("chat")
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
dateset = load_dataset(
|
||||
"json",
|
||||
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
|
||||
)["train"]
|
||||
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
|
||||
|
||||
constant_len_dataset = ConstantLengthDataset(
|
||||
self.tokenizer,
|
||||
[dataset],
|
||||
seq_length=2048,
|
||||
)
|
||||
packed_dataset = Dataset.from_list(list(constant_len_dataset))
|
||||
example = packed_dataset[0]
|
||||
next_bos_index = (
|
||||
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
|
||||
) # add one since we sliced
|
||||
|
||||
# first example doesn't have mask reset
|
||||
assert example["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][0] == 1
|
||||
|
||||
# but subsequent one does
|
||||
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
|
||||
assert example["attention_mask"][next_bos_index] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -18,6 +18,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
|
||||
@@ -117,3 +117,32 @@ class ValidationTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
validate_config(cfg)
|
||||
|
||||
def test_gradient_accumulations_or_batch_size(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"gradient_accumulation_steps": 1,
|
||||
"batch_size": 1,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"batch_size": 1,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user