Merge branch 'main' into patch-1
This commit is contained in:
37
README.md
37
README.md
@@ -22,7 +22,7 @@
|
|||||||
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| falcon | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❓ |
|
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
|
|
||||||
|
|
||||||
## Quickstart ⚡
|
## Quickstart ⚡
|
||||||
@@ -33,6 +33,7 @@
|
|||||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||||
|
|
||||||
pip3 install -e .
|
pip3 install -e .
|
||||||
|
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||||
|
|
||||||
accelerate config
|
accelerate config
|
||||||
|
|
||||||
@@ -53,6 +54,7 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
|||||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0
|
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0
|
||||||
```
|
```
|
||||||
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod
|
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod
|
||||||
|
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0-gptq`: for gptq
|
||||||
- `winglian/axolotl:dev`: dev branch (not usually up to date)
|
- `winglian/axolotl:dev`: dev branch (not usually up to date)
|
||||||
|
|
||||||
Or run on the current files for development:
|
Or run on the current files for development:
|
||||||
@@ -67,9 +69,19 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
|||||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||||
|
|
||||||
3. Install python dependencies with ONE of the following:
|
3. Install python dependencies with ONE of the following:
|
||||||
- `pip3 install -e .` (recommended, supports QLoRA, no gptq/int4 support)
|
- Recommended, supports QLoRA, NO gptq/int4 support
|
||||||
- `pip3 install -e .[gptq]` (next best if you don't need QLoRA, but want to use gptq)
|
```bash
|
||||||
- `pip3 install -e .[gptq_triton]`
|
pip3 install -e .
|
||||||
|
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||||
|
```
|
||||||
|
- gptq/int4 support, NO QLoRA
|
||||||
|
```bash
|
||||||
|
pip3 install -e .[gptq]
|
||||||
|
```
|
||||||
|
- same as above but not recommended
|
||||||
|
```bash
|
||||||
|
pip3 install -e .[gptq_triton]
|
||||||
|
```
|
||||||
|
|
||||||
- LambdaLabs
|
- LambdaLabs
|
||||||
<details>
|
<details>
|
||||||
@@ -78,7 +90,8 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
|
|||||||
|
|
||||||
1. Install python
|
1. Install python
|
||||||
```bash
|
```bash
|
||||||
sudo apt install python3.9
|
sudo apt update
|
||||||
|
sudo apt install -y python3.9
|
||||||
|
|
||||||
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
|
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
|
||||||
sudo update-alternatives --config python # pick 3.9 if given option
|
sudo update-alternatives --config python # pick 3.9 if given option
|
||||||
@@ -205,14 +218,18 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"conversations": [{"role": "...", "value": "..."}]}
|
{"conversations": [{"role": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
- custom prompts structure:
|
|
||||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
|
||||||
2. Use your custom file name as the dataset type.
|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
#### How to add custom prompts
|
||||||
|
|
||||||
|
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||||
|
2. Use your custom file name as the dataset type.
|
||||||
|
|
||||||
Optionally, download some datasets, see [data/README.md](data/README.md)
|
Optionally, download some datasets, see [data/README.md](data/README.md)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Config
|
### Config
|
||||||
|
|
||||||
See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
|
See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
|
||||||
@@ -370,7 +387,7 @@ train_on_inputs: false
|
|||||||
# don't use this, leads to wonky training (according to someone on the internet)
|
# don't use this, leads to wonky training (according to someone on the internet)
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
|
|
||||||
# does not work with current implementation of 4-bit LoRA
|
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
|
|
||||||
# stop training after this many evaluation losses have increased in a row
|
# stop training after this many evaluation losses have increased in a row
|
||||||
@@ -400,6 +417,8 @@ flash_attention: # require a100 for llama
|
|||||||
# whether to use scaled-dot-product attention
|
# whether to use scaled-dot-product attention
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
sdp_attention:
|
sdp_attention:
|
||||||
|
# Landmark attention (only llama)
|
||||||
|
landmark_attention:
|
||||||
|
|
||||||
# resume from a specific checkpoint dir
|
# resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
|
|||||||
92
examples/falcon/config-7b-qlora.yml
Normal file
92
examples/falcon/config-7b-qlora.yml
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
# 1b: tiiuae/falcon-rw-1b
|
||||||
|
# 40b: tiiuae/falcon-40b
|
||||||
|
base_model: tiiuae/falcon-7b
|
||||||
|
base_model_config: tiiuae/falcon-7b
|
||||||
|
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||||
|
trust_remote_code: true
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
load_in_8bit: false
|
||||||
|
# enable 4bit for QLoRA
|
||||||
|
load_in_4bit: true
|
||||||
|
gptq: false
|
||||||
|
strict: false
|
||||||
|
push_dataset_to_hub:
|
||||||
|
datasets:
|
||||||
|
- path: QingyiSi/Alpaca-CoT
|
||||||
|
data_files:
|
||||||
|
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||||
|
type: "alpaca:chat"
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
# enable QLoRA
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
sequence_len: 2048
|
||||||
|
max_packed_sequence_len:
|
||||||
|
|
||||||
|
# hyperparameters from QLoRA paper Appendix B.2
|
||||||
|
# "We find hyperparameters to be largely robust across datasets"
|
||||||
|
lora_r: 64
|
||||||
|
lora_alpha: 16
|
||||||
|
# 0.1 for models up to 13B
|
||||||
|
# 0.05 for 33B and 65B models
|
||||||
|
lora_dropout: 0.05
|
||||||
|
# add LoRA modules on all linear layers of the base model
|
||||||
|
lora_target_modules:
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_run_id:
|
||||||
|
wandb_log_model:
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
# QLoRA paper Table 9
|
||||||
|
# - 16 for 7b & 13b
|
||||||
|
# - 32 for 33b, 64 for 64b
|
||||||
|
# Max size tested on A6000
|
||||||
|
# - 7b: 40
|
||||||
|
# - 40b: 4
|
||||||
|
# decrease if OOM, increase for max VRAM utilization
|
||||||
|
micro_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
num_epochs: 3
|
||||||
|
# Optimizer for QLoRA
|
||||||
|
optimizer: paged_adamw_32bit
|
||||||
|
torchdistx_path:
|
||||||
|
lr_scheduler: cosine
|
||||||
|
# QLoRA paper Table 9
|
||||||
|
# - 2e-4 for 7b & 13b
|
||||||
|
# - 1e-4 for 33b & 64b
|
||||||
|
learning_rate: 0.0002
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: true
|
||||||
|
gradient_checkpointing: true
|
||||||
|
# stop training after this many evaluation losses have increased in a row
|
||||||
|
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||||
|
early_stopping_patience: 3
|
||||||
|
resume_from_checkpoint:
|
||||||
|
auto_resume_from_checkpoints: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention: true
|
||||||
|
flash_attention:
|
||||||
|
gptq_groupsize:
|
||||||
|
gptq_model_v1:
|
||||||
|
warmup_steps: 10
|
||||||
|
eval_steps: 5
|
||||||
|
save_steps: 10
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.000001
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<|endoftext|>"
|
||||||
|
bos_token: ">>ABSTRACT<<"
|
||||||
|
eos_token: "<|endoftext|>"
|
||||||
@@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from transformers import GenerationConfig
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
from axolotl.utils.data import load_prepare_datasets
|
from axolotl.utils.data import load_prepare_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -64,13 +64,17 @@ def get_multi_line_input() -> Optional[str]:
|
|||||||
|
|
||||||
|
|
||||||
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
||||||
tokenizer.add_special_tokens({"unk_token": "<unk>"})
|
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||||
tokenizer.add_special_tokens({"bos_token": "<s>"})
|
|
||||||
tokenizer.add_special_tokens({"eos_token": "</s>"})
|
for token, symbol in default_tokens.items():
|
||||||
|
# If the token isn't already specified in the config, add it
|
||||||
|
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||||
|
tokenizer.add_special_tokens({token: symbol})
|
||||||
|
|
||||||
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
|
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
print("=" * 80)
|
||||||
# support for multiline inputs
|
# support for multiline inputs
|
||||||
instruction = get_multi_line_input()
|
instruction = get_multi_line_input()
|
||||||
if not instruction:
|
if not instruction:
|
||||||
@@ -79,7 +83,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
|||||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||||
)
|
)
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
print("=" * 40)
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generation_config = GenerationConfig(
|
generation_config = GenerationConfig(
|
||||||
@@ -98,10 +102,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
|||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
)
|
)
|
||||||
|
streamer = TextStreamer(tokenizer)
|
||||||
generated = model.generate(
|
generated = model.generate(
|
||||||
inputs=batch["input_ids"].to(cfg.device),
|
inputs=batch["input_ids"].to(cfg.device),
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
|
streamer=streamer,
|
||||||
)
|
)
|
||||||
|
print("=" * 40)
|
||||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||||
|
|
||||||
|
|
||||||
@@ -183,6 +190,9 @@ def train(
|
|||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
cfg.bf16 = False
|
cfg.bf16 = False
|
||||||
|
|
||||||
|
if cfg.tf32:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||||
logging.info(f"loading tokenizer... {tokenizer_config}")
|
logging.info(f"loading tokenizer... {tokenizer_config}")
|
||||||
|
|||||||
1595
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
1595
src/axolotl/monkeypatch/llama_landmark_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -78,6 +78,13 @@ def load_tokenized_prepared_datasets(
|
|||||||
else:
|
else:
|
||||||
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
||||||
logging.info("Loading raw datasets...")
|
logging.info("Loading raw datasets...")
|
||||||
|
|
||||||
|
if cfg.seed:
|
||||||
|
seed = cfg.seed
|
||||||
|
else:
|
||||||
|
logging.info("No seed provided, using default seed of 42")
|
||||||
|
seed = 42
|
||||||
|
|
||||||
datasets = []
|
datasets = []
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
for d in cfg.datasets:
|
for d in cfg.datasets:
|
||||||
@@ -127,11 +134,11 @@ def load_tokenized_prepared_datasets(
|
|||||||
# support for using a subset of the data
|
# support for using a subset of the data
|
||||||
if d.shards:
|
if d.shards:
|
||||||
if "train" in ds:
|
if "train" in ds:
|
||||||
ds = ds.shuffle(seed=42)["train"].shard(
|
ds = ds.shuffle(seed=seed)["train"].shard(
|
||||||
num_shards=d.shards, index=0
|
num_shards=d.shards, index=0
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
|
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
||||||
d_type = d.type
|
d_type = d.type
|
||||||
d_type_split = d_type.split(":")
|
d_type_split = d_type.split(":")
|
||||||
d_base_type = d_type_split[0]
|
d_base_type = d_type_split[0]
|
||||||
@@ -239,7 +246,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
samples: List[int] = []
|
samples: List[int] = []
|
||||||
for d in datasets:
|
for d in datasets:
|
||||||
samples = samples + list(d)
|
samples = samples + list(d)
|
||||||
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import ( # pylint: disable=unused-import # noqa: F401
|
||||||
|
LlamaForCausalLM,
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"This version of transformers does not support Llama. Consider upgrading."
|
"This version of transformers does not support Llama. Consider upgrading."
|
||||||
@@ -82,37 +84,47 @@ def load_model(
|
|||||||
cfg,
|
cfg,
|
||||||
adapter="lora"
|
adapter="lora"
|
||||||
):
|
):
|
||||||
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
"""
|
"""
|
||||||
Load a model from a base model and a model type.
|
Load a model from a base model and a model type.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
is_llama_derived_model = "llama" in base_model or (
|
cfg.is_llama_derived_model = "llama" in base_model or (
|
||||||
cfg.model_type and "llama" in cfg.model_type.lower()
|
cfg.model_type and "llama" in cfg.model_type.lower()
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_llama_derived_model and cfg.flash_attention:
|
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||||
if cfg.device not in ["mps", "cpu"] and cfg.inference is False:
|
if cfg.device not in ["mps", "cpu"] and inference is False:
|
||||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||||
|
|
||||||
logging.info("patching with flash attention")
|
logging.info("patching with flash attention")
|
||||||
replace_llama_attn_with_flash_attn()
|
replace_llama_attn_with_flash_attn()
|
||||||
elif is_llama_derived_model and cfg.xformers_attention:
|
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_attention,
|
hijack_llama_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("patching with xformers attention")
|
logging.info("patching with xformers attention")
|
||||||
hijack_llama_attention()
|
hijack_llama_attention()
|
||||||
elif is_llama_derived_model and cfg.sdp_attention:
|
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_sdp_attention,
|
hijack_llama_sdp_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("patching with sdp attention")
|
logging.info("patching with sdp attention")
|
||||||
hijack_llama_sdp_attention()
|
hijack_llama_sdp_attention()
|
||||||
|
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
|
||||||
|
MEM_TOKEN,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("patching with landmark attention")
|
||||||
|
|
||||||
|
# TODO: Check if this would overwrite previous additional_special_tokens
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
||||||
|
|
||||||
if cfg.bf16:
|
if cfg.bf16:
|
||||||
torch_dtype = torch.bfloat16
|
torch_dtype = torch.bfloat16
|
||||||
@@ -127,11 +139,18 @@ def load_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
replace_peft_model_with_int4_lora_model()
|
replace_peft_model_with_int4_lora_model()
|
||||||
from peft import prepare_model_for_int8_training
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logging.exception(err)
|
logging.exception(err)
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
|
try:
|
||||||
|
from peft import prepare_model_for_kbit_training
|
||||||
|
except ImportError:
|
||||||
|
# For backward compatibility
|
||||||
|
from peft import (
|
||||||
|
prepare_model_for_int8_training as prepare_model_for_kbit_training,
|
||||||
|
)
|
||||||
|
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
@@ -143,7 +162,7 @@ def load_model(
|
|||||||
bnb_4bit_quant_type="nf4",
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
if cfg.gptq and is_llama_derived_model:
|
if cfg.gptq and cfg.is_llama_derived_model:
|
||||||
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
@@ -181,7 +200,7 @@ def load_model(
|
|||||||
else True,
|
else True,
|
||||||
)
|
)
|
||||||
load_in_8bit = False
|
load_in_8bit = False
|
||||||
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
|
||||||
config = LlamaConfig.from_pretrained(base_model_config)
|
config = LlamaConfig.from_pretrained(base_model_config)
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -235,8 +254,15 @@ def load_model(
|
|||||||
)
|
)
|
||||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||||
# when training starts
|
# when training starts
|
||||||
if config.max_seq_len and cfg.sequence_len > config.max_seq_len:
|
if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
|
||||||
config.max_seq_len = cfg.sequence_len
|
config.max_seq_len = cfg.sequence_len
|
||||||
|
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
||||||
|
elif (
|
||||||
|
hasattr(config, "max_sequence_length")
|
||||||
|
and cfg.sequence_len > config.max_sequence_length
|
||||||
|
):
|
||||||
|
config.max_sequence_length = cfg.sequence_len
|
||||||
|
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
@@ -268,8 +294,8 @@ def load_model(
|
|||||||
(cfg.adapter == "lora" and load_in_8bit)
|
(cfg.adapter == "lora" and load_in_8bit)
|
||||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||||
):
|
):
|
||||||
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
|
logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
model = prepare_model_for_int8_training(model)
|
model = prepare_model_for_kbit_training(model)
|
||||||
|
|
||||||
model, lora_config = load_adapter(model, cfg, adapter)
|
model, lora_config = load_adapter(model, cfg, adapter)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Module containing the Trainer class and related functions"""
|
"""Module containing the Trainer class and related functions"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -62,8 +63,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
if cfg.logging_steps is not None
|
if cfg.logging_steps is not None
|
||||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||||
)
|
)
|
||||||
save_steps = cfg.save_steps
|
|
||||||
eval_steps = cfg.eval_steps
|
|
||||||
|
|
||||||
training_arguments_kwargs = {}
|
training_arguments_kwargs = {}
|
||||||
if cfg.bf16 == "full":
|
if cfg.bf16 == "full":
|
||||||
@@ -74,6 +73,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
training_arguments_kwargs["tf32"] = cfg.tf32
|
training_arguments_kwargs["tf32"] = cfg.tf32
|
||||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||||
|
|
||||||
|
if cfg.seed:
|
||||||
|
training_arguments_kwargs["seed"] = cfg.seed
|
||||||
|
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
from alpaca_lora_4bit.gradient_checkpointing import (
|
from alpaca_lora_4bit.gradient_checkpointing import (
|
||||||
@@ -119,16 +122,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
num_train_epochs=cfg.num_epochs,
|
num_train_epochs=cfg.num_epochs,
|
||||||
learning_rate=cfg.learning_rate,
|
learning_rate=cfg.learning_rate,
|
||||||
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
||||||
save_strategy="steps" if save_steps else "epoch",
|
save_strategy="steps" if cfg.save_steps else "epoch",
|
||||||
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
|
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
||||||
save_steps=save_steps,
|
save_steps=cfg.save_steps,
|
||||||
output_dir=cfg.output_dir,
|
output_dir=cfg.output_dir,
|
||||||
save_total_limit=3,
|
save_total_limit=3,
|
||||||
load_best_model_at_end=(
|
load_best_model_at_end=(
|
||||||
cfg.load_best_model_at_end is not False
|
cfg.load_best_model_at_end is not False
|
||||||
and cfg.val_set_size > 0
|
and cfg.val_set_size > 0
|
||||||
and save_steps
|
and cfg.save_steps
|
||||||
and save_steps % eval_steps == 0
|
and cfg.save_steps % cfg.eval_steps == 0
|
||||||
and cfg.load_in_8bit is not True
|
and cfg.load_in_8bit is not True
|
||||||
)
|
)
|
||||||
or False,
|
or False,
|
||||||
@@ -233,6 +236,23 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
else:
|
else:
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||||
|
|
||||||
|
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
|
||||||
|
|
||||||
|
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
|
||||||
|
model.set_mem_id(mem_id)
|
||||||
|
|
||||||
|
logging.info("Adding landmark attention tokens to dataset")
|
||||||
|
|
||||||
|
for dataset in [train_dataset, eval_dataset]:
|
||||||
|
dataset = dataset.map(
|
||||||
|
partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
|
||||||
|
batched=False,
|
||||||
|
num_proc=32,
|
||||||
|
)
|
||||||
|
|
||||||
trainer_cls = (
|
trainer_cls = (
|
||||||
OneCycleLRSchedulerTrainer
|
OneCycleLRSchedulerTrainer
|
||||||
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
|
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
|
||||||
|
|||||||
@@ -54,6 +54,9 @@ def validate_config(cfg):
|
|||||||
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||||
|
raise ValueError("FSDP is not supported for falcon models")
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -15,3 +15,5 @@ def setup_wandb_env_vars(cfg):
|
|||||||
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
||||||
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
||||||
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||||
|
else:
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|||||||
@@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_falcon_fsdp(self):
|
||||||
|
regex_exp = r".*FSDP is not supported for falcon models.*"
|
||||||
|
|
||||||
|
# Check for lower-case
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "tiiuae/falcon-7b",
|
||||||
|
"fsdp": ["full_shard", "auto_wrap"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=regex_exp):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
# Check for upper-case
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Falcon-7b",
|
||||||
|
"fsdp": ["full_shard", "auto_wrap"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=regex_exp):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "tiiuae/falcon-7b",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user