Compare commits
36 Commits
llama-drop
...
sharegpt-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4d84d56d5 | ||
|
|
669f1d052c | ||
|
|
d4a88e4eca | ||
|
|
2d60ba3a6e | ||
|
|
eb480dfd68 | ||
|
|
133e676bcc | ||
|
|
69fac9a020 | ||
|
|
e0b7eeabfd | ||
|
|
43856c0a39 | ||
|
|
e62d5901b5 | ||
|
|
697c50d408 | ||
|
|
90e0d673f7 | ||
|
|
2642caedf2 | ||
|
|
f34648c8b9 | ||
|
|
e50a64e85e | ||
|
|
f4868d733c | ||
|
|
a7e56d83c2 | ||
|
|
5b0bc48fbc | ||
|
|
9ec20777ba | ||
|
|
590d6032fd | ||
|
|
409ca0f21c | ||
|
|
8662e8ffe8 | ||
|
|
b2edaaeff6 | ||
|
|
b88f51512a | ||
|
|
eb41f76f92 | ||
|
|
383f88d7a7 | ||
|
|
b6ab8aad62 | ||
|
|
85b0be2ba7 | ||
|
|
8fe0e633d2 | ||
|
|
d1236f2c41 | ||
|
|
895f0a0723 | ||
|
|
e7d3e2dbb6 | ||
|
|
60c7c48c97 | ||
|
|
e8cbf50be6 | ||
|
|
d887ad86c3 | ||
|
|
19a600a8b8 |
7
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
7
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
@@ -53,6 +53,13 @@ body:
|
|||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: config
|
||||||
|
attributes:
|
||||||
|
label: Config yaml
|
||||||
|
description: |
|
||||||
|
Please attach the config yaml!
|
||||||
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
id: possible-solution
|
id: possible-solution
|
||||||
attributes:
|
attributes:
|
||||||
|
|||||||
8
.github/workflows/tests.yml
vendored
8
.github/workflows/tests.yml
vendored
@@ -6,9 +6,11 @@ on:
|
|||||||
- "main"
|
- "main"
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
|
- 'requirements.txt'
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
|
- 'requirements.txt'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
@@ -44,7 +46,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install -e .
|
pip3 install -U -e .
|
||||||
pip3 install -r requirements-tests.txt
|
pip3 install -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
@@ -69,8 +71,8 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install -e .
|
pip3 uninstall -y transformers accelerate
|
||||||
pip3 install flash-attn
|
pip3 install -U -e .[flash-attn]
|
||||||
pip3 install -r requirements-tests.txt
|
pip3 install -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run e2e tests
|
- name: Run e2e tests
|
||||||
|
|||||||
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
|
|||||||
disable=missing-function-docstring, line-too-long, import-error,
|
disable=missing-function-docstring, line-too-long, import-error,
|
||||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||||
|
too-many-nested-blocks,
|
||||||
|
|||||||
36
README.md
36
README.md
@@ -124,6 +124,11 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
pip3 install packaging
|
pip3 install packaging
|
||||||
pip3 install -e '.[flash-attn,deepspeed]'
|
pip3 install -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
|
4. (Optional) Login to Huggingface to use gated models/datasets.
|
||||||
|
```bash
|
||||||
|
huggingface-cli login
|
||||||
|
```
|
||||||
|
Get the token at huggingface.co/settings/tokens
|
||||||
|
|
||||||
- LambdaLabs
|
- LambdaLabs
|
||||||
<details>
|
<details>
|
||||||
@@ -180,7 +185,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"instruction": "...", "input": "...", "output": "..."}
|
{"instruction": "...", "input": "...", "output": "..."}
|
||||||
```
|
```
|
||||||
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
|
- `sharegpt`: conversations where `from` is `human`/`gpt`
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
@@ -245,6 +250,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"article": "...", "question": "...", "answer": "..."}
|
{"article": "...", "question": "...", "answer": "..."}
|
||||||
```
|
```
|
||||||
|
- `context_qa.load_v2`: in context question answering (alternate)
|
||||||
|
```json
|
||||||
|
{"context": "...", "question": "...", "answer": "..."}
|
||||||
|
```
|
||||||
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
|
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
|
||||||
```json
|
```json
|
||||||
{"article": "...", "unanswerable_question": "..."}
|
{"article": "...", "unanswerable_question": "..."}
|
||||||
@@ -269,11 +278,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"prompt": "...", "generation": "..."}
|
{"prompt": "...", "generation": "..."}
|
||||||
```
|
```
|
||||||
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"role": "...", "value": "..."}]}
|
{"conversations": [{"role": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
- `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
@@ -308,7 +317,7 @@ Using file:
|
|||||||
#### How to use your custom pretokenized dataset
|
#### How to use your custom pretokenized dataset
|
||||||
|
|
||||||
- Do not pass a `type:`
|
- Do not pass a `type:`
|
||||||
- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
|
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
|
||||||
|
|
||||||
|
|
||||||
### Config
|
### Config
|
||||||
@@ -351,6 +360,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- path: data.jsonl # or json
|
- path: data.jsonl # or json
|
||||||
ds_type: json # see other options below
|
ds_type: json # see other options below
|
||||||
type: alpaca
|
type: alpaca
|
||||||
|
|
||||||
|
# dataset with splits, but no train split
|
||||||
|
dataset:
|
||||||
|
- path: knowrohit07/know_sql
|
||||||
|
type: context_qa.load_v2
|
||||||
|
train_on_split: validation
|
||||||
```
|
```
|
||||||
|
|
||||||
- loading
|
- loading
|
||||||
@@ -408,6 +423,11 @@ tokenizer_legacy:
|
|||||||
# this is reported to improve training speed on some models
|
# this is reported to improve training speed on some models
|
||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
|
|
||||||
|
# used to identify which the model is based on
|
||||||
|
is_falcon_derived_model:
|
||||||
|
is_llama_derived_model:
|
||||||
|
is_mistral_derived_model:
|
||||||
|
|
||||||
# whether you are training a 4-bit GPTQ quantized model
|
# whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
gptq_groupsize: 128 # group size
|
gptq_groupsize: 128 # group size
|
||||||
@@ -439,6 +459,7 @@ datasets:
|
|||||||
data_files: # Optional[str] path to source data files
|
data_files: # Optional[str] path to source data files
|
||||||
shards: # Optional[int] number of shards to split data into
|
shards: # Optional[int] number of shards to split data into
|
||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
|
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
|
|
||||||
# custom user prompt
|
# custom user prompt
|
||||||
- path: repo
|
- path: repo
|
||||||
@@ -466,6 +487,9 @@ datasets:
|
|||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
# push prepared dataset to hub
|
# push prepared dataset to hub
|
||||||
push_dataset_to_hub: # repo path
|
push_dataset_to_hub: # repo path
|
||||||
|
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||||
|
# if not set.
|
||||||
|
dataset_processes: # defaults to os.cpu_count() if not set
|
||||||
# push checkpoints to hub
|
# push checkpoints to hub
|
||||||
hub_model_id: # repo path to push finetuned model
|
hub_model_id: # repo path to push finetuned model
|
||||||
# how to push checkpoints to hub
|
# how to push checkpoints to hub
|
||||||
@@ -547,7 +571,7 @@ torch_compile_backend: # Optional[str]
|
|||||||
# training hyperparameters
|
# training hyperparameters
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
eval_batch_size: 2
|
eval_batch_size:
|
||||||
num_epochs: 3
|
num_epochs: 3
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
learning_rate: 0.00003
|
learning_rate: 0.00003
|
||||||
@@ -631,6 +655,8 @@ flash_optimum:
|
|||||||
xformers_attention:
|
xformers_attention:
|
||||||
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||||
flash_attention:
|
flash_attention:
|
||||||
|
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||||
|
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||||
# whether to use scaled-dot-product attention
|
# whether to use scaled-dot-product attention
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
sdp_attention:
|
sdp_attention:
|
||||||
|
|||||||
@@ -12,17 +12,18 @@ RUN apt-get update && \
|
|||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||||
|
|
||||||
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN cd axolotl && \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
|
||||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[flash-attn]; \
|
pip install -e .[flash-attn]; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN cd axolotl && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
|
||||||
git config --get remote.origin.fetch
|
git config --get remote.origin.fetch
|
||||||
|
|
||||||
# helper for huggingface-login cli
|
# helper for huggingface-login cli
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ push_dataset_to_hub:
|
|||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
|
|||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
is_falcon_derived_model: true
|
||||||
load_in_8bit: true
|
load_in_8bit: true
|
||||||
load_in_4bit: false
|
load_in_4bit: false
|
||||||
gptq: false
|
gptq: false
|
||||||
@@ -11,7 +12,7 @@ push_dataset_to_hub:
|
|||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca:chat
|
type: alpaca:chat
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ base_model_config: tiiuae/falcon-7b
|
|||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
is_falcon_derived_model: true
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
# enable 4bit for QLoRA
|
# enable 4bit for QLoRA
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
@@ -17,7 +18,7 @@ datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||||
type: "alpaca:chat"
|
type: "alpaca:chat"
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
# enable QLoRA
|
# enable QLoRA
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
|
|||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
|
is_falcon_derived_model: true
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
load_in_4bit: false
|
load_in_4bit: false
|
||||||
gptq: false
|
gptq: false
|
||||||
@@ -11,7 +12,7 @@ push_dataset_to_hub:
|
|||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca:chat
|
type: alpaca:chat
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ push_dataset_to_hub:
|
|||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ load_in_8bit: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: openaccess-ai-collective/jeopardy
|
- path: openaccess-ai-collective/jeopardy
|
||||||
type: jeopardy
|
type: jeopardy
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ hf_use_auth_token: true
|
|||||||
datasets:
|
datasets:
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ flash_attention: true
|
|||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
eval_table_size: 5
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ flash_attention: true
|
|||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
eval_table_size: 5
|
eval_table_size:
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./relora-out
|
output_dir: ./relora-out
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ strict: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ flash_attention: true
|
|||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
eval_table_size: 5
|
eval_table_size:
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
12
examples/mistral/README.md
Normal file
12
examples/mistral/README.md
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
**Mistral 7B** is a language model with a total of 7.3 billion parameters, showcasing a notable performance across a variety of benchmarks.
|
||||||
|
|
||||||
|
Fine Tune:
|
||||||
|
```shell
|
||||||
|
accelerate launch -m axolotl.cli.train examples/mistral/config.yml
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
If you run into CUDA OOM, use deepspeed with config zero2.json:
|
||||||
|
```shell
|
||||||
|
accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
|
||||||
|
```
|
||||||
62
examples/mistral/config.yml
Normal file
62
examples/mistral/config.yml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
base_model: mistralai/Mistral-7B-v0.1
|
||||||
|
base_model_config: mistralai/Mistral-7B-v0.1
|
||||||
|
model_type: MistralForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_mistral_derived_model: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
sample_packing:
|
||||||
|
pad_to_sequence_len:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_run_id:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 3
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
eval_steps: 20
|
||||||
|
eval_table_size: 5
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
save_steps:
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
79
examples/mistral/qlora.yml
Normal file
79
examples/mistral/qlora.yml
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
base_model: mistralai/Mistral-7B-v0.1
|
||||||
|
base_model_config: mistralai/Mistral-7B-v0.1
|
||||||
|
model_type: MistralForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_mistral_derived_model: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
sample_packing: True
|
||||||
|
pad_to_sequence_len: True
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_run_id:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 4
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
eval_steps: 20
|
||||||
|
eval_table_size: 5
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
save_steps:
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
@@ -6,7 +6,7 @@ load_in_8bit: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: vicgalle/alpaca-gpt4
|
- path: vicgalle/alpaca-gpt4
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ push_dataset_to_hub:
|
|||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ push_dataset_to_hub:
|
|||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ push_dataset_to_hub:
|
|||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ datasets:
|
|||||||
- path: garage-bAInd/Open-Platypus
|
- path: garage-bAInd/Open-Platypus
|
||||||
type: alpaca
|
type: alpaca
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./phi-sft-out
|
output_dir: ./phi-sft-out
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ datasets:
|
|||||||
- path: garage-bAInd/Open-Platypus
|
- path: garage-bAInd/Open-Platypus
|
||||||
type: alpaca
|
type: alpaca
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./phi-sft-out
|
output_dir: ./phi-sft-out
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ device_map: auto
|
|||||||
datasets:
|
datasets:
|
||||||
- path: vicgalle/alpaca-gpt4
|
- path: vicgalle/alpaca-gpt4
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ load_in_8bit: true
|
|||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ load_in_8bit: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: vicgalle/alpaca-gpt4
|
- path: vicgalle/alpaca-gpt4
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ load_in_8bit: false
|
|||||||
datasets:
|
datasets:
|
||||||
- path: vicgalle/alpaca-gpt4
|
- path: vicgalle/alpaca-gpt4
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- openassistant_best_replies_train.jsonl
|
- openassistant_best_replies_train.jsonl
|
||||||
type: "completion"
|
type: "completion"
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
# enable QLoRA
|
# enable QLoRA
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
|
|||||||
@@ -4,16 +4,15 @@ torch==2.0.1
|
|||||||
auto-gptq
|
auto-gptq
|
||||||
packaging
|
packaging
|
||||||
peft @ git+https://github.com/huggingface/peft.git
|
peft @ git+https://github.com/huggingface/peft.git
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git
|
transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate @ git+https://github.com/huggingface/accelerate
|
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
||||||
deepspeed
|
deepspeed
|
||||||
addict
|
addict
|
||||||
evaluate
|
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
datasets
|
datasets
|
||||||
flash-attn>=2.2.1
|
flash-attn>=2.3.0
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
@@ -31,3 +30,4 @@ scipy
|
|||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
|
fschat==0.2.29
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import transformers
|
|||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
|
check_user_token,
|
||||||
do_inference,
|
do_inference,
|
||||||
do_merge_lora,
|
do_merge_lora,
|
||||||
load_cfg,
|
load_cfg,
|
||||||
@@ -31,6 +32,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
)
|
)
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
|
check_user_token()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import yaml
|
|||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
from accelerate.commands.config import config_args
|
from accelerate.commands.config import config_args
|
||||||
from art import text2art
|
from art import text2art
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
@@ -49,7 +51,7 @@ def print_axolotl_text_art(suffix=None):
|
|||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> Optional[str]:
|
def get_multi_line_input() -> Optional[str]:
|
||||||
print("Give me an instruction (Ctrl + D to finish): ")
|
print("Give me an instruction (Ctrl + D to submit): ")
|
||||||
instruction = ""
|
instruction = ""
|
||||||
for line in sys.stdin:
|
for line in sys.stdin:
|
||||||
instruction += line # pylint: disable=consider-using-join
|
instruction += line # pylint: disable=consider-using-join
|
||||||
@@ -247,3 +249,16 @@ def check_accelerate_default_config():
|
|||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_user_token():
|
||||||
|
# Verify if token is valid
|
||||||
|
api = HfApi()
|
||||||
|
try:
|
||||||
|
user_info = api.whoami()
|
||||||
|
return bool(user_info)
|
||||||
|
except LocalTokenNotFoundError:
|
||||||
|
LOG.warning(
|
||||||
|
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import transformers
|
|||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
|
check_user_token,
|
||||||
load_cfg,
|
load_cfg,
|
||||||
load_datasets,
|
load_datasets,
|
||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
@@ -21,6 +22,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
|
check_user_token()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import os
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset, Sequence, Value
|
||||||
|
|
||||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
"""
|
"""
|
||||||
Dataset that returns tokenized prompts from a stream of text files.
|
Dataset that returns tokenized prompts from a stream of text files.
|
||||||
Args:
|
Args:
|
||||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
|
||||||
dataset (dataset.Dataset): Dataset with text files.
|
dataset (dataset.Dataset): Dataset with text files.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -42,11 +42,15 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
map_kwargs["batch_size"] = 100
|
map_kwargs["batch_size"] = 100
|
||||||
return dataset.map(
|
return (
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
dataset.map(
|
||||||
num_proc=num_proc,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
remove_columns=features,
|
num_proc=num_proc,
|
||||||
**map_kwargs,
|
remove_columns=features,
|
||||||
|
**map_kwargs,
|
||||||
|
)
|
||||||
|
.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
|
||||||
|
.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -55,7 +59,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
"""
|
"""
|
||||||
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||||
Args:
|
Args:
|
||||||
tokenizer (Tokenizer): The processor used for proccessing the data.
|
tokenizer (Tokenizer): The processor used for processing the data.
|
||||||
dataset (dataset.Dataset): Dataset with text files.
|
dataset (dataset.Dataset): Dataset with text files.
|
||||||
seq_length (int): Length of token sequences to return.
|
seq_length (int): Length of token sequences to return.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -711,12 +711,8 @@ class ParallelBlock(nn.Module):
|
|||||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||||
self.block_idx = block_idx
|
self.block_idx = block_idx
|
||||||
|
|
||||||
self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
|
self.mixer = MHA(config, layer_idx=block_idx)
|
||||||
mlp_cls = mlp.pop("mlp_cls")
|
self.mlp = MLP(config)
|
||||||
if mlp_cls == "fused_mlp":
|
|
||||||
self.mlp = FusedMLP(config=config, **mlp)
|
|
||||||
else:
|
|
||||||
self.mlp = MLP(config=config, **mlp)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import logging
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import init_empty_weights
|
||||||
from flash_attn.flash_attn_interface import flash_attn_func
|
from flash_attn.flash_attn_interface import flash_attn_func
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
|
|
||||||
@@ -17,7 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
|
|||||||
# this is a wonky hack to get the remotely loaded module
|
# this is a wonky hack to get the remotely loaded module
|
||||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
# we need to load the model here in order for modeling_btlm to be available
|
# we need to load the model here in order for modeling_btlm to be available
|
||||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
with init_empty_weights():
|
||||||
|
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||||
module_name = model_config.__class__.__module__.replace(
|
module_name = model_config.__class__.__module__.replace(
|
||||||
".configuration_btlm", ".modeling_btlm"
|
".configuration_btlm", ".modeling_btlm"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,101 +0,0 @@
|
|||||||
"""
|
|
||||||
Flash Attention monkey patch for Falcon
|
|
||||||
|
|
||||||
copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from flash_attn import flash_attn_func
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
alibi: Optional[torch.Tensor],
|
|
||||||
attention_mask: torch.Tensor, # pylint: disable=unused-argument
|
|
||||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
head_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
|
||||||
use_cache: bool = False,
|
|
||||||
output_attentions: bool = False, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
fused_qkv = self.query_key_value(
|
|
||||||
hidden_states
|
|
||||||
) # [batch_size, seq_length, 3 x hidden_size]
|
|
||||||
num_kv_heads = (
|
|
||||||
self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
|
||||||
)
|
|
||||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
|
||||||
(
|
|
||||||
query_layer,
|
|
||||||
key_layer,
|
|
||||||
value_layer,
|
|
||||||
) = self._split_heads( # pylint: disable=protected-access
|
|
||||||
fused_qkv
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size, query_length, _, _ = query_layer.shape
|
|
||||||
|
|
||||||
query_layer = query_layer.transpose(1, 2).reshape(
|
|
||||||
batch_size * self.num_heads, query_length, self.head_dim
|
|
||||||
)
|
|
||||||
key_layer = key_layer.transpose(1, 2).reshape(
|
|
||||||
batch_size * num_kv_heads,
|
|
||||||
query_length,
|
|
||||||
self.head_dim,
|
|
||||||
)
|
|
||||||
value_layer = value_layer.transpose(1, 2).reshape(
|
|
||||||
batch_size * num_kv_heads, query_length, self.head_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
|
||||||
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
|
|
||||||
|
|
||||||
if layer_past is not None:
|
|
||||||
past_key, past_value = layer_past
|
|
||||||
# concatenate along seq_length dimension:
|
|
||||||
# - key: [batch_size * self.num_heads, kv_length, head_dim]
|
|
||||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
|
||||||
key_layer = torch.cat((past_key, key_layer), dim=1)
|
|
||||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
|
||||||
|
|
||||||
# unused
|
|
||||||
# _, kv_length, _ = key_layer.shape
|
|
||||||
if use_cache:
|
|
||||||
present = (key_layer, value_layer)
|
|
||||||
else:
|
|
||||||
present = None
|
|
||||||
# unused
|
|
||||||
# attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
|
|
||||||
query_layer_ = (
|
|
||||||
query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.to(torch.bfloat16)
|
|
||||||
)
|
|
||||||
key_layer_ = (
|
|
||||||
key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.to(torch.bfloat16)
|
|
||||||
)
|
|
||||||
value_layer_ = (
|
|
||||||
value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
.to(torch.bfloat16)
|
|
||||||
)
|
|
||||||
|
|
||||||
if alibi is not None:
|
|
||||||
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
|
|
||||||
|
|
||||||
# below output will have shape (batch_size, seqlen, nheads, headdim)
|
|
||||||
attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
|
|
||||||
attn_output = attn_output.reshape(
|
|
||||||
batch_size, query_length, self.num_heads * self.head_dim
|
|
||||||
)
|
|
||||||
output_tensor = self.dense(attn_output)
|
|
||||||
return output_tensor, present
|
|
||||||
|
|
||||||
|
|
||||||
def replace_falcon_attn_with_flash_attn():
|
|
||||||
transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
|
|
||||||
174
src/axolotl/monkeypatch/fastchat_conversation_turns.py
Normal file
174
src/axolotl/monkeypatch/fastchat_conversation_turns.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""
|
||||||
|
monkeypatch to add a get_turns method
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Generator, Tuple
|
||||||
|
|
||||||
|
from fastchat.conversation import SeparatorStyle
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt(self) -> str:
|
||||||
|
ret = ""
|
||||||
|
for role, msg in self.get_turns():
|
||||||
|
ret += role + msg
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_turns( # pylint: disable=too-many-return-statements
|
||||||
|
self,
|
||||||
|
) -> Generator[Tuple[str, str], None, None]:
|
||||||
|
"""Get the prompt for generation."""
|
||||||
|
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||||
|
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
||||||
|
yield "", system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + ": ", message + self.sep
|
||||||
|
else:
|
||||||
|
yield role + ":", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
yield "", system_prompt + seps[0]
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
yield role + ": ", message + seps[i % 2]
|
||||||
|
else:
|
||||||
|
yield role + ":", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
||||||
|
yield "", system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + ": ", message + self.sep
|
||||||
|
else:
|
||||||
|
yield role + ": ", "" # must be end with a space
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
||||||
|
yield "", "" if system_prompt == "" else system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + "\n", message + self.sep
|
||||||
|
else:
|
||||||
|
yield role + "\n", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
||||||
|
yield "", system_prompt
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role, message + self.sep
|
||||||
|
else:
|
||||||
|
yield role, ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
yield role, message + seps[i % 2]
|
||||||
|
else:
|
||||||
|
yield role, ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.RWKV:
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
yield role + ": ", message.replace("\r\n", "\n").replace(
|
||||||
|
"\n\n", "\n"
|
||||||
|
) + "\n\n"
|
||||||
|
else:
|
||||||
|
yield role + ":", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
if self.system_message:
|
||||||
|
yield "", system_prompt
|
||||||
|
else:
|
||||||
|
yield "", "[INST] "
|
||||||
|
for i, (role, message) in enumerate(self.messages[1:]):
|
||||||
|
if message:
|
||||||
|
yield role + " ", message + seps[i % 2]
|
||||||
|
else:
|
||||||
|
yield role, ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||||
|
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||||
|
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||||
|
round_add_n = 1 if self.name == "chatglm2" else 0
|
||||||
|
if system_prompt:
|
||||||
|
yield "", system_prompt + self.sep
|
||||||
|
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if i % 2 == 0:
|
||||||
|
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
|
||||||
|
|
||||||
|
if message:
|
||||||
|
yield f"{role}:", f"{message}{self.sep}"
|
||||||
|
else:
|
||||||
|
yield f"{role}:", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.CHATML:
|
||||||
|
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + "\n", message + self.sep + "\n"
|
||||||
|
else:
|
||||||
|
yield role + "\n", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.CHATINTERN:
|
||||||
|
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
prefix = "<s>" if i % 2 == 0 else ""
|
||||||
|
if message:
|
||||||
|
yield prefix + role + ":", message + seps[i % 2] + "\n"
|
||||||
|
else:
|
||||||
|
yield role + ":", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.DOLLY:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
suffix = "\n\n" if i % 2 == 1 else ""
|
||||||
|
yield role + ":\n", message + seps[i % 2] + suffix
|
||||||
|
else:
|
||||||
|
yield role + ":\n", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.PHOENIX:
|
||||||
|
yield "", system_prompt
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + ": ", "<s>" + message + "</s>"
|
||||||
|
else:
|
||||||
|
yield role + ": " + "<s>", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.ROBIN:
|
||||||
|
yield "", system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + ":\n", message + self.sep
|
||||||
|
else:
|
||||||
|
yield role + ":\n", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.FALCON_CHAT:
|
||||||
|
if self.system_message:
|
||||||
|
yield "", system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
yield role + ": ", message + self.sep
|
||||||
|
else:
|
||||||
|
yield role + ":", ""
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||||
|
|
||||||
|
|
||||||
|
def add_get_turns_to_conversation():
|
||||||
|
import fastchat.conversation
|
||||||
|
|
||||||
|
fastchat.conversation.Conversation.get_turns = get_turns
|
||||||
|
fastchat.conversation.Conversation.get_prompt = get_prompt
|
||||||
@@ -38,7 +38,11 @@ except ImportError:
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
def replace_llama_attn_with_flash_attn(
|
||||||
|
packed: Optional[bool] = False,
|
||||||
|
cross_entropy: Optional[bool] = False,
|
||||||
|
rms_norm: Optional[bool] = False,
|
||||||
|
):
|
||||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||||
_prepare_decoder_attention_mask
|
_prepare_decoder_attention_mask
|
||||||
)
|
)
|
||||||
@@ -49,33 +53,37 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
|||||||
llama_model_forward
|
llama_model_forward
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
# skip only if explicitly disabled
|
||||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
if cross_entropy:
|
||||||
|
try:
|
||||||
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||||
CrossEntropyLoss, inplace_backward=True
|
CrossEntropyLoss, inplace_backward=True
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
# skip only if explicitly disabled
|
||||||
from flash_attn.ops.rms_norm import RMSNorm
|
if rms_norm:
|
||||||
|
try:
|
||||||
|
from flash_attn.ops.rms_norm import RMSNorm
|
||||||
|
|
||||||
class LlamaRMSNorm(RMSNorm):
|
class LlamaRMSNorm(RMSNorm):
|
||||||
"""Patched LLamaRMSNorm"""
|
"""Patched LLamaRMSNorm"""
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
super().__init__(hidden_size, eps=eps)
|
super().__init__(hidden_size, eps=eps)
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
@@ -99,6 +107,7 @@ def flashattn_forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
@@ -476,6 +485,13 @@ def llama_model_forward(
|
|||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
device=inputs_embeds.device,
|
device=inputs_embeds.device,
|
||||||
)
|
)
|
||||||
|
padding_mask = None
|
||||||
|
else:
|
||||||
|
if 0 in attention_mask:
|
||||||
|
padding_mask = attention_mask
|
||||||
|
else:
|
||||||
|
padding_mask = None
|
||||||
|
|
||||||
attention_mask = (
|
attention_mask = (
|
||||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -510,7 +526,9 @@ def llama_model_forward(
|
|||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
# None for past_key_value
|
# None for past_key_value
|
||||||
return module(*inputs)
|
return module(
|
||||||
|
*inputs,
|
||||||
|
)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -519,9 +537,10 @@ def llama_model_forward(
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
None,
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
None,
|
None,
|
||||||
|
padding_mask,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_seqlen,
|
max_seqlen,
|
||||||
)
|
)
|
||||||
@@ -533,6 +552,7 @@ def llama_model_forward(
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
)
|
)
|
||||||
@@ -579,6 +599,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None,
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
@@ -611,6 +632,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
)
|
)
|
||||||
|
|||||||
541
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
Normal file
541
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
Normal file
@@ -0,0 +1,541 @@
|
|||||||
|
"""Flash attention monkey patch for mistral model"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from einops import rearrange
|
||||||
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
|
flash_attn_kvpacked_func,
|
||||||
|
flash_attn_varlen_kvpacked_func,
|
||||||
|
flash_attn_varlen_qkvpacked_func,
|
||||||
|
)
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
|
)
|
||||||
|
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||||
|
|
||||||
|
|
||||||
|
def replace_mistral_attn_with_flash_attn(
|
||||||
|
packed: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||||
|
_prepare_decoder_attention_mask
|
||||||
|
)
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||||
|
flashattn_forward
|
||||||
|
)
|
||||||
|
if packed:
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||||
|
MistralDecoderLayer
|
||||||
|
)
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
||||||
|
mistral_model_forward
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
|
def _prepare_decoder_attention_mask(
|
||||||
|
self,
|
||||||
|
attention_mask,
|
||||||
|
input_shape,
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
sliding_window,
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
# [bsz, seq_len]
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
def flashattn_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
# during training q,k,v always have same seqlen
|
||||||
|
assert key_states.shape == query_states.shape
|
||||||
|
is_causal = True
|
||||||
|
else:
|
||||||
|
# turn off FA causal mask after first inference autoregressive iteration
|
||||||
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
|
is_causal = key_states.shape == query_states.shape
|
||||||
|
|
||||||
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
|
# special handling using sample packing
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||||
|
)
|
||||||
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
elif query_states.shape == key_states.shape:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
qkvpacked=True,
|
||||||
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_seqlen_q,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
else:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
if attention_mask is None or attention_mask.all().item():
|
||||||
|
output = flash_attn_kvpacked_func(
|
||||||
|
query_states,
|
||||||
|
torch.stack([key_states, value_states], 2),
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
output_pad_fn,
|
||||||
|
) = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
kvpacked=True,
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
if q_unpad.dtype != kv_unpad.dtype:
|
||||||
|
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||||
|
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
|
attn_output = output
|
||||||
|
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||||
|
def generate_qkv(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
query_padding_mask=None,
|
||||||
|
key_padding_mask=None,
|
||||||
|
kvpacked=False,
|
||||||
|
qkvpacked=False,
|
||||||
|
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
q: (batch_size, seqlen_q, nheads, d)
|
||||||
|
k: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
v: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
query_padding_mask: (batch_size, seqlen), bool
|
||||||
|
key_padding_mask: (batch_size, seqlen), bool
|
||||||
|
"""
|
||||||
|
assert not (kvpacked and qkvpacked)
|
||||||
|
batch_size, seqlen_q, nheads, d = q.shape
|
||||||
|
_, seqlen_k, nheads_k, _ = k.shape
|
||||||
|
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
|
||||||
|
if query_padding_mask is not None:
|
||||||
|
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||||
|
q, query_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
||||||
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_q,
|
||||||
|
step=seqlen_q,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_q = seqlen_q
|
||||||
|
|
||||||
|
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
||||||
|
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||||
|
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||||
|
else:
|
||||||
|
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||||
|
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_k = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_k,
|
||||||
|
step=seqlen_k,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=k_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_k = seqlen_k
|
||||||
|
|
||||||
|
if qkvpacked:
|
||||||
|
assert nheads == nheads_k
|
||||||
|
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||||
|
qkv = torch.stack([q, k, v], dim=2)
|
||||||
|
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||||
|
|
||||||
|
if kvpacked:
|
||||||
|
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||||
|
kv = torch.stack([k, v], dim=2)
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
kv,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
k_unpad,
|
||||||
|
v_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mistral_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||||
|
)
|
||||||
|
if input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
if past_key_values is not None:
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2]
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
|
cu_seqlens = None
|
||||||
|
max_seqlen = None
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length,
|
||||||
|
seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
cu_seqlens = cu_seqlens.squeeze()
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
# embed positions
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
(batch_size, seq_length_with_past),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
)
|
||||||
|
attention_mask = (
|
||||||
|
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
transformers.logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(decoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
None,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||||
|
"""
|
||||||
|
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[
|
||||||
|
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
||||||
|
"""
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
415
src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
Normal file
415
src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# This code is based off the following work:
|
||||||
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||||
|
""" PyTorch StableLM Epoch model. """
|
||||||
|
import importlib
|
||||||
|
import math
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from einops import rearrange
|
||||||
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
|
flash_attn_varlen_qkvpacked_func,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
|
||||||
|
# this is a wonky hack to get the remotely loaded module
|
||||||
|
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
# we need to load the model here in order for modeling_stablelm_epoch to be available
|
||||||
|
with init_empty_weights():
|
||||||
|
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
module_name = model_config.__class__.__module__.replace(
|
||||||
|
".configuration_stablelm_epoch", ".modeling_stablelm_epoch"
|
||||||
|
)
|
||||||
|
modeling_stablelm = importlib.import_module(module_name)
|
||||||
|
modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access
|
||||||
|
flashattn_attn
|
||||||
|
)
|
||||||
|
modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access
|
||||||
|
stablelm_model_forward
|
||||||
|
)
|
||||||
|
modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access
|
||||||
|
decoder_layer_forward
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x: torch.Tensor):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||||
|
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||||
|
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||||
|
cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
|
||||||
|
sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||||
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||||
|
"""
|
||||||
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return hidden_states
|
||||||
|
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||||
|
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||||
|
)
|
||||||
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def flashattn_attn(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
attention_mask: torch.FloatTensor,
|
||||||
|
position_ids: torch.LongTensor,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False, # pylint: disable=unused-argument
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
query_rot = query_states[..., : self.rotary_ndims]
|
||||||
|
query_pass = query_states[..., self.rotary_ndims :]
|
||||||
|
key_rot = key_states[..., : self.rotary_ndims]
|
||||||
|
key_pass = key_states[..., self.rotary_ndims :]
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_rot, key_rot, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# [batch_size, num_heads, seq_len, head_dim]
|
||||||
|
query_states = torch.cat((query_states, query_pass), dim=-1)
|
||||||
|
key_states = torch.cat((key_states, key_pass), dim=-1)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# Reuse k, v, self_attention
|
||||||
|
key_states = torch.cat((past_key_value[0], key_states), dim=2)
|
||||||
|
value_states = torch.cat((past_key_value[1], value_states), dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# Repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
|
# special handling using sample packing
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
softmax_scale = None
|
||||||
|
|
||||||
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(
|
||||||
|
query_states, key_states.transpose(2, 3)
|
||||||
|
) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
# Upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(
|
||||||
|
attn_weights, dim=-1, dtype=torch.float32
|
||||||
|
).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge heads
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
# Final linear projection
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def decoder_layer_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: Optional[torch.FloatTensor],
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[
|
||||||
|
Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
|
||||||
|
]:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def stablelm_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||||
|
)
|
||||||
|
if input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
if past_key_values is not None:
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2]
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
|
cu_seqlens = None
|
||||||
|
max_seqlen = None
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length,
|
||||||
|
seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
cu_seqlens = cu_seqlens.squeeze()
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
# Embed positions
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
(batch_size, seq_length_with_past),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
)
|
||||||
|
attention_mask = (
|
||||||
|
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# Decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(decoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
None,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# Add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
@@ -24,6 +24,15 @@ def load(tokenizer, cfg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_v2(tokenizer, cfg):
|
||||||
|
return ContextQaV2PromptTokenizingStrategy(
|
||||||
|
ContextV2Prompter(),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AlpacaContextPrompter(AlpacaPrompter):
|
class AlpacaContextPrompter(AlpacaPrompter):
|
||||||
"""
|
"""
|
||||||
Customized system prompted for concise QA
|
Customized system prompted for concise QA
|
||||||
@@ -50,6 +59,38 @@ class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
Tokenization Strategy to combine in-context article with a question and answer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||||
|
return (
|
||||||
|
"Context: "
|
||||||
|
+ prompt["context"]
|
||||||
|
+ "\nQuestion: "
|
||||||
|
+ prompt["question"]
|
||||||
|
+ "\n",
|
||||||
|
"",
|
||||||
|
"Answer: " + prompt["answer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextV2Prompter(AlpacaPrompter):
|
||||||
|
"""
|
||||||
|
Customized system prompted for concise QA
|
||||||
|
"""
|
||||||
|
|
||||||
|
system_prompt = ""
|
||||||
|
system_no_input_prompt = ""
|
||||||
|
|
||||||
|
def match_prompt_style(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
self.turn_format = "{instruction}\n{input}"
|
||||||
|
self.turn_no_input_format = "{instruction}"
|
||||||
|
self.system_format = "{system}"
|
||||||
|
|
||||||
|
|
||||||
class AlpacaMissingInfoContextPromptTokenizingStrategy(
|
class AlpacaMissingInfoContextPromptTokenizingStrategy(
|
||||||
InstructionPromptTokenizingStrategy
|
InstructionPromptTokenizingStrategy
|
||||||
):
|
):
|
||||||
|
|||||||
119
src/axolotl/prompt_strategies/sharegpt.py
Normal file
119
src/axolotl/prompt_strategies/sharegpt.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||||
|
|
||||||
|
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||||
|
from axolotl.prompters import ShareGPTPrompterV2
|
||||||
|
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="chatml",
|
||||||
|
system_template="<|im_start|>system\n{system_message}",
|
||||||
|
system_message="You are a helpful assistant.",
|
||||||
|
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
||||||
|
sep_style=SeparatorStyle.CHATML,
|
||||||
|
sep="<|im_end|>\n",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
conversation = (
|
||||||
|
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||||
|
)
|
||||||
|
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
|
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(
|
||||||
|
conversation=conversation,
|
||||||
|
role_key_model=field_model,
|
||||||
|
role_key_human=field_human,
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
if ds_cfg and ds_cfg["skip"]:
|
||||||
|
strat.skip_invalid = True
|
||||||
|
return strat
|
||||||
|
|
||||||
|
|
||||||
|
def load_role(tokenizer, cfg):
|
||||||
|
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_guanaco(tokenizer, cfg):
|
||||||
|
return GuanacoShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
conversation = (
|
||||||
|
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||||
|
)
|
||||||
|
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
|
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
|
return NousShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(
|
||||||
|
conversation=conversation,
|
||||||
|
role_key_model=field_model,
|
||||||
|
role_key_human=field_human,
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
basic sharegpt strategy used by nous/ldj for input/output keyed data
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_conversation_thread(self):
|
||||||
|
return "conversation"
|
||||||
|
|
||||||
|
def map_conversation_thread(self, conversation):
|
||||||
|
turns = []
|
||||||
|
for turn in conversation:
|
||||||
|
turns.append({"from": "human", "value": turn["input"]})
|
||||||
|
turns.append({"from": "gpt", "value": turn["output"]})
|
||||||
|
return turns
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||||
|
"""
|
||||||
|
|
||||||
|
def map_conversation_thread(self, conversation):
|
||||||
|
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||||
|
turns = [
|
||||||
|
{"from": turn["role"], "value": turn["value"]} for turn in conversation
|
||||||
|
]
|
||||||
|
return turns
|
||||||
|
|
||||||
|
|
||||||
|
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
sharegpt strategy that remaps oasst data to sharegpt format
|
||||||
|
"""
|
||||||
|
|
||||||
|
def map_conversation_thread(self, conversation):
|
||||||
|
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||||
|
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||||
|
turns = [
|
||||||
|
{"from": role_map[turn["role"]], "value": turn["text"]}
|
||||||
|
for turn in conversation
|
||||||
|
]
|
||||||
|
return turns
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Module for Jokes prompts using sharegpt style """
|
"""Module for Jokes prompts using sharegpt style """
|
||||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
from axolotl.prompters import ShareGPTPrompterV2
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg):
|
def load(tokenizer, cfg):
|
||||||
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
ShareGPTPrompterV2(),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
|
|||||||
@@ -1,67 +0,0 @@
|
|||||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
|
||||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg):
|
|
||||||
return SimpleShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_role(tokenizer, cfg):
|
|
||||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_guanaco(tokenizer, cfg):
|
|
||||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|
||||||
"""
|
|
||||||
basic sharegpt strategy to grab conversations from the sample row
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
|
||||||
return prompt["conversations"]
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|
||||||
"""
|
|
||||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
|
||||||
conversations = prompt["conversations"]
|
|
||||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
|
||||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
|
||||||
return turns
|
|
||||||
|
|
||||||
|
|
||||||
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|
||||||
"""
|
|
||||||
sharegpt strategy that remaps oasst data to sharegpt format
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
|
||||||
conversations = prompt["conversations"]
|
|
||||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
|
||||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
|
||||||
turns = [
|
|
||||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
|
||||||
]
|
|
||||||
return turns
|
|
||||||
@@ -4,10 +4,15 @@ import abc
|
|||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
from fastchat.conversation import Conversation
|
||||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
||||||
|
add_get_turns_to_conversation,
|
||||||
|
)
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -18,6 +23,8 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
|||||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
||||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
||||||
|
|
||||||
|
add_get_turns_to_conversation()
|
||||||
|
|
||||||
|
|
||||||
class InvalidDataException(Exception):
|
class InvalidDataException(Exception):
|
||||||
"""
|
"""
|
||||||
@@ -75,7 +82,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
result: BatchEncoding
|
result: BatchEncoding
|
||||||
if not prompt.strip():
|
if not prompt:
|
||||||
LOG.warning("Empty text requested for tokenization.")
|
LOG.warning("Empty text requested for tokenization.")
|
||||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||||
else:
|
else:
|
||||||
@@ -345,72 +352,109 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
Tokenizing strategy for ShareGPT prompts.
|
Tokenizing strategy for ShareGPT prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
_skip_invalid = False
|
||||||
return prompt["conversations"]
|
|
||||||
|
@property
|
||||||
|
def supports_batched(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def skip_invalid(self):
|
||||||
|
return self._skip_invalid
|
||||||
|
|
||||||
|
@skip_invalid.setter
|
||||||
|
def skip_invalid(self, value):
|
||||||
|
self._skip_invalid = value
|
||||||
|
|
||||||
|
def get_conversation_thread(self):
|
||||||
|
return "conversations"
|
||||||
|
|
||||||
|
def map_conversation_thread(self, conversation):
|
||||||
|
return conversation
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
result, current_len = tokenize_prompt_default()
|
tokenized_res = defaultdict(lambda: [])
|
||||||
user_token = self._get_user_token()
|
conv_field = self.get_conversation_thread()
|
||||||
assistant_token = self._get_assistant_token()
|
for prmpt in prompt[conv_field]:
|
||||||
try:
|
result, current_len = tokenize_prompt_default()
|
||||||
for _, part in enumerate(
|
user_token = self._get_user_token()
|
||||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
assistant_token = self._get_assistant_token()
|
||||||
):
|
conversation: Conversation = (
|
||||||
if isinstance(part, tuple):
|
self.prompter._conversation # pylint: disable=protected-access
|
||||||
if part[0] == "USER:":
|
)
|
||||||
turn = part[0] + part[1] if not user_token else part[1]
|
try:
|
||||||
# this is still the user query, we should
|
for _, part in enumerate(
|
||||||
if not part[1].strip():
|
self.prompter.build_prompt(self.map_conversation_thread(prmpt))
|
||||||
LOG.warning(f"user turn has empty text: {prompt}")
|
):
|
||||||
res = self._tokenize(
|
if isinstance(part, tuple):
|
||||||
turn.strip(),
|
if conversation.roles[0] in part[0]:
|
||||||
add_eos_token=False,
|
turn = part[0] + part[1] if not user_token else part[1]
|
||||||
strip_bos_token=True,
|
# this is still the user query, we should
|
||||||
)
|
if not part[1].strip():
|
||||||
if user_token:
|
err_msg = f"user turn has empty text: {prmpt}"
|
||||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
if self.skip_invalid:
|
||||||
# everything from this is masked out from the labels
|
raise ValueError(err_msg)
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
LOG.warning(err_msg)
|
||||||
elif part[0] == "ASSISTANT:":
|
res = self._tokenize(
|
||||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
turn,
|
||||||
turn = part[0] + part[1] if not assistant_token else part[1]
|
add_eos_token=False,
|
||||||
# this should be the assistant response, should end with an eos token
|
strip_bos_token=True,
|
||||||
if not part[1].strip():
|
)
|
||||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
if user_token:
|
||||||
res = self._tokenize(
|
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||||
turn.strip(),
|
# everything from this is masked out from the labels
|
||||||
add_eos_token=True,
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
strip_bos_token=True,
|
elif conversation.roles[1] in part[0]:
|
||||||
)
|
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||||
if assistant_token:
|
turn = part[0] + part[1] if not assistant_token else part[1]
|
||||||
res["input_ids"] = [
|
# this should be the assistant response, should end with an eos token
|
||||||
assistant_token,
|
if not part[1].strip():
|
||||||
*res["input_ids"],
|
err_msg = f"assistant turn has empty text: {prmpt}"
|
||||||
]
|
if self.skip_invalid:
|
||||||
# not masked out from labels
|
raise ValueError(err_msg)
|
||||||
labels = copy.deepcopy(res["input_ids"])
|
LOG.warning(err_msg)
|
||||||
elif part[0] == "SYSTEM:":
|
res = self._tokenize(
|
||||||
part = part[1] # Ignore the system role from preamble
|
turn,
|
||||||
# this is only ever the first part, should include the bos token and the user query
|
add_eos_token=True,
|
||||||
res = self._tokenize(
|
strip_bos_token=True,
|
||||||
part.strip(), add_eos_token=False, strip_bos_token=False
|
)
|
||||||
)
|
if assistant_token:
|
||||||
# everything from this is masked out from the labels
|
res["input_ids"] = [
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
assistant_token,
|
||||||
else:
|
*res["input_ids"],
|
||||||
LOG.warning(f"unhandled role: {part[0]}")
|
]
|
||||||
|
# not masked out from labels
|
||||||
|
labels = copy.deepcopy(res["input_ids"])
|
||||||
|
elif part[0] == "":
|
||||||
|
turn = part[1]
|
||||||
|
# this is only ever the first part, should include the bos token and the user query
|
||||||
|
res = self._tokenize(
|
||||||
|
turn, add_eos_token=False, strip_bos_token=False
|
||||||
|
)
|
||||||
|
# everything from this is masked out from the labels
|
||||||
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
|
else:
|
||||||
|
err_msg = f"unhandled role: {part[0]}"
|
||||||
|
if self.skip_invalid:
|
||||||
|
raise ValueError(err_msg)
|
||||||
|
LOG.warning(err_msg)
|
||||||
|
continue
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
result, current_len = parse_tokenized_to_result(
|
result, current_len = parse_tokenized_to_result(
|
||||||
result,
|
result,
|
||||||
current_len,
|
current_len,
|
||||||
res,
|
res,
|
||||||
labels,
|
labels,
|
||||||
pad_token_id=self.tokenizer.pad_token_id,
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
)
|
)
|
||||||
return result
|
for key, val in sorted(result.items(), key=lambda x: x[0]):
|
||||||
except (KeyError, AssertionError, IndexError) as err:
|
tokenized_res[key].append(val)
|
||||||
raise InvalidDataException(str(err)) from err
|
except (KeyError, AssertionError, IndexError) as err:
|
||||||
|
raise InvalidDataException(str(err)) from err
|
||||||
|
except ValueError as err:
|
||||||
|
LOG.warning("skipping prompt: %s", str(err))
|
||||||
|
return tokenized_res
|
||||||
|
|
||||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||||
if not prompt.strip():
|
if not prompt.strip():
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Module containing prompters"""
|
"""Module containing prompters"""
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum, auto
|
from enum import Enum
|
||||||
from typing import Generator, List, Optional, Tuple, Union
|
from typing import Generator, Optional, Union
|
||||||
|
|
||||||
|
from fastchat.conversation import Conversation, get_conv_template
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
@@ -214,53 +215,6 @@ class ReflectAlpacaPrompter:
|
|||||||
yield res
|
yield res
|
||||||
|
|
||||||
|
|
||||||
class SeparatorStyle(Enum):
|
|
||||||
"""Different separator style."""
|
|
||||||
|
|
||||||
SINGLE = auto()
|
|
||||||
TWO = auto()
|
|
||||||
DOLLY = auto()
|
|
||||||
|
|
||||||
|
|
||||||
# TODO clean this 💩 up
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class Conversation:
|
|
||||||
"""A class that keeps all conversation history."""
|
|
||||||
|
|
||||||
system: str
|
|
||||||
roles: List[str]
|
|
||||||
messages: List[List[str]]
|
|
||||||
offset: int
|
|
||||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
|
||||||
sep: str = "###"
|
|
||||||
sep2: Optional[str] = None
|
|
||||||
|
|
||||||
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
|
|
||||||
# seps = [self.sep, self.sep2]
|
|
||||||
preamble = self.system + self.sep
|
|
||||||
yield ("SYSTEM:", preamble)
|
|
||||||
for _, (role, message) in enumerate(self.messages):
|
|
||||||
if message:
|
|
||||||
yield (role + ":", " " + message)
|
|
||||||
else:
|
|
||||||
LOG.warning(f"role with empty message: {role}")
|
|
||||||
yield (role + ":", "")
|
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
return Conversation(
|
|
||||||
system=self.system,
|
|
||||||
roles=self.roles,
|
|
||||||
messages=[[x, y] for x, y in self.messages],
|
|
||||||
offset=self.offset,
|
|
||||||
sep_style=self.sep_style,
|
|
||||||
sep=self.sep,
|
|
||||||
sep2=self.sep2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def append_message(self, role, message):
|
|
||||||
self.messages.append([role, message])
|
|
||||||
|
|
||||||
|
|
||||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||||
)
|
)
|
||||||
@@ -271,28 +225,27 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
A prompter that generates prompts for the ShareGPT
|
A prompter that generates prompts for the ShareGPT
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
|
role_key_human = "human"
|
||||||
if prompt_style != PromptStyle.CHAT.value:
|
role_key_model = "gpt"
|
||||||
raise ValueError(
|
|
||||||
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
|
def __init__(
|
||||||
)
|
self,
|
||||||
system: str = (
|
prompt_style=None, # pylint: disable=unused-argument
|
||||||
system_prompt
|
conversation: Optional[Union[str, Conversation]] = None,
|
||||||
if system_prompt
|
role_key_human: Optional[str] = None,
|
||||||
else (
|
role_key_model: Optional[str] = None,
|
||||||
"A chat between a curious user and an artificial intelligence assistant. "
|
):
|
||||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
if conversation:
|
||||||
)
|
if isinstance(conversation, Conversation):
|
||||||
)
|
self._conversation = conversation
|
||||||
self._conversation = Conversation(
|
else:
|
||||||
system=system,
|
self._conversation = get_conv_template(conversation)
|
||||||
roles=["USER", "ASSISTANT"],
|
else:
|
||||||
messages=[],
|
self._conversation = get_conv_template("vicuna_v1.1")
|
||||||
offset=0,
|
if role_key_human:
|
||||||
sep_style=SeparatorStyle.TWO,
|
self.role_key_human = role_key_human
|
||||||
sep=" ",
|
if role_key_model:
|
||||||
sep2=" ",
|
self.role_key_model = role_key_model
|
||||||
)
|
|
||||||
|
|
||||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
@@ -306,17 +259,14 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
|
|
||||||
# Add the conversation system prompt if provided, otherwise use the default one
|
# Add the conversation system prompt if provided, otherwise use the default one
|
||||||
if source[0]["from"] == "system":
|
if source[0]["from"] == "system":
|
||||||
conv.system = source[0]["value"]
|
conv.set_system_message(source[0]["value"])
|
||||||
source.pop(0)
|
source.pop(0)
|
||||||
|
|
||||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Apply prompt templates
|
# Apply prompt templates
|
||||||
if (
|
if source[0]["from"] not in roles:
|
||||||
source[0]["from"] not in roles
|
|
||||||
or roles[source[0]["from"]] != conv.roles[0]
|
|
||||||
):
|
|
||||||
# Skip the first one if it is not from human
|
# Skip the first one if it is not from human
|
||||||
source = source[1:]
|
source = source[1:]
|
||||||
except IndexError as err:
|
except IndexError as err:
|
||||||
@@ -326,8 +276,29 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
conv.messages = []
|
conv.messages = []
|
||||||
for j, sentence in enumerate(source):
|
for j, sentence in enumerate(source):
|
||||||
role = roles[sentence["from"]]
|
role = roles[sentence["from"]]
|
||||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
if role != conv.roles[j % 2]:
|
||||||
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
for part in conv.get_prompt():
|
for part in conv.get_turns():
|
||||||
|
if part[0] and not part[1]:
|
||||||
|
LOG.warning(f"role with empty message: {part[0]}")
|
||||||
yield part
|
yield part
|
||||||
|
|
||||||
|
|
||||||
|
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||||
|
"""
|
||||||
|
A V2 prompter that generates prompts for the ShareGPT
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
conversation: Optional[Union[str, Conversation]] = None,
|
||||||
|
role_key_human: Optional[str] = None,
|
||||||
|
role_key_model: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
conversation=conversation,
|
||||||
|
role_key_human=role_key_human,
|
||||||
|
role_key_model=role_key_model,
|
||||||
|
)
|
||||||
|
|||||||
@@ -58,7 +58,9 @@ def train(
|
|||||||
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
if (
|
||||||
|
cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
|
||||||
|
) or cfg.resume_from_checkpoint is True:
|
||||||
possible_checkpoints = [
|
possible_checkpoints = [
|
||||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||||
]
|
]
|
||||||
@@ -71,7 +73,9 @@ def train(
|
|||||||
LOG.info(
|
LOG.info(
|
||||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||||
)
|
)
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
resume_from_checkpoint = (
|
||||||
|
cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
|
||||||
|
)
|
||||||
|
|
||||||
trainer = setup_trainer(
|
trainer = setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ def normalize_config(cfg):
|
|||||||
cfg.batch_size = (
|
cfg.batch_size = (
|
||||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||||
)
|
)
|
||||||
|
if cfg.eval_batch_size is None:
|
||||||
|
cfg.eval_batch_size = cfg.micro_batch_size
|
||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
cfg.eval_table_size = cfg.eval_table_size or 0
|
cfg.eval_table_size = cfg.eval_table_size or 0
|
||||||
@@ -75,6 +77,8 @@ def normalize_config(cfg):
|
|||||||
else:
|
else:
|
||||||
cfg.torch_dtype = torch.float32
|
cfg.torch_dtype = torch.float32
|
||||||
|
|
||||||
|
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
||||||
|
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
cfg.model_config_type = model_config.model_type
|
cfg.model_config_type = model_config.model_type
|
||||||
|
|
||||||
@@ -82,10 +86,39 @@ def normalize_config(cfg):
|
|||||||
cfg.is_llama_derived_model = (
|
cfg.is_llama_derived_model = (
|
||||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||||
or cfg.is_llama_derived_model
|
or cfg.is_llama_derived_model
|
||||||
or "llama" in cfg.base_model
|
or "llama" in cfg.base_model.lower()
|
||||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# figure out if the model is falcon
|
||||||
|
cfg.is_falcon_derived_model = (
|
||||||
|
(
|
||||||
|
hasattr(model_config, "model_type")
|
||||||
|
and model_config.model_type
|
||||||
|
in [
|
||||||
|
"falcon",
|
||||||
|
"RefinedWebModel",
|
||||||
|
"RefinedWeb",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
or cfg.is_falcon_derived_model
|
||||||
|
or "falcon" in cfg.base_model.lower()
|
||||||
|
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg.is_mistral_derived_model = (
|
||||||
|
(
|
||||||
|
hasattr(model_config, "model_type")
|
||||||
|
and model_config.model_type
|
||||||
|
in [
|
||||||
|
"mistral",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
or cfg.is_mistral_derived_model
|
||||||
|
or "mistral" in cfg.base_model.lower()
|
||||||
|
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||||
|
)
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
|
||||||
@@ -126,6 +159,11 @@ def validate_config(cfg):
|
|||||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||||
)
|
)
|
||||||
|
if cfg.eval_batch_size != cfg.micro_batch_size:
|
||||||
|
LOG.warning(
|
||||||
|
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.load_4bit:
|
if cfg.load_4bit:
|
||||||
raise ValueError("cfg.load_4bit parameter has been deprecated")
|
raise ValueError("cfg.load_4bit parameter has been deprecated")
|
||||||
|
|
||||||
@@ -262,6 +300,45 @@ def validate_config(cfg):
|
|||||||
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
|
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.datasets:
|
||||||
|
for idx, ds_cfg in enumerate(cfg.datasets):
|
||||||
|
if not ds_cfg.type:
|
||||||
|
continue
|
||||||
|
if ds_cfg.type == "sharegpt:chat":
|
||||||
|
LOG.warning(
|
||||||
|
PendingDeprecationWarning(
|
||||||
|
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cfg.datasets[idx].type = "sharegpt"
|
||||||
|
if "sharegpt_simple" in ds_cfg.type:
|
||||||
|
LOG.warning(
|
||||||
|
PendingDeprecationWarning(
|
||||||
|
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||||
|
"sharegpt_simple", "sharegpt"
|
||||||
|
)
|
||||||
|
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||||
|
raise ValueError(
|
||||||
|
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.evaluation_strategy
|
||||||
|
and cfg.eval_steps
|
||||||
|
and cfg.evaluation_strategy != "steps"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
|
||||||
|
raise ValueError(
|
||||||
|
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from axolotl.prompt_tokenizers import (
|
|||||||
GPTeacherPromptTokenizingStrategy,
|
GPTeacherPromptTokenizingStrategy,
|
||||||
JeopardyPromptTokenizingStrategy,
|
JeopardyPromptTokenizingStrategy,
|
||||||
OpenAssistantPromptTokenizingStrategy,
|
OpenAssistantPromptTokenizingStrategy,
|
||||||
ShareGPTPromptTokenizingStrategy,
|
|
||||||
SummarizeTLDRPromptTokenizingStrategy,
|
SummarizeTLDRPromptTokenizingStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import (
|
from axolotl.prompters import (
|
||||||
@@ -35,7 +34,6 @@ from axolotl.prompters import (
|
|||||||
MultipleChoiceConcisePrompter,
|
MultipleChoiceConcisePrompter,
|
||||||
MultipleChoiceExplainPrompter,
|
MultipleChoiceExplainPrompter,
|
||||||
ReflectAlpacaPrompter,
|
ReflectAlpacaPrompter,
|
||||||
ShareGPTPrompter,
|
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -76,7 +74,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
cfg, train_dataset, eval_dataset
|
cfg, train_dataset, eval_dataset, tokenizer
|
||||||
)
|
)
|
||||||
if cfg.max_steps:
|
if cfg.max_steps:
|
||||||
total_num_steps = min(
|
total_num_steps = min(
|
||||||
@@ -116,7 +114,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||||
use_auth_token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
dataset = dataset["train"]
|
dataset = dataset["train"]
|
||||||
except Exception: # pylint: disable=broad-except # nosec
|
except Exception: # pylint: disable=broad-except # nosec
|
||||||
@@ -124,7 +122,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
|
|
||||||
if dataset:
|
if dataset:
|
||||||
...
|
...
|
||||||
elif any(prepared_ds_path.glob("*")):
|
elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")):
|
||||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||||
dataset = load_from_disk(str(prepared_ds_path))
|
dataset = load_from_disk(str(prepared_ds_path))
|
||||||
LOG.info("Prepared dataset loaded from disk...")
|
LOG.info("Prepared dataset loaded from disk...")
|
||||||
@@ -157,24 +155,26 @@ def load_tokenized_prepared_datasets(
|
|||||||
d.path,
|
d.path,
|
||||||
name=d.name,
|
name=d.name,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
use_auth_token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except FileNotFoundError:
|
except (FileNotFoundError, ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# prefer local dataset, even if hub exists
|
# prefer local dataset, even if hub exists
|
||||||
local_path = Path(d.path)
|
local_path = Path(d.path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
if local_path.is_dir():
|
if local_path.is_dir():
|
||||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
if not d.type:
|
||||||
ds = load_dataset(
|
ds = load_from_disk(d.path)
|
||||||
d.path,
|
else:
|
||||||
name=d.name,
|
ds = load_dataset(
|
||||||
data_files=d.data_files,
|
d.path,
|
||||||
streaming=False,
|
name=d.name,
|
||||||
split=None,
|
data_files=d.data_files,
|
||||||
)
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds_type = "json"
|
ds_type = "json"
|
||||||
if d.ds_type:
|
if d.ds_type:
|
||||||
@@ -204,14 +204,29 @@ def load_tokenized_prepared_datasets(
|
|||||||
name=d.name,
|
name=d.name,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
data_files=d.data_files,
|
data_files=d.data_files,
|
||||||
use_auth_token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
fp = hf_hub_download(
|
if isinstance(d.data_files, str):
|
||||||
repo_id=d.path,
|
fp = hf_hub_download(
|
||||||
repo_type="dataset",
|
repo_id=d.path,
|
||||||
filename=d.data_files,
|
repo_type="dataset",
|
||||||
)
|
filename=d.data_files,
|
||||||
|
)
|
||||||
|
elif isinstance(d.data_files, list):
|
||||||
|
fp = []
|
||||||
|
for file in d.data_files:
|
||||||
|
fp.append(
|
||||||
|
hf_hub_download(
|
||||||
|
repo_id=d.path,
|
||||||
|
repo_type="dataset",
|
||||||
|
filename=file,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"data_files must be either a string or list of strings"
|
||||||
|
)
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
"json", name=d.name, data_files=fp, streaming=False, split=None
|
"json", name=d.name, data_files=fp, streaming=False, split=None
|
||||||
)
|
)
|
||||||
@@ -234,6 +249,16 @@ def load_tokenized_prepared_datasets(
|
|||||||
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
||||||
if "train" in ds:
|
if "train" in ds:
|
||||||
ds = ds["train"]
|
ds = ds["train"]
|
||||||
|
elif (
|
||||||
|
isinstance(ds, DatasetDict)
|
||||||
|
and d.train_on_split
|
||||||
|
and d.train_on_split in ds
|
||||||
|
):
|
||||||
|
ds = ds[d.train_on_split]
|
||||||
|
elif isinstance(ds, DatasetDict):
|
||||||
|
raise ValueError(
|
||||||
|
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
"input_ids" in ds.features
|
"input_ids" in ds.features
|
||||||
and "attention_mask" in ds.features
|
and "attention_mask" in ds.features
|
||||||
@@ -320,15 +345,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "sharegpt":
|
|
||||||
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
else:
|
else:
|
||||||
suffix = ""
|
suffix = ""
|
||||||
if ":load_" in d.type:
|
if ":load_" in d.type:
|
||||||
@@ -343,7 +359,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
if len(datasets) > 1:
|
if len(datasets) > 1:
|
||||||
LOG.info("shuffle merged datasets")
|
LOG.info("shuffle merged datasets")
|
||||||
dataset = dataset.shuffle(seed=seed)
|
dataset = dataset.shuffle(seed=seed)
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0 and cfg.dataset_prepared_path:
|
||||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
dataset.save_to_disk(prepared_ds_path)
|
dataset.save_to_disk(prepared_ds_path)
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
@@ -403,7 +419,7 @@ def load_prepare_datasets(
|
|||||||
)
|
)
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||||
use_auth_token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
dataset = dataset["train"]
|
dataset = dataset["train"]
|
||||||
except Exception: # pylint: disable=broad-except # nosec
|
except Exception: # pylint: disable=broad-except # nosec
|
||||||
@@ -411,7 +427,7 @@ def load_prepare_datasets(
|
|||||||
|
|
||||||
if dataset:
|
if dataset:
|
||||||
...
|
...
|
||||||
elif any(prepared_ds_path.glob("*")):
|
elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")):
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
"""Module for models and model loading"""
|
"""Module for models and model loading"""
|
||||||
import importlib
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -12,6 +11,7 @@ from optimum.bettertransformer import BetterTransformer
|
|||||||
from peft import PeftConfig, prepare_model_for_kbit_training
|
from peft import PeftConfig, prepare_model_for_kbit_training
|
||||||
from peft.tuners.lora import QuantLinear
|
from peft.tuners.lora import QuantLinear
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
|
AddedToken,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@@ -81,11 +81,22 @@ def load_tokenizer(cfg):
|
|||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
# Mistral's official FA implementation requires left padding
|
||||||
|
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
for k, val in cfg.special_tokens.items():
|
for k, val in cfg.special_tokens.items():
|
||||||
tokenizer.add_special_tokens({k: val})
|
tokenizer.add_special_tokens(
|
||||||
|
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||||
|
)
|
||||||
if cfg.tokens:
|
if cfg.tokens:
|
||||||
tokenizer.add_tokens(list(cfg.tokens))
|
tokenizer.add_tokens(
|
||||||
|
[
|
||||||
|
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
|
||||||
|
for token in cfg.tokens
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
@@ -114,26 +125,29 @@ def load_model(
|
|||||||
|
|
||||||
replace_btlm_attn_with_flash_attn(cfg.base_model)
|
replace_btlm_attn_with_flash_attn(cfg.base_model)
|
||||||
|
|
||||||
if hasattr(model_config, "model_type") and model_config.model_type in [
|
if (
|
||||||
"falcon",
|
hasattr(model_config, "model_type")
|
||||||
"RefinedWebModel",
|
and model_config.model_type == "stablelm_epoch"
|
||||||
"RefinedWeb",
|
):
|
||||||
]:
|
if cfg.flash_attention and cfg.sample_packing:
|
||||||
if cfg.flash_attention:
|
from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
|
||||||
from axolotl.monkeypatch.falcon_attn_hijack_flash import (
|
replace_stablelm_attn_with_flash_attn,
|
||||||
replace_falcon_attn_with_flash_attn,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
replace_falcon_attn_with_flash_attn()
|
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
replace_llama_attn_with_flash_attn,
|
replace_llama_attn_with_flash_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention for sample packing")
|
||||||
replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
|
replace_llama_attn_with_flash_attn(
|
||||||
|
packed=cfg.sample_packing,
|
||||||
|
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||||
|
rms_norm=cfg.flash_attn_rms_norm,
|
||||||
|
)
|
||||||
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_attention,
|
hijack_llama_attention,
|
||||||
@@ -158,6 +172,14 @@ def load_model(
|
|||||||
# Note: This might overwrite previous additional_special_tokens
|
# Note: This might overwrite previous additional_special_tokens
|
||||||
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
||||||
|
|
||||||
|
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||||
|
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||||
|
replace_mistral_attn_with_flash_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching with flash attention")
|
||||||
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||||
replace_llama_rope_with_xpos_rope,
|
replace_llama_rope_with_xpos_rope,
|
||||||
@@ -176,21 +198,11 @@ def load_model(
|
|||||||
LOG.info("patching _expand_mask")
|
LOG.info("patching _expand_mask")
|
||||||
hijack_expand_mask()
|
hijack_expand_mask()
|
||||||
|
|
||||||
# special handling b/c remote MixFormers code doesn't have _no_split_modules set
|
|
||||||
if (
|
|
||||||
"MixFormerSequentialConfig" in model_config.__class__.__name__
|
|
||||||
and cfg.model_type == "AutoModelForCausalLM"
|
|
||||||
):
|
|
||||||
module_name = model_config.__class__.__module__.replace(
|
|
||||||
".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
|
|
||||||
)
|
|
||||||
modeling_phi = importlib.import_module(module_name)
|
|
||||||
# pylint:disable=protected-access
|
|
||||||
modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [
|
|
||||||
"ParallelBlock"
|
|
||||||
]
|
|
||||||
|
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
|
|
||||||
|
model_kwargs["device_map"] = cfg.device_map
|
||||||
|
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||||
|
|
||||||
if cfg.model_revision:
|
if cfg.model_revision:
|
||||||
model_kwargs["revision"] = cfg.model_revision
|
model_kwargs["revision"] = cfg.model_revision
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
@@ -213,6 +225,15 @@ def load_model(
|
|||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4",
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
|
# sample packing uses custom FA2 patch
|
||||||
|
if cfg.flash_attention and not cfg.sample_packing:
|
||||||
|
if (
|
||||||
|
cfg.is_llama_derived_model
|
||||||
|
or cfg.is_falcon_derived_model
|
||||||
|
or cfg.is_mistral_derived_model
|
||||||
|
):
|
||||||
|
model_kwargs["use_flash_attention_2"] = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM
|
||||||
@@ -227,10 +248,8 @@ def load_model(
|
|||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=cfg.torch_dtype,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||||
@@ -264,28 +283,22 @@ def load_model(
|
|||||||
|
|
||||||
model = MixFormerSequentialForCausalLM.from_pretrained(
|
model = MixFormerSequentialForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=cfg.torch_dtype,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif model_type and not cfg.trust_remote_code:
|
elif model_type and not cfg.trust_remote_code:
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
device_map=cfg.device_map,
|
|
||||||
torch_dtype=cfg.torch_dtype,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = getattr(transformers, model_type).from_pretrained(
|
model = getattr(transformers, model_type).from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=cfg.torch_dtype,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -314,8 +327,6 @@ def load_model(
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
device_map=cfg.device_map,
|
|
||||||
torch_dtype=cfg.torch_dtype,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -323,10 +334,8 @@ def load_model(
|
|||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=cfg.torch_dtype,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -337,10 +346,8 @@ def load_model(
|
|||||||
LOG.exception(err)
|
LOG.exception(err)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
device_map=cfg.device_map,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
torch_dtype=cfg.torch_dtype,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -375,7 +382,7 @@ def load_model(
|
|||||||
if model_config.model_type == "btlm":
|
if model_config.model_type == "btlm":
|
||||||
# don't upcast lm_head for btlm
|
# don't upcast lm_head for btlm
|
||||||
continue
|
continue
|
||||||
if "lm_head" in name or "embed_tokens" in name:
|
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
|
||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(torch.float32)
|
module.to(torch.float32)
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ def check_example_labels(example, tokenizer, text_only=False):
|
|||||||
)
|
)
|
||||||
colored_tokens.append(colored_token)
|
colored_tokens.append(colored_token)
|
||||||
|
|
||||||
LOG.info(" ".join(colored_tokens))
|
delimiter = "" if text_only else " "
|
||||||
|
LOG.info(delimiter.join(colored_tokens))
|
||||||
LOG.info("\n\n\n")
|
LOG.info("\n\n\n")
|
||||||
|
print(" ".join(colored_tokens))
|
||||||
|
|
||||||
return " ".join(colored_tokens)
|
return " ".join(colored_tokens)
|
||||||
|
|||||||
@@ -397,23 +397,36 @@ def disable_datasets_caching():
|
|||||||
set_caching_enabled(True)
|
set_caching_enabled(True)
|
||||||
|
|
||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
|
train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
|
eval_dataset = eval_dataset.filter(
|
||||||
|
drop_long, num_proc=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
|
train_dataset = train_dataset.map(
|
||||||
|
add_length, num_proc=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
|
train_dataset = train_dataset.map(
|
||||||
|
add_position_ids, num_proc=cfg.dataset_processes
|
||||||
|
)
|
||||||
if cfg.eval_sample_packing is not False:
|
if cfg.eval_sample_packing is not False:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids, num_proc=os.cpu_count()
|
add_position_ids, num_proc=cfg.dataset_processes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Phi doesn't want the attention_mask feature when training
|
||||||
|
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
|
||||||
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
|
if eval_dataset:
|
||||||
|
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||||
|
|
||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
|
|
||||||
@@ -597,26 +610,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
"sample_packing_efficiency"
|
"sample_packing_efficiency"
|
||||||
] = cfg.sample_packing_eff_est
|
] = cfg.sample_packing_eff_est
|
||||||
|
|
||||||
if cfg.eval_steps and cfg.evaluation_strategy:
|
if cfg.eval_steps:
|
||||||
# assume if the user set both, they know what they're doing
|
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||||
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
|
|
||||||
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
||||||
|
elif cfg.evaluation_strategy:
|
||||||
|
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
|
||||||
elif cfg.val_set_size == 0:
|
elif cfg.val_set_size == 0:
|
||||||
# no eval set, so don't eval
|
# no eval set, so don't eval
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||||
elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
|
|
||||||
# if explicitly set for epoch, just set, and eval steps don't matter
|
|
||||||
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
|
|
||||||
elif cfg.eval_steps:
|
|
||||||
# steps isn't used w/ epochs
|
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
|
||||||
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
|
||||||
else:
|
else:
|
||||||
# we have an eval set, but no steps defined, default to use epoch
|
# we have an eval set, but no steps defined, default to use epoch
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||||
|
|
||||||
if cfg.save_steps:
|
if cfg.save_steps:
|
||||||
# save_steps implies save_strategy of steps
|
|
||||||
training_arguments_kwargs["save_strategy"] = "steps"
|
training_arguments_kwargs["save_strategy"] = "steps"
|
||||||
training_arguments_kwargs["save_steps"] = cfg.save_steps
|
training_arguments_kwargs["save_steps"] = cfg.save_steps
|
||||||
elif cfg.save_strategy:
|
elif cfg.save_strategy:
|
||||||
@@ -662,9 +668,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||||
max_seq_length=cfg.sequence_len,
|
max_seq_length=cfg.sequence_len,
|
||||||
per_device_train_batch_size=cfg.micro_batch_size,
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
per_device_eval_batch_size=cfg.eval_batch_size
|
per_device_eval_batch_size=cfg.eval_batch_size,
|
||||||
if cfg.eval_batch_size is not None
|
|
||||||
else cfg.micro_batch_size,
|
|
||||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
num_train_epochs=cfg.num_epochs,
|
num_train_epochs=cfg.num_epochs,
|
||||||
|
|||||||
116
tests/e2e/test_mistral.py
Normal file
116
tests/e2e/test_mistral.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for lora llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMistral(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models using LoRA
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_lora(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
output_dir = tempfile.mkdtemp()
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||||
|
"base_model_config": "openaccess-ai-collective/tiny-mistral",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 32,
|
||||||
|
"lora_alpha": 64,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": output_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
def test_ft(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
output_dir = tempfile.mkdtemp()
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||||
|
"base_model_config": "openaccess-ai-collective/tiny-mistral",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": output_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if is_torch_bf16_gpu_available():
|
||||||
|
cfg.bf16 = True
|
||||||
|
else:
|
||||||
|
cfg.fp16 = True
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||||
118
tests/e2e/test_mistral_samplepack.py
Normal file
118
tests/e2e/test_mistral_samplepack.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for lora llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMistral(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models using LoRA
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_lora_packing(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
output_dir = tempfile.mkdtemp()
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||||
|
"base_model_config": "openaccess-ai-collective/tiny-mistral",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 32,
|
||||||
|
"lora_alpha": 64,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": output_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
def test_ft_packing(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
output_dir = tempfile.mkdtemp()
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||||
|
"base_model_config": "openaccess-ai-collective/tiny-mistral",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": output_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if is_torch_bf16_gpu_available():
|
||||||
|
cfg.bf16 = True
|
||||||
|
else:
|
||||||
|
cfg.fp16 = True
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||||
2
tests/fixtures/conversation.tokenized.json
vendored
2
tests/fixtures/conversation.tokenized.json
vendored
File diff suppressed because one or more lines are too long
@@ -21,7 +21,7 @@ from axolotl.prompt_tokenizers import (
|
|||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
ShareGPTPromptTokenizingStrategy,
|
ShareGPTPromptTokenizingStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
) as fin:
|
) as fin:
|
||||||
data = fin.read()
|
data = fin.read()
|
||||||
tokenized_conversation = json.loads(data)
|
tokenized_conversation = json.loads(data)
|
||||||
prompter = ShareGPTPrompter("chat")
|
prompter = ShareGPTPrompterV2()
|
||||||
strat = ShareGPTPromptTokenizingStrategy(
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
prompter,
|
prompter,
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
@@ -79,7 +79,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
) as fin:
|
) as fin:
|
||||||
data = fin.read()
|
data = fin.read()
|
||||||
conversation = json.loads(data)
|
conversation = json.loads(data)
|
||||||
prompter = ShareGPTPrompter("chat")
|
prompter = ShareGPTPrompterV2()
|
||||||
strat = ShareGPTPromptTokenizingStrategy(
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
prompter,
|
prompter,
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
|
|||||||
@@ -374,3 +374,194 @@ class ValidationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_sharegpt_deprecation(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
|
||||||
|
)
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert any(
|
||||||
|
"`type: sharegpt:chat` will soon be deprecated." in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
assert cfg.datasets[0].type == "sharegpt"
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
|
||||||
|
)
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert any(
|
||||||
|
"`type: sharegpt_simple` will soon be deprecated." in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
assert cfg.datasets[0].type == "sharegpt:load_role"
|
||||||
|
|
||||||
|
def test_no_conflict_save_strategy(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*save_strategy and save_steps mismatch.*"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_strategy": "no",
|
||||||
|
"save_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*save_strategy and save_steps mismatch.*"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_strategy": "steps",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_strategy": "steps",
|
||||||
|
"save_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_strategy": "no",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_no_conflict_eval_strategy(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "no",
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "steps",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "steps",
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "no",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"val_set_size": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"eval_steps": 10,
|
||||||
|
"val_set_size": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"val_set_size": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"eval_steps": 10,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user