Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
5b15816cf4 drop valueerror as this was from when 4bit required gptq 2024-08-22 19:16:32 -04:00
136 changed files with 1716 additions and 9010 deletions

View File

@@ -28,19 +28,7 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout

View File

@@ -6,7 +6,7 @@ on:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- "*.[q]md"
- "*.md"
- "examples/**/*.y[a]?ml"
workflow_dispatch:

View File

@@ -27,12 +27,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -89,12 +84,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -1,9 +1,6 @@
name: docker-multigpu-tests-biweekly
on:
pull_request:
paths:
- 'tests/e2e/multigpu/*.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -21,17 +18,10 @@ jobs:
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
- cuda: 124
cuda_version: 12.4.1
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"

View File

@@ -26,12 +26,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -88,12 +83,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -27,7 +27,7 @@ jobs:
run: |
pip3 install wheel packaging
pip3 install -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pip3 install -r requirements-tests.txt
- name: Extract tag name
id: tag

View File

@@ -25,7 +25,6 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
timeout-minutes: 20
steps:
@@ -38,23 +37,19 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
- name: Update requirements.txt
run: |
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pip3 install -r requirements-tests.txt
- name: Run tests
run: |
@@ -92,14 +87,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"

View File

@@ -36,7 +36,6 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
timeout-minutes: 20
steps:
@@ -49,20 +48,12 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pip3 install -r requirements-tests.txt
- name: Run tests
run: |
@@ -76,7 +67,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
timeout-minutes: 60
needs: [pre-commit, pytest]
strategy:
@@ -98,13 +89,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -1,3 +1,3 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_third_party=wandb

View File

@@ -11,9 +11,6 @@ ignore_errors = True
[mypy-axolotl.models.mixtral.*]
ignore_errors = True
[mypy-axolotl.integrations.liger.models.*]
ignore_errors = True
[mypy-axolotl.models.phi.*]
ignore_errors = True

295
1991.yml
View File

@@ -1,295 +0,0 @@
base_model: Qwen/Qwen2.5-14B-Instruct
model_type: AutoModelForCausalLM #nohup accelerate launch -m axolotl.cli.train /home/ubuntu/qwen2.5_14B.yml > training_output.log 2>&1 &
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
chat_template: chatml
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
# input_layernorm layers
- model.layers.0.input_layernorm
- model.layers.1.input_layernorm
- model.layers.2.input_layernorm
- model.layers.3.input_layernorm
- model.layers.4.input_layernorm
- model.layers.5.input_layernorm
- model.layers.6.input_layernorm
- model.layers.7.input_layernorm
- model.layers.8.input_layernorm
- model.layers.9.input_layernorm
- model.layers.10.input_layernorm
- model.layers.11.input_layernorm
- model.layers.12.input_layernorm
- model.layers.13.input_layernorm
- model.layers.14.input_layernorm
- model.layers.15.input_layernorm
- model.layers.16.input_layernorm
- model.layers.17.input_layernorm
- model.layers.18.input_layernorm
- model.layers.19.input_layernorm
- model.layers.20.input_layernorm
- model.layers.21.input_layernorm
- model.layers.22.input_layernorm
- model.layers.23.input_layernorm
# lm_head layers
# mlp.down_proj layers
- model.layers.1.mlp.down_proj
- model.layers.35.mlp.down_proj
- model.layers.38.mlp.down_proj
- model.layers.37.mlp.down_proj
- model.layers.36.mlp.down_proj
- model.layers.15.mlp.down_proj
- model.layers.11.mlp.down_proj
- model.layers.12.mlp.down_proj
- model.layers.34.mlp.down_proj
- model.layers.44.mlp.down_proj
- model.layers.45.mlp.down_proj
- model.layers.9.mlp.down_proj
- model.layers.41.mlp.down_proj
- model.layers.33.mlp.down_proj
- model.layers.43.mlp.down_proj
- model.layers.40.mlp.down_proj
- model.layers.13.mlp.down_proj
- model.layers.8.mlp.down_proj
- model.layers.39.mlp.down_proj
- model.layers.10.mlp.down_proj
- model.layers.14.mlp.down_proj
- model.layers.16.mlp.down_proj
- model.layers.31.mlp.down_proj
- model.layers.32.mlp.down_proj
# mlp.gate_proj layers
- model.layers.1.mlp.gate_proj
- model.layers.44.mlp.gate_proj
- model.layers.46.mlp.gate_proj
- model.layers.45.mlp.gate_proj
- model.layers.43.mlp.gate_proj
- model.layers.47.mlp.gate_proj
- model.layers.42.mlp.gate_proj
- model.layers.32.mlp.gate_proj
- model.layers.27.mlp.gate_proj
- model.layers.33.mlp.gate_proj
- model.layers.28.mlp.gate_proj
- model.layers.39.mlp.gate_proj
- model.layers.41.mlp.gate_proj
- model.layers.40.mlp.gate_proj
- model.layers.30.mlp.gate_proj
- model.layers.29.mlp.gate_proj
- model.layers.31.mlp.gate_proj
- model.layers.26.mlp.gate_proj
- model.layers.37.mlp.gate_proj
- model.layers.10.mlp.gate_proj
- model.layers.38.mlp.gate_proj
- model.layers.12.mlp.gate_proj
- model.layers.36.mlp.gate_proj
- model.layers.13.mlp.gate_proj
# mlp.up_proj layers
- model.layers.1.mlp.up_proj
- model.layers.13.mlp.up_proj
- model.layers.11.mlp.up_proj
- model.layers.14.mlp.up_proj
- model.layers.15.mlp.up_proj
- model.layers.12.mlp.up_proj
- model.layers.8.mlp.up_proj
- model.layers.16.mlp.up_proj
- model.layers.9.mlp.up_proj
- model.layers.19.mlp.up_proj
- model.layers.10.mlp.up_proj
- model.layers.7.mlp.up_proj
- model.layers.17.mlp.up_proj
- model.layers.20.mlp.up_proj
- model.layers.21.mlp.up_proj
- model.layers.18.mlp.up_proj
- model.layers.38.mlp.up_proj
- model.layers.37.mlp.up_proj
- model.layers.39.mlp.up_proj
- model.layers.42.mlp.up_proj
- model.layers.41.mlp.up_proj
- model.layers.27.mlp.up_proj
- model.layers.28.mlp.up_proj
- model.layers.34.mlp.up_proj
# model.norm layers
# post_attention_layernorm layers
- model.layers.0.post_attention_layernorm
- model.layers.1.post_attention_layernorm
- model.layers.2.post_attention_layernorm
- model.layers.3.post_attention_layernorm
- model.layers.4.post_attention_layernorm
- model.layers.5.post_attention_layernorm
- model.layers.6.post_attention_layernorm
- model.layers.7.post_attention_layernorm
- model.layers.8.post_attention_layernorm
- model.layers.9.post_attention_layernorm
- model.layers.10.post_attention_layernorm
- model.layers.11.post_attention_layernorm
- model.layers.12.post_attention_layernorm
- model.layers.13.post_attention_layernorm
- model.layers.14.post_attention_layernorm
- model.layers.15.post_attention_layernorm
- model.layers.16.post_attention_layernorm
- model.layers.17.post_attention_layernorm
- model.layers.18.post_attention_layernorm
- model.layers.19.post_attention_layernorm
- model.layers.20.post_attention_layernorm
- model.layers.21.post_attention_layernorm
- model.layers.22.post_attention_layernorm
- model.layers.23.post_attention_layernorm
# self_attn.k_proj layers
- model.layers.47.self_attn.k_proj
- model.layers.39.self_attn.k_proj
- model.layers.41.self_attn.k_proj
- model.layers.37.self_attn.k_proj
- model.layers.35.self_attn.k_proj
- model.layers.44.self_attn.k_proj
- model.layers.38.self_attn.k_proj
- model.layers.14.self_attn.k_proj
- model.layers.7.self_attn.k_proj
- model.layers.12.self_attn.k_proj
- model.layers.11.self_attn.k_proj
- model.layers.32.self_attn.k_proj
- model.layers.10.self_attn.k_proj
- model.layers.8.self_attn.k_proj
- model.layers.9.self_attn.k_proj
- model.layers.6.self_attn.k_proj
- model.layers.45.self_attn.k_proj
- model.layers.42.self_attn.k_proj
- model.layers.5.self_attn.k_proj
- model.layers.40.self_attn.k_proj
- model.layers.33.self_attn.k_proj
- model.layers.0.self_attn.k_proj
- model.layers.34.self_attn.k_proj
- model.layers.13.self_attn.k_proj
# self_attn.o_proj layers
- model.layers.12.self_attn.o_proj
- model.layers.5.self_attn.o_proj
- model.layers.14.self_attn.o_proj
- model.layers.16.self_attn.o_proj
- model.layers.20.self_attn.o_proj
- model.layers.13.self_attn.o_proj
- model.layers.11.self_attn.o_proj
- model.layers.4.self_attn.o_proj
- model.layers.6.self_attn.o_proj
- model.layers.19.self_attn.o_proj
- model.layers.7.self_attn.o_proj
- model.layers.18.self_attn.o_proj
- model.layers.8.self_attn.o_proj
- model.layers.38.self_attn.o_proj
- model.layers.15.self_attn.o_proj
- model.layers.17.self_attn.o_proj
- model.layers.9.self_attn.o_proj
- model.layers.10.self_attn.o_proj
- model.layers.21.self_attn.o_proj
- model.layers.28.self_attn.o_proj
- model.layers.32.self_attn.o_proj
- model.layers.35.self_attn.o_proj
- model.layers.39.self_attn.o_proj
- model.layers.3.self_attn.o_proj
# self_attn.q_proj layers
- model.layers.1.self_attn.q_proj
- model.layers.2.self_attn.q_proj
- model.layers.3.self_attn.q_proj
- model.layers.44.self_attn.q_proj
- model.layers.29.self_attn.q_proj
- model.layers.45.self_attn.q_proj
- model.layers.43.self_attn.q_proj
- model.layers.32.self_attn.q_proj
- model.layers.38.self_attn.q_proj
- model.layers.19.self_attn.q_proj
- model.layers.42.self_attn.q_proj
- model.layers.34.self_attn.q_proj
- model.layers.36.self_attn.q_proj
- model.layers.40.self_attn.q_proj
- model.layers.26.self_attn.q_proj
- model.layers.20.self_attn.q_proj
- model.layers.39.self_attn.q_proj
- model.layers.28.self_attn.q_proj
- model.layers.35.self_attn.q_proj
- model.layers.41.self_attn.q_proj
- model.layers.33.self_attn.q_proj
- model.layers.25.self_attn.q_proj
- model.layers.30.self_attn.q_proj
- model.layers.27.self_attn.q_proj
# self_attn.v_proj layers
- model.layers.0.self_attn.v_proj
- model.layers.7.self_attn.v_proj
- model.layers.39.self_attn.v_proj
- model.layers.31.self_attn.v_proj
- model.layers.15.self_attn.v_proj
- model.layers.10.self_attn.v_proj
- model.layers.32.self_attn.v_proj
- model.layers.41.self_attn.v_proj
- model.layers.6.self_attn.v_proj
- model.layers.33.self_attn.v_proj
- model.layers.42.self_attn.v_proj
- model.layers.29.self_attn.v_proj
- model.layers.14.self_attn.v_proj
- model.layers.9.self_attn.v_proj
- model.layers.35.self_attn.v_proj
- model.layers.38.self_attn.v_proj
- model.layers.13.self_attn.v_proj
- model.layers.30.self_attn.v_proj
- model.layers.5.self_attn.v_proj
- model.layers.34.self_attn.v_proj
- model.layers.28.self_attn.v_proj
- model.layers.37.self_attn.v_proj
- model.layers.27.self_attn.v_proj
- model.layers.11.self_attn.v_proj
# model.embed_tokens layers
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: linear
learning_rate: 5e-6
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
gradient_checkpointing: unsloth
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1
save_total_limit: 4
debug:
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.05
special_tokens:
eos_token: <|im_end|>

View File

@@ -11,10 +11,10 @@ Features:
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Integrated with xformer, flash attention, rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- Log results and optionally checkpoints to wandb or mlflow
- And more!
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
@@ -55,7 +55,6 @@ Features:
- [FSDP + QLoRA](#fsdp--qlora)
- [Weights \& Biases Logging](#weights--biases-logging)
- [Special Tokens](#special-tokens)
- [Liger Kernel](#liger-kernel)
- [Inference Playground](#inference-playground)
- [Merge LORA to base](#merge-lora-to-base)
- [Common Errors 🧰](#common-errors-)
@@ -121,7 +120,7 @@ Features:
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
@@ -383,7 +382,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt
@@ -515,22 +514,6 @@ wandb_name:
wandb_log_model:
```
##### Comet Logging
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
- wandb options
```yaml
use_comet:
comet_api_key:
comet_workspace:
comet_project_name:
comet_experiment_key:
comet_mode:
comet_online:
comet_experiment_config:
```
##### Special Tokens
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
@@ -547,25 +530,6 @@ tokens: # these are delimiters
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
##### Liger Kernel
Liger Kernel: Efficient Triton Kernels for LLM Training
https://github.com/linkedin/Liger-Kernel
Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training.
It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. The Liger Kernel
composes well and is compatible with both FSDP and Deepspeed.
```yaml
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
```
### Inference Playground
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.

View File

@@ -37,7 +37,6 @@ website:
- docs/mac.qmd
- docs/multi-node.qmd
- docs/unsloth.qmd
- docs/amd_hpc.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"

View File

@@ -23,11 +23,12 @@ RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -37,7 +38,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
fi
# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
RUN pip install -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -1,6 +1,6 @@
#!/bin/bash
set -e
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ /workspace/axolotl/tests/e2e/

View File

@@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str):
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
timeout=45 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
)

View File

@@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
timeout=45 * 60,
cpu=8.0,
memory=131072,
)

View File

@@ -14,6 +14,15 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -24,6 +24,15 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -20,6 +20,15 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -20,6 +20,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -1,108 +0,0 @@
---
title: Training with AMD GPUs on HPC Systems
description: A comprehensive guide for using Axolotl on distributed systems with AMD GPUs
---
This guide provides step-by-step instructions for installing and configuring Axolotl on a High-Performance Computing (HPC) environment equipped with AMD GPUs.
## Setup
### 1. Install Python
We recommend using Miniforge, a minimal conda-based Python distribution:
```bash
curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
### 2. Configure Python Environment
Add Python to your PATH and ensure it's available at login:
```bash
echo 'export PATH=~/miniforge3/bin:$PATH' >> ~/.bashrc
echo 'if [ -f ~/.bashrc ]; then . ~/.bashrc; fi' >> ~/.bash_profile
```
### 3. Load AMD GPU Software
Load the ROCm module:
```bash
module load rocm/5.7.1
```
Note: The specific module name and version may vary depending on your HPC system. Consult your system documentation for the correct module name.
### 4. Install PyTorch
Install PyTorch with ROCm support:
```bash
pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7 --force-reinstall
```
### 5. Install Flash Attention
Clone and install the Flash Attention repository:
```bash
git clone --recursive https://github.com/ROCmSoftwarePlatform/flash-attention.git
export GPU_ARCHS="gfx90a"
cd flash-attention
export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
pip install .
```
### 6. Install Axolotl
Clone and install Axolotl:
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl
pip install packaging ninja
pip install -e .
```
### 7. Apply xformers Workaround
xformers appears to be incompatible with ROCm. Apply the following workarounds:
- Edit $HOME/packages/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py modifying the code to always return `False` for SwiGLU availability from xformers.
- Edit $HOME/miniforge3/lib/python3.10/site-packages/xformers/ops/swiglu_op.py replacing the "SwiGLU" function with a pass statement.
### 8. Prepare Job Submission Script
Create a script for job submission using your HPC's particular software (e.g. Slurm, PBS). Include necessary environment setup and the command to run Axolotl training. If the compute node(s) do(es) not have internet access, it is recommended to include
```bash
export TRANSFORMERS_OFFLINE=1
export HF_DATASETS_OFFLINE=1
```
### 9. Download Base Model
Download a base model using the Hugging Face CLI:
```bash
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
```
### 10. Create Axolotl Configuration
Create an Axolotl configuration file (YAML format) tailored to your specific training requirements and dataset. Use FSDP for multi-node training.
Note: Deepspeed did not work at the time of testing. However, if anyone managed to get it working, please let us know.
### 11. Preprocess Data
Run preprocessing on the login node:
```bash
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess /path/to/your/config.yaml
```
### 12. Train
You are now ready to submit your previously prepared job script. 🚂

View File

@@ -83,14 +83,13 @@ lora_on_cpu: true
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -124,48 +123,6 @@ datasets:
# For `completion` datsets only, uses the provided field instead of `text` column
field:
# Using chat template
- path: ...
# Set type to `chat_template` to use this strategy
type: chat_template
# Specify the name of the chat template to use
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
chat_template: tokenizer_default
# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
chat_template_jinja:
# The key in the data example that contains the messages. Default is "messages".
field_messages: messages
# The key in the message turn that contains the role. Default is "role".
message_field_role: role
# The key in the message turn that contains the content. Default is "content".
message_field_content: content
# Optional[Dict[str, List]]. Roles mapping for the messages.
roles:
user: ["human", "user"]
assistant: ["gpt", "assistant", "ai"]
system: ["system"]
## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["gpt", "assistant"]
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
# - all: train on all EOS tokens
# - turn: train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
# See example at `docs/dataset-formats/conversation.qmd`
message_field_training_detail: train_detail
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
@@ -184,16 +141,9 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto'
rl:
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
# Currently supports chatml and inst (mistral/mixtral)
chat_template: chatml
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
@@ -315,21 +265,8 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
mlflow_run_name: # Your run name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Comet configuration if you're using it
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
use_comet: # Enable or disable Comet integration.
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
# Where to save the full-finetuned model to
output_dir: ./completed-model
@@ -364,7 +301,7 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

View File

@@ -6,8 +6,6 @@ order: 3
## sharegpt
UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```{.json filename="data.jsonl"}
@@ -71,138 +69,3 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f
```{.json filename="data.jsonl"}
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
## chat_template
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "content": "..."}]}
```
See `config.qmd` for full configs and supported templates.
### Migrating from sharegpt
Most configs can be adapted as follows:
```yaml
# old
chat_template: chatml
datasets:
- path: ...
type: sharegpt
conversation: chatml
# new (if using tokenizer's chat_template)
datasets:
- path: ...
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
datasets:
- path: ...
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
```
We recommend checking the below examples for other usecases.
### Examples
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
- path: ...
type: chat_template
```
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
```{.json filename="data.jsonl"}
{
"conversations": [
{"from": "system", "value": "You are an AI assistant.", "train": false},
{"from": "human", "value": "Hello", "train": false},
{"from": "assistant", "value": "Hello", "train": true},
{"from": "human", "value": "How are you?", "train": true},
{
"from": "assistant",
"value": "I'm doing very well, thank you!",
"train_detail": [
{"begin_offset": 0, "end_offset": 8, "train": false},
{"begin_offset": 9, "end_offset": 18, "train": true},
{"begin_offset": 19, "end_offset": 30, "train": false},
],
},
{
"from": "human",
"value": "I'm doing very well, thank you!",
"train": true,
},
{"from": "assistant", "value": "Hi there!", "train": true}
]
}
```
The configuration would look like:
```yaml
datasets:
- path: ...
type: chat_template
chat_template: tokenizer_default
field_messages: conversations
message_field_role: from
message_field_content: value
roles_to_train: []
train_on_eos: turn
message_field_training: train
message_field_training_detail: train_detail
```
Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.

View File

@@ -7,7 +7,7 @@ order: 5
- Pass an empty `type:` in your axolotl config.
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
- To indicate that a token should be ignored during training, set its corresponding label to `-100`.
- You must add BOS and EOS, and make sure that you are training on EOS by not setting its label to -100.
- Do not add BOS/EOS. Axolotl will add them for you based on the default tokenizer for the model you're using.
- For pretraining, do not truncate/pad documents to the context window length.
- For instruction training, documents must be truncated/padded as desired.

View File

@@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ignored by comparing the labels
We can check that the right tokens are ingored by comparing the labels
to each token:
```python

View File

@@ -1,28 +0,0 @@
# MultiModal / Vision Language Models (BETA)
### Supported Models
- Mllama, i.e. llama with vision models
### Usage
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
```yaml
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
skip_prepare_dataset: true
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
remove_unused_columns: false
sample_packing: false
# only finetune the Language model, leave the vision model and vision tower frozen
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
```

View File

@@ -1,67 +0,0 @@
base_model: deepseek-ai/DeepSeek-V2-Lite
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD

View File

@@ -1,83 +0,0 @@
base_model: axolotl-quants/DeepSeek-V2.5-bnb-nf4-bf16
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
chat_template: deepseek_v2
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
adapter: qlora
lora_r: 256
lora_alpha: 256
lora_target_linear: true
peft_use_rslora: true
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD

View File

@@ -1,63 +0,0 @@
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,63 +0,0 @@
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,76 +0,0 @@
base_model: NousResearch/Meta-Llama-3.1-8B
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
strict: false
chat_template: llama3
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>

View File

@@ -1,4 +1,6 @@
base_model: NousResearch/Meta-Llama-3.1-8B
base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false

View File

@@ -11,6 +11,7 @@ rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
chat_template: llama3
field_messages: conversation
field_chosen: chosen
field_rejected: rejected

View File

@@ -10,6 +10,7 @@ chat_template: llama3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: llama3
field_messages: messages
message_field_role: role
message_field_content: content

View File

@@ -1,77 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1,76 +0,0 @@
base_model: microsoft/Phi-3.5-mini-instruct
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
chat_template: phi_3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: phi_3
field_messages: messages
message_field_role: role
message_field_content: content
roles:
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 2
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bfloat16: true
bf16: true
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 4
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -2,4 +2,3 @@ pre-commit
black
mypy
types-requests
tbparse

View File

@@ -1,12 +1,12 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.0
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.1
datasets==3.0.1
deepspeed==0.15.3
peft==0.12.0
transformers==4.44.0
tokenizers>=0.19.1
bitsandbytes==0.43.3
accelerate==0.33.0
datasets==2.20.0
deepspeed==0.14.4
pydantic==2.6.3
addict
fire
@@ -16,7 +16,7 @@ flash-attn==2.6.3
sentencepiece
wandb
einops
xformers>=0.0.23.post1
xformers==0.0.27
optimum==1.16.2
hf_transfer
colorama
@@ -33,8 +33,6 @@ gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=2.3.0
liger-kernel==0.3.0
mamba-ssm==1.2.0.post1
@@ -43,14 +41,6 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
trl==0.9.6
zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.4
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.5.0

View File

@@ -1,315 +0,0 @@
accelerate==0.34.1
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.0
aiosignal==1.3.1
aiostream==0.5.2
alembic==1.13.1
annotated-types==0.6.0
annoy==1.17.3
ansible==6.7.0
ansible-core==2.13.13
ansible-vault==2.1.0
anyio==3.7.1
appdirs==1.4.4
art==6.0
asgiref==3.7.2
async-timeout==4.0.2
attrdict==2.0.1
attrs==22.2.0
awscli==1.32.75
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
backoff==2.2.1
base58==2.1.1
beartype==0.17.2
bitnet==0.2.1
bitsandbytes==0.42.0
bittensor==6.7.0
black==23.7.0
blinker==1.7.0
boto3==1.34.75
botocore==1.34.75
cachetools==5.3.3
cachy==0.1.1
certifi==2023.7.22
cffi==1.16.0
cfgv==3.3.1
chai-guanaco==1.2.4
charset-normalizer==3.2.0
cleo==0.6.8
click==8.1.7
cloudpickle==2.0.0
cohere==4.11.2
colorama==0.4.4
coloredlogs==15.0.1
CoLT5-attention==0.10.20
contextlib2==21.6.0
contourpy==1.2.0
cryptography==41.0.3
cycler==0.12.1
cytoolz==0.12.3
databricks-cli==0.18.0
dataclasses-json==0.5.7
datasets==2.11.0
ddt==1.6.0
decorator==5.1.1
deepspeed==0.15.0
# Editable Git install with no remote (dialogpt==0.1)
-e /Users/wing/Projects/ml/dialogpt/src
dill==0.3.6
distlib==0.3.6
docker==7.0.0
docker-pycreds==0.4.0
docstring-parser==0.15
docutils==0.16
ecdsa==0.18.0
einops==0.7.0
einops-exts==0.0.4
einx==0.1.3
entrypoints==0.4
eth-hash==0.6.0
eth-keys==0.5.0
eth-typing==4.0.0
eth-utils==2.3.1
evaluate==0.4.0
exceptiongroup==1.1.1
fastapi==0.109.2
fastcore==1.5.29
ffmpy==0.4.0
filelock==3.12.2
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
fire==0.5.0
first==2.0.2
flake8==7.0.0
Flask==3.0.1
fonttools==4.47.2
frozendict==2.4.1
frozenlist==1.3.3
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
fsspec==2023.6.0
fuzzywuzzy==0.18.0
gitdb==4.0.10
GitPython==3.1.31
google-pasta==0.2.0
gradio==4.42.0
gradio_client==1.3.0
greenlet==2.0.2
grpclib==0.4.7
gunicorn==21.2.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.23.4
humanfriendly==10.0
hyperframe==6.0.1
identify==2.5.24
idna==3.4
immutables==0.20
importlib-metadata==6.7.0
importlib-resources==6.1.1
inflection==0.5.1
iniconfig==2.0.0
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
jsonlines==3.1.0
jsonschema==2.6.0
kiwisolver==1.4.5
langchain==0.0.144
Levenshtein==0.24.0
libcst==1.1.0
liger-kernel==0.0.0
lion-pytorch==0.1.2
llama-cpp-python==0.1.36
llvmlite==0.40.1
local-attention==1.9.0
loguru==0.7.0
Mako==1.3.2
Markdown==3.5.2
markdown-it-py==3.0.0
markdown2==2.4.10
MarkupSafe==2.1.2
marshmallow==3.19.0
marshmallow-enum==1.5.1
matplotlib==3.8.2
mccabe==0.7.0
mdurl==0.1.2
MEGABYTE-pytorch==0.0.7
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
mlflow==2.10.0
modal==0.62.77
more-itertools==10.2.0
mpmath==1.2.1
msgpack==1.0.7
msgpack-numpy-opentensor==0.5.0
multidict==6.0.4
multiprocess==0.70.14
munch==2.5.0
mypy==1.3.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
netaddr==0.10.1
networkx==3.0rc1
nh3==0.2.14
nodeenv==1.8.0
nomic==2.0.2
numba==0.57.1
numexpr==2.8.4
numpy==1.24.4
oauthlib==3.2.2
openai==0.27.4
openapi==1.1.0
openapi-schema-pydantic==1.2.4
optimum==1.8.6
orjson==3.10.7
packaging==23.1
pandas==2.0.0
parameterized==0.9.0
password-strength==0.0.3.post2
pastel==0.1.1
pathos==0.3.0
pathspec==0.11.1
pathtools==0.1.2
peft==0.11.1
pendulum==3.0.0
Pillow==9.5.0
pip-tools==1.11.0
platformdirs==3.2.0
pluggy==1.4.0
poetry==0.7.1
pox==0.3.2
ppft==1.7.6.6
pre-commit==3.3.2
prettytable==3.10.0
prompt-toolkit==3.0.39
protobuf==3.20.2
protobuf3-to-dict==0.1.5
psutil==5.9.5
psycopg==3.1.18
PuLP==2.8.0
py==1.11.0
py-bip39-bindings==0.1.11
py-cpuinfo==9.0.0
py-ed25519-zebra-bindings==1.0.1
py-sr25519-bindings==0.2.0
pyarrow==11.0.0
pyasn1==0.6.0
pycodestyle==2.11.1
pycparser==2.21
pycryptodome==3.20.0
pydantic==2.5.3
pydantic_core==2.14.6
pydub==0.25.1
pyfiglet==0.8.post1
pyflakes==3.2.0
Pygments==2.15.1
PyJWT==2.8.0
pylev==1.4.0
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==2.4.7
pyrsistent==0.14.11
pytest==8.0.2
pytest-asyncio==0.23.4
python-dateutil==2.8.2
python-dotenv==1.0.1
python-Levenshtein==0.24.0
python-multipart==0.0.9
pytz==2023.3
PyYAML==6.0.1
querystring-parser==1.2.4
rapidfuzz==3.6.1
regex==2023.6.3
requests==2.31.0
requests-toolbelt==0.8.0
resolvelib==0.8.1
responses==0.18.0
retry==0.9.2
rich==13.7.0
rsa==4.7.2
ruff==0.6.3
s3transfer==0.10.1
safetensors==0.4.5
sagemaker==2.148.0
scalecodec==1.2.7
schedulefree==1.2.1
schema==0.7.5
scikit-learn==1.4.0
scipy==1.9.3
seaborn==0.13.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==1.19.1
setproctitle==1.3.2
shellingham==1.5.4
shortuuid==1.0.11
shtab==1.6.5
sigtools==4.0.1
six==1.16.0
skypilot==0.4.1
smdebug-rulesconfig==1.0.1
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==1.4.47
sqlparse==0.4.4
starlette==0.36.3
substrate-interface==1.5.2
svgwrite==1.4.3
sympy==1.11.1
synchronicity==0.6.7
tabulate==0.9.0
tblib==1.7.0
tenacity==8.2.2
tensor-parallel==2.0.0
termcolor==2.2.0
text2art==0.2.0
threadpoolctl==3.2.0
tiktoken==0.6.0
time-machine==2.14.1
timm==0.9.16
tokenizers==0.19.1
tokenmonster==1.1.12
toml==0.9.6
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.2.0
torchdata==0.6.1
torchdiffeq==0.2.3
TorchFix==0.4.0
torchtext==0.15.2
torchvision==0.17.0
tqdm==4.66.2
transformers==4.44.2
trl==0.9.6
typer==0.12.5
types-certifi==2021.10.8.3
types-requests==2.31.0.20240125
types-setuptools==69.0.0.20240125
types-toml==0.10.8.7
typing==3.7.4.3
typing-inspect==0.8.0
typing_extensions==4.9.0
tyro==0.5.18
tzdata==2023.3
unique-names-generator==1.0.2
urllib3==2.2.2
uvicorn==0.22.0
vector_quantize_pytorch==1.14.1
virtualenv==20.23.0
voyager==2.0.2
wandb==0.16.2
watchfiles==0.21.0
wavedrom==2.0.3.post3
wcwidth==0.2.6
websocket-client==1.7.0
websockets==12.0
Werkzeug==3.0.1
wonderwords==2.2.0
xxhash==3.2.0
yarl==1.8.2
zetascale==2.2.7
zipp==3.15.0

View File

@@ -1,60 +0,0 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset
@click.command()
@click.argument("dataset", type=str)
@click.option("--split", type=str, default="train")
def parse_dataset(dataset=None, split="train"):
ds_cfg = {}
ds_cfg["path"] = dataset
ds_cfg["split"] = split
ds_cfg["type"] = "chat_template"
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
dataset = load_dataset(dataset, split=split)
features = dataset.features
feature_keys = features.keys()
field_messages = None
for key in ["conversation", "conversations", "messages"]:
if key in feature_keys:
field_messages = key
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages
message_fields = features["conversations"][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:
message_field_role = key
break
if not message_field_role:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role
message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_field_content = key
break
if not message_field_content:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_content"] = message_field_content
print(yaml.dump({"datasets": [ds_cfg]}))
if __name__ == "__main__":
parse_dataset()

View File

@@ -30,9 +30,6 @@ def parse_requirements():
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
@@ -52,35 +49,20 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if (major, minor) >= (2, 3):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
@@ -109,7 +91,6 @@ setup(
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",

View File

@@ -27,11 +27,8 @@ from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
@@ -41,7 +38,7 @@ from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -55,22 +52,8 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_legacy_axolotl_text_art(suffix=None):
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
@@ -83,13 +66,6 @@ def print_legacy_axolotl_text_art(suffix=None):
print_dep_versions()
def print_axolotl_text_art(
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
print(AXOLOTL_LOGO)
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
@@ -257,8 +233,7 @@ def do_inference_gradio(
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
default_tokens: Dict[str, str] = {}
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
@@ -266,13 +241,10 @@ def do_inference_gradio(
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -286,24 +258,7 @@ def do_inference_gradio(
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
@@ -326,7 +281,6 @@ def do_inference_gradio(
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
@@ -411,11 +365,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
cfg.axolotl_config_path = config
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
@@ -443,8 +392,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg
@@ -454,20 +401,12 @@ def load_datasets(
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
cfg, tokenizer
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(

View File

@@ -27,7 +27,6 @@ from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
)
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -71,11 +70,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching():
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.download:
model_name = parsed_cfg.base_model

View File

@@ -3,11 +3,13 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
from typing import Tuple, Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.cli import (
check_accelerate_default_config,
@@ -18,7 +20,6 @@ from axolotl.cli import (
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager
from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
@@ -38,7 +39,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return do_train(parsed_cfg, parsed_cli_args)
def do_train(cfg, cli_args) -> None:
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
@@ -63,13 +64,7 @@ def do_train(cfg, cli_args) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg)
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":

View File

@@ -23,7 +23,7 @@ class TrainerCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)

View File

@@ -1,34 +0,0 @@
"""
ChatML transformation functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages
from .shared import wrap_tools
def format_message(
message: Messages,
message_index: Optional[int] = None, # pylint: disable=unused-argument
) -> Messages:
if message.is_chat_formatted:
return message
# prepend the role prefix within a MessageContents to message.content
message.content.insert(
0,
MessageContents(
type="text",
value=f"<|im_start|>{message.role}\n",
weight=0,
),
)
message.content.append(
MessageContents(type="text", value="<|im_end|>", weight=message.weight)
)
message.content.append(MessageContents(type="text", value="\n", weight=0))
message = wrap_tools(message)
message.is_chat_formatted = True
return message

View File

@@ -1,45 +0,0 @@
"""
Llama 3.x chat formatting functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages
from .shared import wrap_tools
def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
if message.is_chat_formatted:
return message
message_role = message.role
if message.role == "tool":
message_role = "ipython"
# prepend the role prefix within a MessageContents to message.content
message.content.insert(
0,
MessageContents(
type="text",
value=f"<|start_header_id|>{message_role}<|end_header_id|>\n\n",
weight=0,
),
)
message.content.append(
MessageContents(type="text", value="<|eot_id|>", weight=message.weight)
)
message = wrap_tools(message)
if message_index == 0:
message.content.insert(
0,
MessageContents(
type="text",
value="<|begin_of_text|>",
weight=0,
),
)
message.is_chat_formatted = True
return message

View File

@@ -1,47 +0,0 @@
"""
shared functions for format transforms
"""
from axolotl.core.chat.messages import MessageContents, Messages
def wrap_tools(message: Messages):
# loop over message.content by index to find tool calls, we need to wrap each with tags,
# so be wary of indexing issues when changing the list while iterating.
# iterate over the range in reverse order to avoid index shifting
for i in range(len(message.content) - 1, -1, -1):
if message.content[i].type == "tool_call":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_call>\n", weight=message.weight
),
)
# make sure the actual tool call content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_call>\n", weight=message.weight
),
)
elif message.content[i].type == "tool_response":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_response>\n", weight=message.weight
),
)
# make sure the actual tool response content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_response>\n", weight=message.weight
),
)
return message

View File

@@ -1,230 +0,0 @@
"""
internal message representations of chat messages
"""
import json
from enum import Enum
from typing import Any, Callable, List, Optional, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizer
class MessageRoles(str, Enum):
"""
Message roles for the system, user, assistant, and tools
"""
system = "system" # pylint: disable=invalid-name
user = "user" # pylint: disable=invalid-name
assistant = "assistant" # pylint: disable=invalid-name
tool = "tool" # pylint: disable=invalid-name
ipython = ( # pylint: disable=invalid-name
# for responses from builtin tools
"ipython"
)
class MessageContentTypes(str, Enum):
"""
Message content types for text, image, audio, tool calls, and tool responses
"""
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
text = "text" # pylint: disable=invalid-name
image = "image" # pylint: disable=invalid-name
audio = "audio" # pylint: disable=invalid-name
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
tool_response = "tool_response" # pylint: disable=invalid-name
class SpecialToken(str, Enum):
"""
Special tokens for beginning of string and end of string
"""
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
class ToolCallFunction(BaseModel):
"""
Tool call function with name and arguments
"""
name: str
arguments: dict[str, str]
class Tool(BaseModel):
"""
Tool with description, function, and parameters
"""
description: str
function: ToolCallFunction
parameters: dict[str, str] # .properties
class ToolCallContents(BaseModel):
"""
Tool call contents with name, arguments, and optional id
"""
name: str
arguments: dict[str, Union[str, int]]
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments}
if self.id is not None:
data["id"] = self.id
return json.dumps(data)
class ToolResponseContents(BaseModel):
"""
Tool response contents with name, content, and optional id
"""
name: str
content: Union[str, dict[str, Union[str, int, float]]]
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "content": self.content}
if self.id is not None:
data["id"] = self.id
return json.dumps(data)
class MessageContents(BaseModel):
"""
Message contents with type, value, metadata, weight, newline, and end of contents
"""
type: Union[str, MessageContentTypes]
value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
weight: Optional[Union[int, float]] = None
has_newline: bool = False
eoc: bool = False # end of contents
def __str__(self) -> str:
str_val = str(self.value)
if self.has_newline and not str_val.endswith("\n"):
str_val += "\n"
return str_val
class Messages(BaseModel):
"""
Messages with role, content, metadata, weight, and chat formatting
"""
role: Union[MessageRoles, str] # allows for arbitrary roles
content: List["MessageContents"]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
weight: Optional[Union[int, float]] = None
is_chat_formatted: bool = False
def __str__(self) -> str:
return "".join(str(c) for c in self.content)
def tokenized(
self, tokenizer: PreTrainedTokenizer, ignore_index=-100
) -> dict[str, List[int]]:
# iterate over the contents, tokenizing the concatenated string values up to the current MessageContents
# returns a dictionary mapping w input_ids, attention_mask, and labels
input_ids: List[int] = []
labels: List[int] = []
pending_input_ids: List[int] = []
pending_weight = self.weight
running_content = ""
for _, msg_content in enumerate(self.content):
# TODO also handle non-text content types
if msg_content.type in [
MessageContentTypes.text.value,
MessageContentTypes.tool_call.value,
MessageContentTypes.tool_response.value,
]:
running_content += str(msg_content)
tok_results = tokenizer(running_content, add_special_tokens=False)
tok_input_ids = tok_results["input_ids"]
if pending_input_ids:
new_pending_inputs = tok_input_ids[
len(input_ids) : len(input_ids) + len(pending_input_ids)
]
if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids)
if pending_weight:
labels.extend(pending_input_ids)
else:
labels.extend([ignore_index] * len(pending_input_ids))
pending_input_ids = tok_results["input_ids"][len(input_ids) :]
pending_weight = self.weight and msg_content.weight not in [0, 0.0]
input_ids.extend(pending_input_ids)
if pending_weight:
labels.extend(pending_input_ids)
else:
labels.extend([ignore_index] * len(pending_input_ids))
attention_mask = [1] * len(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class Chats(BaseModel):
"""
top level data structure for chat conversations
"""
conversation: List[Messages]
def __str__(self) -> str:
return "".join(str(c) for c in self.conversation)
def tokenized(
self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100
) -> dict[str, List[int]]:
input_ids = []
attention_mask = []
labels = []
for msg in self.conversation:
msg_results = msg.tokenized(tokenizer, ignore_index)
input_ids.extend(msg_results["input_ids"])
attention_mask.extend(msg_results["attention_mask"])
labels.extend(msg_results["labels"])
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class ChatFormattedChats(Chats):
"""
Chat formatted chats with formatter and optional train on inputs
"""
formatter: Callable # [[Union[dict, Chats]], Chats]
train_on_inputs: bool = False
def model_post_init(self, __context):
for i, msg in enumerate(self.conversation):
self.conversation[i] = self.formatter(msg, message_index=i)
if self.train_on_inputs:
self.conversation[i].weight = 1
class PreferenceChats(BaseModel):
"""
representation for preference data for chat
"""
prompt: List[Messages]
chosen: Messages
rejected: Messages

View File

@@ -1,55 +0,0 @@
"""
chat dataset module
"""
import os
from typing import Callable, Optional, Union
from datasets import Dataset
from transformers import PreTrainedTokenizer
from axolotl.core.chat.messages import ChatFormattedChats
class TokenizedChatDataset(Dataset):
"""
Tokenized chat dataset
"""
def __init__(
self,
data: Dataset,
model_transform: Union[PreTrainedTokenizer, Callable],
*args,
message_transform: Optional[Callable] = None,
formatter=None,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,
):
def map_fn(ex):
if message_transform is not None:
ex = message_transform(ex)
if formatter is not None:
ex = ChatFormattedChats(
formatter=formatter,
**ex,
)
else:
ex = ChatFormattedChats(
**ex,
)
return ex.tokenized(model_transform)
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(64, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,
num_proc=num_proc,
keep_in_memory=keep_in_memory,
remove_columns=features,
desc="Tokenizing Chats",
)
super().__init__(tokenized_data.data, *args, **kwargs)

View File

@@ -1,150 +0,0 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
"""
from typing import Any, Mapping, Union
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
train_on_inputs=False,
conversations_field: str = "conversations",
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
message_field_content: Union[str, list[str]] = [
"value",
"text",
"content",
], # commonly "content"
message_field_training: Union[str, list[str]] = [
"train",
"weight",
], # commonly "weight"
):
"""Builds a transform that takes a row from the dataset and converts it to a Chat
Args:
train_on_inputs (bool, optional):
If True, the transform will train on the inputs. If False, the transform will train on the targets.
Defaults to False.
conversations_field (str, optional):
The field name of the conversations. Defaults to "conversations".
message_field_role (str | list[str], optional):
The field name of the role. Defaults to "role".
message_field_content (str | list[str], optional):
The field name of the message content. Defaults to "content".
message_field_training (str | list[str], optional):
The field name of the train/weight. Defaults to "weight".
Returns:
Callable:
A function that takes a list of conversations and returns a list of messages.
"""
message_field_role = (
[message_field_role]
if isinstance(message_field_role, str)
else message_field_role
)
message_field_content = (
[message_field_content]
if isinstance(message_field_content, str)
else message_field_content
)
message_weight_fields = (
[message_field_training]
if isinstance(message_field_training, str)
else message_field_training
)
role_value_mappings = {
"system": "system",
"user": "user",
"human": "user",
"assistant": "assistant",
"gpt": "assistant",
"tool": "tool",
"ipython": "ipython",
}
if train_on_inputs:
role_default_weights_mappings = {
"system": 1,
"user": 1,
"assistant": 1,
"tool": 1,
"ipython": 1,
}
else:
role_default_weights_mappings = {
"system": 0,
"user": 0,
"assistant": 1,
"tool": 0,
"ipython": 0,
}
def transform_builder(sample: Mapping[str, Any]):
if conversations_field not in sample:
raise ValueError(f"Field '{conversations_field}' not found in sample.")
# if none of the role fields are in the message, raise an error
if not any(
role in sample[conversations_field][0] for role in message_field_role
):
raise ValueError("No role field found in message.")
role_field = next(
role
for role in message_field_role
if role in sample[conversations_field][0]
)
if not any(
field in sample[conversations_field][0] for field in message_field_content
):
raise ValueError("No message_content field found in message.")
message_content_field = next(
field
for field in message_field_content
if field in sample[conversations_field][0]
)
if not any(
field in sample[conversations_field][0] for field in message_field_training
):
message_weight_field = None
else:
message_weight_field = next(
field
for field in message_weight_fields
if field in sample[conversations_field][0]
)
messages = []
for message in sample[conversations_field]:
role = role_value_mappings[message[role_field]]
weight = (
int(message[message_weight_field])
if message_weight_field
else role_default_weights_mappings[role]
)
# TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents
if isinstance(message[message_content_field], str):
messages.append(
{
"role": role,
"content": [
{
"type": "text",
"value": message[message_content_field],
}
],
"weight": weight,
}
)
else:
messages.append(
{
"role": role,
"content": message[message_content_field],
"weight": weight,
}
)
return {"conversation": messages}
return transform_builder

View File

@@ -4,10 +4,8 @@ Builder for the training args and trainer
"""
import abc
import gc
import importlib
import importlib.util
import inspect
import logging
import math
import os
@@ -17,17 +15,16 @@ from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, Union
from typing import Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
EarlyStoppingCallback,
PreTrainedModel,
Trainer,
TrainerCallback,
TrainingArguments,
@@ -43,14 +40,13 @@ from trl import (
KTOTrainer,
ORPOConfig,
ORPOTrainer,
RewardConfig,
RewardTrainer,
)
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils import is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
@@ -63,14 +59,12 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
@@ -254,10 +248,6 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
@dataclass
@@ -303,13 +293,6 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
)
@dataclass
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""
Reward config for Reward training
"""
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
@@ -407,10 +390,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
def __init__(
self,
*_args,
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs,
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
@@ -469,14 +454,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
self.args, "loraplus_lr_embedding", None
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
@@ -519,10 +504,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
batch_max_len = (
self.args.per_device_train_batch_size * self.args.max_seq_length
)
batch_max_len = train_batch_size * self.args.max_seq_length
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
lengths=get_dataset_lengths(self.train_dataset),
@@ -666,9 +650,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
@@ -676,18 +658,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
return super().compute_loss(model, inputs, return_outputs=return_outputs)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
@@ -783,13 +755,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
).squeeze(2)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
def orpo_compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
):
def orpo_compute_loss(self, model, inputs, return_outputs=False):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
label_pad_token=-100,
@@ -895,13 +861,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial):
def _save_checkpoint(self, model, trial, metrics=None):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial)
return super()._save_checkpoint(model, trial, metrics=metrics)
class AxolotlMambaTrainer(AxolotlTrainer):
@@ -916,7 +882,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
@@ -1001,9 +966,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if is_sagemaker_mp_enabled():
@@ -1024,36 +989,14 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
def tokenize_row(
self,
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
) -> Dict:
res = super().tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
res = super().tokenize_row(feature, model=model)
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
return res
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1079,14 +1022,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1097,11 +1032,10 @@ class TrainerBuilderBase(abc.ABC):
_model_ref = None
_peft_config = None
def __init__(self, cfg, model, tokenizer, processor=None):
def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
self.processor = processor
# in case the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
@@ -1152,23 +1086,12 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from transformers.integrations.integration_utils import MLflowCallback
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
MLflowCallback,
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
return callbacks
@@ -1238,11 +1161,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
@@ -1267,8 +1185,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -1453,10 +1369,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs[
"auto_find_batch_size"
] = self.cfg.auto_find_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
@@ -1490,22 +1402,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_arguments_kwargs["report_to"] = report_to
if self.cfg.use_wandb:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
@@ -1546,9 +1451,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["multipack_real_batches"] = (
not self.cfg.flash_attention or self.cfg.multipack_real_batches
)
training_arguments_kwargs[
"multipack_real_batches"
] = not self.cfg.flash_attention
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
)
@@ -1593,10 +1498,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template
)
if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
@@ -1608,9 +1509,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
@@ -1654,22 +1552,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"accelerator_config"
] = self.cfg.accelerator_config
training_args_cls = (
AxolotlTrainingArguments
if not self.cfg.reward_model
else AxolotlRewardConfig
)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
@@ -1682,37 +1571,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not self.cfg.reward_model:
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not self.cfg.reward_model:
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
eval_data_collator=self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
@@ -1746,14 +1625,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
RewardDataCollatorWithPadding,
]
]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
@@ -1764,12 +1638,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
else:
collator = DataCollatorForSeq2Seq
collator = DataCollatorForSeq2Seq
return collator(
self.tokenizer,
@@ -1955,7 +1824,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
dpo_trainer_kwargs["generate_during_eval"] = True
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
@@ -1967,17 +1836,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
dpo_trainer_kwargs["processing_class"] = self.tokenizer
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
tokenizer=self.tokenizer,
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)

View File

@@ -1,58 +0,0 @@
### AXOLOTL COMMUNITY LICENSE AGREEMENT
This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
and conditions set forth in this Agreement.
1. Definitions
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
which may be licensed separately by their respective authors and/or licensors.
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
permits Plugin Integrations to integrate with the Axolotl service.
2. Grant of License
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
- Licensee must comply with all the terms and conditions of this Agreement.
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
portions of the Software.
2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
3. Restrictions
3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
third parties to fine-tune artificial intelligence models.
3.2 Licensee shall not:
- Use the Software for any illegal or unauthorized purpose.
- Reverse engineer, decompile, or disassemble the Software.
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
Software or interfere with any third-party use of the Software.
3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
4. Intellectual Property Rights
4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
Licensee.
5. Disclaimer of Warranty
5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
6. Termination
6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
copies in its possession.
7. Governing Law
7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
without regards to conflicts of laws provisions thereof.
8. Entire Agreement
8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
Licensees continued use of the Software after any such updates shall constitute acceptance of updated terms
on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
bound by the terms and conditions of this Agreement.
This Agreement was last updated on August 23, 2024.

View File

@@ -1,420 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""
Base class for all plugins.
A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl.
Plugins can be used to integrate third-party models, modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
import importlib
import logging
from typing import List
class BasePlugin:
"""
Base class for all plugins. Defines the interface for plugin methods.
Attributes:
None
Methods:
register(cfg): Registers the plugin with the given configuration.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_load(cfg, model): Performs actions after the model is loaded.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
"""
def __init__(self):
"""
Initializes the BasePlugin.
"""
def register(self, cfg):
"""
Registers the plugin with the given configuration.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def get_input_args(self):
"""
Returns a pydantic model for the plugin's input arguments.
"""
def pre_model_load(self, cfg):
"""
Performs actions before the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def post_model_load(self, cfg, model):
"""
Performs actions after the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def pre_lora_load(self, cfg, model):
"""
Performs actions before LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def post_lora_load(self, cfg, model):
"""
Performs actions after LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def create_optimizer(self, cfg, trainer):
"""
Creates and returns an optimizer for training.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
"""
def create_lr_scheduler(self, cfg, trainer, optimizer):
"""
Creates and returns a learning rate scheduler.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
Returns:
object: The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model):
"""
Adds callbacks to the trainer before training.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Adds callbacks to the trainer after training.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
def post_train(self, cfg, model):
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
"""
def post_train_unload(self, cfg):
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def load_plugin(plugin_name: str) -> BasePlugin:
"""
Loads a plugin based on the given plugin name.
The plugin name should be in the format "module_name.class_name".
This function splits the plugin name into module and class, imports the module,
retrieves the class from the module, and creates an instance of the class.
Parameters:
plugin_name (str): The name of the plugin to be loaded. The name should be in the format "module_name.class_name".
Returns:
BasePlugin: An instance of the loaded plugin.
Raises:
ImportError: If the plugin module cannot be imported.
"""
# split the plugin name into module and class
module_name, class_name = plugin_name.rsplit(".", 1)
# import the module
module = importlib.import_module(module_name)
# instantiate the class
plugin_class = getattr(module, class_name)
# create an instance of the class
plugin = plugin_class()
return plugin
class PluginManager:
"""
The PluginManager class is responsible for loading and managing plugins.
It should be a singleton so it can be accessed from anywhere in the codebase.
Attributes:
plugins (List[BasePlugin]): A list of loaded plugins.
Methods:
get_instance(): Static method to get the singleton instance of PluginManager.
register(plugin_name: str): Registers a new plugin by its name.
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: List[BasePlugin] = []
_instance = None
def __new__(cls):
"""
Creates a new instance of PluginManager if it doesn't exist yet.
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: List[BasePlugin] = []
return cls._instance
@staticmethod
def get_instance() -> "PluginManager":
"""
Returns the singleton instance of PluginManager.
If the instance doesn't exist, it creates a new one.
"""
if PluginManager._instance is None:
PluginManager()
return PluginManager._instance # type: ignore
def register(self, plugin_name: str):
"""
Registers a new plugin by its name.
Parameters:
plugin_name (str): The name of the plugin to be registered.
Returns:
None
Raises:
ImportError: If the plugin module cannot be imported.
"""
try:
plugin = load_plugin(plugin_name)
self.plugins.append(plugin)
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
def get_input_args(self):
"""
Returns a list of Pydantic classes for all registered plugins' input arguments.'
Returns:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
"""
input_args = []
for plugin in self.plugins:
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
return input_args
def pre_model_load(self, cfg):
"""
Calls the pre_model_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
None
"""
for plugin in self.plugins:
plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model):
"""
Calls the post_model_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
"""
Calls the pre_lora_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
"""
Calls the post_lora_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins:
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
return None
def create_lr_scheduler(self, cfg, trainer, optimizer):
"""
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins:
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
return None
def add_callbacks_pre_trainer(self, cfg, model):
"""
Calls the add_callbacks_pre_trainer method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Calls the add_callbacks_post_trainer method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.post_train_unload(cfg)

View File

@@ -1,65 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""
module to handle merging the plugins' input arguments with the base configurations.
this was moved here to prevent circular imports
"""
from typing import Any, Dict, List
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
)
def merge_input_args():
"""
Merges input arguments from registered plugins with the base configurations.
This function retrieves the input arguments from registered plugins using the PluginManager.
It then dynamically creates new classes, AxolotlConfigWCapabilities and AxolotlInputConfig,
that inherit from the base configurations and include the input arguments from the plugins.
Returns:
tuple: A tuple containing the newly created classes, AxolotlConfigWCapabilities and AxolotlInputConfig.
"""
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = []
dynamic_input = ""
for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, globals(), namespace
)
AxolotlInputConfig = namespace[ # pylint: disable=invalid-name
"AxolotlInputConfig"
]
AxolotlConfigWCapabilities = namespace[ # pylint: disable=invalid-name
"AxolotlConfigWCapabilities"
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,189 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import logging
import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
class LigerPlugin(BasePlugin):
"""
Plugin for LIGER integraton with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.model_config_type == "llama":
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
from transformers.models.llama import modeling_llama
if cfg.liger_rope:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
elif cfg.liger_fused_linear_cross_entropy:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
elif cfg.model_config_type == "mistral":
from liger_kernel.transformers.model.mistral import (
lce_forward as mistral_lce_forward,
)
from transformers.models.mistral import modeling_mistral
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
from .models.jamba import lce_forward as jamba_lce_forward
if cfg.liger_rope:
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
)
modeling_mod = sys.modules[model.__class__.__module__]
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
if cfg.liger_rope:
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy:
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

View File

@@ -1,32 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for handling LIGER input arguments.
"""
from typing import Optional
from pydantic import BaseModel
class LigerArgs(BaseModel):
"""
Input args for LIGER.
"""
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None

View File

@@ -1,127 +0,0 @@
"""
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
# @replace_return_docstrings(
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
# )
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
>>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
loss = None
logits = None
if self.training:
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# flatten tokens
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
shift_labels = shift_labels.view(-1)
lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -1,173 +0,0 @@
"""
Jamba model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple, Union
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.jamba.modeling_jamba import (
_CONFIG_FOR_DOC,
JAMBA_INPUTS_DOCSTRING,
HybridMambaAttentionDynamicCache,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: Optional[Union[int, None]] = None,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int` or `None`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
`input_ids`. Only last token logits are needed for generation, and calculating them only for that token
can save memory, which becomes pretty significant for long sequences.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, JambaForCausalLM
>>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
return_dict=return_dict,
)
hidden_states = outputs[0]
loss = None
logits = None
if self.training:
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# flatten tokens
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
shift_labels = shift_labels.view(-1)
lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
if num_logits_to_keep is None:
logits = self.lm_head(hidden_states)
else:
logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
logits = logits.float()
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits if return_dict else outputs[-1],
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(
loss.device
) # make sure to reside in the same device
if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)

View File

@@ -1,13 +0,0 @@
# LM Eval Harness
### Usage
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
```

View File

@@ -1,42 +0,0 @@
"""
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
class LMEvalPlugin(BasePlugin):
"""
Plugin for LM Evaluation Harness integraton with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)

View File

@@ -1,15 +0,0 @@
"""
Module for handling lm eval harness input arguments.
"""
from typing import List, Optional
from pydantic import BaseModel
class LMEvalArgs(BaseModel):
"""
Input args for lm eval harness
"""
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8

View File

@@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,21 +0,0 @@
## Spectrum: Targeted Training on Signal to Noise Ratio
by Eric Hartford, Lucas Atkins, Fernando Fernandes, David Golchinfar
This plugin contains code to freeze the bottom fraction of modules in a model, based on the Signal-to-Noise Ratio (SNR).
### Overview
Spectrum is a tool for scanning and evaluating the Signal-to-Noise Ratio (SNR) of layers in large language models.
By identifying the top n% of layers with the highest SNR, you can optimize training efficiency.
### Usage
```yaml
plugins:
- axolotl.integrations.spectrum.SpectrumPlugin
spectrum_top_fraction: 0.5
# Optional if using a pre-scanned model as your base_model. Useful if using a model mirror
spectrum_model_name: meta-llama/Meta-Llama-3.1-8B
```

View File

@@ -1,102 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
"""
import json
import logging
import requests
from axolotl.integrations.base import BasePlugin
from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401
def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
unfrozen_parameters = {}
for layer_name, info in snr_data.items():
layer_type = info["type"]
if layer_type not in unfrozen_parameters:
unfrozen_parameters[layer_type] = []
unfrozen_parameters[layer_type].append((layer_name, info["snr"]))
top_layers_by_type = {}
for layer_type, layers in unfrozen_parameters.items():
layers_sorted = sorted(layers, key=lambda x: x[1], reverse=True)
num_top_layers = int(len(layers) * top_fraction)
top_layers_by_type[layer_type] = [
layer[0] for layer in layers_sorted[:num_top_layers]
]
unfrozen_parameters = [
"^lm_head.weight$",
"^model.embed_tokens.weight$",
]
for layer_type, layer_names in top_layers_by_type.items():
for layer_name in layer_names:
unfrozen_parameters.append(layer_name)
return unfrozen_parameters
class SpectrumPlugin(BasePlugin):
"""
Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
"""
base_url = "https://raw.githubusercontent.com/cognitivecomputations/spectrum/main/model_snr_results/"
base_path = "./model_snr_results/"
snr_file_template = "snr_results_{model_name_slug}.json"
def get_input_args(self):
return "axolotl.integrations.spectrum.SpectrumArgs"
def pre_model_load(self, cfg):
if cfg.get("spectrum_model_name"):
model_name = cfg["spectrum_model_name"]
else:
model_name = cfg["base_model"]
top_fraction = cfg.get("spectrum_top_fraction", 50)
model_slug = model_name.replace("/", "-").replace("_", "-")
snr_url = self.base_url + self.snr_file_template.format(
model_name_slug=model_slug
)
snr_path = self.base_path + self.snr_file_template.format(
model_name_slug=model_slug
)
# first check if the files exist locally and read the json
snr_data = None
try:
with open(snr_path, "r", encoding="utf-8") as fin:
snr_data = json.load(fin)
except FileNotFoundError:
pass
except Exception as exc: # pylint: disable=broad-exception-caught
logging.warning(f"Failed to read SNR data from {snr_path}: {exc}")
if not snr_data:
try:
snr_data = requests.get(snr_url, timeout=60).json()
except requests.exceptions.RequestException as exc:
logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
return
# also catch json parsing errors
except json.JSONDecodeError as exc:
logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
return
unfrozen_parameters = _generate_unfrozen_params_yaml(
snr_data, top_fraction=top_fraction
)
cfg["unfrozen_parameters"] = unfrozen_parameters

View File

@@ -1,29 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for handling Spectrum input arguments.
"""
from typing import Optional
from pydantic import BaseModel
class SpectrumArgs(BaseModel):
"""
Input args for Spectrum.
"""
spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None

133
src/axolotl/loraplus.py Normal file
View File

@@ -0,0 +1,133 @@
"""Module for LoRA+"""
# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus
import logging
from functools import reduce
from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
LOG = logging.getLogger("axolotl.loraplus")
def get_module(name, opt_model):
"""
Retrieve a module from a model using its parameter name.
Args:
name (str): Full name of the parameter, typically including module path.
opt_model (torch.nn.Module): The model from which to retrieve the module.
Returns:
Module corresponding to the given name.
"""
parent_idx = 2 if "lora" in name else 1
module_names = name.split(sep=".")[:-parent_idx]
module = reduce(getattr, module_names, opt_model)
return module
def create_loraplus_optimizer(
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding=None,
):
"""
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
Args:
opt_model (torch.nn.Module): The model for which the optimizer is being created.
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
Returns:
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
"""
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
if loraplus_lr_embedding is None:
loraplus_lr_embedding = 1e-6
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
param_groups = {
"groupA": {},
"groupB": {},
"groupB_no_decay": {},
"embedding": {},
}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
module = get_module(name, opt_model)
if isinstance(module, lora.Embedding):
param_groups["embedding"][name] = param
elif "lora_B" in name or param.ndim == 1:
if name in decay_parameters:
param_groups["groupB"][name] = param
else:
param_groups["groupB_no_decay"][name] = param
else:
param_groups["groupA"][name] = param
assigned_param_groups = ""
for group, group_params in param_groups.items():
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
LOG.info(assigned_param_groups)
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
optimizer_grouped_parameters = [
{
"params": list(param_groups["groupA"].values()),
"weight_decay": weight_decay,
"lr": lr,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": loraplus_lr_embedding,
},
{
"params": list(param_groups["groupB"].values()),
"weight_decay": weight_decay,
"lr": lr * loraplus_lr_ratio,
},
{
"params": list(param_groups["groupB_no_decay"].values()),
"weight_decay": 0.0,
"lr": lr * loraplus_lr_ratio,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
return optimizer

View File

@@ -1,229 +0,0 @@
"""
Monkeypatch for Vision Llama for FA2 support
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple
import torch
from flash_attn.flash_attn_interface import flash_attn_func
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
from transformers.models.mllama.modeling_mllama import (
MllamaTextCrossAttention,
MllamaTextSelfAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import is_flash_attn_greater_or_equal_2_10
class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention):
"""
Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and
implements the forward pass using Flash Attention for improved performance.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check if flash attention version is greater or equal to 2.1
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[ # pylint: disable=unused-argument
torch.Tensor
] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
query_states = self.q_norm(query_states)
if cross_attention_states is not None:
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
key_states = self.k_norm(key_states)
if past_key_value is not None:
key_states, value_states = past_key_value.update(
key_states,
value_states,
self.layer_idx,
{"cache_position": cache_position},
)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_value.key_cache[self.layer_idx],
past_key_value.value_cache[self.layer_idx],
)
else:
raise ValueError(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# Apply Flash Attention
dropout_rate = self.dropout if self.training else 0.0
output = flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=dropout_rate,
softmax_scale=None,
causal=False,
return_attn_probs=output_attentions,
)
attn_output = output.contiguous().view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention):
"""
Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and
implements the forward pass using Flash Attention for improved performance.
"""
def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs):
super().__init__(config, layer_idx, *args, **kwargs)
# Check if flash attention version is greater or equal to 2.1
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
past_key_value=None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.dropout if self.training else 0.0
# Handle potential silent casting to float32
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = (
self.config._pre_quantization_dtype # pylint: disable=protected-access
)
else:
target_dtype = self.q_proj.weight.dtype
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def patch_mllama():
from transformers.models.mllama.modeling_mllama import (
MLLAMA_TEXT_ATTENTION_CLASSES,
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES,
MLLAMA_VISION_ATTENTION_CLASSES,
MllamaPreTrainedModel,
)
MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access
True
)
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
"flash_attention_2"
] = MllamaTextCrossFlashAttention2
# fallback to SDPA
MLLAMA_VISION_ATTENTION_CLASSES[
"flash_attention_2"
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]

View File

@@ -22,6 +22,7 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
@@ -43,19 +44,7 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def is_xformers_available() -> bool:
try:
import xformers # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
def is_xformers_swiglu_available() -> bool:
if not is_xformers_available():
return False
from xformers.ops.common import get_xformers_operator
try:
@@ -68,11 +57,6 @@ def is_xformers_swiglu_available() -> bool:
def replace_llama_mlp_with_swiglu(model):
if is_xformers_swiglu_available():
from axolotl.monkeypatch.xformers_ import FusedMLP
else:
raise RuntimeError("xformers SwiGLU not available for this environment")
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
@@ -197,6 +181,49 @@ class FusedAttention(LlamaAttention):
set_module_name(model, name, new_attn)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(

View File

@@ -9,18 +9,18 @@ from axolotl.monkeypatch.utils import (
def hijack_llama_prepare_4d_mask():
from transformers import modeling_attn_mask_utils
from transformers.models.llama import modeling_llama
import transformers.modeling_attn_mask_utils
import transformers.models.llama.modeling_llama
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)

View File

@@ -10,7 +10,6 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mllama_text_model",
"llama",
"mistral",
"mixtral",
@@ -27,18 +26,15 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
# def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
# elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
elif hasattr(transformers, "modeling_flash_attention_utils"):
if not has_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return

View File

@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: List[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str],
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)

View File

@@ -16,7 +16,6 @@
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# pylint: disable=duplicate-code
""" PyTorch StableLM Epoch model. """
import importlib
import math

View File

@@ -16,6 +16,28 @@ from transformers.models.llama.modeling_llama import (
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """ if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
"""
PATCHED_CEL_CODE = """ if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
"""
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@@ -60,6 +82,12 @@ def get_forward_code() -> str:
return forward
def check_cel_is_patchable() -> bool:
forward = get_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_CEL_CODE in forward
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
@@ -72,31 +100,48 @@ def check_self_attn_is_patchable() -> bool:
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
def UnslothForCausalLMLoss( # pylint: disable=invalid-name
logits,
labels,
vocab_size: int, # pylint: disable=unused-argument
num_items_in_batch: int = None,
ignore_index: int = -100, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
)
return loss
if model_type == "llama":
from transformers.loss import loss_utils
forward = get_forward_code()
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
else:
raise ValueError("Unsupported model type")

View File

@@ -17,9 +17,11 @@ def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
max_num = int(torch.max(attention_mask).item())
batch_size, _ = attention_mask.shape
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
for i in range(1, max_num + 1):
mask = attention_mask == i
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
result = counts.flatten()
nonzero_indices = torch.nonzero(result).squeeze(-1)
return result[nonzero_indices]

View File

@@ -1,51 +0,0 @@
"""
Fused MLP layer for incrementally improved training efficiency
"""
import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import set_module_name
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)

View File

@@ -9,12 +9,8 @@ from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies")
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
def load(strategy, tokenizer, cfg, ds_cfg):
try:
if strategy == "messages":
from .messages import load as messages_load
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
@@ -28,12 +24,9 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None
return None

View File

@@ -1,10 +0,0 @@
### example yaml
```yaml
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
```

View File

@@ -1,35 +0,0 @@
"""Module to load prompt strategies."""
import importlib
import inspect
import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
def load(strategy, tokenizer, cfg, ds_cfg):
# pylint: disable=duplicate-code
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(
f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
)
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
return None

View File

@@ -1,102 +0,0 @@
"""
Bradley-Terry model with chat template prompt strategy.
"""
import logging
from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
class BTChatTemplateStrategy(ChatTemplateStrategy):
"""
Bradley-Terry reward model pairwise chat template prompt strategy.
"""
def tokenize_prompt(self, prompt):
"""
:param prompt: the actual row of data from the underlying dataset
:return:
"""
self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt)
return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0,
"input_ids_rejected": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0,
}
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", None
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1
if not cfg.reward_model
else cfg.sequence_len,
}
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
}
strategy = BTChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy

View File

@@ -1,27 +0,0 @@
"""
chatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template
"""
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
prompt = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn

View File

@@ -5,11 +5,9 @@ HF Chat Templates prompt strategy
import logging
from typing import Any, Dict, List, Optional
from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.chat_templates import chat_templates
# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -22,13 +20,12 @@ class ChatTemplatePrompter(Prompter):
def __init__(
self,
tokenizer,
processor=None,
chat_template=None,
max_length=2048,
message_field_role: str = "from",
message_field_content: str = "value",
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
message_field_training: str = "train",
message_field_training_detail: str = "train_detail",
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
@@ -47,12 +44,11 @@ class ChatTemplatePrompter(Prompter):
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.tokenizer = tokenizer
self.processor: ProcessorMixin = processor
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
def build_prompt(self, conversation, add_generation_prompt=False):
turns = [
{
"role": self.roles[t[self.message_field_role]],
@@ -65,28 +61,6 @@ class ChatTemplatePrompter(Prompter):
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
if self.processor:
text = self.processor.apply_chat_template(
turns,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
batch = self.processor(
text=text,
images=images,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
)
# workaround since processor works in batches instead of single examples
for k, val in batch.items():
if k in ["pixel_values"]:
batch[k] = val.tolist()
else:
batch[k] = val.squeeze().tolist()
return batch
return self.tokenizer.apply_chat_template(
turns,
truncation=True,
@@ -212,12 +186,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
train_on_inputs,
sequence_len,
roles_to_train=None,
train_on_eos=None,
train_on_eos="last",
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.train_on_eos = train_on_eos
self.images = "images"
@property
def messages(self):
@@ -228,40 +201,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self._messages = messages
def tokenize_prompt(self, prompt):
# Old simple legacy behavior that works reliably.
if (
not self.roles_to_train
and not self.train_on_eos
and not self.prompter.message_field_training
and not self.prompter.message_field_training_detail
):
turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt(
turns[:-1],
add_generation_prompt=True,
images=images,
)
tokenized_res = self.prompter.build_prompt(turns, images=images)
tokenized_prompt = {}
if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
tokenized_prompt["input_ids"] = input_ids
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
else:
input_ids = tokenized_res["input_ids"]
tokenized_prompt = tokenized_res
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids
tokenized_prompt["labels"] = labels
return tokenized_prompt
turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids)
@@ -280,11 +219,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
should_train = (
train_turn
if train_turn is not None
else (
bool(train_detail is not None)
if train_detail is not None
else self.train_on_inputs or role in self.roles_to_train
)
else bool(train_detail is not None)
if train_detail is not None
else self.train_on_inputs or role in self.roles_to_train
)
LOG.debug(f"Should train: {should_train}")
@@ -398,40 +335,29 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
return prompt[self.messages]
def get_images(self, prompt):
return prompt.get(self.images, None)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", "training"),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
None,
"message_field_training_detail", "train_detail"
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
"max_length": cfg.sequence_len,
}
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
strategy = ChatTemplateStrategy(

View File

@@ -2,16 +2,15 @@
DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
from axolotl.utils.chat_templates import chat_templates
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
chat_template_str = chat_templates(cfg.chat_template)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -31,12 +30,6 @@ def default(
role_map[source] = target
def transform_fn(sample, tokenizer=None):
chat_template_string = get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
messages = sample[field_messages]
messages = [
{
@@ -53,29 +46,28 @@ def default(
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[dummy_user_message, chosen],
[chosen],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[dummy_user_message, rejected],
[rejected],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])

View File

@@ -1,34 +0,0 @@
"""Module to load message prompt strategies."""
import importlib
import inspect
import logging
LOG = logging.getLogger("axolotl.prompt_strategies.messages")
def load(tokenizer, cfg, ds_cfg, processor=None):
try:
strategy = ds_cfg.get("input_transform", "chat")
# pylint: disable=duplicate-code
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(
f".{strategy}", "axolotl.prompt_strategies.messages"
)
func = getattr(mod, load_fn)
load_kwargs = {}
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None

View File

@@ -1,84 +0,0 @@
"""
Chat dataset wrapping strategy for new internal messages representations
"""
from typing import Any, Callable, Dict, Optional
from axolotl.core.datasets.chat import TokenizedChatDataset
from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):
"""
Chat dataset wrapping strategy for new internal messages representations
"""
def __init__(
self,
processor,
message_transform=None,
formatter=None,
**kwargs, # pylint: disable=unused-argument
):
"""
:param processor: tokenizer or image processor
:param kwargs:
"""
self.processor = processor
self.dataset = None
self.message_transform = message_transform
self.formatter = formatter
def wrap_dataset(
self,
dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs, # pylint: disable=unused-argument
):
self.dataset = TokenizedChatDataset(
dataset,
message_transform=self.message_transform,
model_transform=self.processor,
formatter=self.formatter,
process_count=process_count,
keep_in_memory=keep_in_memory,
)
return self.dataset
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
field_messages = ds_cfg.get("field_messages")
message_field_role = ds_cfg.get("message_field_role")
message_field_content = ds_cfg.get("message_field_content")
message_field_training = ds_cfg.get("message_field_training")
builder_kwargs = {}
if field_messages:
builder_kwargs["conversations_field"] = field_messages
if message_field_role:
builder_kwargs["message_field_role"] = message_field_role
if message_field_content:
builder_kwargs["message_field_content"] = message_field_content
if message_field_training:
builder_kwargs["message_field_training"] = message_field_training
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
format_message = (
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
)
if chat_template == "chatml":
from axolotl.core.chat.format.chatml import format_message # noqa F811
if chat_template.startswith("llama3"):
from axolotl.core.chat.format.llama3x import format_message # noqa F811
message_transform: Callable = chat_message_transform_builder(
train_on_inputs=ds_cfg.get("train_on_inputs", False),
**builder_kwargs,
)
strategy = ChatMessageDatasetWrappingStrategy(
tokenizer, message_transform=message_transform, formatter=format_message
)
return strategy

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.chat_templates import chat_templates
class Message(BaseModel):
@@ -28,13 +28,18 @@ def load(
"""
chatml transforms for datasets with system, input, chosen, rejected
"""
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
tokenizer.chat_template = chat_template_string
chat_template = chat_templates("chatml")
if ds_cfg and "chat_template" in ds_cfg:
chat_template = ds_cfg["chat_template"]
try:
chat_template = chat_templates(chat_template)
except ValueError:
pass
tokenizer.chat_template = chat_template
return ORPOTokenizingStrategy(
ORPOPrompter(chat_template_string, tokenizer),
ORPOPrompter(chat_template, tokenizer),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -243,30 +248,28 @@ class ORPOPrompter(Prompter):
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()
chat_template_str = chat_templates(cfg.chat_template)
def transform_fn(sample, tokenizer=None):
res = {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, tokenizer=tokenizer
)
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]

View File

@@ -61,9 +61,6 @@ def build_loader(
default_conversation: Optional[str] = None,
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
)
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg

View File

@@ -30,12 +30,6 @@ class InvalidDataException(Exception):
"""
class DatasetWrappingStrategy(abc.ABC):
"""
Abstract class for wrapping datasets for Chat Messages
"""
class PromptTokenizingStrategy(abc.ABC):
"""
Abstract class for tokenizing strategies

View File

@@ -352,12 +352,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
"Please help us by creating an Issue to add support for this conversation type."
)
if self._conversation.name in ["llama3"]:
role = from_role
else:
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
if (

View File

@@ -10,6 +10,7 @@ from typing import Optional, Tuple, Union
import torch
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
@@ -23,7 +24,7 @@ from axolotl.core.tokenizer_utils import fix_untrained_tokens
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
try:
@@ -68,9 +69,6 @@ def train(
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
processor = None
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
@@ -96,11 +94,10 @@ def train(
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
if model.generation_config is not None:
model.generation_config.do_sample = True
# we wait unitl the last possible moment to setup Accelerator
Accelerator()
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model.generation_config.do_sample = True
model_ref = None
if cfg.rl and cfg.rl != "orpo":
@@ -125,7 +122,6 @@ def train(
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
@@ -260,10 +256,8 @@ def train(
if not cfg.hub_model_id:
try:
trainer.create_model_card(
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
)
except (AttributeError, UnicodeDecodeError):
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
except AttributeError:
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated

View File

@@ -1,12 +1,8 @@
"""
Basic utils for Axolotl
"""
import importlib.util
import importlib
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None

View File

@@ -29,7 +29,7 @@ from transformers import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
references=[[r] for r in references],
predictions=predictions,
)
scores["eval_" + metric_name] = score
scores[metric_name] = score
return scores
def predict_with_generate():
@@ -747,15 +747,6 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
elif logger == "comet_ml" and is_comet_available():
import comet_ml
experiment = comet_ml.get_running_experiment()
if experiment:
experiment.log_table(
f"{name} - Predictions vs Ground Truth.csv",
pd.DataFrame(table_data),
)
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)

View File

@@ -1,43 +0,0 @@
"""Comet module for trainer callbacks"""
import logging
from typing import TYPE_CHECKING
import comet_ml
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
"""Callback to save axolotl config to comet"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
comet_experiment = comet_ml.start(source="axolotl")
comet_experiment.log_other("Created from", "axolotl")
comet_experiment.log_asset(
self.axolotl_config_path,
file_name="axolotl-config",
)
LOG.info(
"The Axolotl config has been saved to the Comet Experiment under assets."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
return control

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff Show More