Compare commits

...

51 Commits

Author SHA1 Message Date
Wing Lian
d7ec10e337 add support for MoRA 2024-06-01 16:14:56 -04:00
Wing Lian
05b0bd08d2 need to add back drop_last for sampler (#1676) 2024-05-31 13:13:13 -04:00
Wing Lian
d4f6c65e4c cleanup the deepspeed proxy model at the end of training (#1675) 2024-05-30 13:40:35 -04:00
Wing Lian
a944f7b32b load explicit splits on datasets (#1652) 2024-05-29 22:27:59 -04:00
Wing Lian
9d4225a058 set chat_template in datasets config automatically (#1664)
* set chat_template in datasets config automatically

* dynamic chat_template, not just chatml
2024-05-29 22:27:26 -04:00
Wing Lian
f7332ac449 use mixins for orpo and kto configs so they work with axolotl customizations (#1674) 2024-05-29 22:27:00 -04:00
Wing Lian
16d46b74e4 re-enable phi for tests in modal ci (#1373) 2024-05-29 15:41:46 -04:00
Wing Lian
a6b37bdeb4 revert multipack batch sampler changes (#1672)
* revert multipack batch sampler changes

* fix default val for drop_last
2024-05-29 11:51:18 -04:00
Wing Lian
b7520801a3 handle the system role too for chat templates (#1671) 2024-05-29 10:21:11 -04:00
Wing Lian
fe650dd326 make sure the CI fails when pytest script fails (#1669)
* make sure the pytest script fails

* make sure the defaults come through for tests

* make sure tensorboard is loaded for test assertion
2024-05-29 10:12:11 -04:00
Abe Voelker
49b967b62f Fix README quick start example usage model dirs (#1668) 2024-05-28 18:10:40 -04:00
Seungduk Kim
65db903714 Correct name of MixtralBlockSparseTop2MLP (L -> l) (#1667) 2024-05-28 18:10:29 -04:00
Davide Caroselli
6a5a725f10 Fix: ensure correct handling of val_set_size as float or int (#1655)
* Fix: ensure correct handling of val_set_size as float or int

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-28 12:00:32 -04:00
Wing Lian
f5febc729a fix lint issue that snuck through (#1665) 2024-05-28 11:36:50 -04:00
Faria Huq
230e0ac363 Fix Lora config error for Llama3 (#1659)
The current yml config throws an error: ValueError: Please set lora_modules_to_save to [`embed_tokens`, `lm_head`] when using an adapter and changing the special tokens.

I added the required changes to resolve it; see the sketch after this entry.
2024-05-28 11:25:08 -04:00
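For reference, the fix amounts to the `lora_modules_to_save` addition that appears in the llama-3 example hunk later in this diff:

```yaml
# Required when a LoRA adapter changes the special tokens: the embedding
# layers must be trained and saved alongside the adapter weights.
lora_modules_to_save:
  - embed_tokens
  - lm_head
```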
Keith Stevens
cc11c6bce2 Generalizing the chat_template prompt strategy (#1660) [skip ci]
The strategy now supports configuring several fields:

* the data field holding message arrays
* the role and content fields for each message
* role mapping from source to target types

Additionally, this adds a sample llama3-8b instruct template using the chat template strategy.
2024-05-28 11:24:13 -04:00
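For reference, the generalized strategy's dataset configuration looks like the llama-3 example config added later in this diff:

```yaml
datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template
    chat_template: llama3
    field_messages: messages        # the data field holding message arrays
    message_field_role: role        # role field within each message
    message_field_content: content  # content field within each message
    roles:                          # role mapping from source to target types
      user:
        - user
      assistant:
        - assistant
```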
Maciek
5f91064040 Fix Google Colab notebook 2024-05 (#1662) [skip ci]
* include mlflow installation in the colab notebook

Without explicitly installing mlflow, the `accelerate launch` command fails.

* update the colab notebook to use the latest tinyllama config
2024-05-28 11:23:52 -04:00
Wing Lian
ef223519c9 update deps (#1663) [skip ci]
* update deps and tweak logic so axolotl is pip installable

* use vcs url format

* using dependency_links isn't supported per docs
2024-05-28 11:23:34 -04:00
Charles Frye
8a20a7b711 document how to use save_strategy="no" (#1653) [skip ci]
The literal value `no` is parsed in some YAML parsers to the boolean `False`, which fails Pydantic validation. To be sure that the value is parsed to the string `"no"`, the value should be enclosed in quotes. [Discussion on StackOverflow](https://stackoverflow.com/questions/53648244/specifying-the-string-value-yes-in-yaml).
2024-05-24 14:15:44 -04:00
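A minimal sketch of the gotcha described above, using PyYAML (which follows YAML 1.1 boolean parsing):

```python
import yaml

# A bare `no` is parsed as the boolean False, which fails Pydantic validation...
print(yaml.safe_load("save_strategy: no"))    # {'save_strategy': False}
# ...while the quoted form survives as the string "no".
print(yaml.safe_load('save_strategy: "no"'))  # {'save_strategy': 'no'}
```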
Wing Lian
367b2e879b Switch to parallel FFD bin packing algorithm. (#1619)
* Switch to parallel FFD bin packing algorithm.

Add support for packing in a distributed context.
Add packing efficiency estimate back.

* revert changes to distributed code

* chore: lint

* fix config w new params for packing test

* add sample_packing_group_size and sample_packing_bin_size to cfg schema

* fix lambda function

* fix sampler/dataloader calculations for packing

---------

Co-authored-by: dsesclei <dave@sescleifer.com>
2024-05-23 17:32:14 -04:00
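For intuition, first-fit-decreasing (FFD) is the classic greedy packing scheme: sort items longest-first, then drop each into the first bin with room. A minimal single-process sketch (illustrative only, not the parallel implementation from this commit):

```python
def ffd_pack(lengths, capacity):
    """Pack sequence lengths into bins of `capacity` tokens, first-fit-decreasing."""
    bins, free = [], []
    for length in sorted(lengths, reverse=True):
        for i, cap in enumerate(free):
            if length <= cap:  # first existing bin with room
                bins[i].append(length)
                free[i] -= length
                break
        else:  # no existing bin fits; open a new one
            bins.append([length])
            free.append(capacity - length)
    return bins

# e.g. packing token counts into a 4096-token context:
print(ffd_pack([4000, 2500, 1500, 1200, 800], 4096))
# -> [[4000], [2500, 1500], [1200, 800]]
```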
Wing Lian
bbfed318bc support for custom messages field in sharegpt (#1651) 2024-05-23 13:03:22 -04:00
Jaydeep Thik
84bb8061ba Update tiny-llama qlora.yml addressing eval packing error (#1638) 2024-05-22 08:34:06 -04:00
George Grigorev
a27d5e1f4e enable loraplus setting for dpo trainer (#1646) 2024-05-22 08:29:06 -04:00
Wing Lian
6299eb5919 allow report_to for multiple providers (#1647) 2024-05-22 08:27:44 -04:00
Leonard
7c2bf3091f Fix llama3 chat_template (extra <|eot_id|> on last turn) (#1635)
* Fix llama3 chat_template (the {{eos_token}} leads to an extra <|eot_id|> being added in the last turn). Output now matches official Llama 3 Instruct model

* add tests

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-21 09:08:53 -04:00
Ben Redmond
22ae21a6c2 Add KTO support (#1640)
* add kto support

* test cleanup

* fix outdated comment

* fix llama3 ultra

* chore: lint

* update to use rl_beta instead of dpo_beta

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-20 16:05:16 -04:00
Wing Lian
ba45531802 fixes to save on fractional save_steps (#1643) 2024-05-20 14:24:45 -04:00
Wing Lian
8a1572a831 Unsloth optims for Llama (#1609)
* WIP for unsloth integrations

* import the unsloth code in the right context

* add unsloth mlp, qkv, o lora optimizations

* apply unsloth mlp and qkv kernels
2024-05-20 09:55:06 -04:00
Jeffrey Quesnelle
702a669cad add save_only_model option (#1634) 2024-05-17 00:23:18 -04:00
Wing Lian
891ae8aa13 fix ray install (#1630) 2024-05-16 01:25:42 -04:00
Wing Lian
0c49ecc429 more fixes to work with runpod + skypilot (#1629) 2024-05-16 00:05:56 -04:00
Wing Lian
60113437e4 cloud image w/o tmux (#1628) 2024-05-15 22:27:40 -04:00
Wing Lian
419b2a6a98 install rsync too (#1627) 2024-05-15 21:36:00 -04:00
Wing Lian
2501a371c6 fix setting the authorized keys when there are more than one in the env var (#1626) 2024-05-15 20:48:56 -04:00
Wing Lian
e6937e884b fix symlinks for axolotl outputs (#1625) 2024-05-15 19:41:45 -04:00
Wing Lian
039e2a0370 bump versions of deps (#1621)
* bump versions of deps

* bump transformers too

* fix xformers deps and include s3fs install
2024-05-15 13:27:44 -04:00
Wing Lian
4fde300e5f update outputs path so that we can mount workspace to /workspace/data (#1623)
* update outputs path so that we can mount workspace to /workspace/data

* fix ln order
2024-05-15 12:44:13 -04:00
Wing Lian
3319780300 update torch 2.2.1 -> 2.2.2 (#1622) 2024-05-15 09:45:27 -04:00
bofeng huang
81da7d2531 Fix total_num_steps (#1566)
* Fix `total_num_steps`

* Fix total_num_steps

* lint
2024-05-14 20:10:37 -04:00
Ali Mosavian
1e1921b794 FIX: max_length and max_prompt_length was not being sent to ORPOTrainer (#1584)
* FIX: TRL trainer preprocessing step was running in one process

* FIX: max_length and max_prompt_length was not being sent to ORPOTrainer

* FIX: Change ORPO max prompt length to 1/4 of max length, otherwise we get strange behaviour

* FIX: Removed change from a different PR

* FIX: Black fix

* explicitly set max prompt len for orpo config

---------

Co-authored-by: Ali Mosavian <ali.mosavian@kry.se>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-14 08:51:17 -04:00
Wing Lian
1634ac82e0 make sure to save on the last step (#1615) 2024-05-14 08:48:39 -04:00
Wing Lian
02982733ec fix attention mask collation (#1603) 2024-05-14 08:17:30 -04:00
Chansung Park
5d97e65f95 add dstack section (#1612) [skip ci]
* add dstack section

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-14 08:13:45 -04:00
Wing Lian
2147cf6837 Llama3 dpo (#1610)
* add dpo llama3

* fix dpo bos and eos

* bos token gets added automatically by the tokenizer

* explicit <|end_of_text|> not needed, as eot_id is sufficient

---------

Co-authored-by: Nero10578 <owenarliawan@gmail.com>
2024-05-11 18:29:03 -04:00
Ram
50421c8b1d feat: Add LLaMA-3 instruct prompt strategies for fine-tuning (#1553)
* Add prompt strategies

* Update modified URL

* Update modified URL

* Update fastchat_conversation_turns.py

* Update register function

* Remove extra /n for system prompt

* Fix return

* Fix BOS

* Update requirements, pylint

* Linting

* Linting

* fix tuples, make sure to set system message in template

* tests for llama3 tokenization

* fix conditionals for loading chat template

---------

Co-authored-by: Ram <ram@Rams-MacBook-Pro.local>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-11 00:08:04 -04:00
Antoni-Joan Solergibert
b32c08f8cc adding llama3 fastchat conversation monkeypatch (#1539)
* adding llama3 fastchat conversation monkeypatch

* Updated conversation turns to work with PR3259 of FastChat

* fixed bos token

* bump fastchat version

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-05-10 10:40:05 -04:00
Wing Lian
fff06af8d0 ignore the fsdp_config section too (#1606) [skip ci] 2024-05-09 13:30:39 -04:00
Wing Lian
796a085b2f make sure to save the lora adapter at the end of RL/dpo training (#1573) 2024-05-08 10:39:33 -04:00
Wing Lian
cb78a36374 improve tool handling roles (#1587) 2024-05-07 11:30:40 -04:00
NanoCode012
8b9c15b17f feat: exclude mamba blocks for jamba (#1578) 2024-05-07 22:52:57 +09:00
Chirag Jain
9e1480e9ca Pass deepspeed and fsdp as None explicitly when merging adapters to allow custom device_map (#1575) 2024-05-07 22:47:55 +09:00
117 changed files with 2190 additions and 535 deletions

View File

@@ -30,7 +30,7 @@ jobs:
- cuda: "121" - cuda: "121"
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.2.1 pytorch: 2.2.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121" - cuda: "121"
cuda_version: 12.1.0 cuda_version: 12.1.0

View File

@@ -28,7 +28,7 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
-pytorch: 2.2.1
+pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
@@ -89,7 +89,7 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
-pytorch: 2.2.1
+pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
@@ -125,3 +125,45 @@ jobs:
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: winglian/axolotl-cloud-term
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Build
uses: docker/build-push-action@v5
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-no-tmux
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -27,7 +27,7 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
-pytorch: 2.2.1
+pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
@@ -89,7 +89,7 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
-pytorch: 2.2.1
+pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0

View File

@@ -82,7 +82,12 @@ jobs:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
-pytorch: 2.2.1
+pytorch: 2.2.2
+num_gpus: 1
+- cuda: 121
+cuda_version: 12.1.0
+python_version: "3.11"
+pytorch: 2.3.0
num_gpus: 1
steps:
- name: Checkout

View File

@@ -34,6 +34,7 @@ Features:
- [Mac](#mac)
- [Google Colab](#google-colab)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
+- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
- [Dataset](#dataset)
- [Config](#config)
- [Train](#train)
@@ -123,11 +124,11 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./lora-out"
+    --lora_model_dir="./outputs/lora-out"

# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./lora-out" --gradio
+    --lora_model_dir="./outputs/lora-out" --gradio

# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
@@ -292,6 +293,42 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```
#### Launching on public clouds via dstack
To launch GPU instances (both on-demand and spot) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
Write a job description in YAML as below:
```yaml
# dstack.yaml
type: task
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
env:
- HUGGING_FACE_HUB_TOKEN
- WANDB_API_KEY
commands:
- accelerate launch -m axolotl.cli.train config.yaml
ports:
- 6006
resources:
gpu:
memory: 24GB..
count: 2
```
Then, simply run the job with the `dstack run` command. Append the `--spot` option if you want a spot instance. The `dstack run` command will show you the cheapest instance across multiple cloud services:
```bash
pip install dstack
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
```
For further and more fine-grained use cases, please refer to the official [dstack documentation](https://dstack.ai/docs/) and the detailed description of the [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) in the official repository.
### Dataset

Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.

View File

@@ -1,4 +1,5 @@
#!/bin/bash
+set -e

pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest /workspace/axolotl/tests/e2e/patched/

View File

@@ -11,7 +11,7 @@ ARG PYTORCH_VERSION="2.1.2"
ENV PYTORCH_VERSION=$PYTORCH_VERSION

RUN apt-get update && \
-    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
+    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs

WORKDIR /workspace

View File

@@ -0,0 +1,27 @@
ARG BASE_TAG=main
FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \
pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

View File

@@ -186,6 +186,11 @@ eval_sample_packing:
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
total_num_tokens:
+# Increasing the following values helps with packing, but usually only slightly (<1%).
+# The number of samples packed at a time.
+sample_packing_group_size: 100000
+# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
+sample_packing_bin_size: 200

# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
@@ -285,7 +290,7 @@ lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
-save_strategy: # Set to `no` to skip checkpoint saves
+save_strategy: # Set to `"no"` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time

View File

@@ -38,7 +38,7 @@ wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: btlm-out
+output_dir: ./outputs/btlm-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1

View File

@@ -25,7 +25,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
batch_size: 4
micro_batch_size: 4
num_epochs: 2

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:

View File

@@ -1,216 +1,223 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "AKjdG7tbTb-n"
},
"source": [
"# Example notebook for running Axolotl on google colab"
]
},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "RcbNpOgWRcii"
-},
-"outputs": [],
-"source": [
-"import torch\n",
-"# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
-"assert (torch.cuda.is_available()==True)"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "h3nLav8oTRA5"
-},
-"source": [
-"## Install Axolotl and dependencies"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"colab": {
-"base_uri": "https://localhost:8080/"
-},
-"id": "3c3yGAwnOIdi",
-"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
-},
-"outputs": [],
-"source": [
-"!pip install torch==\"2.1.2\"\n",
-"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
-"!pip install flash-attn==\"2.5.0\"\n",
-"!pip install deepspeed==\"0.13.1\""
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "BW2MFr7HTjub"
-},
-"source": [
-"## Create a yaml config file"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"id": "9pkF2dSoQEUN"
-},
-"outputs": [],
-"source": [
-"import yaml\n",
-"\n",
-"# Your YAML string\n",
-"yaml_string = \"\"\"\n",
-"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
-"model_type: LlamaForCausalLM\n",
-"tokenizer_type: LlamaTokenizer\n",
-"is_llama_derived_model: true\n",
-"\n",
-"load_in_8bit: false\n",
-"load_in_4bit: true\n",
-"strict: false\n",
-"\n",
-"datasets:\n",
-" - path: mhenrichsen/alpaca_2k_test\n",
-" type: alpaca\n",
-"dataset_prepared_path:\n",
-"val_set_size: 0.05\n",
-"output_dir: ./qlora-out\n",
-"\n",
-"adapter: qlora\n",
-"lora_model_dir:\n",
-"\n",
-"sequence_len: 1096\n",
-"sample_packing: true\n",
-"pad_to_sequence_len: true\n",
-"\n",
-"lora_r: 32\n",
-"lora_alpha: 16\n",
-"lora_dropout: 0.05\n",
-"lora_target_modules:\n",
-"lora_target_linear: true\n",
-"lora_fan_in_fan_out:\n",
-"\n",
-"wandb_project:\n",
-"wandb_entity:\n",
-"wandb_watch:\n",
-"wandb_name:\n",
-"wandb_log_model:\n",
-"\n",
-"mlflow_experiment_name: colab-example\n",
-"\n",
-"gradient_accumulation_steps: 1\n",
-"micro_batch_size: 1\n",
-"num_epochs: 4\n",
-"max_steps: 20\n",
-"optimizer: paged_adamw_32bit\n",
-"lr_scheduler: cosine\n",
-"learning_rate: 0.0002\n",
-"\n",
-"train_on_inputs: false\n",
-"group_by_length: false\n",
-"bf16: false\n",
-"fp16: true\n",
-"tf32: false\n",
-"\n",
-"gradient_checkpointing: true\n",
-"early_stopping_patience:\n",
-"resume_from_checkpoint:\n",
-"local_rank:\n",
-"logging_steps: 1\n",
-"xformers_attention:\n",
-"flash_attention: false\n",
-"\n",
-"warmup_steps: 10\n",
-"evals_per_epoch:\n",
-"saves_per_epoch:\n",
-"debug:\n",
-"deepspeed:\n",
-"weight_decay: 0.0\n",
-"fsdp:\n",
-"fsdp_config:\n",
-"special_tokens:\n",
-"\n",
-"\"\"\"\n",
-"\n",
-"# Convert the YAML string to a Python dictionary\n",
-"yaml_dict = yaml.safe_load(yaml_string)\n",
-"\n",
-"# Specify your file path\n",
-"file_path = 'test_axolotl.yaml'\n",
-"\n",
-"# Write the YAML file\n",
-"with open(file_path, 'w') as file:\n",
-" yaml.dump(yaml_dict, file)\n"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {
-"id": "bidoj8YLTusD"
-},
-"source": [
-"## Launch the training"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"colab": {
-"base_uri": "https://localhost:8080/"
-},
-"id": "ydTI2Jk2RStU",
-"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
-},
-"outputs": [],
-"source": [
-"# By using the !, the command will be executed as a bash command\n",
-"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Play with inference"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"# By using the !, the command will be executed as a bash command\n",
-"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
-" --qlora_model_dir=\"./qlora-out\" --gradio"
-]
-}
-],
-"metadata": {
-"accelerator": "GPU",
-"colab": {
-"gpuType": "T4",
-"provenance": []
-},
-"kernelspec": {
-"display_name": "Python 3",
-"name": "python3"
-},
-"language_info": {
-"name": "python"
-}
-},
-"nbformat": 4,
-"nbformat_minor": 0
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "RcbNpOgWRcii"
+},
+"outputs": [],
+"source": [
+"import torch\n",
+"# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
+"assert (torch.cuda.is_available()==True)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "h3nLav8oTRA5"
+},
+"source": [
+"## Install Axolotl and dependencies"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"colab": {
+"base_uri": "https://localhost:8080/"
+},
+"id": "3c3yGAwnOIdi",
+"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
+},
+"outputs": [],
+"source": [
+"!pip install torch==\"2.1.2\"\n",
+"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
+"!pip install flash-attn==\"2.5.0\"\n",
+"!pip install deepspeed==\"0.13.1\"\n",
+"!pip install mlflow==\"2.13.0\""
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "BW2MFr7HTjub"
+},
+"source": [
+"## Create a yaml config file"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "9pkF2dSoQEUN"
+},
+"outputs": [],
+"source": [
+"import yaml\n",
+"\n",
+"# Your YAML string\n",
+"yaml_string = \"\"\"\n",
+"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
+"model_type: LlamaForCausalLM\n",
+"tokenizer_type: LlamaTokenizer\n",
+"\n",
+"load_in_8bit: false\n",
+"load_in_4bit: true\n",
+"strict: false\n",
+"\n",
+"datasets:\n",
+" - path: mhenrichsen/alpaca_2k_test\n",
+" type: alpaca\n",
+"dataset_prepared_path:\n",
+"val_set_size: 0.05\n",
+"output_dir: ./outputs/qlora-out\n",
+"\n",
+"adapter: qlora\n",
+"lora_model_dir:\n",
+"\n",
+"sequence_len: 4096\n",
+"sample_packing: true\n",
+"eval_sample_packing: false\n",
+"pad_to_sequence_len: true\n",
+"\n",
+"lora_r: 32\n",
+"lora_alpha: 16\n",
+"lora_dropout: 0.05\n",
+"lora_target_modules:\n",
+"lora_target_linear: true\n",
+"lora_fan_in_fan_out:\n",
+"\n",
+"wandb_project:\n",
+"wandb_entity:\n",
+"wandb_watch:\n",
+"wandb_name:\n",
+"wandb_log_model:\n",
+"\n",
+"gradient_accumulation_steps: 4\n",
+"micro_batch_size: 2\n",
+"num_epochs: 4\n",
+"optimizer: paged_adamw_32bit\n",
+"lr_scheduler: cosine\n",
+"learning_rate: 0.0002\n",
+"\n",
+"train_on_inputs: false\n",
+"group_by_length: false\n",
+"bf16: auto\n",
+"fp16:\n",
+"tf32: false\n",
+"\n",
+"gradient_checkpointing: true\n",
+"early_stopping_patience:\n",
+"resume_from_checkpoint:\n",
+"local_rank:\n",
+"logging_steps: 1\n",
+"xformers_attention:\n",
+"flash_attention: true\n",
+"\n",
+"warmup_steps: 10\n",
+"evals_per_epoch: 4\n",
+"saves_per_epoch: 1\n",
+"debug:\n",
+"deepspeed:\n",
+"weight_decay: 0.0\n",
+"fsdp:\n",
+"fsdp_config:\n",
+"special_tokens:\n",
+"\n",
+"\"\"\"\n",
+"\n",
+"# Convert the YAML string to a Python dictionary\n",
+"yaml_dict = yaml.safe_load(yaml_string)\n",
+"\n",
+"# Specify your file path\n",
+"file_path = 'test_axolotl.yaml'\n",
+"\n",
+"# Write the YAML file\n",
+"with open(file_path, 'w') as file:\n",
+" yaml.dump(yaml_dict, file)\n"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "bidoj8YLTusD"
+},
+"source": [
+"## Launch the training"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"colab": {
+"base_uri": "https://localhost:8080/"
+},
+"id": "ydTI2Jk2RStU",
+"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
+},
+"outputs": [],
+"source": [
+"# By using the !, the command will be executed as a bash command\n",
+"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Play with inference"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# By using the !, the command will be executed as a bash command\n",
+"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
+" --qlora_model_dir=\"./qlora-out\" --gradio"
+]
+}
+],
+"metadata": {
+"accelerator": "GPU",
+"colab": {
+"gpuType": "T4",
+"provenance": []
+},
+"kernelspec": {
+"display_name": "Python 3 (ipykernel)",
+"language": "python",
+"name": "python3"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.12.1"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 4
}

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 512
sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 512
sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 512
sample_packing: false

View File

@@ -28,7 +28,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./falcon-7b
+output_dir: ./outputs/falcon-7b
batch_size: 2
micro_batch_size: 1
num_epochs: 4

View File

@@ -42,7 +42,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
# QLoRA paper Table 9
# - 16 for 7b & 13b

View File

@@ -28,7 +28,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./falcon-7b
+output_dir: ./outputs/falcon-7b
batch_size: 2
micro_batch_size: 1
num_epochs: 4

View File

@@ -12,7 +12,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
val_set_size: 0.1
-output_dir: ./out
+output_dir: ./outputs/out
adapter: qlora
lora_r: 32

View File

@@ -23,7 +23,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 2

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 4096
sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 4096
sample_packing: false

View File

@@ -21,7 +21,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./jeopardy-bot-7b
+output_dir: ./outputs/jeopardy-bot-7b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true

View File

@@ -33,7 +33,7 @@ wandb_project:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./model-out
+output_dir: ./outputs/model-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
-output_dir: ./lisa-out
+output_dir: ./outputs/lisa-out
sequence_len: 4096
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./relora-out
+output_dir: ./outputs/relora-out
adapter: qlora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true

View File

@@ -0,0 +1,76 @@
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
chat_template: llama3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: llama3
field_messages: messages
message_field_role: role
message_field_content: content
roles:
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
@@ -24,6 +24,9 @@ lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
+lora_modules_to_save:
+  - embed_tokens
+  - lm_head
wandb_project:
wandb_entity:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
-output_dir: ./out/qlora-llama3-70b
+output_dir: ./outputs/out/qlora-llama3-70b
adapter: qlora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 2048
sample_packing: false

View File

@@ -23,7 +23,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
eval_sample_packing: false
adapter: lora

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
model_config:
output_router_logits: true

View File

@@ -16,7 +16,7 @@ datasets:
type: chat_template.argilla
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
-output_dir: ./mistral-qlora-orpo-out
+output_dir: ./outputs/mistral-qlora-orpo-out
adapter: qlora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
model_config:
output_router_logits: true

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
model_config:
output_router_logits: true

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:

View File

@@ -21,7 +21,7 @@ model_config:
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 8000
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:

View File

@@ -23,7 +23,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./mpt-alpaca-7b
+output_dir: ./outputs/mpt-alpaca-7b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4

View File

@@ -25,7 +25,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./openllama-out
+output_dir: ./outputs/openllama-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4

View File

@@ -31,7 +31,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4

View File

@@ -25,7 +25,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./phi-sft-out
+output_dir: ./outputs/phi-sft-out
sequence_len: 2048
sample_packing: true

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./phi-sft-out
+output_dir: ./outputs/phi-sft-out
sequence_len: 2048
sample_packing: true

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./phi-sft-out
+output_dir: ./outputs/phi-sft-out
sequence_len: 2048
sample_packing: true

View File

@@ -26,7 +26,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./pythia-12b
+output_dir: ./outputs/pythia-12b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 5

View File

@@ -20,7 +20,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./lora-alpaca-pythia
+output_dir: ./outputs/lora-alpaca-pythia
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 4

View File

@@ -13,7 +13,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 2048 # supports up to 8192
sample_packing: false

View File

@@ -13,7 +13,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 2048 # supports up to 8192
sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 1024 # supports up to 32k
sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 1024 # supports up to 32k
sample_packing: false

View File

@@ -24,7 +24,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./redpajama-alpaca-3b
+output_dir: ./outputs/redpajama-alpaca-3b
batch_size: 4
micro_batch_size: 1
num_epochs: 4

View File

@@ -23,7 +23,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./lora-replit
+output_dir: ./outputs/lora-replit
batch_size: 8
micro_batch_size: 1
num_epochs: 4

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
-output_dir: ./out
+output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true

View File

@@ -12,7 +12,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.2
-output_dir: ./qlora
+output_dir: ./outputs/qlora
adapter: qlora
lora_model_dir:

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true

View File

@@ -11,7 +11,7 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./lora-out
+output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true

View File

@@ -14,7 +14,7 @@ pretraining_dataset:
type: pretrain
dataset_prepared_path:
val_set_size: 0.0
-output_dir: ./model-out
+output_dir: ./outputs/model-out
sequence_len: 2048
sample_packing: true

View File

@@ -11,13 +11,14 @@ datasets:
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
+eval_sample_packing: false
pad_to_sequence_len: true
lora_r: 32

View File

@@ -40,7 +40,7 @@ wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
# QLoRA paper Table 9
# - 16 for 7b & 13b

View File

@@ -33,7 +33,7 @@ eval_sample_packing: false
eval_batch_size: 1
# LoRA
-output_dir: ./qlora-out
+output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
lora_r: 32

View File

@@ -1,22 +1,22 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
-peft==0.10.0
-transformers @ git+https://github.com/huggingface/transformers.git@43d17c18360ac9c3d3491389328e2fe55fe8f9ce
-tokenizers==0.15.0
-bitsandbytes==0.43.0
-accelerate==0.28.0
-deepspeed==0.13.1
+peft==0.11.1
+transformers==4.41.1
+tokenizers==0.19.1
+bitsandbytes==0.43.1
+accelerate==0.30.1
+deepspeed==0.14.2
pydantic==2.6.3
addict
fire
PyYAML>=6.0
requests
-datasets==2.15.0
-flash-attn==2.5.5
+datasets==2.19.1
+flash-attn==2.5.8
sentencepiece
wandb
einops
-xformers==0.0.22
+xformers==0.0.26.post1
optimum==1.16.2
hf_transfer
colorama
@@ -28,7 +28,7 @@ scipy
scikit-learn==1.2.2
pynvml
art
-fschat @ git+https://github.com/lm-sys/FastChat.git@5095615810cf613dba7f27dd155f571fcff976d8
+fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
gradio==3.50.2
tensorboard
@@ -39,6 +39,6 @@ s3fs
gcsfs
# adlfs
-trl==0.8.5
+trl==0.8.6
zstandard==0.22.0
fastcore

View File

@@ -0,0 +1,82 @@
#!/bin/bash
# Export specific ENV variables to /etc/rp_environment
echo "Exporting environment variables..."
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
conda init
# this needs to come after conda init
echo 'source /etc/rp_environment' >> ~/.bashrc
add_keys_to_authorized() {
local key_value=$1
# Create the ~/.ssh directory and set permissions
mkdir -p ~/.ssh
chmod 700 ~/.ssh
# Create the authorized_keys file if it doesn't exist
touch ~/.ssh/authorized_keys
# Initialize an empty key variable
local key=""
# Read the key variable word by word
for word in $key_value; do
# Check if the word looks like the start of a key
if [[ $word == ssh-* ]]; then
# If there's a key being built, add it to the authorized_keys file
if [[ -n $key ]]; then
echo $key >> ~/.ssh/authorized_keys
fi
# Start a new key
key=$word
else
# Append the word to the current key
key="$key $word"
fi
done
# Add the last key to the authorized_keys file
if [[ -n $key ]]; then
echo $key >> ~/.ssh/authorized_keys
fi
# Set the correct permissions
chmod 600 ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
}
if [[ $PUBLIC_KEY ]]; then
# runpod
add_keys_to_authorized "$PUBLIC_KEY"
# Start the SSH service in the background
service ssh start
elif [[ $SSH_KEY ]]; then
# latitude.sh
add_keys_to_authorized "$SSH_KEY"
# Start the SSH service in the background
service ssh start
else
echo "No PUBLIC_KEY or SSH_KEY environment variable provided, not starting openSSH daemon"
fi
# Check if JUPYTER_PASSWORD is set and not empty
if [ -n "$JUPYTER_PASSWORD" ]; then
# Set JUPYTER_TOKEN to the value of JUPYTER_PASSWORD
export JUPYTER_TOKEN="$JUPYTER_PASSWORD"
fi
if [ "$JUPYTER_DISABLE" != "1" ]; then
# Run Jupyter Lab in the background
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
fi
if [ ! -d "/workspace/data/axolotl-artifacts" ]; then
mkdir -p /workspace/data/axolotl-artifacts
fi
if [ ! -L "/workspace/axolotl/outputs" ]; then
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
fi
# Execute the passed arguments (CMD)
exec "$@"

View File

@@ -5,20 +5,53 @@ echo "Exporting environment variables..."
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
echo 'source /etc/rp_environment' >> ~/.bashrc
add_keys_to_authorized() {
local key_value=$1
# Create the ~/.ssh directory and set permissions
mkdir -p ~/.ssh
chmod 700 ~/.ssh
# Create the authorized_keys file if it doesn't exist
touch ~/.ssh/authorized_keys
# Initialize an empty key variable
local key=""
# Read the key variable word by word
for word in $key_value; do
# Check if the word looks like the start of a key
if [[ $word == ssh-* ]]; then
# If there's a key being built, add it to the authorized_keys file
if [[ -n $key ]]; then
echo $key >> ~/.ssh/authorized_keys
fi
# Start a new key
key=$word
else
# Append the word to the current key
key="$key $word"
fi
done
# Add the last key to the authorized_keys file
if [[ -n $key ]]; then
echo $key >> ~/.ssh/authorized_keys
fi
# Set the correct permissions
chmod 600 ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
}
if [[ $PUBLIC_KEY ]]; then
# runpod
-mkdir -p ~/.ssh
-chmod 700 ~/.ssh
-echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
-chmod 700 -R ~/.ssh
+add_keys_to_authorized "$PUBLIC_KEY"
# Start the SSH service in the background
service ssh start
-elif [ -n "$SSH_KEY" ]; then
+elif [[ $SSH_KEY ]]; then
# latitude.sh
-mkdir -p ~/.ssh
-chmod 700 ~/.ssh
-echo $SSH_KEY >> ~/.ssh/authorized_keys
-chmod 700 -R ~/.ssh
+add_keys_to_authorized "$SSH_KEY"
# Start the SSH service in the background
service ssh start
else
@@ -36,5 +69,12 @@ if [ "$JUPYTER_DISABLE" != "1" ]; then
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
fi
if [ ! -d "/workspace/data/axolotl-artifacts" ]; then
mkdir -p /workspace/data/axolotl-artifacts
fi
if [ ! -L "/workspace/axolotl/outputs" ]; then
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
fi
# Execute the passed arguments (CMD)
exec "$@"

View File

@@ -30,8 +30,11 @@ def parse_requirements():
    try:
        if "Darwin" in platform.system():
-            _install_requires.pop(_install_requires.index("xformers==0.0.22"))
+            # don't install xformers on MacOS
+            _install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
        else:
+            # detect the version of torch already installed
+            # and set it so dependencies don't clobber the torch version
            torch_version = version("torch")
            _install_requires.append(f"torch=={torch_version}")
@@ -45,9 +48,15 @@ def parse_requirements():
        else:
            raise ValueError("Invalid version format")

-        if (major, minor) >= (2, 1):
-            _install_requires.pop(_install_requires.index("xformers==0.0.22"))
-            _install_requires.append("xformers>=0.0.23")
+        if (major, minor) >= (2, 3):
+            pass
+        elif (major, minor) >= (2, 2):
+            _install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
+            _install_requires.append("xformers>=0.0.25.post1")
+        else:
+            _install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
+            _install_requires.append("xformers>=0.0.23.post1")
    except PackageNotFoundError:
        pass
@@ -59,7 +68,7 @@ install_requires, dependency_links = parse_requirements()
setup( setup(
name="axolotl", name="axolotl",
version="0.4.0", version="0.4.1",
description="LLM Trainer", description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.", long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"}, package_dir={"": "src"},
@@ -68,13 +77,13 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn==2.5.5", "flash-attn==2.5.8",
], ],
"fused-dense-lib": [ "fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib", "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed==0.13.1", "deepspeed==0.14.2",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [


@@ -25,6 +25,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         load_in_8bit=False,
         load_in_4bit=False,
         flash_attention=False,
+        deepspeed=None,
+        fsdp=None,
         **kwargs,
     )
@@ -40,6 +42,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     parsed_cfg.flash_attention = False
     parsed_cfg.deepspeed = None
     parsed_cfg.fsdp = None
+    parsed_cfg.fsdp_config = None

     do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)


@@ -19,7 +19,10 @@ from axolotl.cli import (
 )
 from axolotl.common.cli import PreprocessCliArgs
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
-from axolotl.prompt_strategies.sharegpt import register_chatml_template
+from axolotl.prompt_strategies.sharegpt import (
+    register_chatml_template,
+    register_llama3_template,
+)

 LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         return_remaining_strings=True
     )

-    if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
-        LOG.info(
-            f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
-        )
-        register_chatml_template(parsed_cfg.default_system_message)
-    else:
-        register_chatml_template()
+    if parsed_cfg.chat_template == "chatml":
+        if parsed_cfg.default_system_message:
+            LOG.info(
+                f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
+            )
+            register_chatml_template(parsed_cfg.default_system_message)
+        else:
+            register_chatml_template()
+    elif parsed_cfg.chat_template == "llama3":
+        if parsed_cfg.default_system_message:
+            LOG.info(
+                f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
+            )
+            register_llama3_template(parsed_cfg.default_system_message)
+        else:
+            register_llama3_template()

     if not parsed_cfg.dataset_prepared_path:
         msg = (

@@ -19,7 +19,10 @@ from axolotl.cli import (
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
-from axolotl.prompt_strategies.sharegpt import register_chatml_template
+from axolotl.prompt_strategies.sharegpt import (
+    register_chatml_template,
+    register_llama3_template,
+)
 from axolotl.train import train

 LOG = logging.getLogger("axolotl.cli.train")
@@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()

+    if cfg.chat_template == "llama3" and cfg.default_system_message:
+        LOG.info(
+            f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
+        )
+        register_llama3_template(cfg.default_system_message)
+    else:
+        register_llama3_template()
+
     if cfg.rl:  # and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:

src/axolotl/core/trainer_builder.py Normal file → Executable file

@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer, ORPOConfig, ORPOTrainer
+from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length

 from axolotl.loraplus import create_loraplus_optimizer
@@ -43,7 +43,7 @@ from axolotl.utils.callbacks import (
     LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
-    SaveModelOnTrainEndCallback,
+    SaveModelCallback,
     bench_eval_callback_factory,
     causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
@@ -91,11 +91,12 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
 @dataclass
-class AxolotlTrainingArguments(TrainingArguments):
+class AxolotlTrainingMixins:
     """
-    Extend the base TrainingArguments for axolotl helpers
+    Mixin class for the Axolotl training args.
     """

+    # pylint: disable=duplicate-code
     model_type: Optional[str] = field(
         default=None, metadata={"help": "HF model configuration model_type."}
     )
@@ -125,14 +126,22 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=1.0,
         metadata={"help": "Sample packing efficiency for calculating batch length."},
     )
+    sample_packing_bin_size: int = field(
+        default=200,
+        metadata={
+            "help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
+        },
+    )
+    sample_packing_group_size: int = field(
+        default=100000,
+        metadata={
+            "help": "The number of samples to group together for packing. Increase for better packing."
+        },
+    )
     max_seq_length: int = field(
         default=2048,
         metadata={"help": "The maximum sequence length the model can handle"},
     )
-    sample_packing_seq_len_multiplier: int = field(
-        default=1,
-        metadata={"help": "the multiplier for the max len for packed sequences"},
-    )
     relora_steps: Optional[int] = field(
         default=None,
         metadata={"help": "how often to reset for ReLoRA"},
@@ -219,6 +228,30 @@ class AxolotlTrainingArguments(TrainingArguments):
     )

+
+@dataclass
+class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
+    """
+    Training arguments for Causal trainer
+
+    This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
+    so it can't be used as a mixin.
+    """
+
+
+@dataclass
+class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
+    """
+    ORPO config for ORPO training
+    """
+
+
+@dataclass
+class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
+    """
+    KTO config for KTO training
+    """
+
+
 class AxolotlTrainer(Trainer):
     """
     Extend the base Trainer for axolotl helpers
@@ -346,11 +379,12 @@ class AxolotlTrainer(Trainer):
             )
             return MultipackBatchSampler(
                 RandomSampler(self.train_dataset),
-                batch_size=batch_size,
-                drop_last=True,
-                batch_max_len=batch_max_len,
                 lengths=get_dataset_lengths(self.train_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                batch_max_len=batch_max_len,
+                batch_size=batch_size,
+                group_size=self.args.sample_packing_group_size,
+                bin_size=self.args.sample_packing_bin_size,
+                drop_last=True,
             )
         if self.args.curriculum_sampling:
             return SequentialSampler(self.train_dataset)
@@ -370,11 +404,12 @@ class AxolotlTrainer(Trainer):
             )
             return MultipackBatchSampler(
                 SequentialSampler(eval_dataset),
-                batch_size=batch_size,
-                drop_last=True,
-                batch_max_len=batch_max_len,
-                lengths=get_dataset_lengths(eval_dataset),
-                packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                lengths=get_dataset_lengths(self.eval_dataset),
+                batch_max_len=batch_max_len,
+                batch_size=batch_size,
+                group_size=self.args.sample_packing_group_size,
+                bin_size=self.args.sample_packing_bin_size,
+                drop_last=True,
             )
         return super()._get_eval_sampler(eval_dataset)
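The two new packing knobs surface through the axolotl config; a hypothetical snippet (DictDefault is axolotl's config mapping, and the values shown are just the field defaults above):

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "sample_packing": True,
        "sample_packing_group_size": 100000,  # samples grouped per packing pass
        "sample_packing_bin_size": 200,  # max samples one packed bin may contain
    }
)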
@@ -798,6 +833,40 @@ class AxolotlDPOTrainer(DPOTrainer):
tag_names = ["axolotl", "dpo"] tag_names = ["axolotl", "dpo"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.optimizer = None
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub) @wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str: def push_to_hub(self, *args, **kwargs) -> str:
""" """
@@ -826,6 +895,14 @@ class AxolotlORPOTrainer(ORPOTrainer):
tag_names = ["axolotl", "orpo"] tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """
Base class for trainer builder Base class for trainer builder
@@ -945,7 +1022,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.loss_watchdog_threshold is not None:
             callbacks.append(LossWatchDogCallback(self.cfg))

-        callbacks.append(SaveModelOnTrainEndCallback())
+        callbacks.append(SaveModelCallback())

         return callbacks
@@ -1071,11 +1148,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.save_safetensors is not None:
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

-        if self.cfg.sample_packing_eff_est:
-            training_arguments_kwargs[
-                "sample_packing_efficiency"
-            ] = self.cfg.sample_packing_eff_est
-
         if self.cfg.dataloader_pin_memory is not None:
             training_arguments_kwargs[
                 "dataloader_pin_memory"
@@ -1123,6 +1195,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             # default to saving each epoch if not defined
             training_arguments_kwargs["save_strategy"] = "epoch"

+        training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
+
         if self.cfg.do_bench_eval:
             training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
             if self.cfg.bench_dataset:
@@ -1202,11 +1276,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
-        report_to = None
+        report_to = []
         if self.cfg.use_wandb:
-            report_to = "wandb"
+            report_to.append("wandb")
         if self.cfg.use_mlflow:
-            report_to = "mlflow"
+            report_to.append("mlflow")
+        if self.cfg.use_tensorboard:
+            report_to.append("tensorboard")

         training_arguments_kwargs["report_to"] = report_to
         training_arguments_kwargs["run_name"] = (
             self.cfg.wandb_name if self.cfg.use_wandb else None
@@ -1246,20 +1323,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["weight_decay"] = ( training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
) )
training_arguments_kwargs["sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.flash_attention is not True
)
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing
if self.cfg.eval_sample_packing is not False
else False
)
training_arguments_kwargs[ training_arguments_kwargs[
"sample_packing_seq_len_multiplier" "multipack_real_batches"
] = self.cfg.micro_batch_size ] = not self.cfg.flash_attention
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
)
if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs[
"sample_packing_bin_size"
] = self.cfg.sample_packing_bin_size
if self.cfg.sample_packing_group_size is not None:
training_arguments_kwargs[
"sample_packing_group_size"
] = self.cfg.sample_packing_group_size
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est
if self.cfg.relora_steps: if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs[ training_arguments_kwargs[
@@ -1429,7 +1513,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
     def get_callbacks(self):
         callbacks = super().get_callbacks()
-        callbacks.append(SaveModelOnTrainEndCallback())
+        callbacks.append(SaveModelCallback())

         return callbacks
@@ -1470,6 +1554,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.bf16 or self.cfg.bfloat16:
             training_args_kwargs["bf16"] = True

+        training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
+        training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
         training_args_kwargs["lr_scheduler_type"] = (
             self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
         )
@@ -1522,17 +1608,36 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha

-        training_args_cls = TrainingArguments
+        training_args_cls = AxolotlTrainingArguments
         if self.cfg.rl == "orpo":
-            training_args_cls = ORPOConfig
+            training_args_cls = AxolotlORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+            if self.cfg.max_prompt_len:
+                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
+
+        if self.cfg.rl == "kto":
+            training_args_cls = AxolotlKTOConfig
+
+            training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
+            training_args_kwargs["desirable_weight"] = (
+                self.cfg.kto_desirable_weight or 1.0
+            )
+            training_args_kwargs["undesirable_weight"] = (
+                self.cfg.kto_undesirable_weight or 1.0
+            )
+
+            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+            if self.cfg.max_prompt_len:
+                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

-        training_args = training_args_cls(
+        training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
+            output_dir=self.cfg.output_dir,
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
-            output_dir=self.cfg.output_dir,
             warmup_steps=self.cfg.warmup_steps,
             logging_first_step=True,
             logging_steps=1,
@@ -1562,7 +1667,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             ] = self.cfg.precompute_ref_log_probs
         if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
             trainer_cls = AxolotlDPOTrainer
-            dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
+            dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
             trainer_cls_args = [self.model, self.model_ref]

             # these aren't used for the ORPO trainer
@@ -1575,6 +1680,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
+        elif self.cfg.rl == "kto":
+            trainer_cls = AxolotlKTOTrainer
+            trainer_cls_args = [self.model]
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")
         dpo_trainer = trainer_cls(
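Taken together, a hypothetical config that exercises the new KTO path (keys as read by HFRLTrainerBuilder above):

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "rl": "kto",
        "rl_beta": 0.1,  # AxolotlKTOConfig beta
        "kto_desirable_weight": 1.0,
        "kto_undesirable_weight": 1.0,
        "sequence_len": 2048,  # forwarded as max_length
        "max_prompt_len": 512,  # forwarded as max_prompt_length
        "dataset_processes": 8,  # forwarded as dataset_num_proc
    }
)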


@@ -123,6 +123,17 @@ def get_turns( # pylint: disable=too-many-return-statements
             else:
                 yield role, ""
             return
+        if self.sep_style == SeparatorStyle.LLAMA3:
+            if self.system_message:
+                # For llama3, the system message is NOT incorporated into the first human instruction
+                # All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
+                yield "", system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
+                else:
+                    yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
+            return
         if self.sep_style == SeparatorStyle.GEMMA:
             if self.system_message:
                 raise ValueError("Gemma chat template does not support system messages")


@@ -42,9 +42,9 @@ def patch_mixtral_moe_forward_zero3() -> None:
         return final_hidden_states, router_logits

     from transformers.models.mixtral.modeling_mixtral import (
-        MixtralBLockSparseTop2MLP,
+        MixtralBlockSparseTop2MLP,
         MixtralSparseMoeBlock,
     )

-    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralBlockSparseTop2MLP.forward = mlp_forward
     MixtralSparseMoeBlock.forward = moe_forward


@@ -0,0 +1,267 @@
"""module for patching with unsloth optimizations"""
import inspect
import logging
import re
import types
from typing import Tuple
from peft import PeftModelForCausalLM
from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)
LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """ if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
"""
PATCHED_CEL_CODE = """ if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
"""
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
""".lstrip(
"\n"
)
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
""".lstrip(
"\n"
)
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
"\n"
)
PATCHED_O_CODE = """
attn_output = self.apply_o(self, attn_output)
""".lstrip(
"\n"
)
def original_apply_qkv(self, hidden_states):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def original_apply_o(self, hidden_states):
attn_output = self.o_proj(hidden_states)
return attn_output
def get_forward_code() -> str:
forward = inspect.getsource(LlamaForCausalLM.forward)
return forward
def test_cel_is_patchable() -> bool:
forward = get_forward_code()
return ORIGINAL_CEL_CODE in forward
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
def test_self_attn_is_patchable() -> bool:
    qkv = get_self_attn_code()
    return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
def integrate_cross_entropy_loss_patch():
forward = get_forward_code()
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
print("patching unsloth fast_cross_entropy_loss")
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
def detab_code(code: str) -> Tuple[str, str]:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
return code, spaces
def patch_self_attn_lora():
self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
self_attn_forward
)
self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def unsloth_attn_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in self_attn_forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
print("patching unsloth attn lora")
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
from unsloth.kernels import apply_lora_mlp_swiglu
apply_lora_mlp = apply_lora_mlp_swiglu
elif peft_model.base_model.config.model_type == "gemma":
from unsloth.kernels import apply_lora_mlp_geglu_approx
apply_lora_mlp = apply_lora_mlp_geglu_approx
else:
raise NotImplementedError(
f"Model type {peft_model.base_model.config.model_type} not supported"
)
for idx, layer in enumerate(peft_model.model.model.layers):
layer_modules = [
getattr(layer.mlp, linear_proj)
for linear_proj in ["gate_proj", "up_proj", "down_proj"]
]
is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
mlp_no_bias = all(
getattr(module, "base_layer", module).bias is None
for module in layer_modules
)
mlp_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
for module in layer_modules
)
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else:
logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
from unsloth.kernels import apply_lora_o, apply_lora_qkv
for idx, layer in enumerate(peft_model.model.model.layers):
if cfg.unsloth_lora_qkv:
layer_modules = [
getattr(layer.self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
qkv_no_bias = all(
getattr(module, "base_layer", module).bias is None
for module in layer_modules
)
qkv_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
for module in layer_modules
)
if is_qkv_lora and qkv_no_bias and qkv_not_dora:
layer.self_attn.apply_qkv = apply_lora_qkv
else:
layer.self_attn.apply_qkv = original_apply_qkv
logging.warning(
"unable to apply unsloth lora qkv patch to layer %d", idx
)
if cfg.unsloth_lora_o:
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
]
is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
o_no_bias = all(
getattr(module, "base_layer", module).bias is None
for module in layer_modules
)
o_not_dora = all(
getattr(module, "lora_magnitude_vector", None) is None
for module in layer_modules
)
if is_o_lora and o_no_bias and o_not_dora:
layer.self_attn.apply_o = apply_lora_o
else:
layer.self_attn.apply_o = original_apply_o
logging.warning(
"unable to apply unsloth lora o_proj patch to layer %d", idx
)
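These patches are opt-in via config flags. unsloth_lora_qkv and unsloth_lora_o are read by integrate_lora_patch above; unsloth_lora_mlp is assumed here to gate integrate_lora_mlp_patch:

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "unsloth_lora_mlp": True,  # assumed flag for integrate_lora_mlp_patch
        "unsloth_lora_qkv": True,  # fused LoRA kernels for q/k/v projections
        "unsloth_lora_o": True,  # fused LoRA kernel for o_proj
    }
)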


@@ -1,24 +1,56 @@
""" """
HF Chat Templates prompt strategy HF Chat Templates prompt strategy
""" """
from typing import Any, Dict, Optional
import logging
from typing import Any, Dict, List, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates from axolotl.utils.chat_templates import chat_templates
LOG = logging.getLogger("axolotl")
class ChatTemplatePrompter(Prompter): class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates""" """prompter for HF chat templates"""
def __init__(self, tokenizer, chat_template=None, max_length=2048): def __init__(
self,
tokenizer,
chat_template=None,
max_length=2048,
message_field_role: str = "from",
message_field_content: str = "value",
roles: Optional[Dict[str, List[str]]] = None,
):
if roles:
self.roles = {s: t for t, sources in roles.items() for s in sources}
else:
self.roles = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
"system": "system",
}
self.message_field_role = message_field_role
self.message_field_content = message_field_content
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.chat_template = chat_template self.chat_template = chat_template
self.max_length = max_length self.max_length = max_length
def build_prompt(self, conversation, add_generation_prompt=False): def build_prompt(self, conversation, add_generation_prompt=False):
turns = [
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
}
for t in conversation
]
return self.tokenizer.apply_chat_template( return self.tokenizer.apply_chat_template(
conversation, turns,
truncation=True, truncation=True,
max_length=self.max_length, max_length=self.max_length,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
@@ -31,9 +63,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts. Tokenizing strategy for instruction-based prompts.
""" """
_messages = "conversations"
@property
def messages(self):
return self._messages
@messages.setter
def messages(self, messages):
self._messages = messages
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
turns = self.get_conversation_thread(prompt) turns = self.get_conversation_thread(prompt)
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True) prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
input_ids = self.prompter.build_prompt(turns) input_ids = self.prompter.build_prompt(turns)
if not self.train_on_inputs: if not self.train_on_inputs:
@@ -51,28 +93,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt return tokenized_prompt
def get_conversation_thread(self, prompt): def get_conversation_thread(self, prompt):
conversations = prompt["conversations"] return prompt[self.messages]
# remap roles - allow for assistant turn
role_map = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
}
turns = [
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
]
return turns
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = ( chat_template = (
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml" ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
) )
message_field_role = (
ds_cfg["message_field_role"]
if ds_cfg and "message_field_role" in ds_cfg
else "from"
)
message_field_content = (
ds_cfg["message_field_content"]
if ds_cfg and "message_field_content" in ds_cfg
else "value"
)
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)), ChatTemplatePrompter(
tokenizer,
chat_templates(chat_template),
message_field_role=message_field_role,
message_field_content=message_field_content,
roles=roles,
),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy return strategy
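With the generalized strategy, a dataset whose messages live under non-ShareGPT keys can be described per dataset. A hypothetical ds_cfg, with keys exactly as consumed by load above (tokenizer and cfg assumed in scope):

ds_cfg = {
    "chat_template": "llama3",
    "field_messages": "messages",  # column holding the message list
    "message_field_role": "role",  # per-message role key
    "message_field_content": "content",  # per-message content key
    "roles": {"user": ["human"], "assistant": ["gpt"]},  # target -> source names
}
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)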


@@ -0,0 +1,133 @@
"""
DPO strategies for llama-3 chat template
"""
def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
return sample
return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample
return transform_fn
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca DPO Pairs
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn
def prompt_pairs(
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample
return transform_fn
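A quick sanity check of what these transforms emit, running intel() on a toy row (values invented):

transform_fn = intel(cfg=None)
row = {"system": "Be terse.", "question": "2 + 2?", "chosen": "4", "rejected": "5"}
row = transform_fn(row)
assert row["prompt"].startswith(
    "<|start_header_id|>system<|end_header_id|>\n\nBe terse.<|eot_id|>"
)
assert row["chosen"] == "4<|eot_id|>" and row["rejected"] == "5<|eot_id|>"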


@@ -0,0 +1,9 @@
"""
module for KTO style dataset transform strategies
"""
from functools import partial
from ..base import load as load_base
load = partial(load_base, module_base="axolotl.prompt_strategies.kto")


@@ -0,0 +1,105 @@
"""
KTO strategies for chatml
"""
# pylint: disable=duplicate-code
def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample
return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/kto-mix-15k conversations
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
return sample
return transform_fn
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca KTO
ex: argilla/distilabel-intel-orca-kto
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample
return transform_fn
def prompt_pairs(
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample
return transform_fn
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample
return transform_fn


@@ -0,0 +1,105 @@
"""
KTO strategies for llama-3 chat template
"""
# pylint: disable=duplicate-code
def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample
return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/kto-mix-15k conversations
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
return sample
return transform_fn
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca KTO
ex: argilla/distilabel-intel-orca-kto
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample
return transform_fn
def prompt_pairs(
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample
return transform_fn
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample
return transform_fn


@@ -0,0 +1,39 @@
"""
User-defined KTO strategies
"""
# pylint: disable=duplicate-code
def default(cfg, dataset_idx=0, **kwargs):  # pylint: disable=unused-argument
    ds_cfg = cfg["datasets"][dataset_idx]["type"]
    if not isinstance(ds_cfg, dict):
        raise ValueError(
            f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
        )
    field_prompt = ds_cfg.get("field_prompt", "prompt")
    field_system = ds_cfg.get("field_system", "system")
    field_completion = ds_cfg.get("field_completion", "completion")
    field_label = ds_cfg.get("field_label", "label")
    prompt_format = ds_cfg.get("prompt_format")
    if not prompt_format:
        prompt_format = "{" + field_prompt + "}"
    completion_format = ds_cfg.get("completion_format")
    if not completion_format:
        completion_format = "{" + field_completion + "}"

    def transform_fn(sample):
        if (
            "{" + field_system + "}" in prompt_format
            and field_system in sample
            and sample[field_system]
        ):
            sample["prompt"] = prompt_format.format(
                system=sample[field_system], prompt=sample[field_prompt]
            )
        else:
            sample["prompt"] = prompt_format.format(prompt=sample[field_prompt])
        sample["completion"] = completion_format.format(
            completion=sample[field_completion]
        )
        sample["label"] = sample[field_label]
        return sample

    return transform_fn
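A hypothetical datasets[n].type mapping consumed by default() above; the format strings are optional and fall back to bare field interpolation:

ds_type = {
    "field_prompt": "prompt",
    "field_system": "system",
    "field_completion": "completion",
    "field_label": "label",
    "prompt_format": "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n",
}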


@@ -1,7 +1,7 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Type
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
name="chatml", name="chatml",
system_template="<|im_start|>system\n{system_message}", system_template="<|im_start|>system\n{system_message}",
system_message=system_message, system_message=system_message,
roles=["<|im_start|>user", "<|im_start|>assistant"], roles=("<|im_start|>user", "<|im_start|>assistant"),
sep_style=SeparatorStyle.CHATML, sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>", sep="<|im_end|>",
) )
@@ -32,83 +32,65 @@ def register_chatml_template(system_message=None):
name="chatml_glaive", name="chatml_glaive",
system_template="<|im_start|>system\n{system_message}", system_template="<|im_start|>system\n{system_message}",
system_message=system_message, system_message=system_message,
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"], roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
sep_style=SeparatorStyle.CHATML, sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>", sep="<|im_end|>",
) )
) )
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def register_llama3_template(system_message=None):
conversation = ( system_message = system_message or "You are a helpful assistant."
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None register_conv_template(
) Conversation(
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None name="llama3",
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None system_message=system_message,
strategy = SimpleShareGPTPromptTokenizingStrategy( roles=("user", "assistant"),
ShareGPTPrompterV2( sep_style=SeparatorStyle.LLAMA3,
conversation=conversation, sep="",
role_key_model=field_model, stop_str="<|eot_id|>",
role_key_human=field_human, stop_token_ids=[128001, 128009],
roles=roles, )
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = UltrachatShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
def load_guanaco(tokenizer, cfg): def build_loader(
return GuanacoShareGPTPromptTokenizingStrategy( tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
ShareGPTPrompterV2(), prompter_cls: Type["ShareGPTPrompterV2"],
tokenizer, default_conversation: Optional[str] = None,
cfg.train_on_inputs, ):
cfg.sequence_len, def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
) conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else default_conversation
)
field_human = (
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
)
field_model = (
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
)
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = tokenization_strategy_cls(
prompter_cls(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy
return _load
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else "chatml_glaive"
)
return GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation=conversation),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -117,6 +99,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
""" """
_strict = False _strict = False
_messages = "conversations"
@property @property
def strict(self): def strict(self):
@@ -126,8 +109,16 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
def strict(self, strict): def strict(self, strict):
self._strict = strict self._strict = strict
@property
def messages(self):
return self._messages
@messages.setter
def messages(self, messages):
self._messages = messages
def get_conversation_thread(self, prompt): def get_conversation_thread(self, prompt):
conversations = prompt["conversations"] conversations = prompt[self.messages]
if self.strict: if self.strict:
return conversations return conversations
role_key = "from" role_key = "from"
@@ -158,7 +149,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
return turns return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleRoleShareGPTPromptTokenizingStrategy(
SimpleShareGPTPromptTokenizingStrategy
):
""" """
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
""" """
@@ -209,3 +202,16 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
conversation = merge_consecutive_messages(conversation) conversation = merge_consecutive_messages(conversation)
return conversation return conversation
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_glaive = build_loader(
GlaiveShareGPTPromptTokenizingStrategy,
ShareGPTPrompterV2,
default_conversation="chatml_glaive",
)
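The factory also makes it cheap to register further ShareGPT-style loaders. A hedged sketch (the loader name is invented; ds_cfg overrides still apply):

load_my_variant = build_loader(
    SimpleShareGPTPromptTokenizingStrategy,
    ShareGPTPrompterV2,
    default_conversation="chatml",
)
# read turns from a "dialogue" column instead of "conversations"
strategy = load_my_variant(tokenizer, cfg, ds_cfg={"field_messages": "dialogue"})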


@@ -263,6 +263,7 @@ CONVERSATION_ROLE_FORMAT = {
"chatml": "<|im_start|>{ROLE}", "chatml": "<|im_start|>{ROLE}",
"zephyr": "<|{ROLE}|>", "zephyr": "<|{ROLE}|>",
"vicuna_v1.1": "{ROLE}", "vicuna_v1.1": "{ROLE}",
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
} }
@@ -348,7 +349,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
) )
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") if (
role != "assistant"
): # back to back assistant calls may be okay for tool calls
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])


@@ -197,6 +197,13 @@ def train(
     trainer.accelerator.wait_for_everyone()
     unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)

+    # the trainer saved a model.safetensors file in the output directory,
+    # but it is a proxy model and should be deleted
+    if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
+        LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
+        LOG.info("This is a proxy model and should be deleted")
+        os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
+
     # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
     # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
     # `zero3_save_16bit_model` is True in DeepSpeed Plugin.
@@ -212,6 +219,10 @@ def train(
         if cfg.flash_optimum and BetterTransformer:
             model = BetterTransformer.reverse(model)

+        if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
+            trainer.model.save_pretrained(
+                cfg.output_dir, safe_serialization=safe_serialization
+            )
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

     if not cfg.hub_model_id:


@@ -3,6 +3,7 @@
 from __future__ import annotations

 import logging
+import math
 import os
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
@@ -775,9 +776,27 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
         return control


-class SaveModelOnTrainEndCallback(TrainerCallback):
+class SaveModelCallback(TrainerCallback):
     """Callback to save model on train end"""

+    def on_step_end(  # pylint: disable=unused-argument
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        # Save
+        if state.global_step >= state.max_steps:
+            control.should_save = True
+        elif (
+            args.save_strategy == IntervalStrategy.STEPS
+            and state.save_steps < 1.0
+            and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
+        ):
+            # workaround to save model on fractional save_steps
+            control.should_save = True
+
     def on_train_end(  # pylint: disable=unused-argument
         self, args, state, control, **kwargs
     ):


@@ -24,6 +24,7 @@ def chat_templates(user_choice: str):
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
} }
if user_choice in templates: if user_choice in templates:
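The new llama3 entry can be exercised directly through a tokenizer; a small sketch (the model id is an assumption):

from transformers import AutoTokenizer

from axolotl.utils.chat_templates import chat_templates

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hi"}],
    chat_template=chat_templates("llama3"),
    tokenize=False,
    add_generation_prompt=True,
)
# text ends with "<|start_header_id|>assistant<|end_header_id|>\n\n"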
