Compare commits

..

23 Commits

Author SHA1 Message Date
Wing Lian
718a8f4153 update flash attention to 2.5.5 for gemma 2024-02-21 23:32:44 -05:00
NanoCode012
a359579371 deprecate: pytorch 2.0.1 image (#1315) [skip ci]
* deprecate: pytorch 2.0.1 image

* deprecate from main image

* Update main.yml

* Update tests.yml
2024-02-22 11:39:47 +09:00
Wing Lian
2752d5f958 multipack for gemma (#1313)
* multipack for gemma

* chore: lint

* handle cache_position kwarg in updated llama modeling

* add position_ids to rotary embed call for updated llama modeling
2024-02-21 19:24:21 -05:00
Monk
9e300aca0c Adding Google's gemma Model (#1312) 2024-02-21 12:56:47 -05:00
NanoCode012
3d2cd804ae fix(readme): update inference md link (#1311) [skip ci] 2024-02-22 02:48:06 +09:00
Jared Palmer
6ab69ec5f8 Add instructions for playing with qlora model to colab example (#1290)
* Add instructions for playing with qlora model to colab example

* Update examples/colab-notebooks/colab-axolotl-example.ipynb

Co-authored-by: JohanWork <39947546+JohanWork@users.noreply.github.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: JohanWork <39947546+JohanWork@users.noreply.github.com>
2024-02-22 02:46:27 +09:00
David Meikle
3c00f406d6 Allow load_best_model_at_end to be configured for early stopping on custom evaluation datasets (#1291)
* Allow load_best_model_at_end when using test_datasets and val_set_size is zero for custom evaluation datasets

* Fixed formatting following failed Lint check
2024-02-22 00:57:18 +09:00
NanoCode012
a7a9a1433a fix(examples): remove is_*_derived as it's parsed automatically (#1297) 2024-02-22 00:52:46 +09:00
Leonardo Emili
e2786cce6a Validation always happens on first step (#1300) 2024-02-22 00:52:24 +09:00
Leonardo Emili
5a5d47458d Add seq2seq eval benchmark callback (#1274)
* Add CausalLMBenchEvalCallback for measuring seq2seq performance

* Fix code for pre-commit

* Fix typing and improve logging

* eval_sample_packing must be false with CausalLMBenchEvalCallback
2024-02-13 08:24:30 -08:00
김진원
8430db22e2 Scheduler implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (#1273) 2024-02-12 21:23:28 -08:00
Wing Lian
4b997c3e1a allow the optimizer prune ratio for ReLoRA to be configurable (#1287)
* allow the optimizer prune ration for relora to be configurable

* update docs for relora

* prevent circular imports
2024-02-12 11:39:51 -08:00
Maxime
fac2d98c26 Add MPS support (#1264)
* add mps support

* linter stuff

* CI fixes

* install packaging for various tests

* Update setup.py

* Revert "install packaging for various tests"

This reverts commit 980e7aa44d.

* Revert "CI fixes"

This reverts commit 4609e3b166.

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-12 08:30:32 -05:00
Wing Lian
ea00dd0852 don't use load and push together (#1284) 2024-02-09 14:54:31 -05:00
Hamel Husain
b2a4cb4396 Update README.md (#1281) 2024-02-09 07:38:08 -08:00
Wing Lian
aaf54dc730 run the docker image builds and push on gh action gpu runners (#1218) 2024-02-09 10:32:54 -05:00
Hamel Husain
9bca7db133 add support for https remote yamls (#1277) 2024-02-08 20:02:17 -08:00
Hamel Husain
91cf4ee72c allow remote data paths (#1278)
* allow remote data paths

* add docs about public url

* only allow https

* better docs

* better docs
2024-02-08 15:02:35 -08:00
Wing Lian
1daecd161e copy edits (#1276) 2024-02-08 09:00:04 -05:00
Wing Lian
4a654b331e Add link to axolotl cloud image on latitude (#1275) 2024-02-08 08:50:11 -05:00
Wing Lian
5698943263 simplify haldning for newer multipack patches so they can be added in a single place (#1270) 2024-02-07 10:46:04 -05:00
Wing Lian
411293bdca contributor avatars (#1269) 2024-02-07 07:09:01 -08:00
Zac Brannelly
73f1bdaa15 Fix bug preventing model_kwargs being injected (#1262) 2024-02-07 09:38:35 -05:00
60 changed files with 769 additions and 292 deletions

View File

@@ -7,16 +7,11 @@ jobs:
build-base:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
runs-on: self-hosted
runs-on: axolotl-gpu-runner
strategy:
fail-fast: false
matrix:
include:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"

View File

@@ -9,16 +9,10 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
@@ -35,7 +29,7 @@ jobs:
python_version: "3.11"
pytorch: 2.1.2
axolotl_extras:
runs-on: [self-hosted, gpu, docker]
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -56,27 +50,16 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
load: true
build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
- name: Unit Tests
run: |
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
- name: Push to Docker Hub
if: github.event_name != 'pull_request'
run: |
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
if [ -n "$latest_tag" ]; then
docker push "$latest_tag"
fi
build-axolotl-runpod:
needs: build-axolotl
@@ -85,11 +68,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
@@ -106,7 +84,7 @@ jobs:
python_version: "3.11"
pytorch: 2.1.2
axolotl_extras:
runs-on: [self-hosted, gpu, docker]
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -69,7 +69,7 @@ jobs:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
pytorch: 2.1.2
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"

View File

@@ -32,6 +32,9 @@ ignore_missing_imports = True
[mypy-bitsandbytes]
ignore_missing_imports = True
[mypy-requests]
ignore_missing_imports = True
[mypy-datasets]
ignore_missing_imports = True

View File

@@ -25,8 +25,8 @@ Features:
- [Installation](#installation)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu) - Runpod, Latitude
- [LambdaLabs](#lambdalabs)
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [Windows](#windows)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Dataset](#dataset)
@@ -34,7 +34,7 @@ Features:
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
- [Inference](#inference)
- [Inference](#inference-playground)
- [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- Advanced Topics
@@ -121,6 +121,10 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
```
## Installation
@@ -182,9 +186,13 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### LambdaLabs
#### Bare Metal Cloud GPU
##### LambdaLabs
<details>
<summary>Click to Expand</summary>
@@ -464,6 +472,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
dataset:
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
dataset:
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
```
- loading
@@ -720,6 +734,8 @@ peft:
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps
relora_anneal_steps: # Number of anneal steps for each relora cycle
relora_prune_ratio: # threshold for optimizer magnitude when pruning
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it
@@ -768,7 +784,8 @@ save_total_limit: # Checkpoints saved at a time
max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
@@ -797,6 +814,7 @@ early_stopping_patience: 3
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
# For one_cycle optim
lr_div_factor: # Learning rate div factor
@@ -976,6 +994,9 @@ Run
accelerate launch -m axolotl.cli.train your_config.yml
```
> [!TIP]
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
#### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning.
@@ -1200,6 +1221,12 @@ pre-commit install
pytest tests/
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Sponsors 🤝❤
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),

View File

@@ -2,7 +2,6 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -177,6 +177,24 @@
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Play with inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
}
],
"metadata": {

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true
load_in_4bit: false
gptq: false

View File

@@ -5,7 +5,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
load_in_4bit: false
gptq: false

65
examples/gemma/qlora.yml Normal file
View File

@@ -0,0 +1,65 @@
# use google/gemma-7b if you have access
base_model: mhenrichsen/gemma-7b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
val_set_size: 0.1
output_dir: ./out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 3
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false

View File

@@ -1,5 +1,4 @@
base_model: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true
gptq_disable_exllama: true
model_type: AutoModelForCausalLM

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
@@ -60,7 +59,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -57,7 +56,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -1,7 +1,6 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,7 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -49,7 +49,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,7 +2,6 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -61,7 +60,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
#default deepspeed, can use more aggresive if needed like zero2, zero3

View File

@@ -1,7 +1,6 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: false
@@ -49,7 +48,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero2.json

View File

@@ -1,7 +1,6 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: true
@@ -68,7 +67,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,7 +2,6 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: true
@@ -58,7 +57,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,7 +2,6 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: false
@@ -58,7 +57,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -0,0 +1,64 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,7 +1,6 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -2,7 +2,6 @@ base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false

View File

@@ -1,7 +1,6 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,8 +1,7 @@
base_model: 01-ai/Yi-34B-Chat
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: false
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -29,7 +28,7 @@ num_epochs: 1
val_set_size: 0.1
evals_per_epoch: 5
eval_table_size:
eval_table_max_new_tokens: 128
eval_max_new_tokens: 128
eval_sample_packing: false
eval_batch_size: 1

View File

@@ -1,3 +1,4 @@
pre-commit
black
mypy
types-requests

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate==0.26.1
@@ -9,8 +9,9 @@ deepspeed>=0.13.1
addict
fire
PyYAML>=6.0
requests
datasets>=2.15.0
flash-attn==2.3.3
flash-attn==2.5.5
sentencepiece
wandb
einops
@@ -22,7 +23,7 @@ numba
numpy>=1.24.4
mlflow
# qlora things
evaluate==0.4.0
evaluate==0.4.1
scipy
scikit-learn==1.2.2
pynvml

View File

@@ -1,5 +1,7 @@
"""setup.py for axolotl"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup
@@ -26,11 +28,25 @@ def parse_requirements():
_install_requires.append(line)
try:
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
if torch_version.startswith("2.1."):
if "Darwin" in platform.system():
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
else:
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 1):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
except PackageNotFoundError:
pass
@@ -51,7 +67,7 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.5.0",
"flash-attn==2.5.5",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",

View File

@@ -1,16 +1,20 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import json
import logging
import math
import os
import random
import sys
import tempfile
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import gradio as gr
import requests
import torch
import yaml
@@ -59,6 +63,52 @@ def print_axolotl_text_art(suffix=None):
print(ascii_art)
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
@@ -270,9 +320,10 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
return not any(el in list2 for el in list1)
def load_cfg(config: Path = Path("examples/"), **kwargs):
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(config)
config = choose_config(Path(config))
# load the config from the yaml file
with open(config, encoding="utf-8") as file:

View File

@@ -3,6 +3,7 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -23,7 +24,7 @@ from axolotl.prompt_strategies.sharegpt import register_chatml_template
LOG = logging.getLogger("axolotl.cli.preprocess")
def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)

View File

@@ -3,6 +3,7 @@ CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -25,7 +26,7 @@ def shard(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)

View File

@@ -3,7 +3,7 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Tuple
from typing import Tuple, Union
import fire
from transformers.hf_argparser import HfArgumentParser
@@ -25,7 +25,7 @@ from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser((TrainerCliArgs))

View File

@@ -28,6 +28,7 @@ from transformers import (
from transformers.trainer_utils import seed_worker
from trl import DPOTrainer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
@@ -37,6 +38,7 @@ from axolotl.utils.callbacks import (
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import (
@@ -49,6 +51,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
try:
@@ -130,6 +133,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
@@ -142,6 +149,9 @@ class AxolotlTrainingArguments(TrainingArguments):
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
@@ -159,6 +169,12 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
class AxolotlTrainer(Trainer):
@@ -216,6 +232,16 @@ class AxolotlTrainer(Trainer):
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
@@ -642,6 +668,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
@@ -790,6 +821,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
@@ -850,8 +883,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and not self.cfg.test_datasets
and self.cfg.val_set_size > 0
and (
(not self.cfg.test_datasets and self.cfg.val_set_size > 0)
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
)
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
@@ -882,6 +917,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
@@ -899,9 +937,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs[
"relora_warmup_steps"
] = self.cfg.relora_warmup_steps
if self.cfg.relora_anneal_steps:
training_arguments_kwargs[
"relora_anneal_steps"
] = self.cfg.relora_anneal_steps
if self.cfg.relora_prune_ratio:
training_arguments_kwargs[
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
@@ -994,7 +1043,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
if use_batch_sampler_collator:
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]

View File

@@ -1,12 +0,0 @@
"""
Patches to support multipack for falcon
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_falcon_attn_with_multipack_flash_attn():
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -275,7 +275,9 @@ def flashattn_forward_with_s2attn(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -425,7 +427,9 @@ def flashattn_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -688,6 +692,9 @@ def llama_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions

View File

@@ -2,9 +2,6 @@
Patches to support multipack for mixtral
"""
import torch
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def patch_mixtral_moe_forward_zero3() -> None:
@@ -51,11 +48,3 @@ def patch_mixtral_moe_forward_zero3() -> None:
MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if for_zero3:
patch_mixtral_moe_forward_zero3()

View File

@@ -0,0 +1,34 @@
"""multipack patching for v2 of sample packing"""
import transformers
from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
def patch_for_multipack(model_type):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -1,12 +0,0 @@
"""
Patches to support multipack for phi2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_phi_attn_with_multipack_flash_attn():
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -1,12 +0,0 @@
"""
Patches to support multipack for qwen2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_qwen2_attn_with_multipack_flash_attn():
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -46,8 +46,9 @@ def reset_optimizer(
*,
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
n_zeros = 0
n_total = 0
@@ -159,6 +160,7 @@ class ReLoRACallback(TrainerCallback):
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
prune_ratio=args.relora_prune_ratio,
)
if self.quantized:

View File

@@ -186,8 +186,8 @@ def mask_2d_to_4d(
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
torch.tensor(1, device=mask.device).to(dtype),
torch.tensor(0, device=mask.device).to(dtype),
)
# Create a block-diagonal mask.

View File

@@ -208,7 +208,10 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
try:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
except AttributeError:
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()

View File

@@ -47,6 +47,12 @@ def gpu_memory_usage_all(device=0):
return usage, reserved - usage, max(0, smi - reserved)
def mps_memory_usage_all():
usage = torch.mps.current_allocated_memory() / 1024.0**3
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
return usage, reserved - usage, 0
@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
@@ -63,7 +69,10 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device):
usage, cache, misc = gpu_memory_usage_all(device)
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
else:
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")

View File

@@ -62,7 +62,6 @@ class EvalFirstStepCallback(
):
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and args.eval_steps < 1.0
and state.global_step == 1
):
control.should_evaluate = True
@@ -361,6 +360,187 @@ def bench_eval_callback_factory(trainer, tokenizer):
return BenchEvalCallback
def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
class CausalLMBenchEvalCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.metrics = self.__maybe_load_metrics()
def __maybe_load_metrics(self):
metrics = {}
for metric in self.cfg.eval_causal_lm_metrics:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
return metrics
def on_evaluate(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
train_dataloader, # pylint: disable=unused-argument
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
trainer.model.eval()
device = torch.device(self.cfg.device)
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i - 1))
start = i
end = len(lst) - 1
ranges.append((start, end))
return ranges
def compute(metric: evaluate.Metric, **kwargs):
# safely compute a metric and return the score if the format is correct
metric_score = None
try:
metric_score = metric.compute(**kwargs)
return (
metric_score["score"]
if "score" in metric_score
else metric_score["mean_score"]
)
except Exception: # pylint: disable=broad-exception-caught
LOG.debug(
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
)
return metric_score
def evaluate_preds(sources, predictions, references):
scores = {}
for metric_name, metric in self.metrics.items():
score = compute(
metric,
references=references,
predictions=predictions,
sources=sources,
)
score = score or compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
scores[metric_name] = score
return scores
def predict_with_generate():
eval_src, eval_pred, eval_ref = [], [], []
for batch in tqdm(eval_dataloader):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])
prompt_token_ids_list = []
completion_token_ids_list = []
for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)
for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)
prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)
prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)
eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)
return eval_src, eval_pred, eval_ref
if is_main_process():
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))
return control
return CausalLMBenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
@@ -388,7 +568,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_table_max_new_tokens,
max_new_tokens=self.cfg.eval_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,

View File

@@ -56,7 +56,13 @@ def normalize_config(cfg):
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
"sacrebleu",
"comet",
"ter",
"chrf",
]
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
@@ -550,6 +556,21 @@ def validate_config(cfg):
if cfg.fsdp and "bnb" in cfg.optimizer:
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)
if cfg.eval_causal_lm_metrics:
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -336,6 +336,16 @@ def load_tokenized_prepared_datasets(
split=None,
storage_options=storage_options,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(

View File

@@ -29,6 +29,10 @@ from transformers import ( # noqa: F401
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
@@ -299,8 +303,15 @@ def load_model(
shifted-sparse attention does not currently support sample packing."
)
# Modify all llama derived models in one block
if cfg.is_llama_derived_model:
if (
cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type)
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block
if cfg.flash_attention:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
@@ -354,43 +365,6 @@ def load_model(
LOG.info("patching mistral with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if (
cfg.model_config_type == "mixtral"
and cfg.flash_attention
and cfg.sample_packing
):
from axolotl.monkeypatch.mixtral import (
replace_mixtral_attn_with_multipack_flash_attn,
)
LOG.info("patching mixtral with flash attention")
mixtral_patch_kwargs = {}
if is_deepspeed_zero3_enabled():
mixtral_patch_kwargs["for_zero3"] = True
replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.falcon import (
replace_falcon_attn_with_multipack_flash_attn,
)
LOG.info("patching falcon with flash attention")
replace_falcon_attn_with_multipack_flash_attn()
if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
LOG.info("patching phi with flash attention")
replace_phi_attn_with_multipack_flash_attn()
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.qwen2 import (
replace_qwen2_attn_with_multipack_flash_attn,
)
LOG.info("patching qwen2 with flash attention")
replace_qwen2_attn_with_multipack_flash_attn()
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
@@ -400,7 +374,7 @@ def load_model(
model_kwargs: Dict[str, Any] = {}
if cfg.model_kwargs:
for key, val in model_kwargs.items():
for key, val in cfg.model_kwargs.items():
model_kwargs[key] = val
max_memory = cfg.max_memory
@@ -435,6 +409,10 @@ def load_model(
model_kwargs["device_map"] = device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype
if torch.backends.mps.is_available():
model_kwargs["device_map"] = "mps:0"
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
# if cfg.rl:
# if torch.cuda.device_count() > 1:
@@ -501,7 +479,7 @@ def load_model(
"flash_attention_2"
)
else:
if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
@@ -677,7 +655,7 @@ def load_model(
):
model.config.eos_token_id = tokenizer.eos_token_id
if hasattr(model, "device") and model.device.type == "cuda":
if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
log_gpu_memory_usage(LOG, "after model load", model.device)
# make sure these are fp32 per Ramesh et al. (2021)

View File

@@ -52,7 +52,7 @@ def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float
num_cycles: float,
):
if current_step < num_warmup_steps:
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
@@ -107,7 +107,7 @@ def _get_cosine_schedule_with_min_lr_lambda(
*,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float
min_lr_ratio: float,
):
# Warm up
if current_step < num_warmup_steps:
@@ -140,3 +140,80 @@ def get_cosine_schedule_with_min_lr(
min_lr_ratio=min_lr_ratio,
)
return LambdaLR(optimizer, lr_lambda)
def _get_cosine_schedule_with_warmup_decay_constant_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
constant_lr_ratio: float,
min_lr_ratio: float,
num_cycles: float,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
num_constant_steps = int(num_training_steps * constant_lr_ratio)
current_step = min(current_step, num_constant_steps)
progress = float(current_step - num_warmup_steps) / float(
max(1, num_constant_steps - num_warmup_steps)
)
return (
max(
0,
(1 - min_lr_ratio)
* 0.5
* (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
+ min_lr_ratio
)
def get_cosine_schedule_with_warmup_decay_constant(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
constant_lr_ratio: float,
min_lr_ratio: float,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to min_lr_ratio until num_training_steps * constant_lr_ratio, after constant_rate returns constant value of min_rate
, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
constant_lr_ratio: (`float`):
The ratio of num_training_steps to decrease by cosine function.
min_lr_ratio: (`float):
The ratio of maximum learning rate for cosine function to decay to minimum learning rate.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_schedule_with_warmup_decay_constant_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
constant_lr_ratio=constant_lr_ratio,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)

52
tests/test_schedulers.py Normal file
View File

@@ -0,0 +1,52 @@
"""
test module for the axolotl.utis.data module
"""
import unittest
import torch
from torch.optim import SGD
from axolotl.utils.schedulers import get_cosine_schedule_with_warmup_decay_constant
class TestCosineConstantLr(unittest.TestCase):
"""
test class for encode pretraining and md5 helper
"""
def setUp(self):
self.train_steps = 1000
self.warmup_steps = 10
self.min_lr_ratio = 0.1
self.constant_lr_ratio = 0.8
self._lr = 0.01
self.optimizer = SGD([torch.tensor(1)], lr=self._lr)
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
self.optimizer,
num_warmup_steps=self.warmup_steps,
num_training_steps=self.train_steps,
min_lr_ratio=self.min_lr_ratio,
constant_lr_ratio=self.constant_lr_ratio,
)
def test_schedulers(self):
self.assertEqual(self.lr_scheduler.get_last_lr()[0], 0)
for _ in range(self.warmup_steps):
self.lr_scheduler.step()
self.assertEqual(self.lr_scheduler.get_last_lr()[0], self._lr)
constant_step = int(self.train_steps * self.constant_lr_ratio)
remaining_step = self.train_steps - constant_step
for _ in range(constant_step):
self.lr_scheduler.step()
self.assertEqual(
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
)
for _ in range(remaining_step):
self.lr_scheduler.step()
self.assertEqual(
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,98 +0,0 @@
"""
This module is used to launch Axolotl with user defined configurations.
"""
import gradio as gr
import yaml
def config(
base_model,
dataset,
dataset_type,
learn_rate,
gradient_accumulation_steps,
micro_batch_size,
seq_length,
num_epochs,
output_dir,
val_size,
):
"""
This function generates a configuration dictionary and saves it as a yaml file.
"""
config_dict = {
"base_model": base_model,
"datasets": [{"path": dataset, "type": dataset_type}],
"learning_rate": learn_rate,
"gradient_accumulation_steps": gradient_accumulation_steps,
"micro_batch_size": micro_batch_size,
"sequence_len": seq_length,
"num_epochs": num_epochs,
"output_dir": output_dir,
"val_set_size": val_size,
}
with open("config.yml", "w", encoding="utf-8") as file:
yaml.dump(config_dict, file)
print(config_dict)
return yaml.dump(config_dict)
with gr.Blocks(title="Axolotl Launcher") as demo:
gr.Markdown(
"""
# Axolotl Launcher
Fill out the required fields below to create a training run.
"""
)
with gr.Row():
base_model_name = gr.Textbox(
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model"
)
mode = gr.Radio(
choices=["Full finetune", "QLoRA", "LoRA"],
label="Training mode",
info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit",
)
with gr.Row():
dataset_path = gr.Textbox("mhenrichsen/alpaca_2k_test", label="Dataset")
dataset_type_name = gr.Dropdown(
choices=["alpaca", "sharegpt"], label="Dataset type", value="alpaca"
)
with gr.Accordion("Hyperparameters", open=False):
gr.Markdown("Choose hyperparameters")
with gr.Row():
learning_rate = gr.Number(0.000001, label="Learning rate")
gradient_accumulation_steps_count = gr.Number(
1, label="Gradient accumulation steps"
)
val_set_size_count = gr.Number(0, label="Validation size")
with gr.Row():
micro_batch_size_count = gr.Number(1, label="Micro batch size")
sequence_length = gr.Number(1024, label="Sequence length")
num_epochs_count = gr.Number(1, label="Epochs")
output_dir_path = gr.Textbox("./model-out", label="Output directory")
create_config = gr.Button("Create config")
output = gr.TextArea(label="Generated config")
create_config.click(
config,
inputs=[
base_model_name,
dataset_path,
dataset_type_name,
learning_rate,
gradient_accumulation_steps_count,
micro_batch_size_count,
sequence_length,
num_epochs_count,
output_dir_path,
val_set_size_count,
],
outputs=output,
)
demo.launch(debug=True, server_name="0.0.0.0", server_port=7860)