Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
459f407e69 avoid crash/oom on train end 2025-05-15 15:53:35 -04:00
298 changed files with 7729 additions and 13497 deletions

View File

@@ -16,9 +16,8 @@ on:
jobs:
build-base:
if: github.repository_owner == 'axolotl-ai-cloud'
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m
runs-on: axolotl-gpu-runner
strategy:
fail-fast: false
matrix:
@@ -29,50 +28,42 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base-nightly"
# # "next" is for release candidates of pytorch
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
# python_version: "3.11"
# pytorch: next
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base-next"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: next
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -94,60 +85,7 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}
CUDNN_VERSION=${{ matrix.cudnn_version }}
CUDA=${{ matrix.cuda }}
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
build-base-uv:
if: github.repository_owner == 'axolotl-ai-cloud'
timeout-minutes: 480
runs-on: ubuntu-latest-m
strategy:
fail-fast: false
matrix:
include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -9,7 +9,6 @@ on:
- '.github/workflows/*.yml'
- "*.[q]md"
- "examples/**/*.y[a]?ml"
- ".pre-commit-config.yaml"
workflow_dispatch:
jobs:

View File

@@ -29,12 +29,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -97,12 +92,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -43,7 +43,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
@@ -59,7 +59,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -25,6 +25,7 @@ jobs:
pre-commit autoupdate
if [[ -n $(git status --porcelain) ]]; then
echo "changes=true" >> $GITHUB_OUTPUT
git diff .pre-commit-config.yaml > pre-commit-update.diff
fi
- name: Create Pull Request
@@ -38,3 +39,11 @@ jobs:
commit-message: "chore: update pre-commit hooks"
body: |
Automated PR to update pre-commit hooks to their latest versions.
<details>
<summary>Changes:</summary>
```diff
${{ steps.update.outputs.diff }}
```
</details>

View File

@@ -44,6 +44,98 @@ jobs:
env:
SKIP: no-commit-to-branch
# preload-cache:
# name: Preload HF cache
# runs-on: ubuntu-latest
# strategy:
# fail-fast: false
# matrix:
# python_version: ["3.11"]
# pytorch_version: ["2.6.0"]
# timeout-minutes: 20
#
# env:
# AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
#
# steps:
# - name: Check out repository code
# uses: actions/checkout@v4
#
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
#
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p /home/runner/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
#
# - name: Setup Python
# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python_version }}
# cache: 'pip' # caching pip dependencies
#
# - name: upgrade pip
# run: |
# pip3 install --upgrade pip
# pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
#
# - name: Install PyTorch
# run: |
# pip3 install torch==${{ matrix.pytorch_version }}
#
# - name: Install dependencies
# run: |
# pip3 show torch
# pip3 install --no-build-isolation -U -e .
# python scripts/unsloth_install.py | sh
# python scripts/cutcrossentropy_install.py | sh
# pip3 install -r requirements-dev.txt -r requirements-tests.txt
#
# - name: Make sure PyTorch version wasn't clobbered
# run: |
# python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
#
# - name: Ensure axolotl CLI was installed
# run: |
# axolotl --help
#
# - name: Pre-Download dataset fixture
# run: |
# huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
#
# - name: Run tests
# run: |
# pytest -v tests/conftest.py
#
# - name: Upload coverage to Codecov
# uses: codecov/codecov-action@v5
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# files: ./coverage.xml
# flags: unittests,pytorch-${{ matrix.pytorch_version }}
# fail_ci_if_error: false
#
# - name: cleanup pip cache
# run: |
# find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
#
# - name: Save HF cache
# id: hf-cache
# uses: actions/cache/save@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
@@ -52,13 +144,22 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
@@ -121,17 +222,27 @@ jobs:
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
# needs: [preload-cache]
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
@@ -184,11 +295,10 @@ jobs:
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
# Run this job first as a gate for running the remainder of the test matrix
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
timeout-minutes: 90
needs: [pre-commit, pytest, pytest-sdist]
strategy:
@@ -201,13 +311,6 @@ jobs:
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -218,7 +321,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -229,7 +332,6 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -238,9 +340,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
@@ -262,13 +362,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps:
@@ -281,7 +375,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -292,7 +386,6 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -322,7 +415,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -19,15 +19,15 @@ repos:
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 7.2.0
rev: 7.1.2
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.7
rev: v3.3.6
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.0
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies:

View File

@@ -242,12 +242,16 @@
# early_stopping_patience: 3
# # Specify a scheduler and kwargs to use with the optimizer
# lr_scheduler: # 'one_cycle' | empty for cosine
# lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
# lr_scheduler_kwargs:
# # For one_cycle optim
# lr_div_factor: # Learning rate div factor
# # For log_sweep optim
# log_sweep_min_lr:
# log_sweep_max_lr:
# # Specify optimizer
# # Valid values are driven by the Transformers OptimizerNames class, see:
# # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134

View File

@@ -22,32 +22,28 @@
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
</p>
## 🎉 Latest Updates
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
## ✨ Overview
Axolotl is a tool designed to streamline post-training for various AI models.
Post-training refers to any modifications or additional training performed on
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
techniques. With support for multiple model architectures and training configurations,
Axolotl makes it easy to get started with these techniques.
Axolotl is designed to work with YAML config files that contain everything you need to
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
and much more.
Features:
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
- Train various Huggingface models such as llama, pythia, falcon, mpt
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- And more!
## 🚀 Quick Start
@@ -55,7 +51,7 @@ Features:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.5.1
- PyTorch ≥2.4.1
### Installation
@@ -85,12 +81,19 @@ axolotl train examples/llama-3/lora-1b.yml
That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/getting-started.html) for a more detailed walkthrough.
## ✨ Key Features
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
- **Easy Configuration**: Simple YAML files to control your training setup
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
- **Flexible Dataset Handling**: Use various formats and custom datasets
- **Cloud Ready**: Run on cloud platforms or local hardware
## 📚 Documentation
- [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments
- [Configuration Guide](https://docs.axolotl.ai/docs/config.html) - Full configuration options and examples
- [Dataset Loading](https://docs.axolotl.ai/docs/dataset_loading.html) - Loading datasets from various sources
- [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
@@ -109,6 +112,31 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
## Supported Models
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
❓: untested
## ❤️ Sponsors
Thank you to our sponsors who help make Axolotl possible:

View File

@@ -17,9 +17,7 @@ quartodoc:
- convert
- prompt_tokenizers
- logging_config
- core.builders.base
- core.builders.causal
- core.builders.rl
- core.trainer_builder
- core.training_args
- core.chat.messages
- core.chat.format.chatml
@@ -45,7 +43,6 @@ quartodoc:
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- cli.quantize
- title: Trainers
desc: Training implementations
contents:
@@ -57,21 +54,13 @@ quartodoc:
- core.trainers.grpo.trainer
- core.trainers.grpo.sampler
- core.trainers.utils
- title: Model Loading
desc: Functionality for loading and patching models, tokenizers, etc.
contents:
- loaders.model
- loaders.tokenizer
- loaders.processor
- loaders.adapter
- loaders.patch_manager
- loaders.constants
- title: Mixins
desc: Mixin classes for augmenting trainers
contents:
- core.trainers.mixins.optimizer
- core.trainers.mixins.rng_state_loader
- core.trainers.mixins.scheduler
- core.trainers.mixins.sequence_parallel
- title: Context Managers
desc: Context managers for altering trainer behaviors
contents:
@@ -129,16 +118,17 @@ quartodoc:
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
- monkeypatch.unsloth_
- monkeypatch.attention.mllama
- monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral
- monkeypatch.gradient_checkpointing.offload_cpu
- monkeypatch.gradient_checkpointing.offload_disk
- title: Utils
desc: Utility functions
contents:
- utils.models
- utils.tokenization
- utils.chat_templates
- utils.lora
- utils.lora_embeddings
- utils.model_shard_quant
- utils.bench
- utils.freeze
@@ -149,7 +139,8 @@ quartodoc:
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.quantization
- utils.gradient_checkpointing.offload_cpu
- utils.gradient_checkpointing.offload_disk
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:
@@ -199,14 +190,12 @@ quartodoc:
- utils.callbacks.lisa
- utils.callbacks.mlflow_
- utils.callbacks.comet_
- utils.callbacks.qat
website:
title: "Axolotl"
description: "We make fine-tuning accessible, scalable, and fun"
favicon: favicon.jpg
google-analytics: "G-9KYCVJBNMQ"
navbar:
logo: image/axolotl_logo_digital_white.svg
title: false
@@ -259,8 +248,6 @@ website:
- docs/lr_groups.qmd
- docs/lora_optims.qmd
- docs/dataset_loading.qmd
- docs/qat.qmd
- docs/quantize.qmd
- section: "Core Concepts"
contents:

View File

@@ -1,52 +0,0 @@
FROM axolotlai/axolotl-base-uv:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py --uv | sh
RUN python scripts/cutcrossentropy_install.py --uv | sh
# So we can test the Docker image
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
@@ -55,7 +55,7 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = f"H100:{N_GPUS}"
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
image=cicd_image,
gpu=GPU_CONFIG,
timeout=90 * 60,
cpu=16.0,
cpu=8.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)

View File

@@ -8,9 +8,8 @@ import tempfile
import jinja2
import modal
import modal.experimental
from jinja2 import select_autoescape
from modal import App
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -18,15 +17,14 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
@@ -40,11 +38,11 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = modal.experimental.raw_dockerfile_image(
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
# context_mount=None,
context_mount=None,
force_build=True,
# gpu="A10G",
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
@@ -57,7 +55,7 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = f"L40S:{N_GPUS}"
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):

View File

@@ -1,31 +0,0 @@
{
"compile": {
"disable": false,
"backend": "inductor"
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -38,6 +38,6 @@ RUN git lfs install --skip-repo && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

View File

@@ -1,40 +0,0 @@
ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
ENV UV_TORCH_BACKEND="cu${CUDA}"
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
&& git lfs install --skip-repo \
&& curl -LsSf https://astral.sh/uv/install.sh | sh
ENV PATH="/root/.local/bin:${PATH}"
RUN uv python install ${PYTHON_VERSION}
WORKDIR /workspace
RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
fi

View File

@@ -209,16 +209,6 @@ axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
### quantize
Quantizes a model using the quantization configuration specified in your YAML file.
```bash
axolotl quantize config.yml
```
See [Quantization](./quantize.qmd) for more details.
## Legacy CLI Usage

View File

@@ -27,8 +27,6 @@ trust_remote_code:
tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy:
# Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer.
tokenizer_use_mistral_common:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
@@ -67,20 +65,6 @@ bnb_config_kwargs:
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# quantization aware training
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
# post-training quantization
quantization:
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
@@ -114,10 +98,8 @@ plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# A list of one or more datasets to finetune the model with
# See https://docs.axolotl.ai/docs/dataset_loading.html for guide on loading datasets
# See https://docs.axolotl.ai/docs/dataset-formats/ for guide on dataset formats
datasets:
# HuggingFace dataset repo | s3:// | gs:// | path to local file or directory
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
@@ -175,10 +157,6 @@ datasets:
# Key containing the messages (default: "messages")
field_messages: messages
# Key containing the tools (default: "tools")
# Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
field_tools: tools
# Key containing the system message (default: "system")
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
field_system: system
@@ -243,7 +221,7 @@ datasets:
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
# Deduplicates datasets and test_datasets with identical entries.
Deduplicates datasets and test_datasets with identical entries.
dataset_exact_deduplication: true
# A list of one or more datasets to eval the model with.
@@ -292,25 +270,10 @@ trl:
num_generations: # Optional[int]. Number of generations to sample.
log_completions: # Optional[bool]. Whether to log completions.
num_completions_to_print: # Optional[int]. Number of completions to print when log_completions is True.
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
scale_rewards: # Optional[bool]. Whether to scale rewards by their standard deviation.
temperature: # Optional[float]. Sampling temperature for the GRPO policy.
top_p: # Optional[float]. Top-p sampling probability for the generation policy.
top_k: # Optional[int]. Top-k sampling for the generation policy.
min_p: # Optional[float]. Minimum probability for the generation policy.
repetition_penalty: # Optional[float]. Penalty for tokens that appear in prompt and generated text.
num_iterations: # Optional[int]. Number of iterations per batch (μ) for GRPO.
epsilon: # Optional[float]. Epsilon value for clipping in the GRPO algorithm.
epsilon_high: # Optional[float]. Upper-bound epsilon value for clipping in the GRPO algorithm.
use_liger_loss: # Optional[bool]. Whether to use Liger loss for GRPO.
loss_type: # Optional[str]. Loss formulation to use. Supported values: grpo, bnpo, dr_grpo.
mask_truncated_completions: # Optional[bool]. Whether to exclude truncated completions from loss calculation.
# reward modelling: `True` or `False`
@@ -520,7 +483,6 @@ output_dir: ./completed-model
# setting to `auto` will enable torch compile when torch>=2.5.1
torch_compile: # Optional[Union[Literal["auto"], bool]]
torch_compile_backend: # Optional[str]
torch_compile_mode: # 'default' | 'reduce-overhead' | 'max-autotune'
# Training hyperparameters
@@ -567,7 +529,7 @@ profiler_steps: # enable the pytorch profiler to capture the first N steps of tr
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package). Default True
# Save model as safetensors (require safetensors package)
save_safetensors:
# Whether to mask out or include the human's prompt from the training labels
@@ -589,24 +551,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
# Valid values are driven by the Transformers SchedulerType class, see:
# https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/trainer_utils.py#L420
# Valid values include
# - 'linear'
# - 'cosine' (default)
# - 'cosine_with_restarts'
# - 'polynomial'
# - 'constant'
# - 'constant_with_warmup'
# - 'inverse_sqrt'
# - 'reduce_lr_on_plateau'
# - 'cosine_with_min_lr'
# - 'warmup_stable_decay'
# Additional schedulers include:
# - 'one_cycle'
# - 'rex'
lr_scheduler:
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | 'linear' | 'cosine_with_restarts' | 'polynomial' | 'constant' | 'constant_with_warmup' | 'inverse_sqrt' | 'reduce_lr_on_plateau' | 'cosine_with_min_lr' | 'warmup_stable_decay' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -624,7 +569,7 @@ lr_div_factor: # Learning rate div factor
#
# Valid values for 'optimizer' include:
# - adamw_torch
# - adamw_torch_fused (default)
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_torch_npu_fused
# - adamw_apex_fused
@@ -688,9 +633,7 @@ weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_beta3: # only used for CAME Optimizer
adam_epsilon:
adam_epsilon2: # only used for CAME Optimizer
# Gradient clipping max norm
max_grad_norm:

View File

@@ -52,9 +52,7 @@ We recommend checking the below examples for other usecases.
### Examples
#### Training on last message
(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
@@ -68,9 +66,7 @@ datasets:
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
#### Overriding default chat template
Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
@@ -80,13 +76,7 @@ datasets:
roles_to_train: ["assistant"] # default value
```
::: {.callout-note}
If you want to use built-in chat_template, use `chat_template: tokenizer_default` (this is set by default).
:::
#### Using default chat template with fallback
Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
@@ -95,9 +85,7 @@ datasets:
type: chat_template
```
#### Custom Jinja template
Using a custom jinja template on OpenAI messages format, training on all assistant messages.
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
@@ -112,9 +100,7 @@ datasets:
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
:::
#### Using template with different token for EOT and EOS
- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
```yaml
eot_tokens:
@@ -139,7 +125,7 @@ Using `eot_tokens` requires each token that exists in `chat_template` to be a si
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
:::
- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
```yaml
eot_tokens:
@@ -159,73 +145,7 @@ If EOS token only appears at the end of a prompt, `train_on_eos: last` is equiva
:::
#### Using tool use
Instead of passing `tools` via the system prompt, an alternative method would be to have the `tools` in a separate column and loaded via `chat_template` to let the template dynamically build it.
```json
{
"tools": [
{
"type": "...",
"function": {
"name": "...",
"description": "...",
"parameters": {
"type": "...",
"properties": {
// ...
},
"required": ["..."],
},
},
},
],
"messages": [
// ...
{
"role": "assistant", // call the function via assistant
"tool_calls": [
{
"type": "function",
"function": {
"name": "...",
"arguments": {
"...": "...",
}
}
}
]
},
{
"role": "tool",
"name": "...",
"content": "..."
},
],
}
```
::: {.callout-note}
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
```yaml
chat_template: llama4
datasets:
- path: ...
type: chat_template
# field_tools: tools # default is `tools`
```
::: {.callout-tip}
Look into the `chat_template` you are using to see if it supports `tools` and what the expected role is for the tool answer. In the example above, the tool answer is expected to be in the `tool` or `ipython` role for `llama4` template.
:::
#### Using fine-grained control over token masking
(Advanced) Using fine-grained control over tokens and turns to train in a conversation
7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
@@ -276,9 +196,7 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
#### Reasoning split
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
```yaml
datasets:

View File

@@ -36,6 +36,10 @@ It is typically recommended to save your dataset as `.jsonl` due to its flexibil
Axolotl supports loading from a Hugging Face hub repo or from local files.
::: {.callout-important}
For pre-training only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts.
:::
### Pre-training from Hugging Face hub datasets
As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
@@ -73,21 +77,18 @@ datasets:
type: completion
```
From local files:
From local files (either example works):
```yaml
datasets:
- path: A.jsonl
type: completion
- path: B.jsonl
- path: json
data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
type: completion
```
::: {.callout-important}
For `completion` only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts. If you are interested in having this for `pretraining_dataset` too, please let us know or help make a PR!
:::
### Pre-training dataset configuration tips
#### Setting max_steps

View File

@@ -54,7 +54,7 @@ datasets:
#### Files
To load a JSON file, you would do something like this:
Usually, to load a JSON file, you would do something like this:
```python
from datasets import load_dataset
@@ -66,11 +66,19 @@ Which translates to the following config:
```yaml
datasets:
- path: data.json
ds_type: json
- path: json
data_files: /path/to/your/file.jsonl
```
In the example above, it can be seen that we can just point the `path` to the file or directory along with the `ds_type` to load the dataset.
However, to make things easier, we have added a few shortcuts for loading local dataset files.
You can just point the `path` to the file or directory along with the `ds_type` to load the dataset. The below example shows for a JSON file:
```yaml
datasets:
- path: /path/to/your/file.jsonl
ds_type: json
```
This works for CSV, JSON, Parquet, and Arrow files.

View File

@@ -8,10 +8,6 @@ format:
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
:::
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
@@ -32,10 +28,11 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu128-2.7.0`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
## Main
@@ -76,10 +73,12 @@ Tags examples:
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `0.9.2`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`
## Cloud

View File

@@ -110,17 +110,3 @@ description: Frequently asked questions
> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
> Internally, `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
**Q: `Data processing error: CAS service error`**
> A: Try disabling XET with `export HF_HUB_DISABLE_XET=1`
**Q: `torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. `**
> A: Depending on the version of torch, you may need to include this in your YAML:
> ```yaml
> flex_attn_compile_kwargs:
> dynamic: false
> mode: max-autotune-no-cudagraphs
> ```

View File

@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
format them.
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
format):
```json
@@ -120,12 +120,6 @@ axolotl train my_training.yml
## Common Tasks {#sec-common-tasks}
::: {.callout-tip}
The same yaml file is used for training, inference, and merging.
:::
### Testing Your Model {#sec-testing}
After training, test your model:
@@ -134,16 +128,6 @@ After training, test your model:
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
```
More details can be found in [Inference](inference.qmd).
### Using a UI {#sec-ui}
Launch a Gradio interface:
```bash
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```
### Preprocessing Data {#sec-preprocessing}
For large datasets, preprocess first:
@@ -152,22 +136,14 @@ For large datasets, preprocess first:
axolotl preprocess my_training.yml
```
Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.
### Using a UI {#sec-ui}
More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
### Merging LoRA weights {#sec-merging-lora}
To merge the LoRA weights back into the base model, run:
Launch a Gradio interface:
```bash
axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```
The merged model will be saved in the `{output_dir}/merged` directory.
More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
## Next Steps {#sec-next-steps}
Now that you have the basics, you might want to:
@@ -180,7 +156,6 @@ Now that you have the basics, you might want to:
Check our other guides for details on these topics:
- [Configuration Guide](config.qmd) - Full configuration options
- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)

View File

@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.5.1
- PyTorch ≥2.4.1
## Installation Methods {#sec-installation-methods}
@@ -25,10 +25,6 @@ Please make sure to have Pytorch installed before installing Axolotl in your loc
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
::: {.callout-important}
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
:::
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
@@ -41,40 +37,6 @@ installed) in order not to clobber it, and so that we set the correct version of
dependencies that are specific to the PyTorch version or other installed
co-dependencies.
### uv Installation {#sec-uv}
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
Install uv if not already installed
```{.bash}
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env
```
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
then create the venv and activate
```{.bash}
export UV_TORCH_BACKEND=cu126
uv venv --no-project --relocatable
source .venv/bin/activate
```
Install PyTorch
- PyTorch 2.6.0 recommended
```{.bash}
uv pip install packaging setuptools wheel
uv pip install torch==2.6.0
uv pip install awscli pydantic
```
Install axolotl from PyPi
```{.bash}
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
```
### Edge/Development Build {#sec-edge-build}
For the latest features between releases:
@@ -110,10 +72,6 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
```
:::
::: {.callout-important}
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
## Cloud Environments {#sec-cloud}

View File

@@ -84,10 +84,6 @@ lora_qkv_kernel: true
lora_o_kernel: true
```
::: {.callout-note}
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -87,7 +87,20 @@ We support sequence parallelism (SP) via the
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
See our [dedicated guide](sequence_parallelism.qmd) for more information.
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
### FSDP + QLoRA {#sec-fsdp-qlora}

View File

@@ -43,7 +43,7 @@ datasets:
# leave the vision model and vision tower frozen
# load_in_8bit: true
adapter: lora
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# (optional) if you want to resize images to a set size
image_size: 512

View File

@@ -1,32 +0,0 @@
---
title: "Quantization Aware Training (QAT)"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
## Overview
[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of models which are quantized
by applying "fake" quantizations to the model's weights (and optionally, activations) during training. This fake
quantization allows for the model to adjust for noise introduced by the quantization, so when the model is eventually
quantized, the accuracy loss is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to provide
support for QAT and post-training quantization (PTQ) in axolotl.
We recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model),
and the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat), for more details.
## Configuring QAT in Axolotl
To enable QAT in axolotl, add the following to your configuration file:
```yaml
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
```
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.

View File

@@ -1,53 +0,0 @@
---
title: "Quantization with torchao"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
Quantization is a technique to lower the memory footprint of your model, potentially at the cost of accuracy or model performance. We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library. Quantization is supported for both post-training quantization (PTQ) and quantization-aware training (QAT).
::: {.callout-note}
We do not currently support quantization techniques such as GGUF/GPTQ,EXL2 at the moment.
:::
## Configuring Quantization in Axolotl
Quantization is configured using the `quantization` key in your configuration file.
```yaml
base_model: # The path to the model to quantize.
quantization:
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
output_dir: # The path to the output directory.
```
Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.md) - you can do this by using the existing QAT configuration file which
you used to train the model:
```yaml
# qat.yml
qat:
activation_dtype: int8
weight_dtype: int8
group_size: 256
quantize_embedding: true
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```
```bash
axolotl quantize qat.yml
```
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.

View File

@@ -16,8 +16,7 @@ feedback. Various methods include, but not limited to:
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
## RLHF using Axolotl
@@ -500,7 +499,7 @@ The input format is a simple JSON input with customizable fields based on the ab
### GRPO
::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
In the latest GRPO implementation, `vLLM` is used to significantly speedup trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM:
@@ -583,20 +582,7 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
#### GRPO with DAPO/Dr. GRPO loss
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.
```yaml
trl:
loss_type: dr_grpo
# Normalizes loss based on max completion length (default: 256)
max_completion_length:
```
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### SimPO

View File

@@ -41,7 +41,7 @@ When sequence parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
4. The trainer uses special ring communication patterns for attention operations
## Requirements
@@ -67,11 +67,9 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
...
```

View File

@@ -28,7 +28,7 @@ pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -30,7 +30,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -29,7 +29,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -1,79 +0,0 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
output_dir: ./outputs/qat_out/
sample_packing: true
pad_to_sequence_len: true
sequence_len: 512
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat:
activation_dtype: int8
weight_dtype: int4
group_size: 32
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -5,10 +5,6 @@ tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
load_in_8bit: true
load_in_4bit: false

View File

@@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
@@ -38,7 +38,6 @@ wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -1,71 +0,0 @@
# Finetune Magistral Small with Axolotl
Magistral Small is a 24B parameter opensource model from MistralAI found on [HuggingFace](https://huggingface.co/mistralai/Magistral-Small-2506). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 recommended)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
```
2. Download the example config:
```bash
axolotl fetch examples
```
3. Run the finetuning example:
```bash
axolotl train examples/magistral/magistral-small-qlora.yaml
```
This config uses about 24GB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Limitations
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
## Related Resources
- [MistralAI Magistral Blog](https://mistral.ai/news/magistral/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
## Future Work
- Add parity to Preference Tuning, RL, Multi-modal, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -1,72 +0,0 @@
base_model: mistralai/Magistral-Small-2506
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing:
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_activation_checkpointing: true

View File

@@ -1,63 +0,0 @@
base_model: mistralai/Magistral-Small-2506
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -27,7 +27,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -25,7 +25,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -2,6 +2,7 @@ base_model: Qwen/Qwen2.5-0.5B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
chat_template: qwen_25
rl: dpo
datasets:

View File

@@ -1,78 +0,0 @@
base_model: Qwen/Qwen3-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/qat_out/
sequence_len: 2048
sample_packing: true
flex_attention: true
pad_to_sequence_len: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat:
activation_dtype: int8
weight_dtype: int4
group_size: 256
fake_quant_after_n_steps: 1000
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
max_steps: 2000
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:

View File

@@ -6,20 +6,21 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.10
liger-kernel==0.5.9
# END section
packaging==23.2
huggingface_hub==0.32.2
huggingface_hub==0.31.0
peft==0.15.2
transformers==4.52.3
transformers==4.51.3
tokenizers>=0.21.1
accelerate==1.7.0
datasets==3.6.0
deepspeed>=0.17.0
trl==0.18.1
hf_xet==1.1.2
accelerate==1.6.0
datasets==3.5.1
deepspeed>=0.15.4
trl==0.17.0
hf_xet==1.1.0
hqq==0.2.5
optimum==1.16.2
hf_transfer
@@ -62,10 +63,8 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.10.0
torchao==0.9.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
mistral-common==1.6.0

View File

@@ -9,8 +9,6 @@ except ImportError as exc:
raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V
USE_UV = "--uv" in sys.argv[1:]
v = V(torch.__version__)
# no cut-cross-entropy support for torch < 2.4.0
@@ -25,9 +23,7 @@ if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"'
)

View File

@@ -11,7 +11,7 @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
```
cd /workspace

View File

@@ -1,15 +1,11 @@
# noqa
# pylint: skip-file
import sys
try:
import torch
except ImportError:
raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:]
v = V(torch.__version__)
cuda = str(torch.version.cuda)
try:
@@ -35,7 +31,6 @@ elif v < V("2.6.0"):
else:
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
uv_prefix = "uv " if use_uv else ""
print(
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"'
)

View File

@@ -118,7 +118,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.0",
"deepspeed==0.15.4",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.10.0"
__version__ = "0.10.0.dev0"

View File

@@ -28,6 +28,7 @@ class TrainerCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
@@ -88,26 +89,6 @@ class VllmServeCliArgs:
},
)
enable_reasoning: Optional[bool] = field(
default=None,
)
reasoning_parser: Optional[str] = field(
default=None,
)
@dataclass
class QuantizeCliArgs:
"""Dataclass with CLI arguments for `axolotl quantize` command."""
base_model: Optional[str] = field(default=None)
weight_dtype: Optional[str] = field(default=None)
activation_dtype: Optional[str] = field(default=None)
quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None)
@dataclass
class EvaluateCliArgs:

View File

@@ -1,5 +1,6 @@
"""Various checks for Axolotl CLI."""
import logging
import os
from pathlib import Path
@@ -7,9 +8,7 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
def check_accelerate_default_config() -> None:

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.6.0"
docker_tag = "main-py3.11-cu124-2.5.1"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"

View File

@@ -1,6 +1,7 @@
"""Configuration loading and processing."""
import json
import logging
import os
import tempfile
from pathlib import Path
@@ -21,12 +22,11 @@ from axolotl.utils.config import (
validate_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__, use_environ=True)
LOG = logging.getLogger(__name__)
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
@@ -119,12 +119,12 @@ def choose_config(path: Path) -> str:
)
if len(yaml_files) == 1:
LOG.info(f"Using default YAML file '{yaml_files[0]}'")
print(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
LOG.info("Choose a YAML file:")
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
LOG.info(f"{idx + 1}. {file}")
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
@@ -133,9 +133,9 @@ def choose_config(path: Path) -> str:
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
LOG.info("Invalid choice. Please choose a number from the list.")
print("Invalid choice. Please choose a number from the list.")
except ValueError:
LOG.info("Invalid input. Please enter a number.")
print("Invalid input. Please enter a number.")
return chosen_file

View File

@@ -1,5 +1,6 @@
"""CLI to run evaluation on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -16,9 +17,8 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:

View File

@@ -1,6 +1,7 @@
"""CLI to run inference on a trained model."""
import importlib
import logging
import sys
from pathlib import Path
from threading import Thread
@@ -21,9 +22,8 @@ from axolotl.utils.chat_templates import (
get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
def get_multi_line_input() -> str:

View File

@@ -2,6 +2,7 @@
# pylint: disable=redefined-outer-name
import logging
import os
import subprocess # nosec B404
import tempfile
@@ -16,7 +17,6 @@ import axolotl
from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
QuantizeCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
@@ -30,11 +30,8 @@ from axolotl.cli.utils import (
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
LOG = get_logger(__name__)
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
@@ -179,7 +176,7 @@ def train(
do_cli(config=cfg_file, **kwargs)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
@@ -336,16 +333,6 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(QuantizeCliArgs)
@filter_none_kwargs
def quantize(config: str, **cli_args: QuantizeCliArgs):
from axolotl.cli.quantize import do_quantize
do_quantize(config, cli_args)
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))

View File

@@ -1,18 +1,20 @@
"""CLI to merge a trained LoRA into a base model."""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
def do_merge_lora(*, cfg: DictDefault) -> None:
@@ -66,6 +68,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
Raises:
ValueError: If target directory for LoRA merged model does not exist.
"""
# pylint: disable=duplicate-code
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(
config,

View File

@@ -1,6 +1,7 @@
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
import json
import logging
import os
import shutil
from pathlib import Path
@@ -10,6 +11,7 @@ import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
import transformers
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@@ -22,11 +24,11 @@ from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
@@ -195,6 +197,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"

View File

@@ -1,5 +1,6 @@
"""CLI to run preprocessing of a dataset."""
import logging
import warnings
from pathlib import Path
from typing import Union
@@ -19,10 +20,9 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:

View File

@@ -1,90 +0,0 @@
"""
CLI to post-training quantize a model using torchao
"""
from pathlib import Path
from typing import Union
from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
LOG = get_logger(__name__)
def do_quantize(
config: Union[Path, str],
cli_args: dict,
):
"""
Quantizes a model's model's weights
Args:
config (Union[Path, str]): The path to the config file
cli_args (dict): Additional command-line arguments
"""
print_axolotl_text_art()
cfg = load_cfg(config)
if cfg.qat and cfg.quantization:
raise ValueError(
"QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file."
)
if cfg.qat:
quantize_cfg = cfg.qat
elif cfg.quantization:
quantize_cfg = cfg.quantization
else:
raise ValueError(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("model_path") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype]
else:
weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchIntDType[activation_dtype]
else:
activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size
quantize_embedding = (
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
)
output_dir = cli_args.get("output_dir") or cfg.output_dir
LOG.info(f"Loading model from {model_path}...")
tokenizer = load_tokenizer(cfg)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
LOG.info(
f"Quantizing model with configuration: \n"
f"\tweight_dtype: {weight_dtype}\n"
f"\tactivation_dtype: {activation_dtype}\n"
f"\tgroup_size: {group_size}\n"
f"\tquantize_embedding: {quantize_embedding}"
)
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -1,6 +1,7 @@
"""CLI to run training on a model."""
import gc
import logging
import os
from pathlib import Path
from typing import Union
@@ -21,6 +22,8 @@ from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
"""

View File

@@ -4,6 +4,7 @@ import concurrent.futures
import dataclasses
import hashlib
import json
import logging
from functools import wraps
from pathlib import Path
from types import NoneType
@@ -19,12 +20,10 @@ from transformers import (
ProcessorMixin,
)
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.models import load_model, load_processor, load_tokenizer
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
def strip_optional_type(field_type: type | str | None):
@@ -319,8 +318,7 @@ def load_model_and_tokenizer(
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
model, _ = model_loader.load()
model, _ = load_model(cfg, tokenizer, inference=inference)
processor = None
if cfg.is_multimodal:

View File

@@ -2,27 +2,14 @@
CLI to start the vllm server for online RL
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import trl
from trl.scripts.vllm_serve import ScriptArguments
from axolotl.cli.config import load_cfg
@dataclass
class AxolotlScriptArguments(ScriptArguments):
"""
Additional arguments for the VLLM server
"""
reasoning_parser: str = field(default="", kw_only=True)
enable_reasoning: bool | None = field(default=None, kw_only=True)
def do_vllm_serve(
config: Union[Path, str],
cli_args: dict,
@@ -37,7 +24,6 @@ def do_vllm_serve(
Returns:
process_id: the process id of the started VLLM server
"""
patch_vllm_worker()
cfg = load_cfg(config)
model = cfg.base_model
@@ -57,16 +43,9 @@ def do_vllm_serve(
enable_prefix_caching = (
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
)
reasoning_parser = (
cli_args.get("reasoning_parser") or cfg.vllm.reasoning_parser or ""
)
enable_reasoning = (
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
# pylint: disable=unexpected-keyword-arg
vllm_script_args = AxolotlScriptArguments(
model=model,
vllm_script_args = ScriptArguments(
model,
tensor_parallel_size=tensor_parallel_size,
host=host,
port=port,
@@ -74,67 +53,5 @@ def do_vllm_serve(
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)
def patch_vllm_worker():
from multiprocessing.connection import Connection
from vllm import LLM
def llm_worker(
script_args: AxolotlScriptArguments,
data_parallel_rank: int,
master_port: int,
connection: Connection,
) -> None:
# Set required environment variables for DP to work with vLLM
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
enable_reasoning=script_args.enable_reasoning,
reasoning_parser=script_args.reasoning_parser,
)
# Send ready signal to parent process
connection.send({"status": "ready"})
while True:
# Wait for commands from the parent process
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
# Handle commands
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args, kwargs = command.get("args", ()), command.get("kwargs", {})
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
trl.scripts.vllm_serve.llm_worker = llm_worker

View File

@@ -1,3 +1,5 @@
"""Various shared constants"""
"""
Various shared constants
"""
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

View File

@@ -1,21 +1,23 @@
"""Dataset loading utilities."""
import logging
import math
import random
from dataclasses import dataclass
from typing import Optional, Union
from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
@dataclass
@@ -28,7 +30,16 @@ class TrainDatasetMeta:
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
"""Randomly sample `num_samples` samples with replacement from `dataset`."""
"""
Randomly sample `num_samples` samples from `dataset`.
Args:
dataset: Dataset.
num_samples: Number of samples to return.
Returns:
Random sample (with replacement) of examples in `dataset`.
"""
return dataset.select(
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
)
@@ -40,37 +51,44 @@ def load_datasets(
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False,
) -> TrainDatasetMeta:
"""Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.
"""
Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample. This is duplicated in
`cfg` and `cli_args`, but is kept due to use in our Colab notebooks.
debug: Whether to print out tokenization of sample
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
preprocess_iterable = (
cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (
cfg.debug
or getattr(cli_args, "debug", False)
or getattr(cli_args, "debug_text_only", False)
or getattr(cli_args, "debug_num_examples", 0) > 0
or debug
):
if ( # pylint: disable=too-many-boolean-expressions
cli_args
and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
)
) or debug:
LOG.info("check_dataset_labels...")
num_examples = cli_args.debug_num_examples if cli_args else 1
@@ -95,10 +113,13 @@ def load_datasets(
def load_preference_datasets(
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
"""
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
Optionally, logs out debug information.
Args:
@@ -109,28 +130,23 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl is RLType.GRPO:
total_num_steps = None
total_num_steps: int | None = None
if cfg.rl is not RLType.GRPO:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if (cli_args and cli_args.debug) or cfg.debug:
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, num_examples)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
dataset=train_samples,
tokenizer=tokenizer,
num_examples=num_examples,
text_only=text_only,
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)

View File

@@ -1,6 +0,0 @@
"""Trainer builder classes"""
from .causal import HFCausalTrainerBuilder
from .rl import HFRLTrainerBuilder
__all__ = ["HFCausalTrainerBuilder", "HFRLTrainerBuilder"]

View File

@@ -1,508 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for trainer builder"""
import abc
import importlib
import logging
import sys
from abc import abstractmethod
from contextlib import suppress
from pathlib import Path
from typing import Any
import torch
from transformers import (
TrainerCallback,
)
from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
with suppress(ImportError):
import torch._dynamo # pylint: disable=ungrouped-imports
class TrainerBuilderBase(abc.ABC):
"""Base class for trainer builder."""
def __init__(self, cfg, model, tokenizer, processor=None):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
self.processor = processor
self._train_dataset = None
self._eval_dataset = None
self._model_ref = None
self._peft_config = None
# If the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instead of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
patch_trainer_get_lr()
@property
def model_ref(self):
return self._model_ref
@model_ref.setter
def model_ref(self, model):
self._model_ref = model
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@property
def peft_config(self):
return self._peft_config
@peft_config.setter
def peft_config(self, peft_config):
self._peft_config = peft_config
@abstractmethod
def build(self, total_num_steps):
pass
def get_callbacks(self) -> list[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
callbacks.append(GPUStatsCallback(cfg=self.cfg))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# TODO
return trainer
def _configure_warmup_and_logging(
self, total_num_steps: int, training_args_kwargs: dict
):
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio:
if total_num_steps:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_ratio = self.cfg.warmup_ratio
elif total_num_steps:
warmup_steps = min(int(0.03 * total_num_steps), 100)
else:
warmup_ratio = 0.03
if warmup_steps == 1:
warmup_steps = 2
if self.cfg.logging_steps is not None:
training_args_kwargs["logging_steps"] = self.cfg.logging_steps
else:
training_args_kwargs["logging_steps"] = (
500 # transformers defaults to 500
if not total_num_steps
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps
def _configure_precision_settings(self, training_args_kwargs: dict):
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
training_args_kwargs["tf32"] = self.cfg.tf32
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
def _configure_scheduler(self, training_args_kwargs: dict):
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
training_args_kwargs["lr_scheduler_type"] = "cosine"
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
else:
training_args_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_args_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
def _configure_optimizer(self, training_args_kwargs: dict, trainer_kwargs: dict):
def _configure_custom_optimizer(
training_args_kwargs: dict, trainer_kwargs: dict
):
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_args_kwargs["learning_rate"],
"weight_decay": training_args_kwargs["weight_decay"],
}
# Adam-specific kwargs
adam_kwargs: dict = {}
if training_args_kwargs.get("adam_beta1") and training_args_kwargs.get(
"adam_beta2"
):
adam_kwargs["betas"] = (
training_args_kwargs.get("adam_beta1"),
training_args_kwargs.get("adam_beta2"),
)
if training_args_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "came_pytorch":
from came_pytorch import CAME
optimizer_cls = CAME
beta1 = training_args_kwargs.get("adam_beta1", 0.9)
beta2 = training_args_kwargs.get("adam_beta2", 0.999)
beta3 = training_args_kwargs.get("adam_beta3", 0.9999)
eps1 = training_args_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_args_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
else:
raise ValueError(
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
# Note: This is not used in training_args_kwargs, but in trainer_kwargs
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
)
# Handle custom optimizer
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
_configure_custom_optimizer(training_args_kwargs, trainer_kwargs)
else:
# Use transformers' optimizer
training_args_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_args_kwargs["optim_args"] = optim_args
if (
self.cfg.optimizer == "adamw_anyprecision"
and Path(self.cfg.torchdistx_path).exists()
):
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
def _configure_hub_parameters(self, training_args_kwargs: dict):
if self.cfg.hub_model_id:
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_args_kwargs["push_to_hub"] = True
training_args_kwargs["hub_private_repo"] = True
training_args_kwargs["hub_always_push"] = True
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
# eval_strategy and eval_steps
if not self.eval_dataset and self.cfg.val_set_size == 0:
# do not eval if no eval_dataset and val_set_size=0
training_args_kwargs["eval_strategy"] = "no"
elif self.cfg.eval_steps:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
training_args_kwargs["eval_on_start"] = True
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
training_args_kwargs["eval_on_start"] = True
def _configure_reporting(self, training_args_kwargs: dict):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_args_kwargs["report_to"] = report_to
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_args_kwargs["run_name"] = None
def _configure_torch_compile(self, training_args_kwargs: dict):
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
def _set_base_training_args(
self, total_num_steps
) -> tuple[dict[str, Any], dict[str, Any]]:
training_args_kwargs: dict[str, Any] = {}
trainer_kwargs: dict[str, Any] = {}
self._configure_warmup_and_logging(total_num_steps, training_args_kwargs)
self._configure_precision_settings(training_args_kwargs)
self._configure_save_and_eval_strategy(training_args_kwargs)
self._configure_gradient_checkpointing(training_args_kwargs)
# set arg into trainer_args_kwargs with same name if value not None
for arg in [
# optim/scheduler
"adam_beta1",
"adam_beta2",
"adam_beta3",
"adam_epsilon",
"adam_epsilon2",
"cosine_min_lr_ratio",
"cosine_constant_lr_ratio",
"optim_target_modules",
# trainer
"max_grad_norm",
"dataloader_num_workers",
"dataloader_pin_memory",
"dataloader_prefetch_factor",
"gradient_accumulation_steps",
"learning_rate",
"embedding_lr",
"embedding_lr_scale",
"lr_groups",
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"seed",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
if self.cfg.eval_batch_size:
training_args_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len
self._configure_reporting(training_args_kwargs)
self._configure_hub_parameters(training_args_kwargs)
self._configure_scheduler(training_args_kwargs)
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
self._configure_torch_compile(training_args_kwargs)
return training_args_kwargs, trainer_kwargs

View File

@@ -1,488 +0,0 @@
"""Builder for causal trainers"""
import inspect
import math
import os
from pathlib import Path
from typing import Type, Union
import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlMambaTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models and reward modeling
using TRL.
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback())
# TODO: check if can move to base class
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
)
callbacks.append(LogPredictionCallback(self.cfg))
if (
self.cfg.use_mlflow
and is_mlflow_available()
and self.cfg.eval_table_size > 0
):
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.do_causal_lm_eval:
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
trainer, self.tokenizer
)
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
def _get_trainer_cls(self):
"""
Gets the trainer class for the given configuration.
"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
if trainer_cls:
return trainer_cls
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer
def build(self, total_num_steps):
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = (
self.cfg.lr_quadratic_warmup
)
if self.cfg.dataloader_drop_last is not None:
training_arguments_kwargs["dataloader_drop_last"] = (
self.cfg.dataloader_drop_last
)
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = (
self.cfg.metric_for_best_model
)
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
# DDP Config
if self.cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = (
self.cfg.ddp_broadcast_buffers
)
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs["auto_find_batch_size"] = (
self.cfg.auto_find_batch_size
)
training_arguments_kwargs["eval_accumulation_steps"] = (
self.cfg.gradient_accumulation_steps
)
training_arguments_kwargs["load_best_model_at_end"] = (
(
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and (
(not self.cfg.test_datasets and self.cfg.val_set_size > 0)
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
)
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
# handle ddp
ddp_find_unused_parameters = None
if self.cfg.ddp:
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
training_arguments_kwargs["ddp_find_unused_parameters"] = (
ddp_find_unused_parameters
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.multipack_real_batches
if self.cfg.multipack_real_batches is not None
else not self.cfg.flash_attention
)
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
)
if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size
)
if self.cfg.sample_packing_group_size is not None:
training_arguments_kwargs["sample_packing_group_size"] = (
self.cfg.sample_packing_group_size
)
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs["sample_packing_efficiency"] = (
self.cfg.sample_packing_eff_est
)
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = (
self.cfg.relora_warmup_steps
)
if self.cfg.relora_anneal_steps:
training_arguments_kwargs["relora_anneal_steps"] = (
self.cfg.relora_anneal_steps
)
if self.cfg.relora_prune_ratio:
training_arguments_kwargs["relora_prune_ratio"] = (
self.cfg.relora_prune_ratio
)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs["lisa_step_interval"] = (
self.cfg.lisa_step_interval
)
training_arguments_kwargs["lisa_layers_attribute"] = (
self.cfg.lisa_layers_attribute
)
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
cfg=self.cfg,
tokenizer=self.tokenizer,
)
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs["neftune_noise_alpha"] = (
self.cfg.neftune_noise_alpha
)
if self.cfg.accelerator_config:
training_arguments_kwargs["accelerator_config"] = (
self.cfg.accelerator_config
)
if self.cfg.image_size:
training_arguments_kwargs["image_size"] = self.cfg.image_size
if self.cfg.image_resize_algorithm:
training_arguments_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
plugin_training_args = plugin_manager.get_training_args(self.cfg)
if plugin_training_args:
training_arguments_kwargs.update(plugin_training_args)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
multiple = 64
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
self.cfg.sequence_len / multiple
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters:
trainer_kwargs["processing_class"] = self.tokenizer
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
if self.cfg.deepspeed and self.cfg.sample_packing:
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = self.cfg.micro_batch_size
return trainer
def build_collator(
self,
training_args, # type: "AxolotlTrainingArguments" # type: ignore
is_eval=False,
**kwargs,
):
if training_args.pretraining:
if (
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":
return MambaDataCollator(tokenizer=self.tokenizer)
use_batch_sampler_collator = False
if is_eval is False and training_args.sample_packing:
use_batch_sampler_collator = True
if is_eval and training_args.eval_sample_packing:
use_batch_sampler_collator = True
collator: Type[
Union[
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
]
]
collator_args = [self.tokenizer]
collator_cls_and_kwargs = None
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
self.cfg, is_eval=is_eval
)
if collator_cls_and_kwargs:
collator = collator_cls_and_kwargs[0]
if kwargs and isinstance(kwargs, dict):
kwargs.update(collator_cls_and_kwargs[1])
elif self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama
if (
self.cfg.flex_attention
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
or (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
)
):
collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processing_strategy"] = get_processing_strategy(
self.processor,
training_args.chat_template,
self.cfg.chat_template,
image_size=training_args.image_size,
image_resize_algorithm=training_args.image_resize_algorithm,
)
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
else:
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
return collator(
*collator_args,
**kwargs,
)

View File

@@ -1,238 +0,0 @@
"""Builder for RLHF trainers"""
import inspect
from pathlib import Path
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
LOG = get_logger(__name__)
class HFRLTrainerBuilder(TrainerBuilderBase):
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
def get_callbacks(self):
callbacks = super().get_callbacks()
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def _get_trainer_cls(self, trainer_kwargs: dict):
"""
Returns trainer_cls and trainer_cls_args
"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
trainer_cls_args = [] # type: ignore
if trainer_cls is not None:
return trainer_cls, trainer_cls_args
trainer_cls = None
trainer_cls_args = [self.model]
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args.append(self.model_ref)
elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer
elif self.cfg.rl is RLType.KTO:
trainer_cls = AxolotlKTOTrainer
elif self.cfg.rl is RLType.SIMPO:
trainer_cls = AxolotlCPOTrainer
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
return trainer_cls, trainer_cls_args
def _build_training_arguments(self, total_num_steps):
"""
Returns training_args and trainer_kwargs
"""
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps=total_num_steps
)
if self.cfg.remove_unused_columns is not None:
training_args_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
else:
training_args_kwargs["remove_unused_columns"] = False
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
training_args_kwargs["undesirable_weight"] = (
self.cfg.kto_undesirable_weight or 1.0
)
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.GRPO:
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig
training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
plugin_training_args = plugin_manager.get_training_args(self.cfg)
if plugin_training_args:
training_args_kwargs.update(plugin_training_args)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
logging_first_step=True,
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args, trainer_kwargs
def build(self, total_num_steps):
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
else:
trainer_kwargs["processing_class"] = self.tokenizer
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
if self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
return trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):
# TODO: build PPOConfig
raise NotImplementedError("PPO trainer builder is not implemented yet.")

View File

@@ -156,6 +156,7 @@ class Messages(BaseModel):
len(input_ids) : len(input_ids) + len(pending_input_ids)
]
if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids)
if pending_weight:

File diff suppressed because it is too large Load Diff

View File

@@ -4,10 +4,11 @@
from __future__ import annotations
import logging
import os
from collections import defaultdict
from functools import partial, wraps
from typing import Callable, Literal, Optional
from functools import wraps
from typing import Literal
import datasets
import torch
@@ -25,24 +26,22 @@ from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
CheckpointSaveMixin,
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils import get_not_null
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
"""Extend the base Trainer for axolotl helpers"""
@@ -69,6 +68,10 @@ class AxolotlTrainer(
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
@@ -105,7 +108,7 @@ class AxolotlTrainer(
)
batch_max_len = train_batch_size * self.args.max_seq_length
sampler = MultipackBatchSampler(
return MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -115,18 +118,12 @@ class AxolotlTrainer(
bin_size=self.args.sample_packing_bin_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
num_processes=self.args.dataset_num_proc,
)
len(sampler)
return sampler
def _get_train_sampler(
self, train_dataset: Optional[Dataset] = None
) -> Optional[Sampler]:
def _get_train_sampler(self) -> Sampler | None:
"""
Helper method to get the sampler for training. Handles cases for sample packing
and curriculum sampling (sequential).
Helper method to get the sampler for training. Handles cases for sequence
parallelism, sample packing, and curriculum sampling (sequential).
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
@@ -135,7 +132,9 @@ class AxolotlTrainer(
use_sample_packing = self.args.sample_packing and not self.args.pretraining
# Determine the base sampler first
if self.args.curriculum_sampling:
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
base_sampler = RandomSampler(self.train_dataset)
@@ -147,26 +146,31 @@ class AxolotlTrainer(
if use_sample_packing:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=train_dataset,
dataset=self.train_dataset,
)
return base_sampler
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
"""
Helper method to get the sampler for evaluation. Handles sample packing case.
Helper method to get the sampler for evaluation. Handles sequence parallelism
and sample packing cases.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Multipacking enabled if training is enabled and eval is not explicitly disabled
use_multipack = (
self.args.sample_packing and self.args.eval_sample_packing is not False
)
# Determine the base sampler
if use_multipack:
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset)
elif use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
return super()._get_eval_sampler(eval_dataset)
@@ -180,93 +184,149 @@ class AxolotlTrainer(
return base_sampler
def _get_dataloader(
self,
dataset: Dataset,
description: str,
batch_size: int,
sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None,
is_training: bool = False,
dataloader_key: Optional[str] = None,
) -> DataLoader:
"""Create a [`~torch.utils.data.DataLoader`] from the given dataset."""
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
data_collator = self.data_collator if is_training else self.eval_data_collator
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if isinstance(dataset, datasets.Dataset):
if is_training:
if not self.args.sample_packing or self.args.pretraining:
dataset = self._remove_unused_columns(
dataset, description="training"
)
elif (
not is_training
and self.args.sample_packing
and self.args.eval_sample_packing is not False
):
batch_size = (
batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
else:
dataset = self._remove_unused_columns(dataset, description=description)
else:
data_collator = self._get_collator_with_removed_columns(
self.data_collator, description=description
)
dataloader_params = {
params = {
"batch_size": batch_size,
"collate_fn": data_collator,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(dataset, torch.utils.data.IterableDataset):
dataloader_params["drop_last"] = get_not_null(
self.args.dataloader_drop_last, True
)
if sampler_fn is not None:
sampler = sampler_fn(dataset)
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
del dataloader_params["drop_last"]
else:
dataloader_params["sampler"] = sampler
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return params
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if is_training:
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
if self.args.sample_packing and (
(is_training and not self.args.pretraining)
or (not is_training and self.args.eval_sample_packing is not False)
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
dataloader = DataLoader(dataset, **dataloader_params)
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader
# Accelerator.free_memory() will destroy the references, so
# we need to store the non-prepared version for eval dataloaders.
# fmt: off
if dataloader_key is not None and self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore # pylint: disable=access-member-before-definition
else:
self._eval_dataloaders = {dataloader_key: dataloader} # pylint: disable=attribute-defined-outside-init
# fmt: on
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
return self.accelerator.prepare(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""Get dataloader for evaluation"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle special case: sample packing is enabled but eval_sample_packing is False
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
# Handle sample packing or sequence parallelism
if (
self.args.sample_packing
and self.args.eval_sample_packing is not False
or self.args.sequence_parallel_degree > 1
):
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
if hasattr(self, "eval_data_collator") and self.eval_data_collator
else self.data_collator
)
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
# Handle dataset preprocessing for SP
if self.args.sequence_parallel_degree > 1:
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
sampler = self._get_eval_sampler(eval_dataset)
dataloader = self._prepare_dataloader(
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
)
return dataloader
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset

View File

@@ -22,19 +22,10 @@ class DPOStrategy:
training_args_kwargs = {}
if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo"
# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None:
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
if cfg.dpo_norm_loss is not None:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
return training_args_kwargs

View File

@@ -14,5 +14,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
dpo_norm_loss: bool | None = False

View File

@@ -1,41 +1,92 @@
"""DPO trainer for axolotl"""
"""
DPO trainer for axolotl
"""
import gc
import random
from functools import wraps
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union
import pandas as pd
import torch
import wandb
from accelerate import PartialState
from datasets import Dataset, IterableDataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from trl import DPOTrainer
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
FeatureExtractionMixin,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
)
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
from trl.trainer.utils import log_table_to_comet_experiment
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
):
"""Extend the base DPOTrainer for axolotl helpers."""
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
# pylint: disable=duplicate-code
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
# pylint: disable=duplicate-code
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing
the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub`
for more details.
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
@@ -44,6 +95,64 @@ class AxolotlDPOTrainer(
return super().push_to_hub(*args, **kwargs)
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[
PreTrainedTokenizerBase,
BaseImageProcessor,
FeatureExtractionMixin,
ProcessorMixin,
],
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Build the kwargs for the `map` function
map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc
with PartialState().main_process_first():
# Extract prompt if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
# Apply the chat template if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
dataset = dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
**map_kwargs,
)
# Tokenize the dataset
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
dataset = dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
remove_columns=["chosen", "rejected"],
fn_kwargs={
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": False,
},
**map_kwargs,
)
return dataset
@staticmethod
def tokenize_row(
features,
@@ -84,19 +193,68 @@ class AxolotlDPOTrainer(
torch.cuda.empty_cache()
return loss
def concatenated_forward(
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def evaluation_loop(
self,
model: nn.Module,
batch: dict[str, Union[list, torch.LongTensor]],
is_ref_model: bool = False,
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: str = self.loss_type # type: ignore[has-type] # pylint: disable=access-member-before-definition
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = "ipo" # pylint: disable=attribute-defined-outside-init
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type # pylint: disable=attribute-defined-outside-init
return res
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(
range(num_samples), k=self.args.eval_batch_size
)
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)
policy_output_decoded, ref_output_decoded = (
self.generate_from_model_and_ref(self.model, random_batch)
)
table = pd.DataFrame(
columns=["Prompt", "Policy", "Ref Model"],
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(
random_batch_dataset["prompt"],
policy_output_decoded,
ref_output_decoded,
)
],
)
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
wandb.log({"game_log": wandb.Table(data=table)})
if "comet_ml" in self.args.report_to:
log_table_to_comet_experiment(
name="game_log.csv",
table=table,
)
# Base evaluation
initial_output = super( # pylint: disable=bad-super-call
DPOTrainer, self
).evaluation_loop(
dataloader,
description,
prediction_loss_only,
ignore_keys,
metric_key_prefix,
)
return initial_output

View File

@@ -2,6 +2,7 @@
import importlib
import inspect
import logging
from typing import Any
from trl.trainer.grpo_trainer import RewardFunc
@@ -12,10 +13,9 @@ from axolotl.core.trainers.grpo.trainer import (
AxolotlGRPOTrainer,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.trl import TRLConfig
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
class GRPOStrategy:
@@ -69,9 +69,6 @@ class GRPOStrategy:
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if cfg.sequence_parallel_degree > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
@@ -109,9 +106,7 @@ class GRPOStrategy:
return grpo_args_kwargs
@classmethod
def set_trainer_args(
cls, cfg: DictDefault
) -> list[Any]: # pylint: disable=unused-argument
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
trainer_args = []
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
@@ -128,7 +123,6 @@ class GRPOStrategy:
trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes
)
return trainer_kwargs
@classmethod
@@ -138,7 +132,7 @@ class GRPOStrategy:
@classmethod
def get_blocklist_args_kwargs(cls) -> list[str]:
return ["dataset_num_proc", "max_length"]
return ["dataset_num_proc"]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
@@ -173,4 +167,4 @@ class GRPOStrategy:
LOG.info(
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
)
return reward_func_fqn
return reward_func

View File

@@ -12,5 +12,3 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
sequence_parallel_degree: int | None = None

View File

@@ -3,7 +3,7 @@
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from functools import partial
from contextlib import nullcontext
from typing import Any
import datasets
@@ -14,7 +14,7 @@ from accelerate.utils import (
broadcast_object_list,
gather,
gather_object,
is_peft_available,
is_peft_model,
)
from datasets import Dataset, IterableDataset
from torch import nn
@@ -30,13 +30,15 @@ from transformers import (
TrainerCallback,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available
from trl import GRPOTrainer
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_deepspeed_available
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
@@ -44,56 +46,67 @@ from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
if is_deepspeed_available():
import deepspeed
class AxolotlGRPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
):
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = (
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
)
train_dataset = self.train_dataset
data_collator = self.data_collator
if isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
# Update vLLM weights while parameters are gathered
for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = (
name.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", "")
)
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
dataloader_params = {
"batch_size": self._train_batch_size
* self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
@@ -117,7 +130,6 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
] = (None, None),
peft_config: "PeftConfig | None" = None,
optimizer_cls_and_kwargs: tuple[type, dict] | None = None,
):
# First call the superclass constructor with all arguments
super().__init__(
@@ -131,7 +143,6 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
)
# Get number of SP groups (number of processes divided by SP degree)
@@ -173,13 +184,6 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
f"the valid values for the number of generations are: {possible_values}."
)
self.sp_group = None
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.local_rank = 0
self.local_world_size = 1
def train(self, *args, **kwargs):
# Initialize the SP group
self.sp_group = get_ring_attn_group()
self.rank = dist.get_rank()
@@ -187,8 +191,6 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
self.local_rank = dist.get_rank(group=self.sp_group)
self.local_world_size = dist.get_world_size(group=self.sp_group)
return super().train(*args, **kwargs)
def _get_train_sampler(self) -> Sampler:
effective_batch_size = (
self.args.per_device_train_batch_size

View File

@@ -3,7 +3,7 @@
# pylint: disable=unused-import
# flake8: noqa
from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,21 +0,0 @@
"""Custom handling to not fail training if fsdp optimizer is not savable"""
from transformers import Trainer
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class CheckpointSaveMixin(Trainer):
"""Mixin to handle saving the optimizer and scheduler if they are not savable."""
def _save_optimizer_and_scheduler(self, output_dir):
try:
super()._save_optimizer_and_scheduler(output_dir)
except NotImplementedError as exc:
LOG.warning(
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible."
)

View File

@@ -1,17 +1,18 @@
"""Module for Axolotl trainer optimizer mixin"""
import logging
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.utils.logging import get_logger
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
class OptimizerMixin(Trainer):
@@ -198,20 +199,3 @@ class OptimizerMixin(Trainer):
)
return self.optimizer
class OptimizerInitMixin:
"""
Mixin to handle common optimizer initialization logic for Trainers (mostly TRL) that do not
accept optimizer_cls_and_kwargs as kwarg in constructor.
"""
def __init__(self, *args, **kwargs):
optimizer_cls_and_kwargs = kwargs.pop("optimizer_cls_and_kwargs", None)
super().__init__(*args, **kwargs)
if (
optimizer_cls_and_kwargs
and self.optimizer_cls_and_kwargs is None
and self.optimizer is None
):
self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs

View File

@@ -6,6 +6,7 @@ See https://github.com/huggingface/transformers/pull/37162
TODO: Remove when upstream added PR to release
"""
import logging
import os
import random
@@ -16,9 +17,7 @@ from transformers.trainer import safe_globals
from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
class RngLoaderMixin(Trainer):

View File

@@ -1,11 +1,12 @@
"""Module for Axolotl trainer scheduler mixin"""
import logging
import torch
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.logging import get_logger
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
@@ -13,7 +14,7 @@ from axolotl.utils.schedulers import (
get_cosine_schedule_with_warmup_decay_constant,
)
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
class SchedulerMixin(Trainer):
@@ -79,15 +80,13 @@ class SchedulerMixin(Trainer):
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (
self.args.learning_rate * self.args.cosine_min_lr_ratio),
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning(
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
@@ -116,11 +115,9 @@ class SchedulerMixin(Trainer):
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning(
"axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning(
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler # type: ignore

View File

@@ -0,0 +1,87 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
import torch.distributed as dist
from datasets import Dataset
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
)
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
This mixin provides functionality for handling sequence parallelism,
specifically for creating appropriate data samplers.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
self.ring_attn_group = get_ring_attn_group()
def _create_sequence_parallel_sampler(
self,
dataset: Dataset,
shuffle: bool = True,
is_eval: bool = False,
) -> DistributedSampler:
"""
Helper method to create sampler for sequence parallelism (SP).
We create a distributed sampler with rank equal to the SP group ID, which
means that all ranks in the SP group receive the same sample / set of samples
per training step. We also set the number of replicas equal to the number of
SP groups, which is a bit of a hack / unintended use, but works!
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)

View File

@@ -1,5 +1,7 @@
"""Module for TRL PPO trainer"""
from typing import Literal, Union
import torch
from tqdm import tqdm
from trl import (
@@ -12,7 +14,6 @@ from trl import (
)
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
@@ -74,19 +75,87 @@ class TRLPPOTrainer(PPOTrainer):
)
class AxolotlORPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
):
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
class AxolotlKTOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer
):
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
metrics = {}
forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
chosen_rewards
).mean()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
rejected_rewards
).mean()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
reward_accuracies
).mean()
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
chosen_rewards - rejected_rewards
).mean()
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
)
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
policy_rejected_logits.detach().mean()
).mean()
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
policy_chosen_logits.detach().mean()
).mean()
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
)
metrics[f"{prefix}log_odds_ratio"] = (
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
)
metrics[f"{prefix}log_odds_chosen"] = (
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
)
for k, v in metrics.items():
metrics[k] = v.item()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
return loss, metrics
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
@@ -94,19 +163,89 @@ class AxolotlKTOTrainer(
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer
):
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
class AxolotlRewardTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer
):
forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
policy_rejected_logps,
)
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = (
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
)
metrics[f"{prefix}rewards/rejected"] = (
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
)
metrics[f"{prefix}rewards/accuracies"] = (
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
)
metrics[f"{prefix}rewards/margins"] = (
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
.mean()
.item()
)
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps)
.detach()
.mean()
.item()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps)
.detach()
.mean()
.item()
)
metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
.mean()
.item()
)
metrics[f"{prefix}logits/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
.mean()
.item()
)
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
)
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
return loss, metrics
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
@@ -114,9 +253,7 @@ class AxolotlRewardTrainer(
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer
):
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""

View File

@@ -2,17 +2,244 @@
extra axolotl specific training args
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Type
from typing import Optional
from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.integrations.config import merge_training_args
from axolotl.utils.schemas.enums import RingAttnFunc
AxolotlTrainingMixins: Type = merge_training_args()
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_optimizer: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
kd_ce_alpha: Optional[float] = field(
default=None,
metadata={
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
},
)
kd_alpha: Optional[float] = field(
default=1.0,
metadata={"help": "The alpha scaling parameter for KD loss"},
)
kd_temperature: Optional[float] = field(
default=1.0,
metadata={
"help": "the temperature parameter for KL divergence loss when using KD"
},
)
kd_zscore_base_temp: Optional[float] = field(
default=None,
metadata={
"help": "the base temperature parameter for KL divergence with z-score when using KD"
},
)
kd_top_k_before_softmax: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
},
)
sequence_parallel_degree: Optional[int] = field(
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)
ring_attn_func: Optional[RingAttnFunc] = field(
default=None,
metadata={
"help": "The ring-flash-attn function to use in sequence parallelism"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section
@dataclass

View File

@@ -1,224 +0,0 @@
"""
Base Axolotl Training Mixins shared across various trainer configs
"""
from dataclasses import dataclass, field
from typing import Optional
from PIL.Image import Resampling
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "The number of processes to use for data processing"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
# kd_ce_alpha: Optional[float] = field(
# default=None,
# metadata={
# "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
# },
# )
#
# kd_alpha: Optional[float] = field(
# default=1.0,
# metadata={"help": "The alpha scaling parameter for KD loss"},
# )
#
# kd_temperature: Optional[float] = field(
# default=1.0,
# metadata={
# "help": "the temperature parameter for KL divergence loss when using KD"
# },
# )
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section

View File

@@ -1,12 +1,12 @@
"""Module containing Dataset functionality"""
import logging
import os
from typing import List, Optional, Union
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
@@ -15,25 +15,25 @@ from .prompt_tokenizers import PromptTokenizingStrategy
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = get_logger(__name__)
LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(Dataset):
"""Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer: The prompt tokenizing method for processing the data.
dataset: Dataset with text files.
process_count: Number of processes to use for tokenizing.
keep_in_memory: Whether to keep the tokenized dataset in memory.
"""
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
process_count (int): Number of processes to use for tokenizing.
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
"""
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
@@ -48,13 +48,6 @@ class TokenizedPromptDataset(Dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
LOG.info(
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
)
num_proc = 1
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
@@ -82,14 +75,14 @@ class TokenizedPromptDataset(Dataset):
def wrap_dataset_for_tokenized_prompt(
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset | IterableDataset,
dataset: Union[Dataset, IterableDataset],
**kwargs,
):
if isinstance(dataset, IterableDataset):
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = list(dataset.features.keys())
features = dataset.features.keys()
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
@@ -100,13 +93,12 @@ def wrap_dataset_for_tokenized_prompt(
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
@@ -117,7 +109,7 @@ class ConstantLengthDataset(IterableDataset):
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.datasets: List[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
@@ -181,10 +173,7 @@ class ConstantLengthDataset(IterableDataset):
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],

View File

@@ -7,6 +7,7 @@ from pathlib import Path
from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
@@ -16,7 +17,6 @@ from axolotl.train import (
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

View File

@@ -10,89 +10,71 @@
# License for the specific language governing permissions and limitations under
# the License.
"""Base class for all plugins.
"""
Base class for all plugins.
A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl.
Plugins can be used to integrate third-party models, modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
from __future__ import annotations
import collections
import importlib
import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
import logging
from typing import OrderedDict
from peft import PeftModel
from torch.optim import Optimizer
import torch
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__, use_environ=True)
if TYPE_CHECKING:
from axolotl.common.datasets import TrainDatasetMeta
class BasePlugin:
"""Base class for all plugins. Defines the interface for plugin methods.
"""
Base class for all plugins. Defines the interface for plugin methods.
A plugin is a reusable, modular, and self-contained piece of code that extends
the functionality of Axolotl. Plugins can be used to integrate third-party models,
modify the training process, or add new features.
Attributes:
None
To create a new plugin, you need to inherit from the BasePlugin class and
implement the required methods.
Note:
Plugin methods include:
- register(cfg): Registers the plugin with the given configuration.
- load_datasets(cfg): Loads and preprocesses the dataset for training.
- pre_model_load(cfg): Performs actions before the model is loaded.
- post_model_build(cfg, model): Performs actions after the model is loaded, but
before LoRA adapters are applied.
- pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
- post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
- post_model_load(cfg, model): Performs actions after the model is loaded,
inclusive of any adapters.
- post_trainer_create(cfg, trainer): Performs actions after the trainer is
created.
- create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
- create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
returns a learning rate scheduler.
- add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before
training.
- add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after
training.
Methods:
register(cfg): Registers the plugin with the given configuration.
load_datasets(cfg): Loads and preprocesses the dataset for training.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
"""
def __init__(self):
"""Initializes the BasePlugin."""
"""
Initializes the BasePlugin.
"""
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration.
def register(self, cfg): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
Args:
cfg: The configuration for the plugin.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def get_input_args(self) -> str | None:
"""Returns a pydantic model for the plugin's input arguments."""
def get_training_args_mixin(self) -> str | None:
"""
Returns a dataclass model for the plugin's training arguments.
Returns a pydantic model for the plugin's input arguments.
"""
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
"""Loads and preprocesses the dataset for training.
def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
"""
Loads and preprocesses the dataset for training.
Args:
cfg: The configuration for the plugin.
@@ -102,189 +84,181 @@ class BasePlugin:
dataset_meta: The metadata for the training dataset.
"""
def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Performs actions before the model is loaded.
Args:
cfg: The configuration for the plugin.
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
# pylint: disable=unused-argument
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions after the model is built/loaded, but before any adapters are applied.
Performs actions before the model is loaded.
Args:
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions before LoRA weights are loaded.
Args:
cfg: The configuration for the plugin.
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after LoRA weights are loaded.
Args:
cfg: The configuration for the plugin.
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after the model is loaded.
Args:
cfg: The configuration for the plugin.
model: The loaded model.
"""
# pylint: disable=unused-argument
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Returns a custom class for the trainer.
Args:
cfg: The global axolotl configuration.
cfg (dict): The configuration for the plugin.
Returns:
The first non-`None` trainer class returned by a plugin.
None
"""
# pylint: disable=unused-argument
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Performs actions after the trainer is created.
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is built/loaded, but before any adapters are applied.
Args:
cfg: The configuration for the plugin.
trainer: The trainer object for training.
cfg (dict): The configuration for the plugin.
"""
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Returns custom training arguments to set on TrainingArgs.
Performs actions after the model is loaded.
Args:
cfg: The global axolotl configuration.
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
object: dict containing the training arguments.
None
"""
def get_collator_cls_and_kwargs(
self, cfg: DictDefault, is_eval: bool = False
): # pylint: disable=unused-argument):
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Returns a custom class for the collator.
Performs actions before LoRA weights are loaded.
Args:
cfg: The global axolotl configuration.
is_eval: Whether this is an eval split.
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
class: The class for the collator.
None
"""
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
Args:
cfg: The configuration for the plugin.
trainer: The trainer object for training.
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
The created optimizer.
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Args:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
"""
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
"""
Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
"""
# pylint: disable=unused-argument
def create_lr_scheduler(
self,
cfg: DictDefault,
trainer: Trainer,
optimizer: Optimizer,
num_training_steps: int,
) -> LRScheduler | None:
"""Creates and returns a learning rate scheduler.
self, cfg, trainer, optimizer, num_training_steps
) -> LRScheduler | None: # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
Args:
cfg: The configuration for the plugin.
trainer: The trainer object for training.
optimizer: The optimizer for training.
num_training_steps: Total number of training steps
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Returns:
The created learning rate scheduler.
object (LRScheduler): The created learning rate scheduler.
"""
# pylint: disable=unused-argument
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> list[Callable]:
"""Set up callbacks before creating the trainer.
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
setup callbacks before creating the trainer.
Args:
cfg: The configuration for the plugin.
model: The loaded model.
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
A list of callback functions to be added to the `TrainingArgs`.
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
# pylint: disable=unused-argument
def add_callbacks_post_trainer(
self, cfg: DictDefault, trainer: Trainer
) -> list[Callable]:
"""Adds callbacks to the trainer after creating the trainer. This is useful for
callbacks that require access to the model or trainer.
self, cfg, trainer
): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer.
Args:
cfg: The configuration for the plugin.
trainer: The trainer object for training.
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
A list of callback functions to be added
List[callable]: A list of callback functions to be added
"""
return []
# pylint: disable=unused-argument
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after training is complete.
def post_train(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after training is complete.
Args:
cfg: The axolotl configuration.
model: The loaded model.
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
"""
def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Performs actions after training is complete and the model is unloaded.
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
Args:
cfg: The configuration for the plugin.
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def load_plugin(plugin_name: str) -> BasePlugin:
"""Loads a plugin based on the given plugin name.
"""
Loads a plugin based on the given plugin name.
The plugin name should be in the format "module_name.class_name". This function
splits the plugin name into module and class, imports the module, retrieves the
class from the module, and creates an instance of the class.
The plugin name should be in the format "module_name.class_name".
This function splits the plugin name into module and class, imports the module,
retrieves the class from the module, and creates an instance of the class.
Args:
plugin_name: The name of the plugin to be loaded. The name should be in the
format "module_name.class_name".
Parameters:
plugin_name (str): The name of the plugin to be loaded. The name should be in the format "module_name.class_name".
Returns:
An instance of the loaded plugin.
BasePlugin: An instance of the loaded plugin.
Raises:
ImportError: If the plugin module cannot be imported.
ImportError: If the plugin module cannot be imported.
"""
# split the plugin name into module and class
module_name, class_name = plugin_name.rsplit(".", 1)
@@ -309,27 +283,29 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin
class PluginManager: # pylint: disable=too-many-public-methods
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.
class PluginManager:
"""
The PluginManager class is responsible for loading and managing plugins.
It should be a singleton so it can be accessed from anywhere in the codebase.
Attributes:
plugins: A list of loaded plugins.
plugins (List[BasePlugin]): A list of loaded plugins.
Note:
Key methods include:
- get_instance(): Static method to get the singleton instance of `PluginManager`.
- register(plugin_name: str): Registers a new plugin by its name.
- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
Methods:
get_instance(): Static method to get the singleton instance of PluginManager.
register(plugin_name: str): Registers a new plugin by its name.
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance: PluginManager | None = None
_cfg: DictDefault | None = None
_instance = None
_cfg = None
def __new__(cls):
"""Creates a new instance of PluginManager if it doesn't exist yet."""
"""
Creates a new instance of PluginManager if it doesn't exist yet.
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: OrderedDict[str, BasePlugin] = (
@@ -339,8 +315,9 @@ class PluginManager: # pylint: disable=too-many-public-methods
@staticmethod
def get_instance() -> "PluginManager":
"""Returns the singleton instance of PluginManager. If the instance doesn't
exist, it creates a new one.
"""
Returns the singleton instance of PluginManager.
If the instance doesn't exist, it creates a new one.
"""
if PluginManager._instance is None:
PluginManager()
@@ -355,30 +332,32 @@ class PluginManager: # pylint: disable=too-many-public-methods
self._cfg = cfg
def register(self, plugin_name: str):
"""Registers a new plugin by its name.
Args:
plugin_name: The name of the plugin to be registered.
Raises:
ImportError: If the plugin module cannot be imported.
"""
try:
LOG.info(f"Attempting to load plugin: {plugin_name}")
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
LOG.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError as exc:
LOG.error(f"Failed to load plugin: {plugin_name}")
# print stacktrace
traceback.print_exc()
print(f"Error: {exc}")
Registers a new plugin by its name.
def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
Parameters:
plugin_name (str): The name of the plugin to be registered.
Returns:
A list of Pydantic classes for all registered plugins' input arguments.'
None
Raises:
ImportError: If the plugin module cannot be imported.
"""
try:
logging.info(f"Attempting to load plugin: {plugin_name}")
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
logging.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
def get_input_args(self):
"""
Returns a list of Pydantic classes for all registered plugins' input arguments.'
Returns:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
"""
input_args = []
for plugin in self.plugins.values():
@@ -387,31 +366,16 @@ class PluginManager: # pylint: disable=too-many-public-methods
input_args.append(input_args_from_plugin)
return input_args
def get_training_args_mixin(self):
def load_datasets(self, cfg, preprocess: bool = False):
"""
Returns a list of dataclasses for all registered plugins' training args mixins'
Returns:
list[str]: A list of dataclsses
"""
training_args = []
for plugin in self.plugins.values():
training_args_from_plugin = plugin.get_training_args_mixin()
if training_args_from_plugin is not None:
training_args.append(training_args_from_plugin)
return training_args
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
"""Calls the load_datasets method of each registered plugin.
Calls the load_datasets method of each registered plugin.
Args:
cfg: The configuration for the plugins.
preprocess: Whether this is preprocess step of the datasets.
preprocess : Whether this is preprocess step of the datasets.
Returns:
The dataset metadata loaded from all registered plugins.
dataset_meta: The dataset metadata loaded from all registered plugins.
"""
return_ds_meta = None
for plugin in self.plugins.values():
@@ -423,66 +387,83 @@ class PluginManager: # pylint: disable=too-many-public-methods
raise RuntimeError("Multiple plugins loaded datasets")
return return_ds_meta
def pre_model_load(self, cfg: DictDefault):
"""Calls the pre_model_load method of all registered plugins.
def pre_model_load(self, cfg):
"""
Calls the pre_model_load method of all registered plugins.
Args:
cfg: The configuration for the plugins.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
"""Calls the `post_model_build` method of all registered plugins after the
model has been built / loaded, but before any adapters have been applied.
def post_model_build(self, cfg, model):
"""
Calls the post_model_build method of all registered plugins after the model has been built/loaded,
but before any adapters have been applied.
Args:
cfg: The configuration for the plugins.
model: The loaded model.
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_model_build(cfg, model)
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Calls the `pre_lora_load` method of all registered plugins.
Args:
cfg: The configuration for the plugins.
model: The loaded model.
def post_model_load(self, cfg, model):
"""
for plugin in self.plugins.values():
plugin.pre_lora_load(cfg, model)
Calls the post_model_load method of all registered plugins after the model has been loaded
inclusive of any adapters
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Calls the `post_lora_load` method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Args:
cfg: The configuration for the plugins.
model: The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Calls the `post_model_load` method of all registered plugins after the model
has been loaded inclusive of any adapters.
Args:
cfg: The configuration for the plugins.
model: The loaded model.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_model_load(cfg, model)
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Calls the `get_trainer_cls` method of all registered plugins and returns the
first non-`None` trainer class.
def pre_lora_load(self, cfg, model):
"""
Calls the pre_lora_load method of all registered plugins.
Args:
cfg: The configuration for the plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
The first non-`None` trainer class returned by a plugin.
None
"""
for plugin in self.plugins.values():
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
"""
Calls the post_lora_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def get_trainer_cls(self, cfg):
"""
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
object: The trainer class, or None if none was found.
"""
for plugin in self.plugins.values():
trainer_cls = plugin.get_trainer_cls(cfg)
@@ -490,61 +471,29 @@ class PluginManager: # pylint: disable=too-many-public-methods
return trainer_cls
return None
def get_training_args(self, cfg):
def post_trainer_create(self, cfg, trainer):
"""
Calls the get_training_args method of all registered plugins and returns the combined training arguments.
Calls the post_trainer_create method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
object: The training arguments
"""
training_args_kwargs = {}
for plugin in self.plugins.values():
training_args = plugin.get_training_args(cfg)
if training_args is not None:
training_args_kwargs.update(training_args)
return training_args_kwargs
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
"""
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
Parameters:
cfg (dict): The configuration for the plugins.
is_eval (bool): Whether this is an eval split.
Returns:
object: The collator class, or None if none was found.
"""
for plugin in self.plugins.values():
collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
if collator is not None:
collator_cls, collator_kwargs = collator
return collator_cls, collator_kwargs
return None
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Calls the `post_trainer_create` method of all registered plugins.
Args:
cfg: The configuration for the plugins.
trainer: The trainer object for training.
None
"""
for plugin in self.plugins.values():
plugin.post_trainer_create(cfg, trainer)
def create_optimizer(self, trainer: Trainer) -> Optimizer | None:
"""Calls the `create_optimizer` method of all registered plugins and returns
the first non-`None` optimizer.
def create_optimizer(self, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Args:
trainer: The trainer object for training.
Parameters:
trainer (object): The trainer object for training.
Returns:
The created optimizer, or `None` if none was found.
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(self.cfg, trainer)
@@ -553,17 +502,17 @@ class PluginManager: # pylint: disable=too-many-public-methods
return None
def create_lr_scheduler(
self, trainer: Trainer, optimizer: Optimizer, num_training_steps: int
self, trainer, optimizer, num_training_steps
) -> LRScheduler | None:
"""Calls the `create_lr_scheduler` method of all registered plugins and returns
the first non-`None` scheduler.
"""
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
Args:
trainer: The trainer object for training.
optimizer: The optimizer for training.
Parameters:
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
Returns:
The created learning rate scheduler, or `None` if not found.
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins.values():
scheduler: LRScheduler | None = plugin.create_lr_scheduler(
@@ -576,17 +525,16 @@ class PluginManager: # pylint: disable=too-many-public-methods
return scheduler
return None
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> list[Callable]:
"""Calls the add_callbacks_pre_trainer method of all registered plugins.
def add_callbacks_pre_trainer(self, cfg, model):
"""
Calls the add_callbacks_pre_trainer method of all registered plugins.
Args:
cfg: The configuration for the plugins.
model: The loaded model.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
A list of callback functions to be added to the `TrainingArgs`.
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins.values():
@@ -595,17 +543,16 @@ class PluginManager: # pylint: disable=too-many-public-methods
callbacks.extend(plugin_callbacks)
return callbacks
def add_callbacks_post_trainer(
self, cfg: DictDefault, trainer: Trainer
) -> list[Callable]:
"""Calls the `add_callbacks_post_trainer` method of all registered plugins.
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Calls the add_callbacks_post_trainer method of all registered plugins.
Args:
cfg: The configuration for the plugins.
trainer: The trainer object for training.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
A list of callback functions to be added to the `TrainingArgs`.
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins.values():
@@ -614,30 +561,41 @@ class PluginManager: # pylint: disable=too-many-public-methods
callbacks.extend(plugin_callbacks)
return callbacks
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Calls the post_train method of all registered plugins.
def post_train(self, cfg, model):
"""
Calls the post_train method of all registered plugins.
Args:
cfg: The configuration for the plugins.
model: The loaded model.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_train(cfg, model)
def post_train_unload(self, cfg: DictDefault):
"""Calls the post_train_unload method of all registered plugins.
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.
Args:
cfg: The configuration for the plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""Base class for factories to create custom optimizers"""
"""
Base class for factories to create custom optimizers
"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> Optimizer | None:
) -> "torch.optim.Optimizer":
pass

View File

@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
This was moved here to prevent circular imports.
"""
from typing import Any, Dict, List, Type
from typing import Any, Dict, List
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
@@ -61,43 +61,3 @@ def merge_input_args():
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
def merge_training_args() -> Type:
"""
Merges training arguments from registered plugins with the base TrainingArguments.
This function retrieves the training arguments from registered plugins using the PluginManager.
It then dynamically creates new classes, AxolotlTrainingMixins,
that inherit from the base configurations and include the training arguments from the plugins.
Returns:
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
"""
# pylint: disable=duplicate-code
from axolotl.core.training_args_base import (
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
)
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
mixin_classes = []
dynamic_input = ""
for plugin_args in training_args_mixins:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
mixin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {}
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, {**globals(), **local_vars}, namespace
)
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
"AxolotlTrainingMixins"
]
return AxolotlTrainingMixins
return AxolotlTrainingMixinsBase

View File

@@ -24,14 +24,6 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
## Usage
**NOTE**: If you are training a VLM model, please use older version of Axolotl as upstream has applied a major VLM refactor, and our patches have not been updated yet.
```bash
git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764
pip3 install --no-build-isolation -e .
```
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -19,16 +19,17 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
import logging
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.logging import get_logger
from axolotl.utils.distributed import is_main_process
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = get_logger(__name__, use_environ=True)
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
@@ -75,9 +76,10 @@ class CutCrossEntropyPlugin(BasePlugin):
cce_patch,
)
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)
if is_main_process(use_environ=True):
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)
# The patch checks model_type internally
cce_patch(cfg.model_config_type)

View File

@@ -15,13 +15,12 @@
"""
Module for handling Cut Cross Entropy input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
class CutCrossEntropyArgs(BaseModel):

View File

@@ -20,15 +20,25 @@ from cut_cross_entropy.transformers.utils import (
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
_CONFIG_FOR_DOC,
COHERE_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,

Some files were not shown because too many files have changed in this diff Show More