Merge branch 'main' into patch-1
README.md (37 lines changed)
@@ -22,7 +22,7 @@
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
-| falcon | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❓ |
+| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |

## Quickstart ⚡
@@ -33,6 +33,7 @@
git clone https://github.com/OpenAccess-AI-Collective/axolotl

pip3 install -e .
+pip3 install -U git+https://github.com/huggingface/peft.git

accelerate config
@@ -53,6 +54,7 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0
```

- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0-gptq`: for gptq
- `winglian/axolotl:dev`: dev branch (not usually up to date)

Or run on the current files for development:
@@ -67,9 +69,19 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
2. Install pytorch stable https://pytorch.org/get-started/locally/

3. Install python dependencies with ONE of the following:
-  - `pip3 install -e .` (recommended, supports QLoRA, no gptq/int4 support)
-  - `pip3 install -e .[gptq]` (next best if you don't need QLoRA, but want to use gptq)
-  - `pip3 install -e .[gptq_triton]`
+  - Recommended, supports QLoRA, NO gptq/int4 support
+    ```bash
+    pip3 install -e .
+    pip3 install -U git+https://github.com/huggingface/peft.git
+    ```
+  - gptq/int4 support, NO QLoRA
+    ```bash
+    pip3 install -e .[gptq]
+    ```
+  - same as above but not recommended
+    ```bash
+    pip3 install -e .[gptq_triton]
+    ```

- LambdaLabs
  <details>
@@ -78,7 +90,8 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
1. Install python
```bash
-sudo apt install python3.9
+sudo apt update
+sudo apt install -y python3.9

sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.9 if given option
@@ -205,14 +218,18 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
-- custom prompts structure:
-  1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
-  2. Use your custom file name as the dataset type.

</details>

+#### How to add custom prompts
+
+1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
+2. Use your custom file name as the dataset type.

Optionally, download some datasets, see [data/README.md](data/README.md)

### Config

See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
@@ -370,7 +387,7 @@ train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet)
group_by_length: false

-# does not work with current implementation of 4-bit LoRA
+# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false

# stop training after this many evaluation losses have increased in a row
@@ -400,6 +417,8 @@ flash_attention: # require a100 for llama
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
+# Landmark attention (only llama)
+landmark_attention:

# resume from a specific checkpoint dir
resume_from_checkpoint:
examples/falcon/config-7b-qlora.yml (new file, 92 lines)
@@ -0,0 +1,92 @@
# 1b: tiiuae/falcon-rw-1b
# 40b: tiiuae/falcon-40b
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
gptq: false
strict: false
push_dataset_to_hub:
datasets:
  - path: QingyiSi/Alpaca-CoT
    data_files:
      - Chain-of-Thought/formatted_cot_data/gsm8k_train.json
    type: "alpaca:chat"
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
# enable QLoRA
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:

# hyperparameters from QLoRA paper Appendix B.2
# "We find hyperparameters to be largely robust across datasets"
lora_r: 64
lora_alpha: 16
# 0.1 for models up to 13B
# 0.05 for 33B and 65B models
lora_dropout: 0.05
# add LoRA modules on all linear layers of the base model
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out

# QLoRA paper Table 9
# - 16 for 7b & 13b
# - 32 for 33b, 64 for 64b
# Max size tested on A6000
# - 7b: 40
# - 40b: 4
# decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 3
# Optimizer for QLoRA
optimizer: paged_adamw_32bit
torchdistx_path:
lr_scheduler: cosine
# QLoRA paper Table 9
# - 2e-4 for 7b & 13b
# - 1e-4 for 33b & 64b
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 5
save_steps: 10
debug:
deepspeed:
weight_decay: 0.000001
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|endoftext|>"
  bos_token: ">>ABSTRACT<<"
  eos_token: "<|endoftext|>"
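The config above trains with `micro_batch_size: 1` and `gradient_accumulation_steps: 2`; the batch the optimizer actually sees is their product (times the number of GPUs under data parallelism). A quick sanity check, where the helper function and the single-GPU default are assumptions of this sketch, not part of axolotl:

```python
# Hypothetical helper (not in axolotl): effective batch size per optimizer step.
def effective_batch_size(micro_batch_size, gradient_accumulation_steps, num_gpus=1):
    # Each optimizer step consumes micro_batch_size batches,
    # gradient_accumulation_steps times, on each GPU.
    return micro_batch_size * gradient_accumulation_steps * num_gpus

# Values from the config above: micro_batch_size: 1, gradient_accumulation_steps: 2
print(effective_batch_size(1, 2))  # 2 samples per optimizer step on one GPU
```

So "decrease if OOM" on `micro_batch_size` can be compensated by raising `gradient_accumulation_steps` to keep the effective batch size constant.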
@@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
-from transformers import GenerationConfig
+from transformers import GenerationConfig, TextStreamer

from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
@@ -64,13 +64,17 @@ def get_multi_line_input() -> Optional[str]:
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
-    tokenizer.add_special_tokens({"unk_token": "<unk>"})
-    tokenizer.add_special_tokens({"bos_token": "<s>"})
-    tokenizer.add_special_tokens({"eos_token": "</s>"})
+    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+
+    for token, symbol in default_tokens.items():
+        # If the token isn't already specified in the config, add it
+        if not (cfg.special_tokens and token in cfg.special_tokens):
+            tokenizer.add_special_tokens({token: symbol})

    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)

    while True:
        print("=" * 80)
        # support for multiline inputs
        instruction = get_multi_line_input()
        if not instruction:
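The new token-defaulting loop above only adds a default special token when the config does not already define one, so user-configured tokens win. A stdlib-only sketch of that guard, where the plain dict standing in for the tokenizer state is an assumption of this example:

```python
# Minimal sketch of the guard in do_inference above; `special_tokens_map`
# stands in for the real tokenizer, which this example does not construct.
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

def apply_default_tokens(special_tokens_map, cfg_special_tokens):
    for token, symbol in default_tokens.items():
        # If the token isn't already specified in the config, add it
        if not (cfg_special_tokens and token in cfg_special_tokens):
            special_tokens_map[token] = symbol
    return special_tokens_map

# No config: all three defaults are applied.
print(apply_default_tokens({}, None))
# Config defines unk_token: only bos/eos defaults are added.
print(apply_default_tokens({}, {"unk_token": "<custom>"}))
```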
@@ -79,7 +83,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
            prompter_module().build_prompt(instruction=instruction.strip("\n"))
        )
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        print("=" * 40)
        model.eval()
        with torch.no_grad():
            generation_config = GenerationConfig(
@@ -98,10 +102,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
                output_hidden_states=False,
                output_scores=False,
            )
+            streamer = TextStreamer(tokenizer)
            generated = model.generate(
                inputs=batch["input_ids"].to(cfg.device),
                generation_config=generation_config,
+                streamer=streamer,
            )
        print("=" * 40)
        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
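The `TextStreamer` wired into `model.generate` above prints tokens as they are produced instead of waiting for the full sequence. A rough stdlib illustration of that callback pattern; the classes and token strings here are inventions of this sketch, not the transformers API:

```python
# Toy illustration of the streaming pattern: the generator pushes each new
# token into a streamer as it is produced, instead of only returning at the end.
class ListStreamer:
    """Collects streamed tokens; a real streamer would print them live."""
    def __init__(self):
        self.tokens = []

    def put(self, token):
        self.tokens.append(token)

def generate(prompt_tokens, streamer, max_new_tokens=3):
    out = list(prompt_tokens)
    for i in range(max_new_tokens):
        token = f"<tok{i}>"  # stand-in for model sampling
        out.append(token)
        streamer.put(token)  # visible immediately, mid-generation
    return out

streamer = ListStreamer()
generate(["hello"], streamer)
print(streamer.tokens)  # ['<tok0>', '<tok1>', '<tok2>']
```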
@@ -183,6 +190,9 @@ def train(
        cfg.fp16 = True
        cfg.bf16 = False

+    if cfg.tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
    # load the tokenizer first
    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
    logging.info(f"loading tokenizer... {tokenizer_config}")
src/axolotl/monkeypatch/llama_landmark_attn.py (new file, 1595 lines)
File diff suppressed because it is too large
@@ -78,6 +78,13 @@ def load_tokenized_prepared_datasets(
    else:
        logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
        logging.info("Loading raw datasets...")

+        if cfg.seed:
+            seed = cfg.seed
+        else:
+            logging.info("No seed provided, using default seed of 42")
+            seed = 42
+
        datasets = []
        # pylint: disable=invalid-name
        for d in cfg.datasets:
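The point of threading `cfg.seed` through (falling back to 42) is that every shuffle below becomes reproducible across runs while remaining configurable. The same idea with stdlib `random`, where the helper name is an assumption of this sketch:

```python
import random

def shuffled(items, seed=None):
    # Mirror the hunk above: fall back to a fixed default when no seed is given,
    # so repeated runs produce the same order.
    if not seed:
        seed = 42
    rng = random.Random(seed)
    out = list(items)
    rng.shuffle(out)
    return out

# Same seed -> identical order on every run; the default is just seed=42.
assert shuffled(range(10), seed=7) == shuffled(range(10), seed=7)
assert shuffled(range(10)) == shuffled(range(10), seed=42)
```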
@@ -127,11 +134,11 @@ def load_tokenized_prepared_datasets(
            # support for using a subset of the data
            if d.shards:
                if "train" in ds:
-                    ds = ds.shuffle(seed=42)["train"].shard(
+                    ds = ds.shuffle(seed=seed)["train"].shard(
                        num_shards=d.shards, index=0
                    )
                else:
-                    ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
+                    ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
            d_type = d.type
            d_type_split = d_type.split(":")
            d_base_type = d_type_split[0]
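Shuffling with a seed and then keeping shard `index=0` of `num_shards` is how `d.shards` selects a reproducible random subset: the shuffle randomizes, the shard slices. A simplified stdlib sketch similar in spirit to `datasets.Dataset.shard` (the helper names and the contiguous slicing are assumptions of this example):

```python
import random

def shard(items, num_shards, index):
    # Simplified contiguous sharding: split into num_shards chunks, keep one.
    size = (len(items) + num_shards - 1) // num_shards
    return items[index * size:(index + 1) * size]

def subset(items, num_shards, seed):
    out = list(items)
    random.Random(seed).shuffle(out)  # shuffle first so the shard is a random sample
    return shard(out, num_shards, index=0)

half = subset(range(100), num_shards=2, seed=42)
print(len(half))  # 50
```

Because the shuffle is seeded, the same config picks the same subset on every run, which was the bug the `seed=42` → `seed=seed` change fixes for user-supplied seeds.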
@@ -239,7 +246,7 @@ def load_tokenized_prepared_datasets(
    samples: List[int] = []
    for d in datasets:
        samples = samples + list(d)
-    dataset = Dataset.from_list(samples).shuffle(seed=42)
+    dataset = Dataset.from_list(samples).shuffle(seed=seed)
    if cfg.local_rank == 0:
        logging.info(
            f"Saving merged prepared dataset to disk... {prepared_ds_path}"
@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
)

try:
-    from transformers import LlamaForCausalLM
+    from transformers import (  # pylint: disable=unused-import # noqa: F401
+        LlamaForCausalLM,
+    )
except ImportError:
    logging.warning(
        "This version of transformers does not support Llama. Consider upgrading."
@@ -82,37 +84,47 @@ def load_model(
    cfg,
    adapter="lora"
):
-    # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    """
    Load a model from a base model and a model type.
    """

    # TODO refactor as a kwarg
    load_in_8bit = cfg.load_in_8bit
-    is_llama_derived_model = "llama" in base_model or (
+    cfg.is_llama_derived_model = "llama" in base_model or (
        cfg.model_type and "llama" in cfg.model_type.lower()
    )

-    if is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and cfg.inference is False:
+    if cfg.is_llama_derived_model and cfg.flash_attention:
+        if cfg.device not in ["mps", "cpu"] and inference is False:
            from axolotl.flash_attn import replace_llama_attn_with_flash_attn

            logging.info("patching with flash attention")
            replace_llama_attn_with_flash_attn()
-    elif is_llama_derived_model and cfg.xformers_attention:
+    elif cfg.is_llama_derived_model and cfg.xformers_attention:
        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_attention,
        )

        logging.info("patching with xformers attention")
        hijack_llama_attention()
-    elif is_llama_derived_model and cfg.sdp_attention:
+    elif cfg.is_llama_derived_model and cfg.sdp_attention:
        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_sdp_attention,
        )

        logging.info("patching with sdp attention")
        hijack_llama_sdp_attention()
+    elif cfg.is_llama_derived_model and cfg.landmark_attention:
+        from axolotl.monkeypatch.llama_landmark_attn import (  # pylint: disable=redefined-outer-name # noqa: F811
+            MEM_TOKEN,
+            LlamaForCausalLM,
+        )
+
+        logging.info("patching with landmark attention")
+
+        # TODO: Check if this would overwrite previous additional_special_tokens
+        tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

    if cfg.bf16:
        torch_dtype = torch.bfloat16
@@ -127,11 +139,18 @@ def load_model(
        )

        replace_peft_model_with_int4_lora_model()
-        from peft import prepare_model_for_int8_training
    except Exception as err:
        logging.exception(err)
        raise err

+    try:
+        from peft import prepare_model_for_kbit_training
+    except ImportError:
+        # For backward compatibility
+        from peft import (
+            prepare_model_for_int8_training as prepare_model_for_kbit_training,
+        )
+
    model_kwargs = {}
    if cfg.adapter == "qlora" and cfg.load_in_4bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
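The try/except ImportError alias above is a standard way to track a renamed API while still supporting older library versions: callers always use the new name, and old installs transparently supply the old implementation under it. The same pattern demonstrated with a made-up module name (everything here except the pattern itself is hypothetical):

```python
# Same backward-compatibility pattern as the peft import above, using a
# deliberately nonexistent module so the fallback branch runs.
try:
    from new_peft_module import prepare_model_for_kbit_training  # hypothetical
except ImportError:
    # Fall back to the old name, re-exported under the new one.
    def prepare_model_for_int8_training(model):
        return model  # stand-in for the legacy implementation

    prepare_model_for_kbit_training = prepare_model_for_int8_training

# Callers only ever see the new name, whichever branch ran.
print(prepare_model_for_kbit_training("model"))  # model
```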
@@ -143,7 +162,7 @@ def load_model(
            bnb_4bit_quant_type="nf4",
        )
    try:
-        if cfg.gptq and is_llama_derived_model:
+        if cfg.gptq and cfg.is_llama_derived_model:
            from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
            from huggingface_hub import snapshot_download
@@ -181,7 +200,7 @@ def load_model(
            else True,
        )
        load_in_8bit = False
-    elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+    elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
        config = LlamaConfig.from_pretrained(base_model_config)
        model = LlamaForCausalLM.from_pretrained(
            base_model,
@@ -235,8 +254,15 @@ def load_model(
        )
        # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
        # when training starts
-        if config.max_seq_len and cfg.sequence_len > config.max_seq_len:
+        if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
            config.max_seq_len = cfg.sequence_len
            logging.warning(f"increasing context length to {cfg.sequence_len}")
+        elif (
+            hasattr(config, "max_sequence_length")
+            and cfg.sequence_len > config.max_sequence_length
+        ):
+            config.max_sequence_length = cfg.sequence_len
+            logging.warning(f"increasing context length to {cfg.sequence_len}")
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            config=config,
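The fix above replaces a bare attribute access (which raises AttributeError on configs that have no `max_seq_len` at all) with `hasattr`, and adds the `max_sequence_length` spelling that some model configs use. A reduced sketch of that branching, with `SimpleNamespace` standing in for the real model configs:

```python
import types

def grow_context(config, sequence_len):
    # Mirrors the hunk above: handle both attribute spellings, and leave
    # configs that expose neither length attribute untouched.
    if hasattr(config, "max_seq_len") and sequence_len > config.max_seq_len:
        config.max_seq_len = sequence_len
    elif (
        hasattr(config, "max_sequence_length")
        and sequence_len > config.max_sequence_length
    ):
        config.max_sequence_length = sequence_len
    return config

mpt_like = types.SimpleNamespace(max_seq_len=2048)
other_like = types.SimpleNamespace(max_sequence_length=2048)
plain = types.SimpleNamespace()  # no length attribute: previously crashed

print(grow_context(mpt_like, 4096).max_seq_len)             # 4096
print(grow_context(other_like, 4096).max_sequence_length)   # 4096
grow_context(plain, 4096)  # no error
```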
@@ -268,8 +294,8 @@ def load_model(
        (cfg.adapter == "lora" and load_in_8bit)
        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
    ):
-        logging.info("converting PEFT model w/ prepare_model_for_int8_training")
-        model = prepare_model_for_int8_training(model)
+        logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
+        model = prepare_model_for_kbit_training(model)

    model, lora_config = load_adapter(model, cfg, adapter)
@@ -1,6 +1,7 @@
"""Module containing the Trainer class and related functions"""

import importlib
+import logging
import math
import os
import sys
@@ -62,8 +63,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
        if cfg.logging_steps is not None
        else max(min(int(0.005 * total_num_steps), 10), 1)
    )
-    save_steps = cfg.save_steps
-    eval_steps = cfg.eval_steps

    training_arguments_kwargs = {}
    if cfg.bf16 == "full":
@@ -74,6 +73,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
    training_arguments_kwargs["tf32"] = cfg.tf32
    training_arguments_kwargs["warmup_steps"] = warmup_steps
    training_arguments_kwargs["logging_steps"] = logging_steps

+    if cfg.seed:
+        training_arguments_kwargs["seed"] = cfg.seed
+
    if cfg.gradient_checkpointing:
        if cfg.gptq:
            from alpaca_lora_4bit.gradient_checkpointing import (
@@ -119,16 +122,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
        num_train_epochs=cfg.num_epochs,
        learning_rate=cfg.learning_rate,
        evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
-        save_strategy="steps" if save_steps else "epoch",
-        eval_steps=eval_steps if cfg.val_set_size > 0 else None,
-        save_steps=save_steps,
+        save_strategy="steps" if cfg.save_steps else "epoch",
+        eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
+        save_steps=cfg.save_steps,
        output_dir=cfg.output_dir,
        save_total_limit=3,
        load_best_model_at_end=(
            cfg.load_best_model_at_end is not False
            and cfg.val_set_size > 0
-            and save_steps
-            and save_steps % eval_steps == 0
+            and cfg.save_steps
+            and cfg.save_steps % cfg.eval_steps == 0
            and cfg.load_in_8bit is not True
        )
        or False,
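The `save_steps % eval_steps == 0` guard above exists because best-model tracking needs an eval result at every save point, so the save schedule must be a multiple of the eval schedule. Condensed into a standalone predicate (the function name and the extra truthiness check on `eval_steps`, which avoids a modulo-by-None, are additions of this sketch):

```python
def can_load_best_model_at_end(val_set_size, save_steps, eval_steps, load_in_8bit=False):
    # Condensed from the TrainingArguments expression above: best-model tracking
    # needs eval metrics at every checkpoint, so save_steps must be a multiple
    # of eval_steps.
    return bool(
        val_set_size > 0
        and save_steps
        and eval_steps
        and save_steps % eval_steps == 0
        and load_in_8bit is not True
    )

print(can_load_best_model_at_end(0.01, save_steps=10, eval_steps=5))  # True
print(can_load_best_model_at_end(0.01, save_steps=10, eval_steps=4))  # False
```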
@@ -233,6 +236,23 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
    else:
        data_collator_kwargs["pad_to_multiple_of"] = 8

+    if cfg.is_llama_derived_model and cfg.landmark_attention:
+        from functools import partial
+
+        from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
+
+        mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
+        model.set_mem_id(mem_id)
+
+        logging.info("Adding landmark attention tokens to dataset")
+
+        for dataset in [train_dataset, eval_dataset]:
+            dataset = dataset.map(
+                partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
+                batched=False,
+                num_proc=32,
+            )
+
    trainer_cls = (
        OneCycleLRSchedulerTrainer
        if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
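The `add_mem_tokens` map above interleaves a landmark memory token into each training sequence at a fixed frequency (`mem_freq=50`), giving attention periodic landmarks to attend to. A toy version of that interleaving, where the token id and the simplified insertion semantics are assumptions of this sketch, not the real monkeypatch:

```python
MEM_TOKEN_ID = 32001  # hypothetical id; the real one comes from the tokenizer

def add_mem_tokens(input_ids, mem_freq, mem_id):
    # Toy version of the dataset map above: insert a memory token after
    # every `mem_freq` real tokens so attention can land on the landmarks.
    out = []
    for i, tok in enumerate(input_ids, start=1):
        out.append(tok)
        if i % mem_freq == 0:
            out.append(mem_id)
    return out

seq = add_mem_tokens(list(range(10)), mem_freq=4, mem_id=MEM_TOKEN_ID)
print(seq)  # [0, 1, 2, 3, 32001, 4, 5, 6, 7, 32001, 8, 9]
```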
@@ -54,6 +54,9 @@ def validate_config(cfg):
            "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
        )

+    if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
+        raise ValueError("FSDP is not supported for falcon models")
+
    # TODO
    # MPT 7b
    # https://github.com/facebookresearch/bitsandbytes/issues/25
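The new check fails fast at config-validation time instead of letting an unsupported falcon + FSDP combination die mid-training. A reduced version of the guard using a plain dict in place of axolotl's `DictDefault`:

```python
def validate_config(cfg):
    # Reduced to just the new falcon/FSDP rule above; cfg is a plain dict here.
    if cfg.get("base_model") and "falcon" in cfg["base_model"].lower() and cfg.get("fsdp"):
        raise ValueError("FSDP is not supported for falcon models")

validate_config({"base_model": "tiiuae/falcon-7b"})  # fine: no fsdp requested
try:
    validate_config({"base_model": "tiiuae/falcon-7b", "fsdp": ["full_shard"]})
except ValueError as err:
    print(err)  # FSDP is not supported for falcon models
```

Note the `.lower()` call, which is what the upper-case test case added below this hunk exercises.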
@@ -15,3 +15,5 @@ def setup_wandb_env_vars(cfg):
        os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
    if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
        os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
+    else:
+        os.environ["WANDB_DISABLED"] = "true"
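As rendered in this hunk, the new else branch switches W&B off entirely via its `WANDB_DISABLED` environment-variable kill switch when no run id is configured. Stripped down to just that branching, with a plain dict in place of the config object:

```python
import os

def setup_wandb_env_vars(cfg):
    # Stripped-down version of the hunk above: publish the run id when one is
    # configured, otherwise disable wandb via its env-var kill switch.
    if cfg.get("wandb_run_id"):
        os.environ["WANDB_RUN_ID"] = cfg["wandb_run_id"]
    else:
        os.environ["WANDB_DISABLED"] = "true"

setup_wandb_env_vars({"wandb_run_id": ""})
print(os.environ.get("WANDB_DISABLED"))  # true
```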
@@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase):
        )

        validate_config(cfg)

+    def test_falcon_fsdp(self):
+        regex_exp = r".*FSDP is not supported for falcon models.*"
+
+        # Check for lower-case
+        cfg = DictDefault(
+            {
+                "base_model": "tiiuae/falcon-7b",
+                "fsdp": ["full_shard", "auto_wrap"],
+            }
+        )
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+        # Check for upper-case
+        cfg = DictDefault(
+            {
+                "base_model": "Falcon-7b",
+                "fsdp": ["full_shard", "auto_wrap"],
+            }
+        )
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "base_model": "tiiuae/falcon-7b",
+            }
+        )
+
+        validate_config(cfg)