Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
7771498eae add guassian dropout support 2023-09-25 14:50:39 -04:00
65 changed files with 536 additions and 2416 deletions

View File

@@ -53,13 +53,6 @@ body:
validations:
required: true
- type: textarea
id: config
attributes:
label: Config yaml
description: |
Please attach the config yaml!
- type: textarea
id: possible-solution
attributes:

View File

@@ -6,11 +6,9 @@ on:
- "main"
paths:
- '**.py'
- 'requirements.txt'
pull_request:
paths:
- '**.py'
- 'requirements.txt'
workflow_dispatch:
jobs:
@@ -46,7 +44,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install -U -e .
pip3 install -e .
pip3 install -r requirements-tests.txt
- name: Run tests
@@ -71,8 +69,8 @@ jobs:
- name: Install dependencies
run: |
pip3 uninstall -y transformers accelerate
pip3 install -U -e .[flash-attn]
pip3 install -e .
pip3 install flash-attn
pip3 install -r requirements-tests.txt
- name: Run e2e tests

View File

@@ -12,4 +12,3 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-nested-blocks,

View File

@@ -124,11 +124,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Huggingface to use gated models/datasets.
```bash
huggingface-cli login
```
Get the token at huggingface.co/settings/tokens
- LambdaLabs
<details>
@@ -185,7 +180,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations where `from` is `human`/`gpt`
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -250,10 +245,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"article": "...", "question": "...", "answer": "..."}
```
- `context_qa.load_v2`: in context question answering (alternate)
```json
{"context": "...", "question": "...", "answer": "..."}
```
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
```json
{"article": "...", "unanswerable_question": "..."}
@@ -278,11 +269,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"prompt": "...", "generation": "..."}
```
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
@@ -317,7 +308,7 @@ Using file:
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
### Config
@@ -360,12 +351,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
# dataset with splits, but no train split
dataset:
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation
```
- loading
@@ -423,11 +408,6 @@ tokenizer_legacy:
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# used to identify which the model is based on
is_falcon_derived_model:
is_llama_derived_model:
is_mistral_derived_model:
# whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
@@ -459,7 +439,6 @@ datasets:
data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
# custom user prompt
- path: repo
@@ -487,9 +466,6 @@ datasets:
dataset_prepared_path: data/last_run_prepared
# push prepared dataset to hub
push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
@@ -571,7 +547,7 @@ torch_compile_backend: # Optional[str]
# training hyperparameters
gradient_accumulation_steps: 1
micro_batch_size: 2
eval_batch_size:
eval_batch_size: 2
num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
@@ -655,8 +631,6 @@ flash_optimum:
xformers_attention:
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:

View File

@@ -12,18 +12,17 @@ RUN apt-get update && \
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
RUN cd axolotl && \
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
else \
pip install -e .[flash-attn]; \
fi
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
RUN cd axolotl && \
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli

View File

@@ -7,7 +7,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: qlora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out

View File

@@ -3,7 +3,6 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true
load_in_4bit: false
gptq: false
@@ -12,7 +11,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: lora
lora_model_dir:

View File

@@ -6,7 +6,6 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
@@ -18,7 +17,7 @@ datasets:
data_files:
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
type: "alpaca:chat"
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
# enable QLoRA
adapter: qlora

View File

@@ -3,7 +3,6 @@ base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
load_in_4bit: false
gptq: false
@@ -12,7 +11,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter:
lora_model_dir:

View File

@@ -7,7 +7,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: qlora
lora_model_dir:

View File

@@ -6,7 +6,7 @@ load_in_8bit: false
datasets:
- path: openaccess-ai-collective/jeopardy
type: jeopardy
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter:
lora_model_dir:

View File

@@ -15,7 +15,7 @@ hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: lora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
@@ -56,7 +56,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size:
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps:
debug:

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
@@ -58,7 +58,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size:
eval_table_size: 5
save_steps:
debug:
deepspeed:

View File

@@ -11,7 +11,7 @@ strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./relora-out

View File

@@ -12,7 +12,7 @@ strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
@@ -56,7 +56,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size:
eval_table_size: 5
save_steps:
debug:
deepspeed:

View File

@@ -1,12 +0,0 @@
**Mistral 7B** is a language model with a total of 7.3 billion parameters, showcasing a notable performance across a variety of benchmarks.
Fine Tune:
```shell
accelerate launch -m axolotl.cli.train examples/mistral/config.yml
```
If you run into CUDA OOM, use deepspeed with config zero2.json:
```shell
accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
```

View File

@@ -1,62 +0,0 @@
base_model: mistralai/Mistral-7B-v0.1
base_model_config: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./out
sequence_len: 8192
sample_packing:
pad_to_sequence_len:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,79 +0,0 @@
base_model: mistralai/Mistral-7B-v0.1
base_model_config: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 8192
sample_packing: True
pad_to_sequence_len: True
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -6,7 +6,7 @@ load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter:
lora_model_dir:

View File

@@ -9,7 +9,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter:
lora_model_dir:

View File

@@ -9,7 +9,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter: lora
lora_model_dir:

View File

@@ -9,7 +9,7 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
adapter: qlora
lora_model_dir:

View File

@@ -13,7 +13,7 @@ datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./phi-sft-out

View File

@@ -13,7 +13,7 @@ datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./phi-sft-out

View File

@@ -10,7 +10,7 @@ device_map: auto
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter:
lora_model_dir:

View File

@@ -4,7 +4,7 @@ load_in_8bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
lora_model_dir:

View File

@@ -7,7 +7,7 @@ load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter:
lora_model_dir:

View File

@@ -5,7 +5,7 @@ load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
lora_model_dir:

View File

@@ -16,7 +16,7 @@ datasets:
data_files:
- openassistant_best_replies_train.jsonl
type: "completion"
dataset_prepared_path:
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
# enable QLoRA
adapter: qlora

View File

@@ -4,15 +4,16 @@ torch==2.0.1
auto-gptq
packaging
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
accelerate @ git+https://github.com/huggingface/accelerate
deepspeed
addict
evaluate
fire
PyYAML>=6.0
datasets
flash-attn>=2.3.0
flash-attn>=2.2.1
sentencepiece
wandb
einops
@@ -30,4 +31,3 @@ scipy
scikit-learn==1.2.2
pynvml
art
fschat==0.2.29

View File

@@ -7,7 +7,6 @@ import transformers
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
do_inference,
do_merge_lora,
load_cfg,
@@ -32,7 +31,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
)
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True

View File

@@ -14,8 +14,6 @@ import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
@@ -51,7 +49,7 @@ def print_axolotl_text_art(suffix=None):
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
print("Give me an instruction (Ctrl + D to finish): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
@@ -249,16 +247,3 @@ def check_accelerate_default_config():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
def check_user_token():
# Verify if token is valid
api = HfApi()
try:
user_info = api.whoami()
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False

View File

@@ -8,7 +8,6 @@ import transformers
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
@@ -22,7 +21,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True

View File

@@ -5,7 +5,7 @@ import os
from typing import List
import torch
from datasets import Dataset, IterableDataset, Sequence, Value
from datasets import Dataset, IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy
@@ -22,7 +22,7 @@ class TokenizedPromptDataset(Dataset):
"""
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
"""
@@ -42,15 +42,11 @@ class TokenizedPromptDataset(Dataset):
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
return (
dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
**map_kwargs,
)
.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
**map_kwargs,
)
@@ -59,7 +55,7 @@ class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data.
tokenizer (Tokenizer): The processor used for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""

View File

@@ -711,8 +711,12 @@ class ParallelBlock(nn.Module):
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.block_idx = block_idx
self.mixer = MHA(config, layer_idx=block_idx)
self.mlp = MLP(config)
self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
mlp_cls = mlp.pop("mlp_cls")
if mlp_cls == "fused_mlp":
self.mlp = FusedMLP(config=config, **mlp)
else:
self.mlp = MLP(config=config, **mlp)
def forward(
self,

View File

@@ -7,7 +7,6 @@ import logging
from typing import Optional, Tuple
import torch
from accelerate import init_empty_weights
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM
@@ -18,8 +17,7 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_btlm to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_btlm", ".modeling_btlm"
)

View File

@@ -0,0 +1,101 @@
"""
Flash Attention monkey patch for Falcon
copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
"""
from typing import Optional, Tuple
import torch
import transformers
from flash_attn import flash_attn_func
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor, # pylint: disable=unused-argument
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
use_cache: bool = False,
output_attentions: bool = False, # pylint: disable=unused-argument
):
fused_qkv = self.query_key_value(
hidden_states
) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = (
self.num_heads if self.new_decoder_architecture else self.num_kv_heads
)
# 3 x [batch_size, seq_length, num_heads, head_dim]
(
query_layer,
key_layer,
value_layer,
) = self._split_heads( # pylint: disable=protected-access
fused_qkv
)
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(
batch_size * self.num_heads, query_length, self.head_dim
)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads, query_length, self.head_dim
)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
# unused
# _, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
present = None
# unused
# attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
query_layer_ = (
query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
key_layer_ = (
key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
value_layer_ = (
value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
if alibi is not None:
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
# below output will have shape (batch_size, seqlen, nheads, headdim)
attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
attn_output = attn_output.reshape(
batch_size, query_length, self.num_heads * self.head_dim
)
output_tensor = self.dense(attn_output)
return output_tensor, present
def replace_falcon_attn_with_flash_attn():
transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward

View File

@@ -1,174 +0,0 @@
"""
monkeypatch to add a get_turns method
"""
import logging
from typing import Generator, Tuple
from fastchat.conversation import SeparatorStyle
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
def get_prompt(self) -> str:
ret = ""
for role, msg in self.get_turns():
ret += role + msg
return ret
def get_turns( # pylint: disable=too-many-return-statements
self,
) -> Generator[Tuple[str, str], None, None]:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message + seps[i % 2]
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ": ", "" # must be end with a space
return
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
yield "", "" if system_prompt == "" else system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role, message + self.sep
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role, message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.RWKV:
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message.replace("\r\n", "\n").replace(
"\n\n", "\n"
) + "\n\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message:
yield "", system_prompt
else:
yield "", "[INST] "
for i, (role, message) in enumerate(self.messages[1:]):
if message:
yield role + " ", message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
round_add_n = 1 if self.name == "chatglm2" else 0
if system_prompt:
yield "", system_prompt + self.sep
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
if message:
yield f"{role}", f"{message}{self.sep}"
else:
yield f"{role}", ""
return
if self.sep_style == SeparatorStyle.CHATML:
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep + "\n"
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
prefix = "<s>" if i % 2 == 0 else ""
if message:
yield prefix + role + ":", message + seps[i % 2] + "\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
suffix = "\n\n" if i % 2 == 1 else ""
yield role + ":\n", message + seps[i % 2] + suffix
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.PHOENIX:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + ": ", "<s>" + message + "</s>"
else:
yield role + ": " + "<s>", ""
return
if self.sep_style == SeparatorStyle.ROBIN:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ":\n", message + self.sep
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.FALCON_CHAT:
if self.system_message:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def add_get_turns_to_conversation():
import fastchat.conversation
fastchat.conversation.Conversation.get_turns = get_turns
fastchat.conversation.Conversation.get_prompt = get_prompt

View File

@@ -12,6 +12,8 @@ import torch.nn.functional as F
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from torch import nn
from transformers import LlamaConfig
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
@@ -38,11 +40,7 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False,
):
def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
@@ -53,37 +51,46 @@ def replace_llama_attn_with_flash_attn(
llama_model_forward
)
# skip only if explicitly disabled
if cross_entropy:
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
# skip only if explicitly disabled
if rms_norm:
try:
from flash_attn.ops.rms_norm import RMSNorm
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
class GaussianDropout(nn.Module):
def __init__(self, p=0.5):
super(GaussianDropout, self).__init__()
if p <= 0 or p >= 1:
raise Exception("p value should accomplish 0 < p < 1")
self.p = p
def forward(self, x):
stddev = (self.p / (1.0 - self.p)) ** 0.5
epsilon = torch.randn_like(x) * stddev
return x * epsilon
# Disable the transformation of the attention mask in LlamaModel as the flash attention
@@ -107,7 +114,6 @@ def flashattn_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -211,7 +217,7 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
@@ -485,13 +491,6 @@ def llama_model_forward(
dtype=torch.bool,
device=inputs_embeds.device,
)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
@@ -526,9 +525,7 @@ def llama_model_forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(
*inputs,
)
return module(*inputs)
return custom_forward
@@ -537,10 +534,9 @@ def llama_model_forward(
hidden_states,
attention_mask,
position_ids,
past_key_value,
None,
output_attentions,
None,
padding_mask,
cu_seqlens,
max_seqlen,
)
@@ -552,7 +548,6 @@ def llama_model_forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
@@ -591,6 +586,15 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
"""
def __init__(self, config: LlamaConfig):
super(LlamaDecoderLayer, self).__init__(config)
self.attn_dropout = None
self.mlp_dropout = None
if config.dropout_attn:
self.attn_dropout = GaussianDropout(p=config.dropout_attn)
if config.dropout_mlp:
self.mlp_dropout = GaussianDropout(p=config.dropout_mlp)
def forward(
self,
hidden_states: torch.Tensor,
@@ -599,7 +603,6 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
@@ -632,16 +635,19 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if self.training and self.attn_dropout:
hidden_states = self.attn_dropout(hidden_states)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.training and self.mlp_dropout:
hidden_states = self.mlp_dropout(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)

View File

@@ -1,541 +0,0 @@
"""Flash attention monkey patch for mistral model"""
# pylint: disable=duplicate-code
import logging
from typing import List, Optional, Tuple, Union
import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
def replace_mistral_attn_with_flash_attn(
packed: Optional[bool] = False,
):
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
flashattn_forward
)
if packed:
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
MistralDecoderLayer
)
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
mistral_model_forward
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
sliding_window,
): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask
def flashattn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
causal=is_causal,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
def mistral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
transformers.logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class MistralDecoderLayer(OriginalMistralDecoderLayer):
"""
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs

View File

@@ -1,415 +0,0 @@
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
""" PyTorch StableLM Epoch model. """
import importlib
import math
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from accelerate import init_empty_weights
from einops import rearrange
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_varlen_qkvpacked_func,
)
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
logger = logging.get_logger(__name__)
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_stablelm_epoch to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_stablelm_epoch", ".modeling_stablelm_epoch"
)
modeling_stablelm = importlib.import_module(module_name)
modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access
flashattn_attn
)
modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access
stablelm_model_forward
)
modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access
decoder_layer_forward
)
def rotate_half(x: torch.Tensor):
"""Rotates half the hidden dims of the input."""
# pylint: disable=invalid-name
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
# pylint: disable=invalid-name
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def flashattn_attn(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, # pylint: disable=unused-argument
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
query_rot = query_states[..., : self.rotary_ndims]
query_pass = query_states[..., self.rotary_ndims :]
key_rot = key_states[..., : self.rotary_ndims]
key_pass = key_states[..., self.rotary_ndims :]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_rot, key_rot, cos, sin, position_ids
)
# [batch_size, num_heads, seq_len, head_dim]
query_states = torch.cat((query_states, query_pass), dim=-1)
key_states = torch.cat((key_states, key_pass), dim=-1)
if past_key_value is not None:
# Reuse k, v, self_attention
key_states = torch.cat((past_key_value[0], key_states), dim=2)
value_states = torch.cat((past_key_value[1], value_states), dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# Repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
softmax_scale = None
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True
)
attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# Upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
# Merge heads
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
# Final linear projection
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def decoder_layer_forward(
self,
hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Union[
Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
]:
# pylint: disable=duplicate-code
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def stablelm_model_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# Add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)

View File

@@ -24,15 +24,6 @@ def load(tokenizer, cfg):
)
def load_v2(tokenizer, cfg):
return ContextQaV2PromptTokenizingStrategy(
ContextV2Prompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class AlpacaContextPrompter(AlpacaPrompter):
"""
Customized system prompted for concise QA
@@ -59,38 +50,6 @@ class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
)
class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenization Strategy to combine in-context article with a question and answer
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
"Context: "
+ prompt["context"]
+ "\nQuestion: "
+ prompt["question"]
+ "\n",
"",
"Answer: " + prompt["answer"],
)
class ContextV2Prompter(AlpacaPrompter):
"""
Customized system prompted for concise QA
"""
system_prompt = ""
system_no_input_prompt = ""
def match_prompt_style(self):
# pylint: disable=duplicate-code
self.turn_format = "{instruction}\n{input}"
self.turn_no_input_format = "{instruction}"
self.system_format = "{system}"
class AlpacaMissingInfoContextPromptTokenizingStrategy(
InstructionPromptTokenizingStrategy
):

View File

@@ -1,119 +0,0 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message="You are a helpful assistant.",
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>\n",
)
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
strat = ShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and ds_cfg["skip"]:
strat.skip_invalid = True
return strat
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
return NousShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy used by nous/ldj for input/output keyed data
"""
def get_conversation_thread(self):
return "conversation"
def map_conversation_thread(self, conversation):
turns = []
for turn in conversation:
turns.append({"from": "human", "value": turn["input"]})
turns.append({"from": "gpt", "value": turn["output"]})
return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
"""
def map_conversation_thread(self, conversation):
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
turns = [
{"from": turn["role"], "value": turn["value"]} for turn in conversation
]
return turns
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps oasst data to sharegpt format
"""
def map_conversation_thread(self, conversation):
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"}
turns = [
{"from": role_map[turn["role"]], "value": turn["text"]}
for turn in conversation
]
return turns

View File

@@ -1,11 +1,11 @@
"""Module for Jokes prompts using sharegpt style """
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
from axolotl.prompters import PromptStyle, ShareGPTPrompter
def load(tokenizer, cfg):
return SimpleJokesShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,

View File

@@ -0,0 +1,67 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import PromptStyle, ShareGPTPrompter
def load(tokenizer, cfg):
return SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
return turns
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps oasst data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"}
turns = [
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
]
return turns

View File

@@ -4,15 +4,10 @@ import abc
import copy
import functools
import logging
from collections import defaultdict
from typing import Dict, List, Tuple, Union
from fastchat.conversation import Conversation
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID
LOG = logging.getLogger("axolotl")
@@ -23,8 +18,6 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
add_get_turns_to_conversation()
class InvalidDataException(Exception):
"""
@@ -82,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC):
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
result: BatchEncoding
if not prompt:
if not prompt.strip():
LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
@@ -352,109 +345,72 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for ShareGPT prompts.
"""
_skip_invalid = False
@property
def supports_batched(self):
return True
@property
def skip_invalid(self):
return self._skip_invalid
@skip_invalid.setter
def skip_invalid(self, value):
self._skip_invalid = value
def get_conversation_thread(self):
return "conversations"
def map_conversation_thread(self, conversation):
return conversation
def get_conversation_thread(self, prompt):
return prompt["conversations"]
def tokenize_prompt(self, prompt):
tokenized_res = defaultdict(lambda: [])
conv_field = self.get_conversation_thread()
for prmpt in prompt[conv_field]:
result, current_len = tokenize_prompt_default()
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
conversation: Conversation = (
self.prompter._conversation # pylint: disable=protected-access
)
try:
for _, part in enumerate(
self.prompter.build_prompt(self.map_conversation_thread(prmpt))
):
if isinstance(part, tuple):
if conversation.roles[0] in part[0]:
turn = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
if not part[1].strip():
err_msg = f"user turn has empty text: {prmpt}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
turn = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
err_msg = f"assistant turn has empty text: {prmpt}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
elif part[0] == "":
turn = part[1]
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
err_msg = f"unhandled role: {part[0]}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
continue
result, current_len = tokenize_prompt_default()
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if part[0] == "USER:":
turn = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
if not part[1].strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn.strip(),
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif part[0] == "ASSISTANT:":
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
turn = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
res = self._tokenize(
turn.strip(),
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
elif part[0] == "SYSTEM:":
part = part[1] # Ignore the system role from preamble
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {part[0]}")
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
for key, val in sorted(result.items(), key=lambda x: x[0]):
tokenized_res[key].append(val)
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
except ValueError as err:
LOG.warning("skipping prompt: %s", str(err))
return tokenized_res
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
if not prompt.strip():

View File

@@ -1,10 +1,9 @@
"""Module containing prompters"""
import dataclasses
import logging
from enum import Enum
from typing import Generator, Optional, Union
from fastchat.conversation import Conversation, get_conv_template
from enum import Enum, auto
from typing import Generator, List, Optional, Tuple, Union
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
@@ -215,6 +214,53 @@ class ReflectAlpacaPrompter:
yield res
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
DOLLY = auto()
# TODO clean this 💩 up
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: Optional[str] = None
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
# seps = [self.sep, self.sep2]
preamble = self.system + self.sep
yield ("SYSTEM:", preamble)
for _, (role, message) in enumerate(self.messages):
if message:
yield (role + ":", " " + message)
else:
LOG.warning(f"role with empty message: {role}")
yield (role + ":", "")
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
)
def append_message(self, role, message):
self.messages.append([role, message])
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
@@ -225,27 +271,28 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
A prompter that generates prompts for the ShareGPT
"""
role_key_human = "human"
role_key_model = "gpt"
def __init__(
self,
prompt_style=None, # pylint: disable=unused-argument
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
):
if conversation:
if isinstance(conversation, Conversation):
self._conversation = conversation
else:
self._conversation = get_conv_template(conversation)
else:
self._conversation = get_conv_template("vicuna_v1.1")
if role_key_human:
self.role_key_human = role_key_human
if role_key_model:
self.role_key_model = role_key_model
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
if prompt_style != PromptStyle.CHAT.value:
raise ValueError(
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
)
system: str = (
system_prompt
if system_prompt
else (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
)
)
self._conversation = Conversation(
system=system,
roles=["USER", "ASSISTANT"],
messages=[],
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2=" ",
)
def build_prompt(self, source) -> Generator[str, None, None]:
if len(source) < 2:
@@ -259,14 +306,17 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
# Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system":
conv.set_system_message(source[0]["value"])
conv.system = source[0]["value"]
source.pop(0)
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
try:
# Apply prompt templates
if source[0]["from"] not in roles:
if (
source[0]["from"] not in roles
or roles[source[0]["from"]] != conv.roles[0]
):
# Skip the first one if it is not from human
source = source[1:]
except IndexError as err:
@@ -276,29 +326,8 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
if role != conv.roles[j % 2]:
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
conv.append_message(role, sentence["value"])
for part in conv.get_turns():
if part[0] and not part[1]:
LOG.warning(f"role with empty message: {part[0]}")
for part in conv.get_prompt():
yield part
class ShareGPTPrompterV2(ShareGPTPrompter):
"""
A V2 prompter that generates prompts for the ShareGPT
"""
def __init__(
self,
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
)

View File

@@ -58,9 +58,7 @@ def train(
safe_serialization = cfg.save_safetensors is True
if (
cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
) or cfg.resume_from_checkpoint is True:
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
@@ -73,9 +71,7 @@ def train(
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = (
cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
)
resume_from_checkpoint = cfg.resume_from_checkpoint
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps

View File

@@ -49,8 +49,6 @@ def normalize_config(cfg):
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
if cfg.eval_batch_size is None:
cfg.eval_batch_size = cfg.micro_batch_size
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
@@ -77,8 +75,6 @@ def normalize_config(cfg):
else:
cfg.torch_dtype = torch.float32
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
@@ -86,39 +82,10 @@ def normalize_config(cfg):
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or "llama" in cfg.base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
# figure out if the model is falcon
cfg.is_falcon_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"falcon",
"RefinedWebModel",
"RefinedWeb",
]
)
or cfg.is_falcon_derived_model
or "falcon" in cfg.base_model.lower()
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
)
cfg.is_mistral_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"mistral",
]
)
or cfg.is_mistral_derived_model
or "mistral" in cfg.base_model.lower()
or (cfg.model_type and "mistral" in cfg.model_type.lower())
)
log_gpu_memory_usage(LOG, "baseline", cfg.device)
@@ -159,11 +126,6 @@ def validate_config(cfg):
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
if cfg.eval_batch_size != cfg.micro_batch_size:
LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)
if cfg.load_4bit:
raise ValueError("cfg.load_4bit parameter has been deprecated")
@@ -300,45 +262,6 @@ def validate_config(cfg):
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
)
if cfg.datasets:
for idx, ds_cfg in enumerate(cfg.datasets):
if not ds_cfg.type:
continue
if ds_cfg.type == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = "sharegpt"
if "sharegpt_simple" in ds_cfg.type:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt"
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if (
cfg.evaluation_strategy
and cfg.eval_steps
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -25,6 +25,7 @@ from axolotl.prompt_tokenizers import (
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
@@ -34,6 +35,7 @@ from axolotl.prompters import (
MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter,
ReflectAlpacaPrompter,
ShareGPTPrompter,
SummarizeTLDRPrompter,
)
from axolotl.utils.dict import DictDefault
@@ -74,7 +76,7 @@ def prepare_dataset(cfg, tokenizer):
with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset, tokenizer
cfg, train_dataset, eval_dataset
)
if cfg.max_steps:
total_num_steps = min(
@@ -114,7 +116,7 @@ def load_tokenized_prepared_datasets(
if cfg.push_dataset_to_hub:
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token,
use_auth_token=use_auth_token,
)
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec
@@ -122,7 +124,7 @@ def load_tokenized_prepared_datasets(
if dataset:
...
elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")):
elif any(prepared_ds_path.glob("*")):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared dataset loaded from disk...")
@@ -155,26 +157,24 @@ def load_tokenized_prepared_datasets(
d.path,
name=d.name,
streaming=True,
token=use_auth_token,
use_auth_token=use_auth_token,
)
ds_from_hub = True
except (FileNotFoundError, ValueError):
except FileNotFoundError:
pass
# prefer local dataset, even if hub exists
local_path = Path(d.path)
if local_path.exists():
if local_path.is_dir():
if not d.type:
ds = load_from_disk(d.path)
else:
ds = load_dataset(
d.path,
name=d.name,
data_files=d.data_files,
streaming=False,
split=None,
)
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_dataset(
d.path,
name=d.name,
data_files=d.data_files,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = "json"
if d.ds_type:
@@ -204,29 +204,14 @@ def load_tokenized_prepared_datasets(
name=d.name,
streaming=False,
data_files=d.data_files,
token=use_auth_token,
use_auth_token=use_auth_token,
)
else:
if isinstance(d.data_files, str):
fp = hf_hub_download(
repo_id=d.path,
repo_type="dataset",
filename=d.data_files,
)
elif isinstance(d.data_files, list):
fp = []
for file in d.data_files:
fp.append(
hf_hub_download(
repo_id=d.path,
repo_type="dataset",
filename=file,
)
)
else:
raise ValueError(
"data_files must be either a string or list of strings"
)
fp = hf_hub_download(
repo_id=d.path,
repo_type="dataset",
filename=d.data_files,
)
ds = load_dataset(
"json", name=d.name, data_files=fp, streaming=False, split=None
)
@@ -249,16 +234,6 @@ def load_tokenized_prepared_datasets(
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds:
ds = ds["train"]
elif (
isinstance(ds, DatasetDict)
and d.train_on_split
and d.train_on_split in ds
):
ds = ds[d.train_on_split]
elif isinstance(ds, DatasetDict):
raise ValueError(
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
)
if (
"input_ids" in ds.features
and "attention_mask" in ds.features
@@ -345,6 +320,15 @@ def load_tokenized_prepared_datasets(
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "sharegpt":
ds_strategy = ShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
suffix = ""
if ":load_" in d.type:
@@ -359,7 +343,7 @@ def load_tokenized_prepared_datasets(
if len(datasets) > 1:
LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed)
if cfg.local_rank == 0 and cfg.dataset_prepared_path:
if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub:
@@ -419,7 +403,7 @@ def load_prepare_datasets(
)
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token,
use_auth_token=use_auth_token,
)
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec
@@ -427,7 +411,7 @@ def load_prepare_datasets(
if dataset:
...
elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")):
elif any(prepared_ds_path.glob("*")):
LOG.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
)

View File

@@ -1,4 +1,5 @@
"""Module for models and model loading"""
import importlib
import logging
import math
import os
@@ -11,7 +12,6 @@ from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
@@ -81,22 +81,11 @@ def load_tokenizer(cfg):
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Mistral's official FA implementation requires left padding
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left"
if cfg.special_tokens:
for k, val in cfg.special_tokens.items():
tokenizer.add_special_tokens(
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
)
tokenizer.add_special_tokens({k: val})
if cfg.tokens:
tokenizer.add_tokens(
[
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
for token in cfg.tokens
]
)
tokenizer.add_tokens(list(cfg.tokens))
return tokenizer
@@ -125,29 +114,26 @@ def load_model(
replace_btlm_attn_with_flash_attn(cfg.base_model)
if (
hasattr(model_config, "model_type")
and model_config.model_type == "stablelm_epoch"
):
if cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
replace_stablelm_attn_with_flash_attn,
if hasattr(model_config, "model_type") and model_config.model_type in [
"falcon",
"RefinedWebModel",
"RefinedWeb",
]:
if cfg.flash_attention:
from axolotl.monkeypatch.falcon_attn_hijack_flash import (
replace_falcon_attn_with_flash_attn,
)
replace_stablelm_attn_with_flash_attn(cfg.base_model)
replace_falcon_attn_with_flash_attn()
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
LOG.info("patching with flash attention for sample packing")
replace_llama_attn_with_flash_attn(
packed=cfg.sample_packing,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
)
LOG.info("patching with flash attention")
replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
@@ -172,14 +158,6 @@ def load_model(
# Note: This might overwrite previous additional_special_tokens
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
replace_mistral_attn_with_flash_attn,
)
LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope,
@@ -198,11 +176,21 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()
# special handling b/c remote MixFormers code doesn't have _no_split_modules set
if (
"MixFormerSequentialConfig" in model_config.__class__.__name__
and cfg.model_type == "AutoModelForCausalLM"
):
module_name = model_config.__class__.__module__.replace(
".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
)
modeling_phi = importlib.import_module(module_name)
# pylint:disable=protected-access
modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [
"ParallelBlock"
]
model_kwargs = {}
model_kwargs["device_map"] = cfg.device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
@@ -225,15 +213,6 @@ def load_model(
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
# sample packing uses custom FA2 patch
if cfg.flash_attention and not cfg.sample_packing:
if (
cfg.is_llama_derived_model
or cfg.is_falcon_derived_model
or cfg.is_mistral_derived_model
):
model_kwargs["use_flash_attention_2"] = True
try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM
@@ -248,8 +227,10 @@ def load_model(
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
**model_kwargs,
)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -283,22 +264,28 @@ def load_model(
model = MixFormerSequentialForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code:
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -327,6 +314,8 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -334,8 +323,10 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -346,8 +337,10 @@ def load_model(
LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -382,7 +375,7 @@ def load_model(
if model_config.model_type == "btlm":
# don't upcast lm_head for btlm
continue
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(torch.float32)

View File

@@ -31,9 +31,7 @@ def check_example_labels(example, tokenizer, text_only=False):
)
colored_tokens.append(colored_token)
delimiter = "" if text_only else " "
LOG.info(delimiter.join(colored_tokens))
LOG.info(" ".join(colored_tokens))
LOG.info("\n\n\n")
print(" ".join(colored_tokens))
return " ".join(colored_tokens)

View File

@@ -397,36 +397,23 @@ def disable_datasets_caching():
set_caching_enabled(True)
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()):
train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long, num_proc=cfg.dataset_processes
)
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length, num_proc=cfg.dataset_processes
)
train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
if cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids, num_proc=cfg.dataset_processes
)
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids, num_proc=cfg.dataset_processes
add_position_ids, num_proc=os.cpu_count()
)
# Phi doesn't want the attention_mask feature when training
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
return train_dataset, eval_dataset
@@ -610,19 +597,26 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
if cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
elif cfg.evaluation_strategy:
if cfg.eval_steps and cfg.evaluation_strategy:
# assume if the user set both, they know what they're doing
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
elif cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
# if explicitly set for epoch, just set, and eval steps don't matter
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
elif cfg.eval_steps:
# steps isn't used w/ epochs
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if cfg.save_steps:
# save_steps implies save_strategy of steps
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = cfg.save_steps
elif cfg.save_strategy:
@@ -668,7 +662,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None
else cfg.micro_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,

View File

@@ -1,116 +0,0 @@
"""
E2E tests for lora llama
"""
import logging
import os
import tempfile
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMistral(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
def test_lora(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_ft(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "pytorch_model.bin").exists()

View File

@@ -1,118 +0,0 @@
"""
E2E tests for lora llama
"""
import logging
import os
import tempfile
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMistral(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
def test_lora_packing(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_ft_packing(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "pytorch_model.bin").exists()

File diff suppressed because one or more lines are too long

View File

@@ -21,7 +21,7 @@ from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
LOG = logging.getLogger("axolotl")
@@ -60,7 +60,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
) as fin:
data = fin.read()
tokenized_conversation = json.loads(data)
prompter = ShareGPTPrompterV2()
prompter = ShareGPTPrompter("chat")
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
@@ -79,7 +79,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
) as fin:
data = fin.read()
conversation = json.loads(data)
prompter = ShareGPTPrompterV2()
prompter = ShareGPTPrompter("chat")
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,

View File

@@ -374,194 +374,3 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
def test_sharegpt_deprecation(self):
cfg = DictDefault(
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"`type: sharegpt:chat` will soon be deprecated." in record.message
for record in self._caplog.records
)
assert cfg.datasets[0].type == "sharegpt"
cfg = DictDefault(
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"`type: sharegpt_simple` will soon be deprecated." in record.message
for record in self._caplog.records
)
assert cfg.datasets[0].type == "sharegpt:load_role"
def test_no_conflict_save_strategy(self):
cfg = DictDefault(
{
"save_strategy": "epoch",
"save_steps": 10,
}
)
with pytest.raises(
ValueError, match=r".*save_strategy and save_steps mismatch.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"save_strategy": "no",
"save_steps": 10,
}
)
with pytest.raises(
ValueError, match=r".*save_strategy and save_steps mismatch.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"save_strategy": "steps",
}
)
validate_config(cfg)
cfg = DictDefault(
{
"save_strategy": "steps",
"save_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"save_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"save_strategy": "no",
}
)
validate_config(cfg)
def test_no_conflict_eval_strategy(self):
cfg = DictDefault(
{
"evaluation_strategy": "epoch",
"eval_steps": 10,
}
)
with pytest.raises(
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "no",
"eval_steps": 10,
}
)
with pytest.raises(
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "steps",
}
)
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "steps",
"eval_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"eval_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "no",
}
)
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "epoch",
"val_set_size": 0,
}
)
with pytest.raises(
ValueError,
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
):
validate_config(cfg)
cfg = DictDefault(
{
"eval_steps": 10,
"val_set_size": 0,
}
)
with pytest.raises(
ValueError,
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
):
validate_config(cfg)
cfg = DictDefault(
{
"val_set_size": 0,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"eval_steps": 10,
"val_set_size": 0.01,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"evaluation_strategy": "epoch",
"val_set_size": 0.01,
}
)
validate_config(cfg)