Merge branch 'main' into patch-1

This commit is contained in:
Angainor Development
2023-06-10 19:07:54 +02:00
committed by GitHub
10 changed files with 1844 additions and 37 deletions

View File

@@ -22,7 +22,7 @@
| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ | | Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ | | cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ | | mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | | | ❌ | ❌ | ❌ | ❓ | | falcon | ✅ | | | ❌ | ❌ | ❌ | ❓ |
## Quickstart ⚡ ## Quickstart ⚡
@@ -33,6 +33,7 @@
git clone https://github.com/OpenAccess-AI-Collective/axolotl git clone https://github.com/OpenAccess-AI-Collective/axolotl
pip3 install -e . pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git
accelerate config accelerate config
@@ -53,6 +54,7 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0 docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.9-cu118-2.0.0
``` ```
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod - `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0`: for runpod
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.0-gptq`: for gptq
- `winglian/axolotl:dev`: dev branch (not usually up to date) - `winglian/axolotl:dev`: dev branch (not usually up to date)
Or run on the current files for development: Or run on the current files for development:
@@ -67,9 +69,19 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
2. Install pytorch stable https://pytorch.org/get-started/locally/ 2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install python dependencies with ONE of the following: 3. Install python dependencies with ONE of the following:
- `pip3 install -e .` (recommended, supports QLoRA, no gptq/int4 support) - Recommended, supports QLoRA, NO gptq/int4 support
- `pip3 install -e .[gptq]` (next best if you don't need QLoRA, but want to use gptq) ```bash
- `pip3 install -e .[gptq_triton]` pip3 install -e .
pip3 install -U git+https://github.com/huggingface/peft.git
```
- gptq/int4 support, NO QLoRA
```bash
pip3 install -e .[gptq]
```
- gptq/int4 support via the `gptq_triton` extra, NO QLoRA (not recommended)
```bash
pip3 install -e .[gptq_triton]
```
- LambdaLabs - LambdaLabs
<details> <details>
@@ -78,7 +90,8 @@ accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
1. Install python 1. Install python
```bash ```bash
sudo apt install python3.9 sudo apt update
sudo apt install -y python3.9
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1 sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.9 if given option sudo update-alternatives --config python # pick 3.9 if given option
@@ -205,14 +218,18 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"conversations": [{"role": "...", "value": "..."}]} {"conversations": [{"role": "...", "value": "..."}]}
``` ```
- custom prompts structure:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type.
</details> </details>
#### How to add custom prompts
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). See the existing files there for examples.
2. Use your custom file name as the dataset type.
Optionally, download some datasets, see [data/README.md](data/README.md) Optionally, download some datasets, see [data/README.md](data/README.md)
### Config ### Config
See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are: See sample configs in [configs](configs) folder or [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
@@ -370,7 +387,7 @@ train_on_inputs: false
# don't use this, leads to wonky training (according to someone on the internet) # don't use this, leads to wonky training (according to someone on the internet)
group_by_length: false group_by_length: false
# does not work with current implementation of 4-bit LoRA # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false gradient_checkpointing: false
# stop training after this many evaluation losses have increased in a row # stop training after this many evaluation losses have increased in a row
@@ -400,6 +417,8 @@ flash_attention: # require a100 for llama
# whether to use scaled-dot-product attention # whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention: sdp_attention:
# Landmark attention (only llama)
landmark_attention:
# resume from a specific checkpoint dir # resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:

View File

@@ -0,0 +1,92 @@
# 1b: tiiuae/falcon-rw-1b
# 40b: tiiuae/falcon-40b
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
gptq: false
strict: false
push_dataset_to_hub:
datasets:
- path: QingyiSi/Alpaca-CoT
data_files:
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
type: "alpaca:chat"
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
# enable QLoRA
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
# hyperparameters from QLoRA paper Appendix B.2
# "We find hyperparameters to be largely robust across datasets"
lora_r: 64
lora_alpha: 16
# 0.1 for models up to 13B
# 0.05 for 33B and 65B models
lora_dropout: 0.05
# add LoRA modules on all linear layers of the base model
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
# QLoRA paper Table 9
# - 16 for 7b & 13b
# - 32 for 33b, 64 for 65b
# Max size tested on A6000
# - 7b: 40
# - 40b: 4
# decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 3
# Optimizer for QLoRA
optimizer: paged_adamw_32bit
torchdistx_path:
lr_scheduler: cosine
# QLoRA paper Table 9
# - 2e-4 for 7b & 13b
# - 1e-4 for 33b & 65b
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 5
save_steps: 10
debug:
deepspeed:
weight_decay: 0.000001
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
eos_token: "<|endoftext|>"

View File

@@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import fire import fire
import torch import torch
import yaml import yaml
from transformers import GenerationConfig from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -64,13 +64,17 @@ def get_multi_line_input() -> Optional[str]:
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
tokenizer.add_special_tokens({"unk_token": "<unk>"}) default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
tokenizer.add_special_tokens({"bos_token": "<s>"})
tokenizer.add_special_tokens({"eos_token": "</s>"}) for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter) prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
while True: while True:
print("=" * 80)
# support for multiline inputs # support for multiline inputs
instruction = get_multi_line_input() instruction = get_multi_line_input()
if not instruction: if not instruction:
@@ -79,7 +83,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
prompter_module().build_prompt(instruction=instruction.strip("\n")) prompter_module().build_prompt(instruction=instruction.strip("\n"))
) )
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
generation_config = GenerationConfig( generation_config = GenerationConfig(
@@ -98,10 +102,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
output_hidden_states=False, output_hidden_states=False,
output_scores=False, output_scores=False,
) )
streamer = TextStreamer(tokenizer)
generated = model.generate( generated = model.generate(
inputs=batch["input_ids"].to(cfg.device), inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config, generation_config=generation_config,
streamer=streamer,
) )
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@@ -183,6 +190,9 @@ def train(
cfg.fp16 = True cfg.fp16 = True
cfg.bf16 = False cfg.bf16 = False
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# load the tokenizer first # load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
logging.info(f"loading tokenizer... {tokenizer_config}") logging.info(f"loading tokenizer... {tokenizer_config}")

File diff suppressed because it is too large Load Diff

View File

@@ -78,6 +78,13 @@ def load_tokenized_prepared_datasets(
else: else:
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}") logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
logging.info("Loading raw datasets...") logging.info("Loading raw datasets...")
if cfg.seed:
seed = cfg.seed
else:
logging.info("No seed provided, using default seed of 42")
seed = 42
datasets = [] datasets = []
# pylint: disable=invalid-name # pylint: disable=invalid-name
for d in cfg.datasets: for d in cfg.datasets:
@@ -127,11 +134,11 @@ def load_tokenized_prepared_datasets(
# support for using a subset of the data # support for using a subset of the data
if d.shards: if d.shards:
if "train" in ds: if "train" in ds:
ds = ds.shuffle(seed=42)["train"].shard( ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=d.shards, index=0 num_shards=d.shards, index=0
) )
else: else:
ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0) ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
d_type = d.type d_type = d.type
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
@@ -239,7 +246,7 @@ def load_tokenized_prepared_datasets(
samples: List[int] = [] samples: List[int] = []
for d in datasets: for d in datasets:
samples = samples + list(d) samples = samples + list(d)
dataset = Dataset.from_list(samples).shuffle(seed=42) dataset = Dataset.from_list(samples).shuffle(seed=seed)
if cfg.local_rank == 0: if cfg.local_rank == 0:
logging.info( logging.info(
f"Saving merged prepared dataset to disk... {prepared_ds_path}" f"Saving merged prepared dataset to disk... {prepared_ds_path}"

View File

@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
) )
try: try:
from transformers import LlamaForCausalLM from transformers import ( # pylint: disable=unused-import # noqa: F401
LlamaForCausalLM,
)
except ImportError: except ImportError:
logging.warning( logging.warning(
"This version of transformers does not support Llama. Consider upgrading." "This version of transformers does not support Llama. Consider upgrading."
@@ -82,37 +84,47 @@ def load_model(
cfg, cfg,
adapter="lora" adapter="lora"
): ):
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
Load a model from a base model and a model type. Load a model from a base model and a model type.
""" """
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
is_llama_derived_model = "llama" in base_model or ( cfg.is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower() cfg.model_type and "llama" in cfg.model_type.lower()
) )
if is_llama_derived_model and cfg.flash_attention: if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and cfg.inference is False: if cfg.device not in ["mps", "cpu"] and inference is False:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn from axolotl.flash_attn import replace_llama_attn_with_flash_attn
logging.info("patching with flash attention") logging.info("patching with flash attention")
replace_llama_attn_with_flash_attn() replace_llama_attn_with_flash_attn()
elif is_llama_derived_model and cfg.xformers_attention: elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import ( from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention, hijack_llama_attention,
) )
logging.info("patching with xformers attention") logging.info("patching with xformers attention")
hijack_llama_attention() hijack_llama_attention()
elif is_llama_derived_model and cfg.sdp_attention: elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import ( from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_sdp_attention, hijack_llama_sdp_attention,
) )
logging.info("patching with sdp attention") logging.info("patching with sdp attention")
hijack_llama_sdp_attention() hijack_llama_sdp_attention()
elif cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
MEM_TOKEN,
LlamaForCausalLM,
)
logging.info("patching with landmark attention")
# TODO: Check if this would overwrite previous additional_special_tokens
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
if cfg.bf16: if cfg.bf16:
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
@@ -127,11 +139,18 @@ def load_model(
) )
replace_peft_model_with_int4_lora_model() replace_peft_model_with_int4_lora_model()
from peft import prepare_model_for_int8_training
except Exception as err: except Exception as err:
logging.exception(err) logging.exception(err)
raise err raise err
try:
from peft import prepare_model_for_kbit_training
except ImportError:
# For backward compatibility
from peft import (
prepare_model_for_int8_training as prepare_model_for_kbit_training,
)
model_kwargs = {} model_kwargs = {}
if cfg.adapter == "qlora" and cfg.load_in_4bit: if cfg.adapter == "qlora" and cfg.load_in_4bit:
model_kwargs["quantization_config"] = BitsAndBytesConfig( model_kwargs["quantization_config"] = BitsAndBytesConfig(
@@ -143,7 +162,7 @@ def load_model(
bnb_4bit_quant_type="nf4", bnb_4bit_quant_type="nf4",
) )
try: try:
if cfg.gptq and is_llama_derived_model: if cfg.gptq and cfg.is_llama_derived_model:
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
@@ -181,7 +200,7 @@ def load_model(
else True, else True,
) )
load_in_8bit = False load_in_8bit = False
elif is_llama_derived_model and "LlamaForCausalLM" in globals(): elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
config = LlamaConfig.from_pretrained(base_model_config) config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,
@@ -235,8 +254,15 @@ def load_model(
) )
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts # when training starts
if config.max_seq_len and cfg.sequence_len > config.max_seq_len: if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
config.max_seq_len = cfg.sequence_len config.max_seq_len = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(config, "max_sequence_length")
and cfg.sequence_len > config.max_sequence_length
):
config.max_sequence_length = cfg.sequence_len
logging.warning(f"increasing context length to {cfg.sequence_len}")
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=config,
@@ -268,8 +294,8 @@ def load_model(
(cfg.adapter == "lora" and load_in_8bit) (cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit)
): ):
logging.info("converting PEFT model w/ prepare_model_for_int8_training") logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
model = prepare_model_for_int8_training(model) model = prepare_model_for_kbit_training(model)
model, lora_config = load_adapter(model, cfg, adapter) model, lora_config = load_adapter(model, cfg, adapter)

View File

@@ -1,6 +1,7 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import importlib import importlib
import logging
import math import math
import os import os
import sys import sys
@@ -62,8 +63,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.logging_steps is not None if cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1) else max(min(int(0.005 * total_num_steps), 10), 1)
) )
save_steps = cfg.save_steps
eval_steps = cfg.eval_steps
training_arguments_kwargs = {} training_arguments_kwargs = {}
if cfg.bf16 == "full": if cfg.bf16 == "full":
@@ -74,6 +73,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["tf32"] = cfg.tf32 training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps training_arguments_kwargs["logging_steps"] = logging_steps
if cfg.seed:
training_arguments_kwargs["seed"] = cfg.seed
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
if cfg.gptq: if cfg.gptq:
from alpaca_lora_4bit.gradient_checkpointing import ( from alpaca_lora_4bit.gradient_checkpointing import (
@@ -119,16 +122,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
num_train_epochs=cfg.num_epochs, num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate, learning_rate=cfg.learning_rate,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no", evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
save_strategy="steps" if save_steps else "epoch", save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=eval_steps if cfg.val_set_size > 0 else None, eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=save_steps, save_steps=cfg.save_steps,
output_dir=cfg.output_dir, output_dir=cfg.output_dir,
save_total_limit=3, save_total_limit=3,
load_best_model_at_end=( load_best_model_at_end=(
cfg.load_best_model_at_end is not False cfg.load_best_model_at_end is not False
and cfg.val_set_size > 0 and cfg.val_set_size > 0
and save_steps and cfg.save_steps
and save_steps % eval_steps == 0 and cfg.save_steps % cfg.eval_steps == 0
and cfg.load_in_8bit is not True and cfg.load_in_8bit is not True
) )
or False, or False,
@@ -233,6 +236,23 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else: else:
data_collator_kwargs["pad_to_multiple_of"] = 8 data_collator_kwargs["pad_to_multiple_of"] = 8
if cfg.is_llama_derived_model and cfg.landmark_attention:
from functools import partial
from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
model.set_mem_id(mem_id)
logging.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map(
partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
batched=False,
num_proc=32,
)
trainer_cls = ( trainer_cls = (
OneCycleLRSchedulerTrainer OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora") if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")

View File

@@ -54,6 +54,9 @@ def validate_config(cfg):
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
) )
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models")
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -15,3 +15,5 @@ def setup_wandb_env_vars(cfg):
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0: if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
else:
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -165,3 +165,36 @@ class ValidationTest(unittest.TestCase):
) )
validate_config(cfg) validate_config(cfg)
def test_falcon_fsdp(self):
    """Check that validate_config rejects FSDP for falcon base models.

    A falcon ``base_model`` combined with an ``fsdp`` setting must raise
    ``ValueError``; the same model without ``fsdp`` must validate cleanly.
    """
    expected_error = r".*FSDP is not supported for falcon models.*"

    # Exercise both an all-lower-case and a capitalized model name.
    for model_name in ("tiiuae/falcon-7b", "Falcon-7b"):
        fsdp_cfg = DictDefault(
            {
                "base_model": model_name,
                "fsdp": ["full_shard", "auto_wrap"],
            }
        )
        with pytest.raises(ValueError, match=expected_error):
            validate_config(fsdp_cfg)

    # A falcon model without FSDP configured must not raise.
    plain_cfg = DictDefault(
        {
            "base_model": "tiiuae/falcon-7b",
        }
    )
    validate_config(plain_cfg)