Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
d465b9fd98 wip, jagged restarts 2024-02-16 14:34:08 -05:00
112 changed files with 809 additions and 5081 deletions

View File

@@ -59,7 +59,6 @@ body:
label: Config yaml
description: |
Please attach the config yaml!
render: yaml
- type: textarea
id: possible-solution

View File

@@ -12,6 +12,11 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"

View File

@@ -17,6 +17,6 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.9"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0

View File

@@ -13,12 +13,16 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
is_latest: true
- cuda: 121
cuda_version: 12.1.0
@@ -55,7 +59,6 @@ jobs:
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |
@@ -70,6 +73,11 @@ jobs:
strategy:
matrix:
include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"

View File

@@ -23,7 +23,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.9"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
@@ -33,7 +33,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
python_version: ["3.9", "3.10", "3.11"]
timeout-minutes: 10
steps:
@@ -58,8 +58,8 @@ jobs:
docker-e2e-tests:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 60
runs-on: [self-hosted, gpu, docker]
timeout-minutes: 30
needs: [pre-commit, pytest]
strategy:
@@ -69,32 +69,44 @@ jobs:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
num_gpus: 1
pytorch: 2.0.1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
num_gpus: 1
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
python-version: "3.10"
- name: Install Modal
images: winglian/axolotl-tests
- name: Build Docker image
run: |
python -m pip install --upgrade pip
pip install modal jinja2
- name: Update env vars
# Set up build arguments
BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
CUDA="${{ matrix.cuda }}"
PYTORCH_VERSION="${{ matrix.pytorch }}"
# Build the Docker image
docker build . \
--file ./docker/Dockerfile-tests \
--build-arg BASE_TAG=$BASE_TAG \
--build-arg CUDA=$CUDA \
--build-arg GITHUB_REF=$GITHUB_REF \
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
--tag ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} \
--no-cache
- name: Unit Tests w docker image
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
- name: GPU Unit Tests w docker image
run: |
modal run cicd.tests
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
- name: GPU Unit Tests monkeypatched w docker image
run: |
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/
- name: Prune image from docker
if: github.ref != 'refs/heads/main'
run: |
docker rmi -f ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}

5
.gitignore vendored
View File

@@ -167,8 +167,3 @@ cython_debug/
# WandB
# wandb creates a folder to store logs for training runs
wandb
# Runs
lora-out/*
qlora-out/*
mlruns/*

View File

@@ -1,5 +1,5 @@
[mypy]
plugins = pydantic.mypy
exclude = venv
[mypy-alpaca_lora_4bit.*]

View File

@@ -31,7 +31,6 @@ repos:
additional_dependencies:
[
'types-PyYAML',
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5

206
README.md
View File

@@ -22,10 +22,10 @@ Features:
- [Introduction](#axolotl)
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Environment](#environment)
- [Installation](#installation)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [Windows](#windows)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
@@ -34,7 +34,7 @@ Features:
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config)
- [Train](#train)
- [Inference](#inference-playground)
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- Advanced Topics
@@ -87,17 +87,15 @@ Features:
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
❓: untested
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
**Requirements**: Python >=3.9 and Pytorch >=2.0.
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
### For developers
```bash
@@ -105,18 +103,9 @@ git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
```
General case:
```
pip3 install -e '.[flash-attn,deepspeed]'
```
Mac: see https://github.com/OpenAccess-AI-Collective/axolotl/blob/13199f678b9aab39e92961323bdbce3234ee4b2b/docs/mac.md
```
pip3 install -e '.'
```
### Usage
```bash
# preprocess datasets - optional but recommended
@@ -138,14 +127,13 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
```
## Advanced Setup
## Installation
### Environment
#### Docker
```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
```
Or run on the current files for development:
@@ -164,7 +152,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAcc
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
```
It additionally:
@@ -179,7 +167,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
</details>
#### Conda/Pip venv
1. Install python >=**3.10**
1. Install python >=**3.9**
2. Install pytorch stable https://pytorch.org/get-started/locally/
@@ -199,7 +187,6 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### Bare Metal Cloud GPU
@@ -213,11 +200,11 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
1. Install python
```bash
sudo apt update
sudo apt install -y python3.10
sudo apt install -y python3.9
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
sudo update-alternatives --config python # pick 3.10 if given option
python -V # should be 3.10
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.9 if given option
python -V # should be 3.9
```
@@ -255,18 +242,15 @@ Please use WSL or Docker!
#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check
```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```
Use one command to launch:
```bash
# On-demand
@@ -276,32 +260,31 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```
### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
Have dataset(s) in one of the following format (JSONL recommended):
#### Pretraining
- `completion`: raw corpus
```json
{"text": "..."}
```
Note: Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
```yaml
pretraining_dataset: # hf path only
```
#### Supervised finetuning
##### Instruction
- `alpaca`: instruction; input(optional)
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
```yml
datasets:
- path: <your-path>
type: sharegpt
conversation: llama-2
```
- `completion`: raw corpus
```json
{"text": "..."}
```
<details>
@@ -379,37 +362,14 @@ pretraining_dataset: # hf path only
```json
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
</details>
##### Template-Free
- `input_output`: template-free prompt construction
```json
{"segments": [{"label": true|false, "text": "..."}]}
```
This is a special format that allows you to construct prompts without using templates. This is for advanced users who want more freedom with prompt construction. See [these docs](docs/input_output.md) for more details.
##### Conversation
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
<details>
<summary>See other formats</summary>
- `pygmalion`: pygmalion
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
@@ -425,8 +385,6 @@ This is a special format that allows you to construct prompts without using temp
</details>
Note: `type: sharegpt` opens a special config `conversation:` that enables conversions to many Conversation types. See dataset section under [all yaml options](#all-yaml-options).
#### How to add custom prompts
For a dataset that is preprocessed for instruction purposes:
@@ -448,16 +406,12 @@ datasets:
format: "[INST] {instruction} [/INST]"
no_input_format: "[INST] {instruction} [/INST]"
```
See full config options under [all yaml options](#all-yaml-options).
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
```yaml
- path: ...
```
### Config
@@ -471,18 +425,22 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- dataset
```yaml
datasets:
# huggingface repo
- path: vicgalle/alpaca-gpt4
type: alpaca
sequence_len: 2048 # max token length for prompt
# huggingface repo with specific configuration/subset
# huggingface repo
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca # format from earlier
# huggingface repo with specific configuration/subset
datasets:
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets
# huggingface repo with multiple named configurations/subsets
datasets:
- path: bigcode/commitpackft
name:
- ruby
@@ -490,29 +448,34 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
datasets:
- path: ...
type: sharegpt
conversation: chatml # default: vicuna_v1.1
conversation: chatml
# local
# local
datasets:
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
# dataset with splits, but no train split
# dataset with splits, but no train split
dataset:
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
dataset:
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
dataset:
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
```
@@ -521,11 +484,9 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
```yaml
load_in_4bit: true
load_in_8bit: true
bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
@@ -533,7 +494,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora
```yaml
adapter: lora # 'qlora' or leave blank for full finetune
adapter: lora # qlora or leave blank for full finetune
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
@@ -542,9 +503,9 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- v_proj
```
<details id="all-yaml-options">
<details>
<summary>All yaml options (click to expand)</summary>
<summary>All yaml options (click me)</summary>
```yaml
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
@@ -556,8 +517,8 @@ base_model_ignore_patterns:
# You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub
revision_of_model:
# Optional tokenizer configuration path in case you want to use a different tokenizer
model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model
tokenizer_config:
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
@@ -574,16 +535,15 @@ tokenizer_legacy:
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# (Internal use only)
# Used to identify which the model is based on
is_falcon_derived_model:
is_llama_derived_model:
is_qwen_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:
is_qwen_derived_model:
# optional overrides to the base model configuration
overrides_of_model_config:
model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
@@ -600,6 +560,8 @@ bnb_config_kwargs:
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
@@ -673,7 +635,7 @@ test_datasets:
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair'
# use RL training: dpo, ipo, kto_pair
rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
@@ -693,7 +655,7 @@ dataset_processes: # defaults to os.cpu_count() if not set
# Only needed if cached dataset is taking too much storage
dataset_keep_in_memory:
# push checkpoints to hub
hub_model_id: # private repo path to push finetuned model
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
@@ -772,8 +734,6 @@ peft:
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps
relora_anneal_steps: # Number of anneal steps for each relora cycle
relora_prune_ratio: # threshold for optimizer magnitude when pruning
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it
@@ -789,7 +749,6 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Where to save the full-finetuned model to
output_dir: ./completed-model
@@ -823,8 +782,7 @@ save_total_limit: # Checkpoints saved at a time
max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
@@ -853,11 +811,14 @@ early_stopping_patience: 3
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
# For one_cycle optim
lr_div_factor: # Learning rate div factor
# For log_sweep optim
log_sweep_min_lr:
log_sweep_max_lr:
# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
@@ -1080,10 +1041,6 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
##### FSDP + QLoRA
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.md) for more information.
##### Weights & Biases Logging
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
@@ -1145,7 +1102,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
### Merge LORA to base
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
@@ -1206,7 +1163,7 @@ If you decode a prompt constructed by axolotl, you might see spaces between toke
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
@@ -1253,20 +1210,11 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
@@ -1303,6 +1251,4 @@ consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsor
#### 🥉 Bronze Sponsors - $500/mo
- [JarvisLabs.ai](https://jarvislabs.ai)
---

View File

@@ -1,39 +0,0 @@
FROM winglian/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image
RUN pip install pytest
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -1,5 +0,0 @@
#!/bin/bash
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/

View File

@@ -1,75 +0,0 @@
"""
modal application to run axolotl gpu tests in Modal
"""
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.0.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.10-cu118-2.0.1"),
"CUDA": os.environ.get("CUDA", "118"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
cpu=8.0,
memory=131072,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
@stub.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -16,7 +16,6 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -20,7 +20,6 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -24,7 +24,6 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -24,7 +24,6 @@
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false

View File

@@ -2,6 +2,7 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -3,10 +3,9 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.1.2"
ARG PYTORCH_VERSION="2.0.1"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -21,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
fi
# So we can test the Docker image

View File

@@ -7,8 +7,8 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
ARG PYTHON_VERSION="3.9"
ARG PYTORCH_VERSION="2.0.1"
ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

View File

@@ -3,10 +3,9 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.1.2"
ARG PYTORCH_VERSION="2.0.1"
ARG GITHUB_REF="main"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -25,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
fi
# So we can test the Docker image

View File

@@ -74,6 +74,7 @@ pip3 install -e '.[flash-attn,deepspeed]'
If you developing on a remote host, you can easily use VSCode to debug remotely. To do so, you will need to follow this [remote - SSH guide](https://code.visualstudio.com/docs/remote/ssh). You can also see the video below on [Docker and Remote SSH debugging](#video---attaching-to-docker-on-remote-host).
```bash
### Configuration

View File

@@ -1,37 +0,0 @@
# FDSP + QLoRA
## Background
Using FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.** For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].
Below, we describe how to use this feature in Axolotl.
## Usage
To enable `QLoRA` with `FSDP`, you need to perform the following steps:
> ![Tip]
> See the [example config](#example-config) file in addition to reading these instructions.
1. Set `adapter: qlora` in your axolotl config file.
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Example Config
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
## References
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
- Related HuggingFace PRs Enabling FDSP + QLoRA:
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
- Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587)
- TRL [PR#1416](https://github.com/huggingface/trl/pull/1416)
- PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550)
[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team.

View File

@@ -1,260 +0,0 @@
# Template-free prompt construction with the `input_output` format
<!-- TOC -->
- [Background](#background)
- [Masking Inputs](#masking-inputs)
- [You may not want prompt templates](#you-may-not-want-prompt-templates)
- [The `input_output` format](#the-input_output-format)
- [Usage](#usage)
- [1. Prepare Data](#1-prepare-data)
- [2. Use `type: input_output`](#2-use-type-input_output)
- [3. Check the prompts](#3-check-the-prompts)
<!-- /TOC -->
<a id="markdown-background" name="background"></a>
## Background
<a id="markdown-masking-inputs" name="masking-inputs"></a>
### Masking Inputs
One of the most popular features of
[axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is
setting the following configuration value:
```yaml
train_on_inputs: false
```
If you declare a [dataset formats](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.
<a id="markdown-you-may-not-want-prompt-templates" name="you-may-not-want-prompt-templates"></a>
### You may not want prompt templates
However, there are many situations where you don't want to use one of
these formats or templates (I usually don't!). This is because they can:
- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters `<|im_start|>` that can
quickly become footguns if you don't include them correctly at
inference time.
- Enforce a *chat* interface when you do not want one. Sometimes you
just want to fine-tune a model to a very specific task and do NOT
want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.
<a id="markdown-the-inputoutput-format" name="the-inputoutput-format"></a>
### The `input_output` format
You can construct your prompts without a template by using the
`input_output` format, by setting `type: input_output` in your
configuration file like this:
**config.yml**
```yaml
train_on_inputs: false # Mask segments of your data
datasets:
- path: output.jsonl
type: input_output # use template free prompt construction
```
Unlike `type: completion`, which is also template-free,
`type: input_output` allows you to mask segments of your text. More
details on how this works are described below.
<a id="markdown-usage" name="usage"></a>
## Usage
This is how you can use the `input_output` format:
<a id="markdown-1-prepare-data" name="1-prepare-data"></a>
### 1. Prepare Data
To use the `input_output` format, collect your data in the following
format into a jsonl file (below is the first row from the file
`output`.jsonl` pretty printed):
```bash
$ head -n1 output.jsonl | python -m json.tool
{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```
Set `label:false` when you want to mask a segment of text so that the
model isn't trained on it. Some things to keep in mind:
> [!IMPORTANT]
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
concatenates all the segments as-is.** The tokenizer doesn't add
anything additional. Notice how I added spaces, newlines, `<s>`
(BOS), and `</s>` (EOS) myself.
> 2. Make sure you check the materialized output to validate that the
prompt is getting assembled how you like.
<a id="markdown-2-use-type-inputoutput" name="2-use-type-inputoutput"></a>
### 2. Use `type: input_output`
Let's materialize data with our `output.jsonl` file by setting
`type: input_output` in our axolotl config:
```yaml
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49
datasets:
- path: output.jsonl
type: input_output
val_set_size: 0.1
sequence_len: 896
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002
train_on_inputs: false
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
```
You can use the following command to materialize your data. The
`--debug` flag will print the tokens, along with the labels so you can
verify that the correct items are being ignored:
```bash
$ python -m axolotl.cli.preprocess training_config.yaml --debug
...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
```
The format is `decoded_token`(`label`, `token_id`), for example,
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
token_id is `1`. When the label is `-100` then that token is ignored for
training.
<a id="markdown-3-check-the-prompts" name="3-check-the-prompts"></a>
### 3. Check the prompts
Here is another way to check the materialized output:
```python
from transformers import AutoTokenizer
from datasets import load_from_disk
import yaml
directory = !ls last_run_prepared/
with open('training_config.yaml', 'r') as f:
cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```
```python
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ingored by comparing the labels
to each token:
```python
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
zip(row['input_ids'], row['labels'])])
```
| token | label | id |
|-------|-------|-------|
| 0 | \<s\> | 1 |
| 1 | Hello | 22557 |
| 2 | \\n | 13 |
| 3 | hi | 12014 |
| 4 | there | 736 |
| 5 | ! | 28808 |
| 6 | . | 28723 |
| 7 | | 28705 |
| 8 | good | -100 |
| 9 | bye | -100 |
| 10 | | -100 |
| 11 | fare | 19111 |
| 12 | well | 5458 |
| 13 | \</s\>| 2 |
If we look at the input data, the above table seems correct! (The jsonl
version is repeated below for reference):
```bash
$ head -n1 output.jsonl | python -m json.tool
{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```

View File

@@ -1,18 +0,0 @@
# Mac M series support
Currently Axolotl on Mac is partially usable, many of the dependencies of Axolotl including Pytorch do not support MPS or have incomplete support.
Current support:
- [x] Support for all models
- [x] Full training of models
- [x] LoRA training
- [x] Sample packing
- [ ] FP16 and BF16 (awaiting AMP support for MPS in Pytorch)
- [ ] Tri-dao's flash-attn (until it is supported use spd_attention as an alternative)
- [ ] xformers
- [ ] bitsandbytes (meaning no 4/8 bits loading and bnb optimizers)
- [ ] qlora
- [ ] DeepSpeed
Untested:
- FSDP

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -177,24 +177,6 @@
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Play with inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
}
],
"metadata": {

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true
load_in_4bit: false
gptq: false

View File

@@ -5,7 +5,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
load_in_4bit: false
gptq: false

View File

@@ -1,65 +0,0 @@
# use google/gemma-7b if you have access
base_model: mhenrichsen/gemma-7b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
val_set_size: 0.1
output_dir: ./out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 3
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false

View File

@@ -1,4 +1,5 @@
base_model: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true
gptq_disable_exllama: true
model_type: AutoModelForCausalLM

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
@@ -59,7 +60,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -56,7 +57,7 @@ s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -1,70 +0,0 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,7 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -49,7 +49,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,6 +2,7 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -60,7 +61,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
#default deepspeed, can use more aggresive if needed like zero2, zero3

View File

@@ -1,6 +1,7 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: false
@@ -48,7 +49,7 @@ flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -1,79 +0,0 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./lora-out
eval_sample_packing: false
adapter: lora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
sdp_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,74 +0,0 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
special_tokens:

View File

@@ -16,12 +16,12 @@ output_dir: ./qlora-out
## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
# - ^lm_head.weight$
# - ^model.embed_tokens.weight$[:32000]
# - model.layers.2[0-9]+.block_sparse_moe.gate
# - model.layers.2[0-9]+.block_sparse_moe.experts
# - model.layers.3[0-9]+.block_sparse_moe.gate
# - model.layers.3[0-9]+.block_sparse_moe.experts
# - lm_head.*
# - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
model_config:
output_router_logits: true
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero2.json

View File

@@ -1,6 +1,7 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: true
@@ -67,7 +68,7 @@ loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,6 +2,7 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: true
@@ -57,7 +58,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -2,6 +2,7 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: false
@@ -57,7 +58,7 @@ flash_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:

View File

@@ -1,69 +0,0 @@
base_model: stabilityai/stablelm-2-1_6b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,66 +0,0 @@
base_model: stabilityai/stablelm-2-1_6b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,36 +0,0 @@
# StableLM 2
This repository contains examples for training and processing using StableLM-2. It also includes a section to help you estimate the GPU requirements for your specific use case.
## Estimating GPU Requirements
| type | deepspeed | batch size | context length | vRAM GPU (GBs) |
|---------------|-----------|------------|----------------|----------------|
| full finetune | N/A | 1 | 4096 | ~21.5GBs |
| full finetune | zero2 | 1 | 4096 | ~20GBs |
| lora | N/A | 1 | 4096 | ~16.6GBs |
The above are estimates and might differ slight depending on the setup for example whether you pack your sequence lengths or not (the above assumes you do to length 4096).
This blog post from Hamel Husain was a great resource for estimating these numbers: https://hamel.dev/notes/llm/03_estimating_vram.html
## Training
We have example scripts here for both full finetuning and lora using the popular alpaca dataset:
```shell
# preprocess the dataset
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/stablelm-2/1.6b/lora.yml
```
Single GPU Training:
```shell
python -m axolotl.cli.train examples/stablelm-2/fft.yml --deepspeed deepspeed_configs/zero2.json
# OR
python -m axolotl.cli.train examples/stablelm-2/1.6b/lora.yml
```
Multinode GPU Training with `accelerate`:
```shell
# make sure you've configured accelerate properly
accelerate launch -m axolotl.cli.train examples/stablelm-2/1.6b/fft.yml --deepspeed deepspeed_configs/zero2.json
```

View File

@@ -1,69 +0,0 @@
base_model: bigcode/starcoder2-3b
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.2
output_dir: ./qlora
adapter: qlora
lora_model_dir:
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 20
evals_per_epoch: 4
eval_steps:
eval_table_size:
saves_per_epoch: 4
save_steps:
save_total_limit: 2
debug:
deepspeed:
weight_decay:
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
@@ -15,7 +16,6 @@ output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora

View File

@@ -2,6 +2,7 @@ base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true

View File

@@ -1,7 +1,8 @@
base_model: 01-ai/Yi-34B-Chat
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: false
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -28,7 +29,7 @@ num_epochs: 1
val_set_size: 0.1
evals_per_epoch: 5
eval_table_size:
eval_max_new_tokens: 128
eval_table_max_new_tokens: 128
eval_sample_packing: false
eval_batch_size: 1

View File

@@ -1,18 +1,17 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.9.0
transformers==4.38.2
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
tokenizers==0.15.0
bitsandbytes>=0.43.0
bitsandbytes>=0.41.1
accelerate==0.26.1
deepspeed==0.13.1
pydantic==2.6.3
deepspeed>=0.13.1
addict
fire
PyYAML>=6.0
requests
datasets>=2.15.0
flash-attn==2.5.5
flash-attn==2.3.3
sentencepiece
wandb
einops
@@ -22,13 +21,14 @@ hf_transfer
colorama
numba
numpy>=1.24.4
mlflow
# qlora things
evaluate==0.4.1
evaluate==0.4.0
scipy
scikit-learn==1.2.2
pynvml
art
fschat==0.2.36
fschat==0.2.34
gradio==3.50.2
tensorboard
@@ -40,4 +40,3 @@ gcsfs
# adlfs
trl>=0.7.9
fastcore>=1.5.29

View File

@@ -18,7 +18,6 @@ def parse_requirements():
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
@@ -68,13 +67,13 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.5.5",
"flash-attn==2.5.0",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
],
"deepspeed": [
"deepspeed==0.13.1",
"deepspeed>=0.13.1",
"deepspeed-kernels",
],
"mamba-ssm": [
@@ -83,11 +82,5 @@ setup(
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"lion-pytorch": [
"lion-pytorch==0.1.2",
],
},
)

View File

@@ -13,6 +13,7 @@ from threading import Thread
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import gradio as gr
import requests
import torch
import yaml
@@ -23,7 +24,6 @@ from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
@@ -214,8 +214,6 @@ def do_inference_gradio(
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
@@ -330,6 +328,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
@@ -342,22 +341,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": os.environ.get("WORLD_SIZE", 1),
"compute_capability": gpu_version,
},
)
validate_config(cfg)
prepare_optim_env(cfg)

View File

@@ -1,55 +0,0 @@
"""module for building the auto wrap policy for FSDP"""
import functools
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
]
def get_wrapping_policy_factory(model_type):
if model_type == "llama":
layer_to_wrap = LlamaDecoderLayer
elif model_type == "mistral":
layer_to_wrap = MistralDecoderLayer
elif model_type == "mixtral":
layer_to_wrap = MixtralDecoderLayer
def get_wrapping_policy():
"""This checks for lora layers (has weight and requires_grad)"""
def lambda_policy_fn(module):
return (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
)
lambda_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
)
transformer_layer_name = layer_to_wrap
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
transformer_layer_name,
),
)
policies = [lambda_policy, transformer_wrap_policy]
return functools.partial(_or_policy, policies=policies)
return get_wrapping_policy

View File

@@ -5,10 +5,8 @@ Builder for the training args and trainer
import abc
import importlib
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
@@ -18,10 +16,7 @@ from typing import List, Optional, Type, Union
import torch
import transformers
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool
from datasets import Dataset
from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -31,21 +26,18 @@ from transformers import (
TrainingArguments,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoMlflowCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import (
@@ -57,13 +49,9 @@ from axolotl.utils.collators import (
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
get_cosine_schedule_with_quadratic_warmup, JaggedLRRestartScheduler,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
@@ -72,10 +60,6 @@ except ImportError:
LOG = logging.getLogger("axolotl.core.trainer_builder")
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
@@ -145,11 +129,19 @@ class AxolotlTrainingArguments(TrainingArguments):
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
metadata={"help": "how many anneal steps to take before reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
jagged_restart_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for jagged restarts"},
)
jagged_restarts_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for jagged restarts"},
)
jagged_restarts_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many anneal steps to take before reset for jagged restarts"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
@@ -163,9 +155,6 @@ class AxolotlTrainingArguments(TrainingArguments):
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
@@ -183,23 +172,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
class AxolotlTrainer(Trainer):
@@ -224,33 +196,6 @@ class AxolotlTrainer(Trainer):
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
@@ -284,16 +229,6 @@ class AxolotlTrainer(Trainer):
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
@@ -303,7 +238,7 @@ class AxolotlTrainer(Trainer):
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
@@ -311,6 +246,21 @@ class AxolotlTrainer(Trainer):
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
if self.args.jagged_restart_steps:
warmup_steps = (
self.args.jagged_restarts_warmup_steps or 10
)
anneal_steps = (
self.args.jagged_restarts_anneal_steps or 1
)
self.lr_scheduler = JaggedLRRestartScheduler(
optimizer,
self.lr_scheduler,
self.args.jagged_restart_steps,
warmup_steps,
anneal_steps,
)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
@@ -477,56 +427,6 @@ class AxolotlTrainer(Trainer):
return super().push_to_hub(*args, **kwargs)
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
rank = int(os.environ.get("LOCAL_RANK", 0))
res = super().create_accelerator_and_postprocess()
if self.args.qlora is False:
return res
# the rest of this method override is specific to fsdp + qlora (for now)
sync_module_states = (
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
)
mp_policy = None
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
if amp == "fp16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
elif amp == "bf16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
# If somehow we figure out how we want to parameterize we want to autocast buffers...
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
# load_param_skip_names = ['inv_freq']
if self.is_fsdp_enabled:
wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrapping_policy(),
cpu_offload=False,
use_orig_params=False,
limit_all_gathers=True,
param_init_fn=lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if (rank != 0 and sync_module_states)
else None,
mixed_precision_policy=mp_policy,
)
self.accelerator.state.fsdp_plugin = fsdp_plugin
return res
class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -750,11 +650,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
if self.cfg.use_mlflow:
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
@@ -774,11 +670,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
@@ -846,9 +737,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
@@ -930,8 +818,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
@@ -992,10 +878,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and (
(not self.cfg.test_datasets and self.cfg.val_set_size > 0)
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
)
and not self.cfg.test_datasets
and self.cfg.val_set_size > 0
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
@@ -1016,10 +900,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
if self.cfg.save_only_model:
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
@@ -1030,9 +912,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
@@ -1050,20 +929,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs[
"relora_warmup_steps"
] = self.cfg.relora_warmup_steps
if self.cfg.relora_anneal_steps:
training_arguments_kwargs[
"relora_anneal_steps"
] = self.cfg.relora_anneal_steps
if self.cfg.relora_prune_ratio:
training_arguments_kwargs[
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
@@ -1075,42 +943,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"neftune_noise_alpha"
] = self.cfg.neftune_noise_alpha
trainer_kwargs = {}
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if "weight_decay" in training_arguments_kwargs:
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
)
trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **lion_kwargs),
None,
)
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
trainer_kwargs = {}
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
data_collator_kwargs = {
"padding": True, # True/"longest" is the default

View File

@@ -30,7 +30,6 @@ class ColorfulFormatter(Formatter):
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",

View File

@@ -1,133 +0,0 @@
"""Module for LoRA+"""
# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus
import logging
from functools import reduce
from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
LOG = logging.getLogger("axolotl.loraplus")
def get_module(name, opt_model):
"""
Retrieve a module from a model using its parameter name.
Args:
name (str): Full name of the parameter, typically including module path.
opt_model (torch.nn.Module): The model from which to retrieve the module.
Returns:
Module corresponding to the given name.
"""
parent_idx = 2 if "lora" in name else 1
module_names = name.split(sep=".")[:-parent_idx]
module = reduce(getattr, module_names, opt_model)
return module
def create_loraplus_optimizer(
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding=None,
):
"""
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
Args:
opt_model (torch.nn.Module): The model for which the optimizer is being created.
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
Returns:
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
"""
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
if loraplus_lr_embedding is None:
loraplus_lr_embedding = 1e-6
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
param_groups = {
"groupA": {},
"groupB": {},
"groupB_no_decay": {},
"embedding": {},
}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
module = get_module(name, opt_model)
if isinstance(module, lora.Embedding):
param_groups["embedding"][name] = param
elif "lora_B" in name or param.ndim == 1:
if name in decay_parameters:
param_groups["groupB"][name] = param
else:
param_groups["groupB_no_decay"][name] = param
else:
param_groups["groupA"][name] = param
assigned_param_groups = ""
for group, group_params in param_groups.items():
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
LOG.info(assigned_param_groups)
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
optimizer_grouped_parameters = [
{
"params": list(param_groups["groupA"].values()),
"weight_decay": weight_decay,
"lr": lr,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": loraplus_lr_embedding,
},
{
"params": list(param_groups["groupB"].values()),
"weight_decay": weight_decay,
"lr": lr * loraplus_lr_ratio,
},
{
"params": list(param_groups["groupB_no_decay"].values()),
"weight_decay": 0.0,
"lr": lr * loraplus_lr_ratio,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
return optimizer

View File

@@ -106,7 +106,7 @@ def get_turns( # pylint: disable=too-many-return-statements
if self.system_message:
contains_sys_msg = True
if self.messages:
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt = self.system_template.format(

View File

@@ -44,18 +44,6 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def is_xformers_swiglu_available() -> bool:
from xformers.ops.common import get_xformers_operator
try:
get_xformers_operator("swiglu_packedw")()
return True
except RuntimeError as exc:
if "No such operator xformers::swiglu_packedw " in str(exc):
return False
return True
def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
@@ -245,6 +233,7 @@ def flashattn_forward_with_s2attn(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -286,9 +275,7 @@ def flashattn_forward_with_s2attn(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -373,6 +360,7 @@ def flashattn_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -437,9 +425,7 @@ def flashattn_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -702,9 +688,6 @@ def llama_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
@@ -768,6 +751,12 @@ def llama_model_forward(
dtype=torch.bool,
device=inputs_embeds.device,
)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
@@ -817,6 +806,7 @@ def llama_model_forward(
past_key_value,
output_attentions,
None,
padding_mask,
cu_seqlens,
max_seqlen,
)
@@ -828,6 +818,7 @@ def llama_model_forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
@@ -874,6 +865,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
@@ -906,6 +898,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)

View File

@@ -1,26 +1,15 @@
"""multipack patching for v2 of sample packing"""
import importlib
import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"falcon",
"phi",
"gemma",
"gemmoe",
"starcoder2",
]
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
def patch_for_multipack(model_type, model_name=None):
def patch_for_multipack(model_type):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
@@ -39,23 +28,3 @@ def patch_for_multipack(model_type, model_name=None):
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemmoe":
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_gemmoe to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_gemmoe", ".modeling_gemmoe"
)
modeling_gemmoe = importlib.import_module(module_name)
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -46,9 +46,8 @@ def reset_optimizer(
*,
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
n_zeros = 0
n_total = 0
@@ -160,7 +159,6 @@ class ReLoRACallback(TrainerCallback):
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
prune_ratio=args.relora_prune_ratio,
)
if self.quantized:
@@ -267,7 +265,7 @@ class ReLoRAScheduler(LRScheduler):
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps - self.warmup_steps:
if step < self.relora_steps:
scale = 1
else:
per_relora_progress = step % self.relora_steps

View File

@@ -1,7 +1,4 @@
"""
HF Chat Templates prompt strategy
"""
from typing import Any, Dict, Optional
from typing import Optional, Dict, Any
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter
@@ -9,8 +6,6 @@ from axolotl.utils.chat_templates import chat_templates
class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates"""
def __init__(self, tokenizer, chat_template=None, max_length=2048):
self.tokenizer = tokenizer
self.chat_template = chat_template
@@ -18,9 +13,7 @@ class ChatTemplatePrompter(Prompter):
def build_prompt(self, conversation, add_generation_prompt=False):
return self.tokenizer.apply_chat_template(
conversation,
truncation=True,
max_length=self.max_length,
conversation, truncation=True, max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
@@ -42,10 +35,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
else:
labels = input_ids
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
"attention_mask": [1] * len(input_ids)
}
return tokenized_prompt
@@ -53,12 +47,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap roles - allow for assistant turn
role_map = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
}
role_map = {"human": "user", "user": "user", "assistant": "assistant", "gpt": "assistant"}
turns = [
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
]
@@ -66,11 +55,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = (
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
ChatTemplatePrompter(
tokenizer,
chat_templates(ds_cfg["conversation"]),
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,

View File

@@ -8,13 +8,14 @@ import logging
LOG = logging.getLogger("axolotl")
def load(strategy, cfg, **kwargs):
def load(strategy, cfg):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
load_kwargs = {}
return func(cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None

View File

@@ -5,7 +5,6 @@ DPO strategies for chatml
def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
@@ -24,28 +23,8 @@ def argilla(
return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample
return transform_fn
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
@@ -69,7 +48,7 @@ def icr(
return transform_fn
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca DPO Pairs
"""
@@ -91,9 +70,7 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
return transform_fn
def prompt_pairs(
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
@@ -111,7 +88,7 @@ def prompt_pairs(
return transform_fn
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
"""

View File

@@ -1,41 +0,0 @@
"""
User-defined DPO strategies
"""
def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
ds_cfg = cfg["datasets"][dataset_idx]["type"]
if not isinstance(ds_cfg, dict):
raise ValueError(
f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
)
field_prompt = ds_cfg.get("field_prompt", "prompt")
field_system = ds_cfg.get("field_system", "system")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
prompt_format = ds_cfg.get("prompt_format")
if not prompt_format:
prompt_format = "{" + field_prompt + "}"
chosen_format = ds_cfg.get("chosen_format")
if not chosen_format:
chosen_format = "{" + field_chosen + "}"
rejected_format = ds_cfg.get("rejected_format")
if not rejected_format:
rejected_format = "{" + field_rejected + "}"
def transform_fn(sample):
if (
"{" + field_system + "}" in prompt_format
and field_system in sample
and sample[field_system]
):
sample["prompt"] = prompt_format.format(
system=sample[field_system], prompt=sample[field_prompt]
)
else:
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
return sample
return transform_fn

View File

@@ -3,7 +3,7 @@ DPO strategies for zephyr
"""
def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
data = {}
data["prompt"] = (

View File

@@ -1,54 +0,0 @@
"""Module for plain input/output prompt pairs"""
from typing import Generator, Tuple
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
class RawInputOutputStrategy(PromptTokenizingStrategy):
"""Prompt Strategy class for input/output pairs"""
def __init__(self, *args, eos_token=None, **kwargs):
super().__init__(*args, **kwargs)
self.eos_token = eos_token
if not eos_token:
self.eos_token = self.tokenizer.eos_token
def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
input_ids = []
labels = []
for label, text in self.prompter.build_prompt(prompt["segments"]):
tokenized_output = self.tokenizer(
text, add_special_tokens=False, return_tensors=None
)["input_ids"]
input_ids += tokenized_output
if label or self.train_on_inputs:
labels += tokenized_output
else:
labels += [IGNORE_TOKEN_ID] * len(tokenized_output)
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
return tokenized_prompt
class RawInputOutputPrompter(Prompter):
"""prompter for raw i/o data"""
def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:
for segment in source:
yield segment["label"], segment["text"]
def load(tokenizer, cfg):
return RawInputOutputStrategy(
RawInputOutputPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -1,15 +1,10 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
from axolotl.utils.tokenization import (
chatml_to_conversation,
merge_consecutive_messages,
)
def register_chatml_template(system_message=None):
@@ -24,16 +19,6 @@ def register_chatml_template(system_message=None):
sep="<|im_end|>",
)
)
register_conv_template(
Conversation(
name="chatml_glaive",
system_template="<|im_start|>system\n{system_message}",
system_message=system_message,
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
@@ -92,26 +77,12 @@ def load_guanaco(tokenizer, cfg):
)
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else "chatml_glaive"
)
return GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation=conversation),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
_strict = False
_strict = True
@property
def strict(self):
@@ -125,25 +96,10 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
conversations = prompt["conversations"]
if self.strict:
return conversations
role_key = "from"
if "role" in conversations[0].keys():
role_key = "role"
value_key = "value"
if "text" in conversations[0].keys():
value_key = "text"
elif "content" in conversations[0].keys():
value_key = "content"
# remap roles - allow for assistant turn"
role_map = {
"user": "human",
"human": "human",
"assistant": "gpt",
"gpt": "gpt",
"system": "system",
}
# remap roles - allow for assistant turn
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
turns = [
{"from": role_map[t[role_key]], "value": t[value_key]}
for t in conversations
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
]
return turns
@@ -187,15 +143,3 @@ class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingSt
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps glaive data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversation = chatml_to_conversation(prompt)
conversation = merge_consecutive_messages(conversation)
return conversation

View File

@@ -360,19 +360,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
LOG.warning(f"expected tuple, got {part}")
continue
tool_role_label = None
if len(conversation.roles) == 3:
(
user_role_label,
assistant_role_label,
tool_role_label,
) = conversation.roles
else:
user_role_label, assistant_role_label = conversation.roles
user, assistant = conversation.roles
role, content = part
# Uses "in" because role contains extra characters
if user_role_label in role:
if user in role:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
@@ -392,7 +384,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant_role_label in role:
elif assistant in role:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
@@ -434,8 +426,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif tool_role_label and tool_role_label in role:
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue

View File

@@ -267,8 +267,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human = "human"
role_key_model = "gpt"
# Optional, only used for tool usage datasets.
role_key_tool = None
def __init__(
self,
@@ -276,7 +274,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
):
if conversation:
if isinstance(conversation, Conversation):
@@ -289,8 +286,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
self.role_key_human = role_key_human
if role_key_model:
self.role_key_model = role_key_model
if role_key_tool:
self.role_key_tool = role_key_tool
def _build_result(self, source):
if len(source) < 2:
@@ -308,8 +303,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
source.pop(0)
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
if self.role_key_tool:
roles[self.role_key_tool] = conv.roles[2]
try:
# Apply prompt templates

View File

@@ -19,7 +19,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -99,7 +99,7 @@ def train(
safe_serialization = cfg.save_safetensors is True
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
freeze_parameters_except(model, cfg.unfrozen_parameters)
trainer = setup_trainer(
cfg,

View File

@@ -24,9 +24,9 @@ def check_cuda_device(default_value):
or not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
or torch.device(device).type == "meta"
):
return default_value
return func(*args, **kwargs)
return wrapper

View File

@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Dict, List
import evaluate
import mlflow
import numpy as np
import pandas as pd
import torch
@@ -41,8 +42,8 @@ from axolotl.utils.distributed import (
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
class EvalFirstStepCallback(
@@ -61,6 +62,7 @@ class EvalFirstStepCallback(
):
if (
args.evaluation_strategy == IntervalStrategy.STEPS
and (args.eval_steps < 1.0 or args.eval_steps > 1)
and state.global_step == 1
):
control.should_evaluate = True
@@ -359,187 +361,6 @@ def bench_eval_callback_factory(trainer, tokenizer):
return BenchEvalCallback
def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
class CausalLMBenchEvalCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.metrics = self.__maybe_load_metrics()
def __maybe_load_metrics(self):
metrics = {}
for metric in self.cfg.eval_causal_lm_metrics:
try:
metrics[metric] = evaluate.load(metric)
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.warning(f"{metric}: {exc.args}")
return metrics
def on_evaluate(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
train_dataloader, # pylint: disable=unused-argument
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
trainer.model.eval()
device = torch.device(self.cfg.device)
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i - 1))
start = i
end = len(lst) - 1
ranges.append((start, end))
return ranges
def compute(metric: evaluate.Metric, **kwargs):
# safely compute a metric and return the score if the format is correct
metric_score = None
try:
metric_score = metric.compute(**kwargs)
return (
metric_score["score"]
if "score" in metric_score
else metric_score["mean_score"]
)
except Exception: # pylint: disable=broad-exception-caught
LOG.debug(
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
)
return metric_score
def evaluate_preds(sources, predictions, references):
scores = {}
for metric_name, metric in self.metrics.items():
score = compute(
metric,
references=references,
predictions=predictions,
sources=sources,
)
score = score or compute(
metric,
references=[[r] for r in references],
predictions=predictions,
)
scores[metric_name] = score
return scores
def predict_with_generate():
eval_src, eval_pred, eval_ref = [], [], []
for batch in tqdm(eval_dataloader):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])
prompt_token_ids_list = []
completion_token_ids_list = []
for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)
for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)
prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)
prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)
eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)
return eval_src, eval_pred, eval_ref
if is_main_process():
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))
return control
return CausalLMBenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
@@ -567,7 +388,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_max_new_tokens,
max_new_tokens=self.cfg.eval_table_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
@@ -755,3 +576,31 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
"""Callback to save axolotl config to mlflow"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
return control

View File

@@ -1,44 +0,0 @@
"""MLFlow module for trainer callbacks"""
import logging
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
import mlflow
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
# pylint: disable=duplicate-code
"""Callback to save axolotl config to mlflow"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
return control

View File

@@ -22,7 +22,6 @@ def chat_templates(user_choice: str):
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
}
if user_choice in templates:

View File

@@ -3,16 +3,11 @@ import json
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities,
AxolotlInputConfig,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
@@ -61,13 +56,7 @@ def normalize_config(cfg):
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
"sacrebleu",
"comet",
"ter",
"chrf",
]
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
@@ -124,7 +113,7 @@ def normalize_config(cfg):
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
or (cfg.model_type and "llama" in cfg.model_type.lower())
)
# figure out if the model is falcon
@@ -140,7 +129,7 @@ def normalize_config(cfg):
)
or cfg.is_falcon_derived_model
or "falcon" in cfg.base_model.lower()
or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
)
cfg.is_mistral_derived_model = (
@@ -153,7 +142,7 @@ def normalize_config(cfg):
)
or cfg.is_mistral_derived_model
or "mistral" in cfg.base_model.lower().split("/")[-1]
or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
or (cfg.model_type and "mistral" in cfg.model_type.lower())
)
cfg.is_qwen_derived_model = (
@@ -164,6 +153,9 @@ def normalize_config(cfg):
]
) or cfg.is_qwen_derived_model
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)
if isinstance(cfg.pretraining_dataset, dict):
cfg.pretraining_dataset = [cfg.pretraining_dataset]
@@ -193,21 +185,7 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].conversation = "chatml"
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
if capabilities:
return DictDefault(
dict(
AxolotlConfigWCapabilities(
**cfg.to_dict(), capabilities=capabilities
).model_dump(exclude_unset=True)
)
)
return DictDefault(
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_unset=True))
)
def legacy_validate_config(cfg):
def validate_config(cfg):
"""
This is a "pre-validation" step that handles the yaml configuration before we have any
information about the model architecture
@@ -379,11 +357,11 @@ def legacy_validate_config(cfg):
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
if cfg.gptq and cfg.revision_of_model:
if cfg.gptq and cfg.model_revision:
raise ValueError(
"revision_of_model is not supported for GPTQ models. "
"model_revision is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove revision_of_model from the config."
+ "point to its path, and remove model_revision from the config."
)
# if cfg.sample_packing and cfg.sdp_attention:
@@ -496,6 +474,9 @@ def legacy_validate_config(cfg):
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
@@ -569,21 +550,6 @@ def legacy_validate_config(cfg):
if cfg.fsdp and "bnb" in cfg.optimizer:
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)
if cfg.eval_causal_lm_metrics:
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +0,0 @@
"""module for gpu capabilities"""
from typing import Optional
from pydantic import BaseModel, Field
class GPUCapabilities(BaseModel):
"""model to manage the gpu capabilities statically"""
bf16: bool = Field(default=False)
fp8: bool = Field(default=False)
n_gpu: int = Field(default=1)
n_node: int = Field(default=1)
compute_capability: Optional[str] = Field(default=None)

View File

@@ -937,9 +937,7 @@ def load_prepare_dpo_datasets(cfg):
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
ds_transform_fn = load_dpo(_type, _cfg)
split_datasets[i] = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",

View File

@@ -12,4 +12,4 @@ class DictDefault(Dict):
return None
def __or__(self, other):
return DictDefault(super().__ror__(other))
return DictDefault(super().__or__(other))

View File

@@ -3,14 +3,13 @@ module to freeze/unfreeze parameters by name
"""
import logging
import re
from typing import Callable, List, Tuple
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.utils.freeze")
def freeze_layers_except(model, regex_patterns):
def freeze_parameters_except(model, regex_patterns):
"""
Freezes all layers of the given model except for the layers that match given regex patterns.
Periods in the patterns are treated as literal periods, not as wildcard characters.
@@ -18,209 +17,22 @@ def freeze_layers_except(model, regex_patterns):
Parameters:
- model (nn.Module): The PyTorch model to be modified.
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.
Also, to match the entire layer name, the pattern should start with "^" and end with "$", otherwise it will match any part of the layer name.
The range pattern part is optional and it is not compiled as a regex pattern which means you must put "$" before the range pattern if you want to match the entire layer name.
E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]
Returns:
None; the model is modified in place.
"""
if isinstance(regex_patterns, str):
regex_patterns = [regex_patterns]
# Escape periods and compile the regex patterns
compiled_patterns = [
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
]
patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]
# First, freeze all parameters in the model
for param in model.parameters():
param.requires_grad = False
# Unfreeze layers that match the regex patterns
for name, param in model.named_parameters():
param.requires_grad = False
unfrozen_ranges = []
for pattern in patterns:
if not pattern.match(name):
continue
if any(pattern.match(name) for pattern in compiled_patterns):
if is_main_process():
LOG.debug(f"unfreezing {name}")
param.requires_grad = True
if pattern.range is not None:
unfrozen_ranges.append(pattern.range)
merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))
if param.requires_grad and is_main_process():
unfrozen_ranges = (
f" with ranges {merged_unfrozen_ranges}"
if merged_unfrozen_ranges
else ""
)
LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")
if not merged_unfrozen_ranges:
continue
# The range list we need is actually the inverted of the merged ranges
ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))
param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))
if is_main_process() and all(
not param.requires_grad for param in model.parameters()
):
LOG.warning("All parameters are frozen. Model will not be trained.")
def _invert_ranges(
given_ranges: List[Tuple[int, int]], layer_size: int
) -> List[Tuple[int, int]]:
"""
Inverts a list of ranges to obtain the ranges not covered by the given ranges.
Parameters:
- given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
Returns:
- List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
"""
if not given_ranges:
return [(0, layer_size)]
inverted_ranges = []
current_start = 0
for start, end in sorted(given_ranges):
if start > current_start:
inverted_ranges.append((current_start, start))
current_start = max(current_start, end)
# Handle the case where the last given range does not reach the end of the total_size
if current_start < layer_size:
inverted_ranges.append((current_start, layer_size))
return inverted_ranges
def _merge_ranges(
given_ranges: List[Tuple[int, int | None]], layer_size: int
) -> List[Tuple[int, int]]:
"""
Merges overlapping ranges and sorts the given ranges.
This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
as tuples, where the first element is the start index (inclusive) and the second element is the end
index (exclusive). The end index can be None, indicating that the range extends to the end of the
sequence.
Parameters:
- given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
Returns:
- List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
"""
# End of each range can be determined now since we have the total size
processed_ranges = [
(start, end if end is not None else layer_size) for start, end in given_ranges
]
# No need to merge if there's only one or no ranges
if len(processed_ranges) <= 1:
return processed_ranges
sorted_ranges = sorted(processed_ranges)
merged_ranges = [sorted_ranges[0]]
for start, end in sorted_ranges[1:]:
prev_start, prev_end = merged_ranges[-1]
if start <= prev_end:
merged_ranges[-1] = (prev_start, max(prev_end, end))
else:
merged_ranges.append((start, end))
return merged_ranges
def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
"""
Create a hook to freeze parameters in specified ranges by setting their gradients to zero.
This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
two integers representing the start and end indices of the range.
Parameters:
- ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.
Returns:
- Callable: A hook function to be used with `register_hook` on parameters.
Example usage:
```
ranges_to_freeze = [(0, 10), (20, 30)]
hook = _create_freeze_parameters_hook(ranges_to_freeze)
model.register_hook(hook)
```
"""
def freeze_parameters_hook(gradients):
for start, end in ranges_to_freeze:
gradients[start:end].zero_()
return freeze_parameters_hook
class LayerNamePattern:
"""
Represents a regex pattern for layer names, potentially including a parameter index range.
"""
def __init__(self, pattern: str):
"""
Initializes a new instance of the LayerNamePattern class.
Parameters:
- pattern (str): The regex pattern for layer names, potentially including a parameter index range.
"""
self.raw_pattern = pattern
name_pattern, self.range = self._parse_pattern(pattern)
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
def match(self, name: str) -> bool:
"""
Checks if the given layer name matches the regex pattern.
Parameters:
- name (str): The layer name to check.
Returns:
- bool: True if the layer name matches the pattern, False otherwise.
"""
return self.name_regex.match(name) is not None
def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
"""
Extracts the range pattern from the given pattern.
Parameters:
- pattern (str): The pattern to extract the range from.
Returns:
- Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
"""
match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
if not match:
return pattern, None
base_pattern, start_part, end_part = match.groups()
if end_part is None and start_part.isdecimal():
index = int(start_part)
return base_pattern, (index, index + 1)
# [:end] or [start:] or [start:end]
start = int(start_part) if start_part else 0
end = int(end_part) if end_part else None
if end is not None and start >= end:
raise ValueError(
f"Invalid range in layer name pattern: {pattern}."
"End of range must be greater than start."
)
return base_pattern, (start, end)

View File

@@ -7,7 +7,7 @@ from axolotl.utils.dict import DictDefault
def setup_mlflow_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("mlflow_") or key.startswith("hf_mlflow_"):
if key.startswith("mlflow_"):
value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0:

View File

@@ -1,20 +1,13 @@
"""Module for models and model loading"""
# pylint: disable=too-many-lines
import logging
import math
import os
import types
from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict
import bitsandbytes as bnb
import safetensors
import torch
import transformers
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from peft import (
LoftQConfig,
PeftConfig,
@@ -23,7 +16,6 @@ from peft import (
prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from torch import Tensor, nn
from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
@@ -35,9 +27,7 @@ from transformers import ( # noqa: F401
PreTrainedTokenizerBase,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -96,8 +86,8 @@ def load_model_config(cfg):
model_config_name = cfg.tokenizer_config
trust_remote_code = cfg.trust_remote_code is True
config_kwargs = {}
if cfg.revision_of_model:
config_kwargs["revision"] = cfg.revision_of_model
if cfg.model_revision:
config_kwargs["revision"] = cfg.model_revision
try:
model_config = AutoConfig.from_pretrained(
@@ -114,8 +104,8 @@ def load_model_config(cfg):
)
raise err
if cfg.overrides_of_model_config:
for key, val in cfg.overrides_of_model_config.items():
if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)
check_model_config(cfg, model_config)
@@ -272,117 +262,6 @@ def load_tokenizer(cfg):
return tokenizer
def replace_linear(
model: nn.Module,
linear_replacement: Type[nn.Module],
quant_config: Union[dict, None] = None,
skip_modules=None,
**kwargs,
):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
List of modules names not to convert. Defaults to `lm_head`.
"""
if skip_modules is None:
skip_modules = ["lm_head"]
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(
module, linear_replacement, quant_config, skip_modules, **kwargs
)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
if issubclass(linear_replacement, Linear4bit):
model._modules[ # pylint: disable=protected-access
name
] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
**kwargs,
)
else:
raise ValueError(
f"Unsupported linear replacement: {type(linear_replacement)}"
)
return model
def load_and_quantize(
module: nn.Module,
name: str,
value: Tensor,
device: torch.device = None,
dtype: torch.dtype = None,
skip_names: Optional[List[str]] = None,
is_meta_rank: bool = False,
low_memory: bool = True,
verbose: bool = False,
quant_method: str = "bnb",
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
"""
if skip_names is None:
skip_names = []
def place_on_device(value):
if is_meta_rank:
device = "meta"
elif low_memory:
device = "cpu"
else:
device = "cuda"
return value.to(device=device, dtype=dtype)
if any(skip_name in name for skip_name in skip_names):
if verbose:
print(f"Skipping {name} because it is in skip_names")
return
module_key, _, value_key = name.rpartition(".")
try:
submodule = module.get_submodule(module_key)
except AttributeError as exc:
print(f"Module {module_key} not found:\n{exc}")
return
try:
if quant_method == "bnb":
param = submodule.get_parameter(value_key)
if isinstance(param, Params4bit):
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
# shape as the quantized Params4bit with an initialized quant_state. However,
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
if is_meta_rank:
value = type(param)(value.data.to("meta"), **value.__dict__)
elif low_memory:
value = type(param)(value.data.to("cpu"), **value.__dict__)
else:
value = type(param)(place_on_device(value).data)
except AttributeError:
# it's a buffer
value = place_on_device(value)
setattr(submodule, value_key, value)
def load_model(
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
@@ -393,7 +272,7 @@ def load_model(
Load a model for a given configuration and tokenizer.
"""
base_model = cfg.base_model
model_type = cfg.type_of_model
model_type = cfg.model_type
model_config = load_model_config(cfg)
# TODO refactor as a kwarg
@@ -429,7 +308,7 @@ def load_model(
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
patch_for_multipack(cfg.model_config_type)
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block
@@ -515,7 +394,7 @@ def load_model(
if max_memory is not None:
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
from accelerate import infer_auto_device_map
from accelerate import infer_auto_device_map, init_empty_weights
with init_empty_weights():
model_canvas = AutoModelForCausalLM.from_config(model_config)
@@ -547,8 +426,8 @@ def load_model(
if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]
if cfg.revision_of_model:
model_kwargs["revision"] = cfg.revision_of_model
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
if not hasattr(model_config, "quantization_config"):
LOG.warning("model config does not contain quantization_config information")
@@ -617,78 +496,8 @@ def load_model(
model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = "eager" # pylint: disable=protected-access
qlora_fsdp = (
cfg.fsdp
and cfg.adapter == "qlora"
and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
)
try:
if qlora_fsdp:
if cfg.bf16 or cfg.bfloat16:
torch_dtype, compute_dtype = torch.float32, torch.bfloat16
elif cfg.fp16 or cfg.float16:
torch_dtype, compute_dtype = torch.float32, torch.float16
else:
torch_dtype, compute_dtype = torch.float32, torch.float16
with init_empty_weights():
LOG.info("Loading model with empty weights.")
model = AutoModelForCausalLM.from_config(model_config)
model.model = replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=torch_dtype,
)
model.is_loaded_in_4bit = True
# Grab the safetensors files that hold the weights
try:
idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(base_model, idx)
except OSError:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
except OSError as exc:
# This means the model probably doesn't have a safetensors file
raise exc
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
def load_and_quantize_parallel(name_param, model, **kwargs):
name, param = name_param
load_and_quantize(model, name, param, **kwargs)
param_count = sum((p.numel() for n, p in model.named_parameters()))
for filename in files:
weights = safetensors.torch.load_file(filename)
quant_method = "bnb"
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
left = int(os.cpu_count() / torch.cuda.device_count())
right = int(
8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
)
n_workers = min(left, right)
parallel(
load_and_quantize_parallel,
weights.items(),
n_workers=n_workers,
threadpool=True,
model=model,
dtype=torch_dtype,
device=cfg.local_rank,
skip_names=[],
is_meta_rank=(cfg.local_rank != 0),
verbose=False,
quant_method=quant_method,
)
elif (
if (
model_config.model_type == "llama"
and not cfg.trust_remote_code
and not cfg.gptq
@@ -703,12 +512,11 @@ def load_model(
if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)
if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
if cfg.flash_attn_fuse_mlp:
LOG.info("patching with SwiGLU")
replace_llama_mlp_with_swiglu(model)
@@ -804,7 +612,7 @@ def load_model(
LOG.exception(err)
raise err
if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
if isinstance(model, (PeftModel, PeftModelForCausalLM)):
model = model.merge_and_unload()
embeddings_len = (
@@ -883,9 +691,6 @@ def load_model(
if cfg.adapter == "lora" and loftq_bits:
skip_prepare_model_for_kbit_training = True
if qlora_fsdp:
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]:
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
@@ -900,7 +705,7 @@ def load_model(
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
if needs_fa2_dtype or cfg.flash_attention:
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules():
if "norm" in name:
@@ -918,12 +723,7 @@ def load_model(
else:
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if (
cfg.ddp
and not load_in_8bit
and not (cfg.rl and cfg.load_in_4bit)
and not qlora_fsdp
):
if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
# TODO revaldate this conditional
model.to(f"cuda:{cfg.local_rank}")
@@ -1012,30 +812,6 @@ def find_all_linear_names(model):
return list(lora_module_names)
def setup_quantized_meta_for_peft(model: nn.Module):
"""Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument
return self
for param in model.parameters():
if isinstance(param, Params4bit):
param.quant_state._orig_to = ( # pylint: disable=protected-access
param.quant_state.to
)
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
def setup_quantized_peft_meta_for_training(model: nn.Module):
"""Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
for param in model.parameters():
if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
param.quant_state.to = (
param.quant_state._orig_to # pylint: disable=protected-access
)
param.quant_state._orig_to = None # pylint: disable=protected-access
def load_lora(model, cfg, inference=False, config_only=False):
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
@@ -1053,10 +829,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
if cfg.peft_use_rslora:
lora_config_kwargs["use_rslora"] = cfg.use_rslora
lora_config = LoraConfig(
r=cfg.lora_r,
@@ -1074,11 +846,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
if config_only:
return None, lora_config
rank = int(os.environ.get("LOCAL_RANK", 0))
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
setup_quantized_meta_for_peft(model)
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA")
model_kwargs: Any = {}
@@ -1094,9 +861,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
else:
model = get_peft_model(model, lora_config)
if rank == 0:
model.print_trainable_parameters()
elif cfg.fsdp and cfg.adapter == "qlora":
setup_quantized_peft_meta_for_training(model)
model.print_trainable_parameters()
return model, lora_config

View File

@@ -1,6 +1,7 @@
"""Module for custom LRScheduler class"""
import math
from functools import partial
from typing import Sequence
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -52,7 +53,7 @@ def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float,
num_cycles: float
):
if current_step < num_warmup_steps:
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
@@ -107,7 +108,7 @@ def _get_cosine_schedule_with_min_lr_lambda(
*,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float,
min_lr_ratio: float
):
# Warm up
if current_step < num_warmup_steps:
@@ -142,78 +143,46 @@ def get_cosine_schedule_with_min_lr(
return LambdaLR(optimizer, lr_lambda)
def _get_cosine_schedule_with_warmup_decay_constant_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
constant_lr_ratio: float,
min_lr_ratio: float,
num_cycles: float,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
class JaggedLRRestartScheduler(LRScheduler):
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
num_constant_steps = int(num_training_steps * constant_lr_ratio)
current_step = min(current_step, num_constant_steps)
def __init__(
self,
optimizer: Optimizer,
inner_schedule: LRScheduler,
jagged_restarts_steps: int,
jagged_restarts_warmup_steps: int,
jagged_restarts_anneal_steps: int = 1,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.restarts_steps = jagged_restarts_steps
self.warmup_steps = jagged_restarts_warmup_steps
self.anneal_steps = jagged_restarts_anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
progress = float(current_step - num_warmup_steps) / float(
max(1, num_constant_steps - num_warmup_steps)
)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch
return (
max(
0,
(1 - min_lr_ratio)
* 0.5
* (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
+ min_lr_ratio
)
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.restarts_steps:
scale = 1
else:
per_relora_progress = step % self.restarts_steps
if per_relora_progress < self.warmup_steps:
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
elif per_relora_progress > (self.restarts_steps - self.anneal_steps):
cycle_t = min(
1.0,
(self.restarts_steps - per_relora_progress) / self.anneal_steps,
)
else:
cycle_t = 1
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
def get_cosine_schedule_with_warmup_decay_constant(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
constant_lr_ratio: float,
min_lr_ratio: float,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Implementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to min_lr_ratio until num_training_steps * constant_lr_ratio, after constant_rate returns constant value of min_rate
, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
constant_lr_ratio: (`float`):
The ratio of num_training_steps to decrease by cosine function.
min_lr_ratio: (`float):
The ratio of maximum learning rate for cosine function to decay to minimum learning rate.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_schedule_with_warmup_decay_constant_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
constant_lr_ratio=constant_lr_ratio,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
if isinstance(original, Sequence):
return [lr * scale for lr in original]
return original * scale

View File

@@ -2,8 +2,6 @@
import logging
import re
from typing import Dict, List
from termcolor import colored
@@ -38,65 +36,3 @@ def check_example_labels(example, tokenizer, text_only=False):
LOG.info("\n\n\n")
return " ".join(colored_tokens)
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
GLAIVE_TO_SHAREGPT_ROLE = {
"SYSTEM": "system",
"USER": "human",
"ASSISTANT": "gpt",
"FUNCTION RESPONSE": "tool",
}
GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")
def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
"""
Converts a ChatML formatted row to a list of messages in ShareGPT format.
Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
"""
system_prompt = row.get("system")
if system_prompt:
system_prompt = system_prompt.removeprefix("SYSTEM: ")
chat_str = row["chat"]
chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]
chat_msg_dicts = [
{"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
]
if system_prompt:
chat_msg_dicts = [
{"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
] + chat_msg_dicts
return chat_msg_dicts
def merge_consecutive_messages(messages):
"""
Merge consecutive messages from the same sender into a single message.
This can be useful with datasets that contain multiple consecutive tool calls.
"""
merged_messages = []
current_from = None
current_message = ""
for msg in messages:
if current_from == msg["from"]:
current_message += msg["value"]
else:
if current_from is not None:
merged_messages.append({"from": current_from, "value": current_message})
current_from = msg["from"]
current_message = msg["value"]
if current_from is not None:
merged_messages.append({"from": current_from, "value": current_message})
return merged_messages

View File

@@ -255,7 +255,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
data_loader_len = len(data_loader) // cfg.batch_size
data_loader_len = len(data_loader) // batch_size
actual_eff = sampler.efficiency()
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends

Some files were not shown because too many files have changed in this diff Show More