Compare commits

..

1 commit

Author     SHA1        Message              Date
Wing Lian  5bb4a782ce  dataloader defaults  2023-12-12 17:33:31 -05:00
86 changed files with 3353 additions and 2502 deletions

View File

@@ -28,12 +28,7 @@ jobs:
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.1 pytorch: 2.1.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -27,56 +27,38 @@ jobs:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.1 pytorch: 2.1.0
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.1
axolotl_extras: axolotl_extras:
runs-on: [self-hosted, gpu, docker] runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Docker metadata - name: Docker metadata
id: metadata id: metadata
uses: docker/metadata-action@v5 uses: docker/metadata-action@v3
with: with:
images: winglian/axolotl images: winglian/axolotl
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v3 uses: docker/login-action@v2
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/ - name: Set up Docker Buildx
- name: Build and export to Docker uses: docker/setup-buildx-action@v2
uses: docker/build-push-action@v5 - name: Build
uses: docker/build-push-action@v4
with: with:
context: . context: .
load: true
build-args: | build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }} PYTORCH_VERSION=${{ matrix.pytorch }}
file: ./docker/Dockerfile file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: | tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
- name: Unit Tests
run: |
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
- name: Push to Docker Hub
if: github.event_name != 'pull_request'
run: |
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
if [ -n "$latest_tag" ]; then
docker push "$latest_tag"
fi
build-axolotl-runpod: build-axolotl-runpod:
needs: build-axolotl needs: build-axolotl
if: github.repository_owner == 'OpenAccess-AI-Collective' if: github.repository_owner == 'OpenAccess-AI-Collective'
@@ -98,31 +80,26 @@ jobs:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.1 pytorch: 2.1.0
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.1
axolotl_extras: axolotl_extras:
runs-on: [self-hosted, gpu, docker] runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Docker metadata - name: Docker metadata
id: metadata id: metadata
uses: docker/metadata-action@v5 uses: docker/metadata-action@v3
with: with:
images: winglian/axolotl-runpod images: winglian/axolotl-runpod
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v3 uses: docker/login-action@v2
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v2
- name: Build - name: Build
uses: docker/build-push-action@v5 uses: docker/build-push-action@v4
with: with:
context: . context: .
build-args: | build-args: |

View File

@@ -1,46 +0,0 @@
name: e2e-docker-tests
on:
pull_request:
paths:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
workflow_dispatch:
jobs:
build-axolotl:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.1
runs-on: [self-hosted, gpu, docker]
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Build Docker image
run: |
# Set up build arguments
BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
CUDA="${{ matrix.cuda }}"
PYTORCH_VERSION="${{ matrix.pytorch }}"
# Build the Docker image
docker build . \
--file ./docker/Dockerfile \
--build-arg BASE_TAG=$BASE_TAG \
--build-arg CUDA=$CUDA \
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
--tag test-axolotl
- name: Unit Tests w docker image
run: |
docker run --rm test-axolotl pytest --ignore=tests/e2e/ /workspace/axolotl/tests/

README.md — 103 lines changed
View File

@@ -36,9 +36,7 @@ Features:
- [Train](#train) - [Train](#train)
- [Inference](#inference) - [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base) - [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- [Common Errors](#common-errors-) - [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Need Help?](#need-help-) - [Need Help?](#need-help-)
- [Badge](#badge-) - [Badge](#badge-)
- [Community Showcase](#community-showcase) - [Community Showcase](#community-showcase)
@@ -253,13 +251,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json ```json
{"conversations": [{"from": "...", "value": "..."}]} {"conversations": [{"from": "...", "value": "..."}]}
``` ```
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
```yml
datasets:
- path: <your-path>
type: sharegpt
conversation: llama-2
```
- `completion`: raw corpus - `completion`: raw corpus
```json ```json
{"text": "..."} {"text": "..."}
@@ -520,14 +511,6 @@ model_config:
type: # linear | dynamic type: # linear | dynamic
factor: # float factor: # float
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# Whether you are training a 4-bit GPTQ quantized model # Whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
@@ -550,11 +533,6 @@ tf32: true # require >=ampere
bfloat16: true # require >=ampere bfloat16: true # require >=ampere
float16: true float16: true
# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
lora_on_cpu: true
# A list of one or more datasets to finetune the model with # A list of one or more datasets to finetune the model with
datasets: datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
@@ -594,9 +572,6 @@ datasets:
# For `completion` datasets only, uses the provided field instead of `text` column # For `completion` datasets only, uses the provided field instead of `text` column
field: field:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
# Currently supports chatml and inst (mistral/mixtral)
chat_template: chatml
# Axolotl attempts to save the dataset as an arrow after packing the data together so # Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path # subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared dataset_prepared_path: data/last_run_prepared
@@ -648,8 +623,7 @@ max_memory:
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
# If you already have a lora model trained that you want to load, put that here. # If you already have a lora model trained that you want to load, put that here.
# This means after training, if you want to test the model, you should set this to the value of `output_dir`. # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`.
lora_model_dir: lora_model_dir:
# LoRA hyperparameters # LoRA hyperparameters
@@ -676,6 +650,10 @@ lora_modules_to_save:
# - embed_tokens # - embed_tokens
# - lm_head # - lm_head
# Once you complete training, the model will be saved to the following directory.
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
# ReLoRA configuration # ReLoRA configuration
@@ -685,7 +663,6 @@ relora_warmup_steps: # Number of per-restart warmup steps
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it # wandb configuration if you're using it
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # Your wandb project name wandb_project: # Your wandb project name
wandb_entity: # A wandb Team name if using a Team wandb_entity: # A wandb Team name if using a Team
@@ -714,11 +691,9 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003 learning_rate: 0.00003
lr_quadratic_warmup: lr_quadratic_warmup:
logging_steps: logging_steps:
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
save_strategy: # Set to `no` to skip checkpoint saves save_strategy: # Set to `no` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch save_steps: # Leave empty to save at each epoch
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
save_total_limit: # Checkpoints saved at a time save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that # Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed. # if both are set, num_epochs will not be guaranteed.
@@ -743,9 +718,6 @@ group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
# use_reentrant: false
# Stop training after this many evaluation losses have increased in a row # Stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
@@ -800,7 +772,7 @@ max_grad_norm:
# Augmentation techniques # Augmentation techniques
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral # currently only supported on Llama and Mistral
neftune_noise_alpha: noisy_embedding_alpha:
# Whether to bettertransformers # Whether to bettertransformers
flash_optimum: flash_optimum:
@@ -815,6 +787,11 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention # Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention: sdp_attention:
# Landmark attention (only llama)
landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# LLaMA only
xpos_rope:
# Resume from a specific checkpoint dir # Resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
@@ -937,9 +914,8 @@ accelerate launch -m axolotl.cli.train your_config.yml
You can optionally pre-tokenize the dataset with the following before finetuning. You can optionally pre-tokenize the dataset with the following before finetuning.
This is recommended for large datasets. This is recommended for large datasets.
- Set `dataset_prepared_path:` to a local folder for saving and loading pre-tokenized dataset. - Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
- (Optional): Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface. - Use `--debug` to see preprocessed examples.
- (Optional): Use `--debug` to see preprocessed examples.
```bash ```bash
python -m axolotl.cli.preprocess your_config.yml python -m axolotl.cli.preprocess your_config.yml
@@ -982,8 +958,6 @@ fsdp_config:
##### Weights & Biases Logging ##### Weights & Biases Logging
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
- wandb options - wandb options
```yaml ```yaml
wandb_mode: wandb_mode:
@@ -994,28 +968,9 @@ wandb_name:
wandb_log_model: wandb_log_model:
``` ```
##### Special Tokens ### Inference
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this: Pass the appropriate flag to the train command:
```yml
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
tokens: # these are delimiters
- "<|im_start|>"
- "<|im_end|>"
```
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
### Inference Playground
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
The config file is the same config file used for training.
Pass the appropriate flag to the inference command, depending upon what kind of model was trained:
- Pretrained LORA: - Pretrained LORA:
```bash ```bash
@@ -1041,20 +996,18 @@ Please use `--sample_packing False` if you have it on and receive the error simi
### Merge LORA to base ### Merge LORA to base
The following command will merge your LORA adapter with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwise this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`. Add below flag to train command above
```bash ```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model" python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
``` ```
You may need to use the `gpu_memory_limit` and/or `lora_on_cpu` config options to avoid running out of memory. If you still run out of CUDA memory, you can try to merge in system RAM with If you run out of CUDA memory, you can try to merge in system RAM with
```bash ```bash
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ... CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
``` ```
although this will be very slow; using the config options above is recommended instead.
## Common Errors 🧰 ## Common Errors 🧰
See also the [FAQ's](./docs/faq.md). See also the [FAQ's](./docs/faq.md).
@@ -1067,10 +1020,6 @@ Please reduce any below
- `gradient_accumulation_steps` - `gradient_accumulation_steps`
- `sequence_len` - `sequence_len`
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
Using adamw_bnb_8bit might also save you some memory.
> `failed (exitcode: -9)` > `failed (exitcode: -9)`
Usually means your system has run out of system memory. Usually means your system has run out of system memory.
@@ -1093,20 +1042,6 @@ It's safe to ignore it.
See the [NCCL](docs/nccl.md) guide. See the [NCCL](docs/nccl.md) guide.
### Tokenization Mismatch b/w Inference & Training
For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.
If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine-tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
## Need help? 🙋♂️ ## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you

View File

@@ -1,39 +0,0 @@
{
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 0,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1"
ENV PYTORCH_VERSION=$PYTORCH_VERSION ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev apt-get install -y vim curl
WORKDIR /workspace WORKDIR /workspace
@@ -19,15 +19,13 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \ pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
else \ else \
pip install -e .[deepspeed,flash-attn]; \ pip install -e .[deepspeed,flash-attn]; \
fi fi
# So we can test the Docker image
RUN pip install pytest
# fix so that git fetch/pull from remote works # fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch git config --get remote.origin.fetch

View File

@@ -1,35 +0,0 @@
# RLHF (Beta)
### Overview
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Various methods include, but are not limited to:
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)
### RLHF using Axolotl
[!IMPORTANT]
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
The various RL training methods are implemented in trl and wrapped via axolotl. Below are examples of how you can use various preference datasets to train models that use ChatML.
#### DPO
```yaml
rl: true
datasets:
- path: Intel/orca_dpo_pairs
split: train
type: intel_apply_chatml
- path: argilla/ultrafeedback-binarized-preferences
split: train
type: argilla_apply_chatml
```
#### IPO
```yaml
rl: ipo
```

View File

@@ -72,8 +72,8 @@ gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 32 warmup_steps: 32
evals_per_epoch: 4 eval_steps:
saves_per_epoch: 1 save_steps:
save_total_limit: save_total_limit:
debug: debug:

View File

@@ -49,8 +49,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1

View File

@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -54,8 +54,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -56,8 +56,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -51,8 +51,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 40 warmup_steps: 40
evals_per_epoch: 4 eval_steps: 5
saves_per_epoch: 1 save_steps: 43
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -80,8 +80,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 5
saves_per_epoch: 1 save_steps: 10
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.000001 weight_decay: 0.000001

View File

@@ -51,8 +51,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 40 warmup_steps: 40
evals_per_epoch: 4 eval_steps: 5
saves_per_epoch: 1 save_steps: 43
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -46,8 +46,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1

View File

@@ -42,8 +42,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 110
saves_per_epoch: 1 save_steps: 660
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1

View File

@@ -58,9 +58,9 @@ flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true flash_attn_fuse_mlp: true
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1 weight_decay: 0.1

View File

@@ -62,8 +62,8 @@ flash_attention:
sdp_attention: sdp_attention:
flash_optimum: flash_optimum:
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 eval_steps:
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1

View File

@@ -54,10 +54,10 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -56,9 +56,9 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -60,8 +60,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps: 50
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -1,4 +1,5 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -16,7 +17,6 @@ output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -54,11 +54,15 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 eval_table_size:
save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens: special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -47,10 +47,10 @@ xformers_attention:
flash_attention: flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps:
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 save_steps: 0.25
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -17,7 +17,6 @@ output_dir: ./out
sequence_len: 8192 sequence_len: 8192
sample_packing: true sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
eval_sample_packing: false
wandb_project: wandb_project:
wandb_entity: wandb_entity:
@@ -47,10 +46,10 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -1,5 +1,5 @@
base_model: mistralai/Mixtral-8x7B-v0.1 base_model: DiscoResearch/mixtral-7b-8expert
model_type: AutoModelForCausalLM model_type: MixtralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
trust_remote_code: true trust_remote_code: true
@@ -14,18 +14,6 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./qlora-out output_dir: ./qlora-out
## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
# - lm_head.*
# - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
model_config:
output_router_logits: true
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
@@ -79,10 +67,10 @@ loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps:
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed/zero2.json deepspeed: deepspeed/zero2.json
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.1 val_set_size: 0.05
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
@@ -66,10 +66,10 @@ loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -44,8 +44,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 110
saves_per_epoch: 1 save_steps: 660
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0001 weight_decay: 0.0001

View File

@@ -49,8 +49,8 @@ flash_attention: true
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1

View File

@@ -54,8 +54,8 @@ flash_attention: true
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1

View File

@@ -48,8 +48,8 @@ flash_attention: true
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1

View File

@@ -59,8 +59,8 @@ xformers_attention:
flash_attention: flash_attention:
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1

View File

@@ -59,8 +59,8 @@ xformers_attention:
flash_attention: flash_attention:
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 eval_steps: 0.05
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.1 weight_decay: 0.1

View File

@@ -33,5 +33,5 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
weight_decay: 0.1 weight_decay: 0.1
evals_per_epoch: 4 eval_steps: 0.05
logging_steps: 1 logging_steps: 1

View File

@@ -56,10 +56,10 @@ xformers_attention:
flash_attention: flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -56,10 +56,10 @@ xformers_attention:
flash_attention: flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 0.05
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -45,8 +45,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 110
saves_per_epoch: 1 save_steps: 660
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0001 weight_decay: 0.0001

View File

@@ -45,8 +45,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 eval_steps: 50
saves_per_epoch: 1 save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0 weight_decay: 0

View File

@@ -1,17 +0,0 @@
# Overview
This is a simple example of how to finetune TinyLlama 1.1B using either LoRA or QLoRA:
LoRa:
```
accelerate launch -m axolotl.cli.train examples/tiny-llama/lora.yml
```
qLoRa:
```
accelerate launch -m axolotl.cli.train examples/tiny-llama/qlora.yml
```
Both take about 10 minutes to complete on a 4090.

View File

@@ -1,58 +0,0 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
max_steps: 200
pretraining_dataset:
path: c4
name: en
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out
sequence_len: 2048
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch:
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,66 +0,0 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -78,8 +78,8 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 eval_steps: 50
saves_per_epoch: 1 save_steps: 50
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.0

View File

@@ -1,5 +0,0 @@
# Overview
This is an example of a Yi-34B-Chat configuration. It demonstrates that it is possible to finetune a 34B model on a GPU with 24GB of VRAM.
Tested on an RTX 4090 with `python -m axolotl.cli.train examples/mistral/qlora.yml`: a single epoch of finetuning on the alpaca dataset using qlora runs in 47 minutes, using 97% of available memory.

View File

@@ -1,76 +0,0 @@
base_model: 01-ai/Yi-34B-Chat
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: false
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
sequence_len: 1024
bf16: true
fp16: false
tf32: false
flash_attention: true
special_tokens:
bos_token: "<|startoftext|>"
eos_token: "<|endoftext|>"
unk_token: "<unk>"
# Data
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
warmup_steps: 10
# Iterations
num_epochs: 1
# Evaluation
val_set_size: 0.1
evals_per_epoch: 5
eval_table_size:
eval_table_max_new_tokens: 128
eval_sample_packing: false
eval_batch_size: 1
# LoRA
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
# Sampling
sample_packing: false
pad_to_sequence_len: false
# Batching
gradient_accumulation_steps: 4
micro_batch_size: 1
gradient_checkpointing: true
# wandb
wandb_project:
# Optimizer
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
# Misc
train_on_inputs: false
group_by_length: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
debug:
deepspeed:
weight_decay: 0
fsdp:
fsdp_config:

View File

@@ -2,7 +2,7 @@
auto-gptq==0.5.1 auto-gptq==0.5.1
packaging packaging
peft==0.6.0 peft==0.6.0
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0 transformers @ git+https://github.com/huggingface/transformers.git@df5c5c62ae253055336f5bb0828ca8e3e15ab6bd
tokenizers==0.15.0 tokenizers==0.15.0
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate==0.24.1 accelerate==0.24.1
@@ -29,7 +29,7 @@ scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml
art art
fschat==0.2.34 fschat==0.2.29
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
@@ -37,5 +37,3 @@ tensorboard
s3fs s3fs
gcsfs gcsfs
# adlfs # adlfs
trl @ git+https://github.com/huggingface/trl.git@main

View File

@@ -1,7 +1,5 @@
"""setup.py for axolotl""" """setup.py for axolotl"""
from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup from setuptools import find_packages, setup
@@ -24,13 +22,12 @@ def parse_requirements():
# Handle standard packages # Handle standard packages
_install_requires.append(line) _install_requires.append(line)
try: # TODO(wing) remove once xformers release supports torch 2.1.0
torch_version = version("torch") if "torch==2.1.0" in _install_requires:
if torch_version.startswith("2.1.1"): _install_requires.pop(_install_requires.index("xformers>=0.0.22"))
_install_requires.pop(_install_requires.index("xformers==0.0.22")) _install_requires.append(
_install_requires.append("xformers==0.0.23") "xformers @ git+https://github.com/facebookresearch/xformers.git@main"
except PackageNotFoundError: )
pass
return _install_requires, _dependency_links return _install_requires, _dependency_links

View File

@@ -2,7 +2,6 @@
import importlib import importlib
import logging import logging
import math
import os import os
import random import random
import sys import sys
@@ -17,7 +16,6 @@ import yaml
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args from accelerate.commands.config import config_args
from art import text2art from art import text2art
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
@@ -25,7 +23,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import add_defaults, normalize_config, validate_config
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
@@ -73,7 +71,7 @@ def do_merge_lora(
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model") LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload(progressbar=True) model = model.merge_and_unload()
model.to(dtype=cfg.torch_dtype) model.to(dtype=cfg.torch_dtype)
if cfg.local_rank == 0: if cfg.local_rank == 0:
@@ -81,7 +79,6 @@ def do_merge_lora(
model.save_pretrained( model.save_pretrained(
str(Path(cfg.output_dir) / "merged"), str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
progressbar=True,
) )
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
@@ -106,7 +103,15 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter importlib.import_module("axolotl.prompters"), prompter
) )
model = model.to(cfg.device, dtype=cfg.torch_dtype) if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
while True: while True:
print("=" * 80) print("=" * 80)
@@ -171,7 +176,15 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter importlib.import_module("axolotl.prompters"), prompter
) )
model = model.to(cfg.device, dtype=cfg.torch_dtype) if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
def generate(instruction): def generate(instruction):
if not instruction: if not instruction:
@@ -288,6 +301,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
normalize_config(cfg) normalize_config(cfg)
add_defaults(cfg)
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)
return cfg return cfg
@@ -328,94 +343,6 @@ def load_datasets(
) )
def load_rl_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
) -> TrainDatasetMeta:
train_datasets: List[Any] = []
for i, ds_cfg in enumerate(cfg.datasets):
train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
# eval_dataset = load_dataset(
# cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
# )
eval_dataset = None
def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
return sample
def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample
def apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample
def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample
for i, data_set in enumerate(train_datasets):
_type = cfg.datasets[i]["type"]
ds_type_fn = locals()[_type]
train_datasets[i] = data_set.map(ds_type_fn)
train_dataset = concatenate_datasets(train_datasets)
# eval_dataset = eval_dataset.map(intel_apply_chatml)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def check_accelerate_default_config(): def check_accelerate_default_config():
if Path(config_args.default_yaml_config_file).exists(): if Path(config_args.default_yaml_config_file).exists():
LOG.warning( LOG.warning(

View File

@@ -18,22 +18,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
return_remaining_strings=True return_remaining_strings=True
) )
parsed_cli_args.merge_lora = True parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
parsed_cfg = load_cfg(
config,
merge_lora=True,
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
**kwargs,
)
if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
if not Path(parsed_cfg.lora_model_dir).exists():
raise ValueError(
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
)
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -12,7 +12,6 @@ from axolotl.cli import (
check_user_token, check_user_token,
load_cfg, load_cfg,
load_datasets, load_datasets,
load_rl_datasets,
print_axolotl_text_art, print_axolotl_text_art,
) )
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
@@ -23,18 +22,15 @@ LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
if parsed_cfg.rl: dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -9,7 +9,7 @@ import math
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps from functools import partial
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -20,7 +20,6 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import seed_worker from transformers.trainer_utils import seed_worker
from trl import DPOTrainer
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
@@ -60,12 +59,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False, default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."}, metadata={"help": "Use quadratic warmup for cosine scheduling."},
) )
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field( sample_packing: bool = field(
default=False, default=False,
metadata={"help": "Use sample packing for efficient training."}, metadata={"help": "Use sample packing for efficient training."},
@@ -127,7 +120,6 @@ class AxolotlTrainer(Trainer):
""" """
args = None # type: AxolotlTrainingArguments args = None # type: AxolotlTrainingArguments
tag_names = ["axolotl"]
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs): def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
self.num_epochs = num_epochs self.num_epochs = num_epochs
@@ -163,7 +155,7 @@ class AxolotlTrainer(Trainer):
return self.lr_scheduler return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing:
return MultipackBatchSampler( return MultipackBatchSampler(
RandomSampler(self.train_dataset), RandomSampler(self.train_dataset),
self.args.train_batch_size, self.args.train_batch_size,
@@ -199,7 +191,7 @@ class AxolotlTrainer(Trainer):
return super()._get_eval_sampler(eval_dataset) return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader: def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing:
train_dataset = self.train_dataset train_dataset = self.train_dataset
train_dataset = train_dataset.remove_columns(["length"]) train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator data_collator = self.data_collator
@@ -298,41 +290,12 @@ class AxolotlTrainer(Trainer):
# return (loss, outputs) if return_outputs else loss # return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs) return super().compute_loss(model, inputs, return_outputs=return_outputs)
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = self._sanitize_kwargs_for_tagging(
tag_names=self.tag_names, kwargs=kwargs
)
return super().push_to_hub(*args, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):
""" """
Mamba specific trainer to handle loss calculation Mamba specific trainer to handle loss calculation
""" """
tag_names = ["axolotl", "mamba"]
def compute_loss( def compute_loss(
self, self,
model, model,
@@ -359,8 +322,6 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
Trainer subclass that uses the OneCycleLR scheduler Trainer subclass that uses the OneCycleLR scheduler
""" """
tag_names = ["axolotl", "onecycle"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.lr_scheduler = None self.lr_scheduler = None
@@ -390,8 +351,6 @@ class ReLoRATrainer(AxolotlTrainer):
Trainer subclass that uses the OneCycleLR scheduler Trainer subclass that uses the OneCycleLR scheduler
""" """
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.lr_scheduler = None self.lr_scheduler = None
@@ -427,21 +386,12 @@ class TrainerBuilderBase(abc.ABC):
_train_dataset = None _train_dataset = None
_eval_dataset = None _eval_dataset = None
_model_ref = None
def __init__(self, cfg, model, tokenizer): def __init__(self, cfg, model, tokenizer):
self.cfg = cfg self.cfg = cfg
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
@property
def model_ref(self):
return self._model_ref
@model_ref.setter
def model_ref(self, model):
self._model_ref = model
@property @property
def train_dataset(self): def train_dataset(self):
return self._train_dataset return self._train_dataset
@@ -582,14 +532,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[ training_arguments_kwargs[
"gradient_checkpointing" "gradient_checkpointing"
] = self.cfg.gradient_checkpointing ] = self.cfg.gradient_checkpointing
if self.cfg.gradient_checkpointing_kwargs:
training_arguments_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
else:
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
if self.cfg.fsdp: if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config: if self.cfg.fsdp_config:
@@ -617,7 +559,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True training_arguments_kwargs["hub_private_repo"] = True
training_arguments_kwargs["hub_always_push"] = True
if self.cfg.hub_strategy: if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
@@ -751,9 +692,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep") and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine" else "cosine"
) )
training_arguments_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["weight_decay"] = ( training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
) )
@@ -774,13 +712,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs training_arguments_kwargs
) )
training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
"neftune_noise_alpha"
] = self.cfg.neftune_noise_alpha
training_args = ( training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs, **training_arguments_kwargs,
@@ -806,6 +737,26 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64 data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
set_model_mem_id(self.model, self.tokenizer)
LOG.info("Adding landmark attention tokens to dataset")
for dataset in [self.train_dataset, self.eval_dataset]:
dataset = dataset.map(
partial(
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
),
batched=False,
num_proc=32,
)
trainer_cls = self._get_trainer_cls() trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls trainer_kwargs, trainer_cls
@@ -815,7 +766,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset, eval_dataset=self.eval_dataset,
args=training_args, args=training_args,
data_collator=self.build_collator(training_args, **data_collator_kwargs), data_collator=self.build_collator(**data_collator_kwargs),
bench_data_collator=transformers.DataCollatorForSeq2Seq( bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer, self.tokenizer,
return_tensors="pt", return_tensors="pt",
@@ -836,10 +787,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return trainer return trainer
def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs): def build_collator(self, **kwargs):
if training_args.pretraining:
return None
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer) return MambaDataCollator(tokenizer=self.tokenizer)
@@ -848,96 +796,3 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return_tensors="pt", return_tensors="pt",
**kwargs, **kwargs,
) )
class HFDPOTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for DPO Trainer
"""
def get_callbacks(self):
callbacks = []
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
return callbacks
def build_training_arguments(self, total_num_steps):
training_args_kwargs = {}
for arg in [
"adam_beta1",
"adam_beta2",
"adam_epsilon",
"dataloader_num_workers",
"dataloader_pin_memory",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=total_num_steps,
remove_unused_columns=False,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
evaluation_strategy="no",
# eval_steps=self.cfg.eval_steps,
save_strategy="steps",
save_steps=self.cfg.save_steps,
output_dir=self.cfg.output_dir,
warmup_steps=self.cfg.warmup_steps,
bf16=True,
gradient_checkpointing=self.cfg.gradient_checkpointing,
gradient_checkpointing_kwargs={"use_reentrant": False},
logging_first_step=True,
logging_steps=1,
optim=self.cfg.optimizer,
save_total_limit=self.cfg.save_total_limit or 5,
**training_args_kwargs,
)
return training_args
def build(self, total_num_steps):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
dpo_trainer = DPOTrainer(
self.model,
self.model_ref,
args=training_args,
beta=self.cfg.dpo_beta or 0.1,
train_dataset=self.train_dataset,
# eval_dataset=self.eval_dataset,
eval_dataset=None,
tokenizer=self.tokenizer,
max_length=self.cfg.sequence_len,
max_target_length=None,
max_prompt_length=self.cfg.sequence_len,
generate_during_eval=True,
**dpo_trainer_kwargs,
)
return dpo_trainer
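A rough usage sketch for the DPO builder above; `cfg`, the model pair, tokenizer and dataset are placeholders, and the real call site lives elsewhere in axolotl:
# Hypothetical wiring (illustrative only, not the exact axolotl call site).
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
builder.model_ref = model_ref            # frozen reference model used by the DPO loss
builder.train_dataset = train_dataset    # assumes a setter mirroring the getter shown earlier
dpo_trainer = builder.build(total_num_steps=1000)
dpo_trainer.train()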
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = []
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
return callbacks
def build(self, total_num_steps):
# build PPOConfig
pass

View File

@@ -1,66 +0,0 @@
"""
module for TRL PPO training
"""
import torch
from tqdm import tqdm
from trl import PPOTrainer
class TRLPPOTrainer(PPOTrainer):
"""
wrapper for ppo trainer to handle customizations
"""
def train(
self,
reward_pipe,
resume_from_checkpoint=None, # pylint: disable=unused-argument
):
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": 32,
}
sent_kwargs = {
"return_all_scores": True,
"function_to_apply": "none",
"batch_size": 16,
}
for epoch, batch in tqdm( # pylint: disable=unused-variable
enumerate(self.dataloader)
):
query_tensors = batch["input_ids"]
# generate model response
response_tensors, ref_response_tensors = self.generate(
query_tensors,
return_prompt=False,
generate_ref_response=True,
**generation_kwargs
)
batch["response"] = self.tokenizer.batch_decode(response_tensors)
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = reward_pipe(texts, **sent_kwargs)
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
ref_rewards = [
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
]
batch["ref_rewards"] = ref_rewards
# Run PPO step
stats = self.step(query_tensors, response_tensors, rewards)
self.log_stats(
stats,
batch,
rewards,
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
)
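A hedged sketch of driving this wrapper with a sentiment reward pipeline, in the spirit of the TRL examples; the model name is an assumption and `ppo_trainer` is a TRLPPOTrainer constructed elsewhere with its model, tokenizer and dataloader:
from transformers import pipeline

# Index 1 of each pipeline output is read as the reward above (output[1]["score"]),
# which matches a two-class sentiment model such as lvwerra/distilbert-imdb.
reward_pipe = pipeline("text-classification", model="lvwerra/distilbert-imdb")
ppo_trainer.train(reward_pipe)  # ppo_trainer: assumed pre-built TRLPPOTrainer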

View File

@@ -0,0 +1,6 @@
"""
Custom modeling code for mixtral
"""
from .configuration_moe_mistral import MixtralConfig # noqa
from .modeling_moe_mistral import MixtralForCausalLM # noqa

View File

@@ -0,0 +1,154 @@
# coding=utf-8
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mistral model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json",
"mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json",
}
class MixtralConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate a
Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
[mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
[mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`MistralModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 14336):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
allows sequences of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention window size. If not specified, will default to `4096`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from axolotl.models.mixtral import MixtralConfig, MixtralForCausalLM
>>> # Initializing a Mixtral-style configuration
>>> configuration = MixtralConfig()
>>> # Initializing a model from that configuration
>>> model = MixtralForCausalLM(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mistral"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
attention_dropout=0.0,
num_experts_per_token=2,
num_experts=8,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.num_experts = num_experts
self.num_experts_per_token = num_experts_per_token
# pylint: disable=duplicate-code
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
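A small, hedged example of the MoE-specific knobs this class adds on top of the Mistral configuration; the sizes below are toy values, not canonical Mixtral settings:
# Illustrative toy configuration (not part of the diff).
config = MixtralConfig(
    hidden_size=512,
    intermediate_size=1024,
    num_hidden_layers=4,
    num_attention_heads=8,
    num_key_value_heads=2,
    num_experts=8,            # total experts per MoE layer
    num_experts_per_token=2,  # experts routed to for each token
)
print(config.num_experts, config.num_experts_per_token)  # 8 2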

File diff suppressed because it is too large

View File

@@ -82,44 +82,15 @@ def get_turns( # pylint: disable=too-many-return-statements
else: else:
yield role + ":", "" yield role + ":", ""
return return
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral": if self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message: if self.system_message:
if self.messages:
# For llama, the system message is incorporated into the first human instruction
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt += first_msg
self.messages.pop(0)
yield "", system_prompt yield "", system_prompt
for i, (role, message) in enumerate(self.messages): else:
yield "", "[INST] "
for i, (role, message) in enumerate(self.messages[1:]):
if message: if message:
if (i % 2 == 0 and not self.system_message) or ( yield role + " ", message + seps[i % 2]
i % 2 != 0 and self.system_message
):
role = "<s> " + role
yield role + " ", message
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
contains_sys_msg = False
if self.system_message:
contains_sys_msg = True
if self.messages:
# There is no clear guidance on how to handle system messages in Mistral, so we just prepend it to the first human instruction, separated by a newline
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt = self.system_template.format(
system_message=" " + self.system_message
)
system_prompt += first_msg
self.messages.pop(0)
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message and i == 0 and not contains_sys_msg:
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
elif message:
yield role + " ", message
else: else:
yield role, "" yield role, ""
return return
@@ -147,15 +118,6 @@ def get_turns( # pylint: disable=too-many-return-statements
else: else:
yield role + "\n", "" yield role + "\n", ""
return return
if self.sep_style == SeparatorStyle.CHATGLM3:
if self.system_message:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + "\n", " " + message
else:
yield role
return
if self.sep_style == SeparatorStyle.CHATINTERN: if self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2] seps = [self.sep, self.sep2]

File diff suppressed because it is too large

View File

@@ -1,22 +0,0 @@
"""
Patches to support multipack for mixtral
"""
import transformers
def replace_mixtral_attn_with_multipack_flash_attn():
from .modeling_mixtral import (
MixtralMultipackFlashAttention2,
mixtral_decoder_layer_forward,
mixtral_model_forward,
)
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
mixtral_decoder_layer_forward
)
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
mixtral_model_forward
)
transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
"flash_attention_2"
] = MixtralMultipackFlashAttention2

View File

@@ -1,383 +0,0 @@
"""
Mixtral modeling for multipack
"""
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
import logging
import warnings
from typing import List, Optional, Tuple, Union
import torch
from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from transformers import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral.modeling_mixtral import (
MixtralFlashAttention2,
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
"""
Custom multipack implementation with flash attention 2
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._flash_attn_uses_top_left_mask = True
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling when using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
attn_output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=self.attention_dropout,
softmax_scale=None,
causal=True,
)
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
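The packed path above depends on `cu_seqlens` and `max_seqlen` derived from position ids that restart at 0 at each packed-sequence boundary. A self-contained sketch of that mapping (illustrative only; the real helper is `get_cu_seqlens_from_pos_ids`):
import torch

position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])  # three sequences packed into one row
starts = (position_ids[0] == 0).nonzero(as_tuple=True)[0]
lengths = torch.diff(torch.cat([starts, torch.tensor([position_ids.shape[1]])]))
cu_seqlens = torch.cat([torch.tensor([0]), lengths.cumsum(0)]).to(torch.int32)
max_seqlen = int(lengths.max())
assert cu_seqlens.tolist() == [0, 3, 5, 9] and max_seqlen == 4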
def mixtral_decoder_layer_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
should not be returned during inference.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
if output_router_logits:
outputs += (router_logits,)
return outputs
def mixtral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if (
attention_mask is not None
and self._attn_implementation == "flash_attention_2"
and use_cache
):
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = (
attention_mask
if (attention_mask is not None and 0 in attention_mask)
else None
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
LOG.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
output_router_logits,
use_cache,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_router_logits,
]
if v is not None
)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)

View File

@@ -0,0 +1,65 @@
"""
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
from peft import PeftModel
from transformers import PreTrainedModel
def patch_neft(alpha, model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
embeddings.noisy_embedding_alpha = alpha
old_forward = embeddings.forward
# This hack seems to be needed to properly use a custom forward pass
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
embeddings, embeddings.__class__
)
setattr(embeddings, "forward", bound_method)
embeddings._old_forward = old_forward # pylint: disable=protected-access
return model
def unpatch_neft(model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
if hasattr(embeddings, "_old_forward"):
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
del embeddings._old_forward # pylint: disable=protected-access
del embeddings.noisy_embedding_alpha
def neft_forward(self, inputs: torch.Tensor):
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
if self.training:
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
-mag_norm, mag_norm
)
return embeddings
def pretrain_hook(cfg, trainer):
if cfg.noisy_embedding_alpha:
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
def post_train_hook(cfg, trainer):
if cfg.noisy_embedding_alpha:
unpatch_neft(trainer.model)
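For intuition, `neft_forward` follows the NEFTune paper: during training it adds uniform noise in [-alpha/sqrt(L*d), +alpha/sqrt(L*d)], where L is the sequence length and d the embedding dimension. A short worked example, assuming alpha=5 (illustrative, not part of the diff):
import torch

alpha = 5.0
seq_len, hidden_dim = 1024, 4096
embeddings = torch.zeros(1, seq_len, hidden_dim)   # stand-in for the embedding output
mag_norm = alpha / (seq_len * hidden_dim) ** 0.5   # 5 / sqrt(1024*4096) ~= 0.00244
noised = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
print(mag_norm)  # per-element noise bound, applied only while the module is in training mode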

View File

@@ -0,0 +1,94 @@
# pylint: skip-file
"""
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
"""
import torch
import transformers
import transformers.models.llama.modeling_llama
from einops import rearrange
class XposRotaryEmbedding(torch.nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scale_base=2048,
use_xpos=True,
):
super().__init__()
self.max_seq_len_cached = max_position_embeddings
self.scale_base = scale_base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
freqs = torch.einsum("i , j -> i j", t, inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.register_buffer("freqs_cached", freqs, persistent=False)
if not use_xpos:
self.register_buffer("scale", None)
self.register_buffer("scale_cached", torch.ones(1))
return
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
scale_cached = scale ** rearrange(power, "n -> n 1")
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
self.register_buffer("scale", scale, persistent=False)
self.register_buffer("scale_cached", scale_cached, persistent=False)
def forward(
self,
x,
seq_len,
):
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
self.inv_freq
)
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
self.register_buffer("freqs_cached", freqs)
if self.scale is None:
self.register_buffer(
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
)
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
power = (t - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, "n -> n 1")
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
self.register_buffer("scale_cached", scale)
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
freqs = freqs[position_ids, :]
if scale.shape[-1] != 1:
scale = scale[position_ids, :]
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
return q_embed, k_embed
def replace_llama_rope_with_xpos_rope():
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
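For orientation, xpos applies the scale to queries and its inverse to keys (see `apply_rotary_pos_emb` above), so the scaling only enters the attention scores through relative position. A quick sanity check of `rotate_half` as defined above (illustrative only):
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # tensor([-3., -4.,  1.,  2.])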

View File

@@ -81,9 +81,8 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.tokenizer.add_special_tokens( self.sequence_len = 4096
{"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")} self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
)
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):

View File

@@ -39,23 +39,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
return strategy return strategy
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = UltrachatShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg): def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy( return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(), ShareGPTPrompterV2(),
@@ -126,17 +109,3 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
] ]
return turns return turns
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps ultrachat data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversations = prompt["messages"]
role_map = {"user": "human", "assistant": "gpt"}
turns = [
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns

View File

@@ -33,8 +33,8 @@ class AlpacaPrompter(Prompter):
Base class for alpaca prompters Base class for alpaca prompters
""" """
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request." system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_format: str = "{system}" system_format: str = "{system}"
turn_format: str turn_format: str
turn_no_input_format: str turn_no_input_format: str

View File

@@ -12,13 +12,12 @@ import transformers.modelcard
from accelerate.logging import get_logger from accelerate.logging import get_logger
from datasets import Dataset from datasets import Dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from pkg_resources import get_distribution # type: ignore
from transformers.deepspeed import is_deepspeed_zero3_enabled from transformers.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
@@ -61,12 +60,6 @@ def train(
msg += " and peft_config..." msg += " and peft_config..."
LOG.debug(msg) LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model_ref = None
if cfg.rl:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
@@ -85,11 +78,8 @@ def train(
) )
resume_from_checkpoint = cfg.resume_from_checkpoint resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.unfrozen_parameters:
freeze_parameters_except(model, cfg.unfrozen_parameters)
trainer = setup_trainer( trainer = setup_trainer(
cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
) )
if hasattr(model, "config"): if hasattr(model, "config"):
@@ -122,12 +112,6 @@ def train(
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)""" badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = get_distribution("axolotl").version
if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
if cfg.group_by_length: if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length") LOG.info("hang tight... sorting dataset for group_by_length")
@@ -188,26 +172,25 @@ def train(
if not cfg.hub_model_id: if not cfg.hub_model_id:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./")) trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()
return model, tokenizer return model, tokenizer
def pretrain_hooks(_cfg, _trainer): def pretrain_hooks(cfg, trainer):
""" """
Run hooks right before kicking off the training Run hooks right before kicking off the training
:param cfg: :param cfg:
:param trainer: :param trainer:
:return: :return:
""" """
neft_embeddings.pretrain_hook(cfg, trainer)
def post_train_hooks(_cfg, _trainer): def post_train_hooks(cfg, trainer):
""" """
Run hooks right after training completes Run hooks right after training completes
:param cfg: :param cfg:
:param trainer: :param trainer:
:return: :return:
""" """
neft_embeddings.post_train_hook(cfg, trainer)

View File

@@ -4,8 +4,6 @@ from __future__ import annotations
import logging import logging
import os import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Dict, List from typing import TYPE_CHECKING, Dict, List
import evaluate import evaluate
@@ -563,15 +561,10 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
): ):
if is_main_process(): if is_main_process():
try: try:
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later. artifact = wandb.Artifact(name="axolotl-config", type="config")
with NamedTemporaryFile( artifact.add_file(local_path=self.axolotl_config_path)
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" wandb.run.log_artifact(artifact)
) as temp_file: LOG.info("Axolotl config has been saved to WandB as an artifact.")
copyfile(self.axolotl_config_path, temp_file.name)
wandb.save(temp_file.name)
LOG.info(
"The Axolotl config has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError) as err: except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}") LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control return control

View File

@@ -1,29 +0,0 @@
"""
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""
def chat_templates(user_choice: str):
"""
Finds the correct chat_template for the tokenizer_config.
Args:
user_choice (str): The user's choice of template.
Returns:
str: The chosen template string.
Raises:
ValueError: If the user_choice is not found in the templates.
"""
templates = {
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
}
if user_choice in templates:
return templates[user_choice]
raise ValueError(f"Template '{user_choice}' not found.")

View File

@@ -178,24 +178,3 @@ class MambaDataCollator:
"input_ids": input_ids, "input_ids": input_ids,
"labels": labels, "labels": labels,
} }
@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack, specific to using the BatchSampler
"""
def __call__(self, features, return_tensors=None):
chunked_data = {}
for feature in features.keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [(1) * np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [np.array(item) for item in features[feature]]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
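In effect, each call receives pre-batched feature lists for one packed batch and flattens them into a single row before delegating to the parent collator; a toy illustration of the concatenation step (standalone, with made-up values):
import numpy as np

features = {
    "input_ids": [[1, 2, 3], [4, 5]],
    "attention_mask": [[1, 1, 1], [1, 1]],
}
chunked = {k: np.concatenate([np.array(item) for item in v]) for k, v in features.items()}
print(chunked["input_ids"])  # [1 2 3 4 5]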

View File

@@ -41,6 +41,16 @@ def choose_device(cfg):
cfg.device_map = None cfg.device_map = None
def add_defaults(cfg):
# setup sane defaults if left unspecified
if cfg.dataloader_num_workers is None:
cfg.dataloader_num_workers = int(os.getenv("WORLD_SIZE", "1"))
if cfg.dataloader_prefetch_factor is None:
cfg.dataloader_prefetch_factor = cfg.batch_size * 2
if cfg.dataloader_pin_memory is None:
cfg.dataloader_pin_memory = True
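These defaults only fill values the user left unset; a worked example of what they evaluate to, assuming a hypothetical run with batch_size=8 and WORLD_SIZE=4:
import os

os.environ.setdefault("WORLD_SIZE", "4")   # normally set by the launcher (e.g. torchrun/accelerate)
batch_size = 8                              # stand-in for cfg.batch_size
dataloader_num_workers = int(os.getenv("WORLD_SIZE", "1"))  # -> 4
dataloader_prefetch_factor = batch_size * 2                 # -> 16
dataloader_pin_memory = True
print(dataloader_num_workers, dataloader_prefetch_factor, dataloader_pin_memory)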
def normalize_config(cfg): def normalize_config(cfg):
# setup some derived config / hyperparams # setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or ( cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
@@ -77,15 +87,6 @@ def normalize_config(cfg):
else: else:
cfg.torch_dtype = torch.float32 cfg.torch_dtype = torch.float32
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step
cfg.save_steps = save_steps
if cfg.evals_per_epoch:
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
if eval_steps < 1.0: # prevent evals on every step
cfg.eval_steps = eval_steps
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count() cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
if not cfg.base_model_config: if not cfg.base_model_config:
@@ -361,27 +362,6 @@ def validate_config(cfg):
cfg.datasets[idx].type = cfg.datasets[idx].type.replace( cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt" "sharegpt_simple", "sharegpt"
) )
if cfg.saves_per_epoch and cfg.save_steps:
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
)
if (
cfg.evals_per_epoch
and cfg.evaluation_strategy
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps": if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError( raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
@@ -422,6 +402,11 @@ def validate_config(cfg):
if cfg.warmup_steps and cfg.warmup_ratio: if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
LOG.warning(
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
)
if cfg.wandb_run_id and not cfg.wandb_name: if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id cfg.wandb_name = cfg.wandb_run_id
@@ -429,39 +414,6 @@ def validate_config(cfg):
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
) )
if cfg.noisy_embedding_alpha is not None:
# Deprecated, use neftune_noise_alpha
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
if cfg.neftune_noise_alpha is None:
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
else:
# User is providing both; bail and have them sort out their settings
raise ValueError(
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
)
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
raise ValueError("neftune_noise_alpha must be > 0.0")
if (
cfg.adapter
and cfg.tokens
and (
not cfg.lora_modules_to_save
or not all(
x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
)
)
):
raise ValueError(
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
)
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
raise ValueError(
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -2,7 +2,6 @@
import functools import functools
import hashlib import hashlib
import logging import logging
from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
@@ -15,7 +14,6 @@ from datasets import (
load_from_disk, load_from_disk,
) )
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -41,14 +39,11 @@ from axolotl.prompters import (
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
UnsupportedPrompter, UnsupportedPrompter,
) )
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers.multipack import MultipackBatchSampler
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
calculate_total_num_steps, calculate_total_num_steps,
process_datasets_for_packing, process_datasets_for_packing,
process_pretraining_datasets_for_packing,
) )
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -69,17 +64,9 @@ def prepare_dataset(cfg, tokenizer):
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
else: else:
path = cfg.pretraining_dataset
name = None
if isinstance(cfg.pretraining_dataset, dict):
path = cfg.pretraining_dataset["path"]
name = cfg.pretraining_dataset["name"]
train_dataset = load_pretraining_dataset( train_dataset = load_pretraining_dataset(
path, cfg.pretraining_dataset,
tokenizer, tokenizer,
cfg,
name=name,
max_tokens=cfg.sequence_len, max_tokens=cfg.sequence_len,
seed=cfg.seed or 42, seed=cfg.seed or 42,
) )
@@ -819,27 +806,9 @@ def encode_pretraining(
return ret return ret
def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42): def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
if cfg.sample_packing: encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( dataset = load_dataset(path, streaming=True, split="train")
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
)
encode = functools.partial(
encode_packed_pretraining,
tokenizer,
collate_fn,
max_seq_length=max_tokens,
batch_size=cfg.micro_batch_size,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train", name=name)
dataset = dataset.shuffle(seed=seed, buffer_size=10_000) dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.map( dataset = dataset.map(
encode, encode,
@@ -850,63 +819,3 @@ def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, s
remove_columns=dataset.features.keys(), remove_columns=dataset.features.keys(),
) )
return dataset return dataset
def encode_packed_pretraining(
tokenizer: PreTrainedTokenizerBase,
collate_fn,
examples: List[str],
max_seq_length: int = 2048,
batch_size: int = 4,
) -> Dict[str, List]:
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
res = tokenizer(
examples,
truncation=True,
max_length=max_seq_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
attention_mask = [seq + [1] for seq in res["attention_mask"]]
tokenized_examples = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
train_dataset = Dataset.from_dict(tokenized_examples)
train_dataset = process_pretraining_datasets_for_packing(
train_dataset, max_seq_length
)
sampler = MultipackBatchSampler(
RandomSampler(train_dataset),
batch_size=batch_size,
drop_last=True,
batch_max_len=batch_size * max_seq_length,
lengths=(
train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
)
chunked_data = defaultdict(list)
for data in sampler:
features = train_dataset[data]
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)
for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))
return chunked_data

View File

@@ -1,38 +0,0 @@
"""
module to freeze/unfreeze parameters by name
"""
import logging
import re
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.utils.freeze")
def freeze_parameters_except(model, regex_patterns):
"""
Freezes all layers of the given model except for the layers that match given regex patterns.
Periods in the patterns are treated as literal periods, not as wildcard characters.
Parameters:
- model (nn.Module): The PyTorch model to be modified.
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
Returns:
None; the model is modified in place.
"""
# Escape periods and compile the regex patterns
compiled_patterns = [
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
]
# First, freeze all parameters in the model
for param in model.parameters():
param.requires_grad = False
# Unfreeze layers that match the regex patterns
for name, param in model.named_parameters():
if any(pattern.match(name) for pattern in compiled_patterns):
if is_main_process():
LOG.debug(f"unfreezing {name}")
param.requires_grad = True
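A minimal usage sketch for the helper above, assuming it is importable as axolotl.utils.freeze.freeze_parameters_except; the toy model and pattern names are illustrative only:
# freeze everything except parameters whose names match the given patterns
import torch.nn as nn
from axolotl.utils.freeze import freeze_parameters_except  # assumed import path
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))  # stand-in for a real model
freeze_parameters_except(model, ["1.weight", "1.bias"])  # periods are matched literally
print([n for n, p in model.named_parameters() if p.requires_grad])
# expected: ['1.weight', '1.bias']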

View File

@@ -2,7 +2,7 @@
import logging import logging
import math import math
import os import os
from typing import Any, Optional, Tuple # noqa: F401 from typing import Optional, Tuple # noqa: F401
import addict import addict
import bitsandbytes as bnb import bitsandbytes as bnb
@@ -21,12 +21,10 @@ from transformers import ( # noqa: F401
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from transformers.deepspeed import is_deepspeed_zero3_enabled
from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -56,19 +54,25 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
def load_model_config(cfg): def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True trust_remote_code = cfg.trust_remote_code is True
model_type = cfg.model_type
try: if model_type == "MixtralForCausalLM":
model_config = AutoConfig.from_pretrained( from axolotl.models.mixtral.configuration_moe_mistral import MixtralConfig
model_config_name, trust_remote_code=trust_remote_code
) model_config = MixtralConfig.from_pretrained(model_config_name)
except ValueError as err: else:
if "mamba" in model_config_name: try:
return addict.Dict( model_config = AutoConfig.from_pretrained(
{ model_config_name, trust_remote_code=trust_remote_code
"model_type": "mamba",
}
) )
raise err except ValueError as err:
if "mamba" in model_config_name:
return addict.Dict(
{
"model_type": "mamba",
}
)
raise err
if cfg.model_config: if cfg.model_config:
for key, val in cfg.model_config.items(): for key, val in cfg.model_config.items():
@@ -137,23 +141,6 @@ def load_tokenizer(cfg):
if cfg.special_tokens: if cfg.special_tokens:
for k, val in cfg.special_tokens.items(): for k, val in cfg.special_tokens.items():
# check that a new special token is actually being added and, if adapter
# training is enabled, make sure lora_modules_to_save is set accordingly
if (
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
and cfg.adapter
and (
not cfg.lora_modules_to_save
or not all(
x in cfg.lora_modules_to_save
for x in ["embed_tokens", "lm_head"]
)
)
):
raise ValueError(
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
)
tokenizer.add_special_tokens( tokenizer.add_special_tokens(
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
) )
@@ -187,12 +174,6 @@ def load_tokenizer(cfg):
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if cfg.chat_template:
tokenizer.chat_template = chat_templates(cfg.chat_template)
else:
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
)
return tokenizer return tokenizer
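A brief sketch of the chat_template wiring above; it assumes "chatml" is one of the template names chat_templates accepts and that the tokenizer supports apply_chat_template:
# minimal sketch, assuming chat_templates("chatml") returns a Jinja template string
from transformers import AutoTokenizer
from axolotl.utils.chat_templates import chat_templates
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder tokenizer
tokenizer.chat_template = chat_templates("chatml")
messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "hi there"},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))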
@@ -200,7 +181,6 @@ def load_model(
cfg: DictDefault, cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
inference: bool = False, inference: bool = False,
reference_model: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
""" """
Load a model for a given configuration and tokenizer. Load a model for a given configuration and tokenizer.
@@ -255,6 +235,17 @@ def load_model(
LOG.info("patching with sdp attention") LOG.info("patching with sdp attention")
hijack_llama_sdp_attention() hijack_llama_sdp_attention()
elif cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
MEM_TOKEN,
patch_llama_with_landmark_attn,
)
LOG.info("patching with landmark attention")
patch_llama_with_landmark_attn()
# Note: This might overwrite previous additional_special_tokens
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing: if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.mistral_attn_hijack_flash import ( from axolotl.monkeypatch.mistral_attn_hijack_flash import (
@@ -264,17 +255,13 @@ def load_model(
LOG.info("patching with flash attention") LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if ( if cfg.is_llama_derived_model and cfg.xpos_rope:
cfg.model_config_type == "mixtral" from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
and cfg.flash_attention replace_llama_rope_with_xpos_rope,
and cfg.sample_packing
):
from axolotl.monkeypatch.mixtral import (
replace_mixtral_attn_with_multipack_flash_attn,
) )
LOG.info("patching with flash attention") LOG.info("patching with xpos rope")
replace_mixtral_attn_with_multipack_flash_attn() replace_llama_rope_with_xpos_rope()
if ( if (
cfg.is_llama_derived_model cfg.is_llama_derived_model
@@ -288,50 +275,9 @@ def load_model(
model_kwargs = {} model_kwargs = {}
max_memory = cfg.max_memory model_kwargs["device_map"] = cfg.device_map
device_map = cfg.device_map model_kwargs["max_memory"] = cfg.max_memory
if cfg.gpu_memory_limit:
gpu_memory_limit = (
str(cfg.gpu_memory_limit) + "GiB"
if isinstance(cfg.gpu_memory_limit, int)
else cfg.gpu_memory_limit
)
max_memory = {}
for i in range(torch.cuda.device_count()):
max_memory[i] = gpu_memory_limit
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
if max_memory is not None:
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
from accelerate import infer_auto_device_map, init_empty_weights
with init_empty_weights():
model_canvas = AutoModelForCausalLM.from_config(model_config)
model_canvas.tie_weights()
device_map = infer_auto_device_map(
model_canvas,
max_memory=max_memory,
dtype=cfg.torch_dtype,
)
# We can discard max_memory now as we have a device map set up for us
max_memory = None
model_kwargs["device_map"] = device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype model_kwargs["torch_dtype"] = cfg.torch_dtype
# TODO can we put the reference model on its own GPU? I think we have to move logits around to calculate the loss
# if cfg.rl:
# if torch.cuda.device_count() > 1:
# if reference_model:
# model_kwargs["device_map"] = "cuda:" + str(
# torch.cuda.current_device() + 1
# )
# else:
# model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]
if cfg.model_revision: if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision model_kwargs["revision"] = cfg.model_revision
@@ -347,45 +293,24 @@ def load_model(
**model_config.quantization_config **model_config.quantization_config
) )
if cfg.adapter == "qlora" and cfg.load_in_4bit: if cfg.adapter == "qlora" and cfg.load_in_4bit:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
"llm_int8_has_fp16_weight": False,
"bnb_4bit_compute_dtype": cfg.torch_dtype,
"bnb_4bit_use_double_quant": True,
"bnb_4bit_quant_type": "nf4",
}
if cfg.bnb_config_kwargs:
bnb_config.update(cfg.bnb_config_kwargs)
model_kwargs["quantization_config"] = BitsAndBytesConfig( model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config, load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=cfg.torch_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
) )
# sample packing uses custom FA2 patch # sample packing uses custom FA2 patch
if cfg.flash_attention: if cfg.flash_attention and not cfg.sample_packing:
if not cfg.sample_packing: if (
if ( cfg.is_llama_derived_model
cfg.is_llama_derived_model or cfg.is_falcon_derived_model
or cfg.is_falcon_derived_model or cfg.is_mistral_derived_model
or cfg.is_mistral_derived_model ):
or model_config.model_type == "mixtral" # TODO enable once properly supported in transformers
): # model_kwargs["attn_implementation"] = "flash_attention_2"
model_kwargs["attn_implementation"] = "flash_attention_2" model_kwargs["use_flash_attention_2"] = True # legacy, to be deprecated
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
if model_config.model_type == "mixtral":
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
try: try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
@@ -447,6 +372,15 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs, **model_kwargs,
) )
elif model_type == "MixtralForCausalLM":
from axolotl.models.mixtral import MixtralForCausalLM
model = MixtralForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
elif model_type == "MambaLMHeadModel": elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work # FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
@@ -455,6 +389,7 @@ def load_model(
model_kwargs["device"] = torch.cuda.current_device() model_kwargs["device"] = torch.cuda.current_device()
del model_kwargs["torch_dtype"] del model_kwargs["torch_dtype"]
del model_kwargs["device_map"] del model_kwargs["device_map"]
del model_kwargs["max_memory"]
model = MambaLMHeadModel.from_pretrained( model = MambaLMHeadModel.from_pretrained(
base_model, base_model,
@@ -598,11 +533,9 @@ def load_model(
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(cfg.torch_dtype) module.to(cfg.torch_dtype)
lora_config = None model, lora_config = load_adapter(model, cfg, cfg.adapter)
if not reference_model or cfg.lora_model_dir:
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
@@ -711,15 +644,10 @@ def load_lora(model, cfg, inference=False):
if cfg.lora_model_dir: if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - LoRA") LOG.debug("Loading pretained PEFT - LoRA")
model_kwargs: Any = {}
if cfg.lora_on_cpu:
model_kwargs["max_memory"] = {"cpu": "256GiB"}
model_kwargs["device_map"] = {"": "cpu"}
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.lora_model_dir,
is_trainable=(not inference), is_trainable=(not inference),
**model_kwargs,
) )
else: else:
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)

View File

@@ -31,8 +31,8 @@ def check_example_labels(example, tokenizer, text_only=False):
) )
colored_tokens.append(colored_token) colored_tokens.append(colored_token)
output = " ".join(colored_tokens) delimiter = "" if text_only else " "
LOG.info(output) LOG.info(delimiter.join(colored_tokens))
LOG.info("\n\n\n") LOG.info("\n\n\n")
return output return " ".join(colored_tokens)

View File

@@ -12,7 +12,7 @@ from accelerate.logging import get_logger
from datasets import set_caching_enabled from datasets import set_caching_enabled
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.samplers import MultipackBatchSampler from axolotl.utils.samplers import MultipackBatchSampler
@@ -143,16 +143,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
return train_dataset, eval_dataset return train_dataset, eval_dataset
def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
train_dataset = train_dataset.filter(drop_long)
train_dataset = train_dataset.map(
add_position_ids,
)
return train_dataset
def calculate_total_num_steps(cfg, train_dataset, update=True): def calculate_total_num_steps(cfg, train_dataset, update=True):
if not cfg.total_num_tokens: if not cfg.total_num_tokens:
total_num_tokens = np.sum( total_num_tokens = np.sum(
@@ -286,16 +276,10 @@ def prepare_optim_env(cfg):
setup_fsdp_envs(cfg) setup_fsdp_envs(cfg)
elif cfg.deepspeed: elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl: trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
else:
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.train_dataset = train_dataset trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset trainer_builder.eval_dataset = eval_dataset

View File

@@ -1,59 +0,0 @@
"""
unit tests for axolotl.core.trainer_builder
"""
import pytest
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@pytest.fixture(name="cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00005,
"save_steps": 100,
"output_dir": "./model-out",
"warmup_steps": 10,
"gradient_checkpointing": False,
"optimizer": "adamw_torch",
"sequence_len": 2048,
"rl": True,
"adam_beta1": 0.998,
"adam_beta2": 0.9,
"adam_epsilon": 0.00001,
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
}
)
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(cfg):
return load_tokenizer(cfg)
@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer)
class TestHFDPOTrainerBuilder:
"""
TestCase class for DPO trainer builder
"""
def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True

View File

@@ -1,109 +0,0 @@
"""
E2E tests for mixtral
"""
import logging
import os
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMixtral(unittest.TestCase):
"""
Test case for Mixtral models
"""
@with_temp_dir
def test_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sequence_len": 1024,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -1,123 +0,0 @@
"""
E2E tests for mixtral
"""
import logging
import os
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMixtral(unittest.TestCase):
"""
Test case for Mixtral models with sample packing
"""
@with_temp_dir
def test_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sequence_len": 2048,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"sample_packing": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sequence_len": 2048,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"sample_packing": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (
"axolotl.monkeypatch.mixtral.modeling_mixtral"
in model.model.layers[0].self_attn.__class__.__module__
)
assert (
"MixtralMultipackFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -1,99 +0,0 @@
"""
E2E smoke tests to check that the monkeypatches are in place for certain configurations
"""
import unittest
from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from .utils import with_temp_dir
class TestModelPatches(unittest.TestCase):
"""
TestCases for the multipack monkey patches
"""
@with_temp_dir
def test_mixtral_multipack(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
assert (
"axolotl.monkeypatch.mixtral.modeling_mixtral"
in model.model.layers[0].self_attn.__class__.__module__
)
assert (
"MixtralMultipackFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
@with_temp_dir
def test_mistral_multipack(self, temp_dir):
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
assert (
"axolotl.monkeypatch.mistral_attn_hijack_flash"
in model.model.layers[0].self_attn.forward.__module__
)

File diff suppressed because one or more lines are too long

View File

@@ -1,82 +0,0 @@
"""Module for testing streaming dataset sequence packing"""
import unittest
from functools import partial
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data import encode_packed_pretraining
class TestPacking(unittest.TestCase):
"""
Test class for packing streaming dataset sequences
"""
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.pad_token = "</s>"
self.max_seq_length = 2048
self.batch_size = 2
def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
dataset = load_dataset(
"c4",
"en",
streaming=True,
)["train"]
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=self.max_seq_length,
)
encode = partial(
encode_packed_pretraining,
self.tokenizer,
collate_fn,
max_seq_length=self.max_seq_length,
batch_size=self.batch_size,
)
dataset = dataset.map(
encode,
batched=True,
input_columns="text",
remove_columns=dataset.features.keys(),
)
trainer_loader = DataLoader(
dataset,
batch_size=1,
collate_fn=None,
drop_last=True,
)
idx = 0
for data in trainer_loader:
if idx > 10:
break
assert data["input_ids"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
assert data["position_ids"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
assert data["labels"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
assert data["attention_mask"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
idx += 1
if __name__ == "__main__":
unittest.main()

View File

@@ -2,7 +2,6 @@
import json import json
import logging import logging
import unittest import unittest
from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -26,50 +25,6 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
test_data = {
"multi_turn_sys": {
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "123"},
{"from": "gpt", "value": "sit"},
]
},
"single_turn_sys": {
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
]
},
"single_turn_no_sys": {
"conversations": [
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
]
},
"multi_turn_no_sys": {
"conversations": [
{"from": "human", "value": "abc"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "123"},
{"from": "gpt", "value": "sit"},
]
},
}
def prompt_strat(conversation, tokenizer):
"Helper function to create a prompt strategy for testing."
prompter = ShareGPTPrompterV2(conversation=conversation)
return ShareGPTPromptTokenizingStrategy(
prompter,
tokenizer,
False,
2048,
)
class TestPromptTokenizationStrategies(unittest.TestCase): class TestPromptTokenizationStrategies(unittest.TestCase):
""" """
@@ -159,70 +114,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
in self._caplog.records[0].message in self._caplog.records[0].message
) )
def test_sharegpt_llama(self):
"Make sure the sharegpt/llama is tokenized and formatted correctly."
strat = prompt_strat("llama-2", self.tokenizer)
def tokenize(conv):
return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
def decode(ids):
return strat.tokenizer.decode(ids)
# fmt: off
# System message, multi-turn conversations
mt_ids = tokenize(test_data['multi_turn_sys'])
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# System message, single-turn conversations
st_ids = tokenize(test_data['single_turn_sys'])
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
# No system message, single-turn
ns_ids = tokenize(test_data['single_turn_no_sys'])
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
# No system message, multi-turn
ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# fmt: on
def test_sharegpt_mistral(self):
"Make sure the sharegpt/mistral is tokenized and formatted correctly."
strat = prompt_strat("mistral", self.tokenizer)
def tokenize(conv):
return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
def decode(ids):
return strat.tokenizer.decode(ids)
# fmt: off
# System message, multi-turn conversations
mt_ids = tokenize(test_data['multi_turn_sys'])
assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# System message, single-turn conversations
st_ids = tokenize(test_data['single_turn_sys'])
assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
# No system message, single-turn
ns_ids = tokenize(test_data['single_turn_no_sys'])
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
# No system message, multi-turn
ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# fmt: on
def test_sharegpt_changes_roles(self): def test_sharegpt_changes_roles(self):
conversation = { conversation = {
"roles": ["USER", "CHARACTER"], "roles": ["USER", "CHARACTER"],

View File

@@ -3,8 +3,6 @@ Test cases for the tokenizer loading
""" """
import unittest import unittest
import pytest
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_tokenizer
@@ -33,40 +31,6 @@ class TestTokenizers(unittest.TestCase):
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert "Fast" not in tokenizer.__class__.__name__ assert "Fast" not in tokenizer.__class__.__name__
def test_special_tokens_modules_to_save(self):
# setting special_tokens to new token
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"adapter": "lora",
"special_tokens": {"bos_token": "[INST]"},
}
)
with pytest.raises(
ValueError,
match=r".*Please set lora_modules_to_save*",
):
load_tokenizer(cfg)
# setting special_tokens but not changing from default
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"adapter": "lora",
"special_tokens": {"bos_token": "<s>"},
}
)
load_tokenizer(cfg)
# non-adapter setting special_tokens
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"special_tokens": {"bos_token": "[INST]"},
}
)
load_tokenizer(cfg)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -682,43 +682,6 @@ class ValidationTest(unittest.TestCase):
validate_config(cfg) validate_config(cfg)
def test_add_tokens_adapter(self):
cfg = DictDefault(
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
)
with pytest.raises(
ValueError,
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
):
validate_config(cfg)
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"tokens": ["<|imstart|>"],
"lora_modules_to_save": ["embed_tokens"],
}
)
with pytest.raises(
ValueError,
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
):
validate_config(cfg)
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"tokens": ["<|imstart|>"],
"lora_modules_to_save": ["embed_tokens", "lm_head"],
}
)
validate_config(cfg)
class ValidationWandbTest(ValidationTest): class ValidationWandbTest(ValidationTest):
""" """