Compare commits

..

24 Commits

Author SHA1 Message Date
Wing Lian
c9880977be split llmcompressor from vllm checks 2025-04-29 08:35:06 -04:00
Wing Lian
f196941315 additional fixes for docker and saving compressed 2025-04-28 13:16:29 -04:00
Rahul Tuli
5be047ac46 Fix: Test
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
2025-04-28 13:16:29 -04:00
Rahul Tuli
758115b8c6 Apply patch from @winglian
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
2025-04-28 13:16:29 -04:00
Rahul Tuli
0dc1da5876 Add: line about further optimizations using llmcompressor
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
2025-04-28 13:16:29 -04:00
Rahul Tuli
f3e876dbfc Address Review Comments:
* deleted redundant docs/llm_compressor.qmd
* incorporated feedback in integration README.md
* added llmcompressor integration to docs/custom_integrations.qmd

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
2025-04-28 13:16:29 -04:00
Rahul Tuli
99c13ef60c Add: .qmd file 2025-04-28 13:16:29 -04:00
Rahul Tuli
2c24434ee0 Tests, Style, Updates 2025-04-28 13:16:29 -04:00
Rahul Tuli
586268a0d7 Rebase and updates! 2025-04-28 13:16:29 -04:00
Rahul Tuli
b600e119b6 Add: llm_compressor integration documentation 2025-04-28 13:16:29 -04:00
Rahul Tuli
a8e5ba000e Move: LLMCompressorPlugin into it's own submodule 2025-04-28 13:16:29 -04:00
Rahul Tuli
bc3dfa666d Update model config 2025-04-28 13:16:29 -04:00
Rahul Tuli
4371f3459e Use: absolute import 2025-04-28 13:16:29 -04:00
Rahul Tuli
cc58d5e072 Rename: sft.yaml to sparse-finetuning.yaml 2025-04-28 13:16:29 -04:00
Rahul Tuli
d197b054e3 Add: llcompressor installable 2025-04-28 13:16:29 -04:00
Rahul Tuli
7e1e153831 Address review comments from @markurtz 2025-04-28 13:16:29 -04:00
Rahul Tuli
42de3096cf Apply suggestions from @markurtz
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
2025-04-28 13:16:29 -04:00
Rahul Tuli
27758840a1 Update llmcompressor version to latest 2025-04-28 13:16:29 -04:00
Rahul Tuli
8dbf5c215a Revert: TODO's 2025-04-28 13:16:29 -04:00
Rahul Tuli
6411ca3fe1 Use: warning over warn 2025-04-28 13:16:29 -04:00
Rahul Tuli
813809c54d pre commit hooks 2025-04-28 13:16:29 -04:00
Rahul Tuli
af7cfdc30b Add:llmcompressor instalable 2025-04-28 13:16:29 -04:00
Rahul Tuli
b76d2d1130 Update: review comments! 2025-04-28 13:16:29 -04:00
Rahul Tuli
7946f89df4 Add: SFTPlugin with llmcompressor 2025-04-28 13:16:29 -04:00
103 changed files with 409 additions and 5773 deletions

View File

@@ -22,6 +22,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""

View File

@@ -15,6 +15,11 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -30,7 +35,7 @@ jobs:
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
axolotl_extras: vllm
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -62,7 +67,6 @@ jobs:
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |
@@ -78,6 +82,11 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -9,7 +9,6 @@ on:
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -33,6 +32,13 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras: # no vllm support for 2.4.1
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -12,6 +12,11 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -65,6 +70,11 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -1,61 +0,0 @@
name: Preview
on:
workflow_dispatch:
pull_request:
types: [opened, synchronize, reopened]
# Run the workflow only when one of these files changes
paths:
- '**/*.md' # any Markdown file
- '**/*.qmd' # any Quarto file
- '_quarto.yaml'
permissions:
checks: write
contents: write
deployments: write
issues: write
discussions: write
pages: write
pull-requests: write
statuses: write
jobs:
preview:
runs-on: ubuntu-latest
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Set up Quarto
uses: quarto-dev/quarto-actions/setup@v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install dependencies
run: |
python3 -m pip install jupyter quartodoc
python3 -m pip install -e . --no-deps
- name: Build autodoc
run: quartodoc build
- name: Quarto render
run: quarto render
- name: Netlify Publish
uses: nwtgck/actions-netlify@v3.0
with:
publish-dir: './_site'
enable-pull-request-comment: true
enable-github-deployment: true
github-token: ${{ secrets.GITHUB_TOKEN }}
deploy-message: "Deployed On Netlify"
github-deployment-environment: 'preview'
github-deployment-description: 'Preview Deployment'
env:
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20
steps:
@@ -106,6 +106,13 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -27,9 +27,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
TRANSFORMERS_IS_CI: "yes"
jobs:
pre-commit:
name: pre-commit
@@ -44,101 +41,15 @@ jobs:
env:
SKIP: no-commit-to-branch
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -207,15 +118,24 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20
steps:
@@ -276,6 +196,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...

161
.runpod/.gitignore vendored
View File

@@ -1,161 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
pod/scripts/config.yaml

View File

@@ -1,18 +0,0 @@
FROM axolotlai/axolotl-cloud:main-py3.11-cu124-2.6.0
COPY .runpod/requirements.txt /requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade pip && \
python3 -m pip install --upgrade -r /requirements.txt
# Environment settings
ARG BASE_VOLUME="/runpod-volume"
ENV BASE_VOLUME=$BASE_VOLUME
ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
COPY .runpod/src /src
WORKDIR /src
CMD ["python3", "/src/handler.py"]

View File

@@ -1,335 +0,0 @@
<h1>LLM Post Training- Full fine-tune, LoRA, QLoRa etc. Llama/Mistral/Gemma and more</h1>
# Configuration Options
This document outlines all available configuration options for training models. The configuration can be provided as a JSON request.
## Usage
You can use these configuration Options:
1. As a JSON request body:
```json
{
"input": {
"user_id": "user",
"model_id": "model-name",
"run_id": "run-id",
"credentials": {
"wandb_api_key": "", # add your Weights & biases key. TODO: you will be able to set this in Enviornment variables.
"hf_token": "", # add your HF_token. TODO: you will be able to set this in Enviornment variables.
},
"args": {
"base_model": "NousResearch/Llama-3.2-1B",
// ... other options
}
}
}
```
## Configuration Options
### Model Configuration
| Option | Description | Default |
| ------------------- | --------------------------------------------------------------------------------------------- | -------------------- |
| `base_model` | Path to the base model (local or HuggingFace) | Required |
| `base_model_config` | Configuration path for the base model | Same as base_model |
| `revision_of_model` | Specific model revision from HuggingFace hub | Latest |
| `tokenizer_config` | Custom tokenizer configuration path | Optional |
| `model_type` | Type of model to load | AutoModelForCausalLM |
| `tokenizer_type` | Type of tokenizer to use | AutoTokenizer |
| `hub_model_id` | Repository ID where the model will be pushed on Hugging Face Hub (format: username/repo-name) | Optional |
## Model Family Identification
| Option | Default | Description |
| -------------------------- | ------- | ------------------------------ |
| `is_falcon_derived_model` | `false` | Whether model is Falcon-based |
| `is_llama_derived_model` | `false` | Whether model is LLaMA-based |
| `is_qwen_derived_model` | `false` | Whether model is Qwen-based |
| `is_mistral_derived_model` | `false` | Whether model is Mistral-based |
## Model Configuration Overrides
| Option | Default | Description |
| ----------------------------------------------- | ---------- | ---------------------------------- |
| `overrides_of_model_config.rope_scaling.type` | `"linear"` | RoPE scaling type (linear/dynamic) |
| `overrides_of_model_config.rope_scaling.factor` | `1.0` | RoPE scaling factor |
### Model Loading Options
| Option | Description | Default |
| -------------- | ----------------------------- | ------- |
| `load_in_8bit` | Load model in 8-bit precision | false |
| `load_in_4bit` | Load model in 4-bit precision | false |
| `bf16` | Use bfloat16 precision | false |
| `fp16` | Use float16 precision | false |
| `tf32` | Use tensor float 32 precision | false |
## Memory and Device Settings
| Option | Default | Description |
| ------------------ | --------- | ----------------------- |
| `gpu_memory_limit` | `"20GiB"` | GPU memory limit |
| `lora_on_cpu` | `false` | Load LoRA on CPU |
| `device_map` | `"auto"` | Device mapping strategy |
| `max_memory` | `null` | Max memory per device |
## Training Hyperparameters
| Option | Default | Description |
| ----------------------------- | --------- | --------------------------- |
| `gradient_accumulation_steps` | `1` | Gradient accumulation steps |
| `micro_batch_size` | `2` | Batch size per GPU |
| `eval_batch_size` | `null` | Evaluation batch size |
| `num_epochs` | `4` | Number of training epochs |
| `warmup_steps` | `100` | Warmup steps |
| `warmup_ratio` | `0.05` | Warmup ratio |
| `learning_rate` | `0.00003` | Learning rate |
| `lr_quadratic_warmup` | `false` | Quadratic warmup |
| `logging_steps` | `null` | Logging frequency |
| `eval_steps` | `null` | Evaluation frequency |
| `evals_per_epoch` | `null` | Evaluations per epoch |
| `save_strategy` | `"epoch"` | Checkpoint saving strategy |
| `save_steps` | `null` | Saving frequency |
| `saves_per_epoch` | `null` | Saves per epoch |
| `save_total_limit` | `null` | Maximum checkpoints to keep |
| `max_steps` | `null` | Maximum training steps |
### Dataset Configuration
```yaml
datasets:
- path: vicgalle/alpaca-gpt4 # HuggingFace dataset or TODO: You will be able to add the local path.
type: alpaca # Format type (alpaca, gpteacher, oasst, etc.)
ds_type: json # Dataset type
data_files: path/to/data # Source data files
train_on_split: train # Dataset split to use
```
## Chat Template Settings
| Option | Default | Description |
| ------------------------ | -------------------------------- | ---------------------- |
| `chat_template` | `"tokenizer_default"` | Chat template type |
| `chat_template_jinja` | `null` | Custom Jinja template |
| `default_system_message` | `"You are a helpful assistant."` | Default system message |
## Dataset Processing
| Option | Default | Description |
| ----------------------------- | -------------------------- | --------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
## LoRA Configuration
| Option | Default | Description |
| -------------------------- | ---------------------- | ------------------------------ |
| `adapter` | `"lora"` | Adapter type (lora/qlora) |
| `lora_model_dir` | `""` | Directory with pretrained LoRA |
| `lora_r` | `8` | LoRA attention dimension |
| `lora_alpha` | `16` | LoRA alpha parameter |
| `lora_dropout` | `0.05` | LoRA dropout |
| `lora_target_modules` | `["q_proj", "v_proj"]` | Modules to apply LoRA |
| `lora_target_linear` | `false` | Target all linear modules |
| `peft_layers_to_transform` | `[]` | Layers to transform |
| `lora_modules_to_save` | `[]` | Modules to save |
| `lora_fan_in_fan_out` | `false` | Fan in/out structure |
## Optimization Settings
| Option | Default | Description |
| ------------------------- | ------- | -------------------------- |
| `train_on_inputs` | `false` | Train on input prompts |
| `group_by_length` | `false` | Group by sequence length |
| `gradient_checkpointing` | `false` | Use gradient checkpointing |
| `early_stopping_patience` | `3` | Early stopping patience |
## Learning Rate Scheduling
| Option | Default | Description |
| -------------------------- | ---------- | -------------------- |
| `lr_scheduler` | `"cosine"` | Scheduler type |
| `lr_scheduler_kwargs` | `{}` | Scheduler parameters |
| `cosine_min_lr_ratio` | `null` | Minimum LR ratio |
| `cosine_constant_lr_ratio` | `null` | Constant LR ratio |
| `lr_div_factor` | `null` | LR division factor |
## Optimizer Settings
| Option | Default | Description |
| ---------------------- | ------------ | ------------------- |
| `optimizer` | `"adamw_hf"` | Optimizer choice |
| `optim_args` | `{}` | Optimizer arguments |
| `optim_target_modules` | `[]` | Target modules |
| `weight_decay` | `null` | Weight decay |
| `adam_beta1` | `null` | Adam beta1 |
| `adam_beta2` | `null` | Adam beta2 |
| `adam_epsilon` | `null` | Adam epsilon |
| `max_grad_norm` | `null` | Gradient clipping |
## Attention Implementations
| Option | Default | Description |
| -------------------------- | ------- | ----------------------------- |
| `flash_optimum` | `false` | Use better transformers |
| `xformers_attention` | `false` | Use xformers |
| `flash_attention` | `false` | Use flash attention |
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
| `sdp_attention` | `false` | Use scaled dot product |
| `s2_attention` | `false` | Use shifted sparse attention |
## Tokenizer Modifications
| Option | Default | Description |
| ---------------- | ------- | ---------------------------- |
| `special_tokens` | - | Special tokens to add/modify |
| `tokens` | `[]` | Additional tokens |
## Distributed Training
| Option | Default | Description |
| ----------------------- | ------- | --------------------- |
| `fsdp` | `null` | FSDP configuration |
| `fsdp_config` | `null` | FSDP config options |
| `deepspeed` | `null` | Deepspeed config path |
| `ddp_timeout` | `null` | DDP timeout |
| `ddp_bucket_cap_mb` | `null` | DDP bucket capacity |
| `ddp_broadcast_buffers` | `null` | DDP broadcast buffers |
<details>
<summary><h3>Example Configuration Request:</h3></summary>
Here's a complete example for fine-tuning a LLaMA model using LoRA:
```json
{
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "test-run",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "NousResearch/Llama-3.2-1B",
"load_in_8bit": false,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca"
}
],
"dataset_prepared_path": "last_run_prepared",
"val_set_size": 0.1,
"output_dir": "./outputs/lora-out",
"adapter": "lora",
"sequence_len": 2048,
"sample_packing": true,
"eval_sample_packing": true,
"pad_to_sequence_len": true,
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_modules": [
"gate_proj",
"down_proj",
"up_proj",
"q_proj",
"v_proj",
"k_proj",
"o_proj"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": false,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"loss_watchdog_threshold": 5,
"loss_watchdog_patience": 3,
"warmup_steps": 10,
"evals_per_epoch": 4,
"saves_per_epoch": 1,
"weight_decay": 0,
"hub_model_id": "runpod/llama-fr-lora",
"wandb_name": "test-run-1",
"wandb_project": "test-run-1",
"wandb_entity": "axo-test",
"special_tokens": {
"pad_token": "<|end_of_text|>"
}
}
}
}
```
</details>
### Advanced Features
#### Wandb Integration
- `wandb_project`: Project name for Weights & Biases
- `wandb_entity`: Team name in W&B
- `wandb_watch`: Monitor model with W&B
- `wandb_name`: Name of the W&B run
- `wandb_run_id`: ID for the W&B run
#### Performance Optimization
- `sample_packing`: Enable efficient sequence packing
- `eval_sample_packing`: Use sequence packing during evaluation
- `torch_compile`: Enable PyTorch 2.0 compilation
- `flash_attention`: Use Flash Attention implementation
- `xformers_attention`: Use xFormers attention implementation
### Available Optimizers
The following optimizers are supported:
- `adamw_hf`: HuggingFace's AdamW implementation
- `adamw_torch`: PyTorch's AdamW
- `adamw_torch_fused`: Fused AdamW implementation
- `adamw_torch_xla`: XLA-optimized AdamW
- `adamw_apex_fused`: NVIDIA Apex fused AdamW
- `adafactor`: Adafactor optimizer
- `adamw_anyprecision`: Anyprecision AdamW
- `adamw_bnb_8bit`: 8-bit AdamW from bitsandbytes
- `lion_8bit`: 8-bit Lion optimizer
- `lion_32bit`: 32-bit Lion optimizer
- `sgd`: Stochastic Gradient Descent
- `adagrad`: Adagrad optimizer
## Notes
- Set `load_in_8bit: true` or `load_in_4bit: true` for memory-efficient training
- Enable `flash_attention: true` for faster training on modern GPUs
- Use `gradient_checkpointing: true` to reduce memory usage
- Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory
For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html).
### Errors:
- if you face any issues with the Flash Attention-2, Delete yoor worker and Re-start.

View File

@@ -1,93 +0,0 @@
{
"title": "Axolotl Fine-Tuning",
"description": "Serverless fine-tuning of open-source LLMs with Axolotl. Supports LoRA, QLoRA, DPO, and more using Hugging Face models and datasets.",
"type": "serverless",
"category": "language",
"iconUrl": "https://avatars.githubusercontent.com/u/167502477",
"config": {
"runsOn": "GPU",
"containerDiskInGb": 200,
"gpuCount": 1,
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
],
"presets": [],
"env": [
{
"key": "TOKENIZER",
"input": {
"name": "Tokenizer",
"type": "string",
"description": "Name or path of the Hugging Face tokenizer to use.",
"default": "",
"advanced": true
}
},
{
"key": "MAX_NUM_SEQS",
"input": {
"name": "Max Num Seqs",
"type": "number",
"description": "Maximum number of sequences per iteration.",
"default": 256,
"advanced": true
}
},
{
"key": "DISABLE_LOG_STATS",
"input": {
"name": "Disable Log Stats",
"type": "boolean",
"description": "Disable logging statistics.",
"default": false,
"trueValue": "true",
"falseValue": "false"
}
},
{
"key": "LOAD_FORMAT",
"input": {
"name": "Load Format",
"type": "string",
"description": "The format of the model weights to load.",
"default": "auto",
"options": [
{
"label": "auto",
"value": "auto"
},
{
"label": "pt",
"value": "pt"
},
{
"label": "safetensors",
"value": "safetensors"
},
{
"label": "npcache",
"value": "npcache"
},
{
"label": "dummy",
"value": "dummy"
},
{
"label": "tensorizer",
"value": "tensorizer"
},
{
"label": "bitsandbytes",
"value": "bitsandbytes"
}
],
"advanced": true
}
}
]
}
}

View File

@@ -1,7 +0,0 @@
# Required Python packages get listed here, one per line.
# Reccomended to lock the version number to avoid unexpected changes.
# You can also install packages from a git repository, e.g.:
# git+https://github.com/runpod/runpod-python.git
# To learn more, see https://pip.pypa.io/en/stable/reference/requirements-file-format/
runpod~=1.7.0

View File

@@ -1,577 +0,0 @@
# # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# # This can also be a relative path to a model on disk
# base_model: ./llama-7b-hf
# # You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
# base_model_ignore_patterns:
# # If the base_model repo on hf hub doesn't include configuration .json files,
# # You can set that here, or leave this empty to default to base_model
# base_model_config: ./llama-7b-hf
# # You can specify to choose a specific model revision from huggingface hub
# model_revision:
# # Optional tokenizer configuration override in case you want to use a different tokenizer
# # than the one defined in the base model
# tokenizer_config:
# # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
# model_type: AutoModelForCausalLM
# # Corresponding tokenizer for the model AutoTokenizer is a good choice
# tokenizer_type: AutoTokenizer
# # Trust remote code for untrusted source
# trust_remote_code:
# # use_fast option for tokenizer loading from_pretrained, default to True
# tokenizer_use_fast:
# # Whether to use the legacy tokenizer setting, defaults to True
# tokenizer_legacy:
# # Resize the model embeddings when new tokens are added to multiples of 32
# # This is reported to improve training speed on some models
# resize_token_embeddings_to_32x:
# # Used to identify which the model is based on
# is_falcon_derived_model:
# is_llama_derived_model:
# # Please note that if you set this to true, `padding_side` will be set to "left" by default
# is_mistral_derived_model:
# is_qwen_derived_model:
# # optional overrides to the base model configuration
# model_config:
# # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
# rope_scaling:
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
# gptq_model_v1: false # v1 or v2
# # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
# load_in_8bit: true
# # Use bitsandbytes 4 bit
# load_in_4bit:
# # Use CUDA bf16
# bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
# # Use CUDA fp16
# fp16: true
# # Use CUDA tf32
# tf32: true # require >=ampere
# # No AMP (automatic mixed precision)
# bfloat16: true # require >=ampere
# float16: true
# # A list of one or more datasets to finetune the model with
# datasets:
# # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
# - path: vicgalle/alpaca-gpt4
# # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
# type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
# ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
# data_files: # Optional[str] path to source data files
# shards: # Optional[int] number of shards to split data into
# name: # Optional[str] name of dataset configuration to load
# train_on_split: train # Optional[str] name of dataset split to load from
# # Optional[str] fastchat conversation type, only used with type: sharegpt
# conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# field_human: # Optional[str]. Human key to use for conversation.
# field_model: # Optional[str]. Assistant key to use for conversation.
# # Custom user prompt
# - path: repo
# type:
# # The below are defaults. only set what's needed.
# system_prompt: ""
# system_format: "{system}"
# field_system: system
# field_instruction: instruction
# field_input: input
# field_output: output
# # Customizable to be single line or multi-line
# # 'format' can include {input}
# format: |-
# User: {instruction} {input}
# Assistant:
# # 'no_input_format' cannot include {input}
# no_input_format: "{instruction} "
# # For `completion` datsets only, uses the provided field instead of `text` column
# field:
# # Axolotl attempts to save the dataset as an arrow after packing the data together so
# # subsequent training attempts load faster, relative path
# dataset_prepared_path: data/last_run_prepared
# # Push prepared dataset to hub
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
# dataset_processes: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
# # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
# hub_strategy:
# # Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# # Required to be true when used in combination with `push_dataset_to_hub`
# hf_use_auth_token: # boolean
# # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
# val_set_size: 0.04
# # Num shards for whole dataset
# dataset_shard_num:
# # Index of shard to use for whole dataset
# dataset_shard_idx:
# # The maximum length of an input to train with, this should typically be less than 2048
# # as most models have a token/context limit of 2048
# sequence_len: 2048
# # Pad inputs so each step uses constant sized buffers
# # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
# pad_to_sequence_len:
# # Max sequence length to concatenate training samples together up to
# # Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# # FutureWarning: This will soon be DEPRECATED
# max_packed_sequence_len: 1024
# # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
# sample_packing:
# # Set to 'false' if getting errors during eval with sample_packing on.
# eval_sample_packing:
# # You can set these packing optimizations AFTER starting a training at least once.
# # The trainer will provide recommended values for these values.
# sample_packing_eff_est:
# total_num_tokens:
# # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
# adapter: lora
# # If you already have a lora model trained that you want to load, put that here.
# # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
# lora_model_dir:
# # LoRA hyperparameters
# # For more details about the following options, see:
# # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
# lora_r: 8
# lora_alpha: 16
# lora_dropout: 0.05
# lora_target_modules:
# - q_proj
# - v_proj
# # - k_proj
# # - o_proj
# # - gate_proj
# # - down_proj
# # - up_proj
# lora_target_linear: # If true, will target all linear layers
# # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
# lora_modules_to_save:
# # - embed_tokens
# # - lm_head
# # Once you complete training, the model will be saved to the following directory.
# # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# # Make sure `lora_model_dir` points to this directory if you want to use the trained model.
# lora_out_dir:
# lora_fan_in_fan_out: false
# # ReLoRA configuration
# # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
# relora_steps: # Number of steps per ReLoRA restart
# relora_warmup_steps: # Number of per-restart warmup steps
# relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# # wandb configuration if you're using it
# wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
# wandb_project: # Your wandb project name
# wandb_entity: # A wandb Team name if using a Team
# wandb_watch:
# wandb_run_id: # Set the name of your wandb run
# wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# # Where to save the full-finetuned model to
# output_dir: ./completed-model
# # Whether to use torch.compile and which backend to use
# torch_compile: # bool
# torch_compile_backend: # Optional[str]
# # Training hyperparameters
# # If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
# gradient_accumulation_steps: 1
# # The number of samples to include in each batch. This is the number of samples sent to each GPU.
# micro_batch_size: 2
# eval_batch_size:
# num_epochs: 4
# warmup_steps: 100 # cannot use with warmup_ratio
# warmup_ratio: 0.05 # cannot use with warmup_steps
# learning_rate: 0.00003
# lr_quadratic_warmup:
# logging_steps:
# save_strategy: # Set to `no` to skip checkpoint saves
# save_steps: # Leave empty to save at each epoch
# eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
# save_total_limit: # Checkpoints saved at a time
# # Maximum number of iterations to train for. It precedes num_epochs which means that
# # if both are set, num_epochs will not be guaranteed.
# # e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
# max_steps:
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# # Save model as safetensors (require safetensors package)
# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
# # May be slower to start, as it must download and sort the entire dataset.
# # Note that training loss may have an oscillating pattern with this enabled.
# group_by_length: false
# # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
# gradient_checkpointing: false
# # Stop training after this many evaluation losses have increased in a row
# # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
# early_stopping_patience: 3
# # Specify a scheduler and kwargs to use with the optimizer
# lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
# lr_scheduler_kwargs:
# # For one_cycle optim
# lr_div_factor: # Learning rate div factor
# # For log_sweep optim
# log_sweep_min_lr:
# log_sweep_max_lr:
# # Specify optimizer
# # Valid values are driven by the Transformers OptimizerNames class, see:
# # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
# #
# # Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# # torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# # in the examples/ for your model and fine-tuning use case.
# #
# # Valid values for 'optimizer' include:
# # - adamw_hf
# # - adamw_torch
# # - adamw_torch_fused
# # - adamw_torch_xla
# # - adamw_apex_fused
# # - adafactor
# # - adamw_anyprecision
# # - sgd
# # - adagrad
# # - adamw_bnb_8bit
# # - lion_8bit
# # - lion_32bit
# # - paged_adamw_32bit
# # - paged_adamw_8bit
# # - paged_lion_32bit
# # - paged_lion_8bit
# optimizer:
# # Specify weight decay
# weight_decay:
# # adamw hyperparams
# adam_beta1:
# adam_beta2:
# adam_epsilon:
# # Gradient clipping max norm
# max_grad_norm:
# # Augmentation techniques
# # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# # currently only supported on Llama and Mistral
# noisy_embedding_alpha:
# # Whether to bettertransformers
# flash_optimum:
# # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
# xformers_attention:
# # Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
# flash_attention:
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# # Whether to use scaled-dot-product attention
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# sdp_attention:
# # Landmark attention (only llama)
# landmark_attention:
# # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# # LLaMA only
# xpos_rope:
# # Resume from a specific checkpoint dir
# resume_from_checkpoint:
# # If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# # Be careful with this being turned on between different models.
# auto_resume_from_checkpoints: false
# # Don't mess with this, it's here for accelerate and torchrun
# local_rank:
# # Add or change special tokens.
# # If you add tokens here, you don't need to add them to the `tokens` list.
# special_tokens:
# # bos_token: "<s>"
# # eos_token: "</s>"
# # unk_token: "<unk>"
# # Add extra tokens.
# tokens:
# # FSDP
# fsdp:
# fsdp_config:
# # Deepspeed config path. e.g., deepspeed/zero3.json
# deepspeed:
# # Advanced DDP Arguments
# ddp_timeout:
# ddp_bucket_cap_mb:
# ddp_broadcast_buffers:
# # Path to torch distx for optim 'adamw_anyprecision'
# torchdistx_path:
# # Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
# pretraining_dataset:
# # Debug mode
# debug:
# # Seed
# seed:
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
revision_of_model: ${REVISION_OF_MODEL}
tokenizer_config: ${TOKENIZER_CONFIG}
model_type: ${MODEL_TYPE}
tokenizer_type: ${TOKENIZER_TYPE}
trust_remote_code: ${TRUST_REMOTE_CODE}
tokenizer_use_fast: ${TOKENIZER_USE_FAST}
tokenizer_legacy: ${TOKENIZER_LEGACY}
resize_token_embeddings_to_32x: ${RESIZE_TOKEN_EMBEDDINGS_TO_32X}
is_falcon_derived_model: ${IS_FALCON_DERIVED_MODEL}
is_llama_derived_model: ${IS_LLAMA_DERIVED_MODEL}
is_qwen_derived_model: ${IS_QWEN_DERIVED_MODEL}
is_mistral_derived_model: ${IS_MISTRAL_DERIVED_MODEL}
overrides_of_model_config:
rope_scaling:
type: ${ROPE_SCALING_TYPE}
factor: ${ROPE_SCALING_FACTOR}
bnb_config_kwargs:
llm_int8_has_fp16_weight: ${BNB_LLM_INT8_HAS_FP16_WEIGHT}
bnb_4bit_quant_type: ${BNB_4BIT_QUANT_TYPE}
bnb_4bit_use_double_quant: ${BNB_4BIT_USE_DOUBLE_QUANT}
gptq: ${GPTQ}
load_in_8bit: ${LOAD_IN_8BIT}
load_in_4bit: ${LOAD_IN_4BIT}
bf16: ${BF16}
fp16: ${FP16}
tf32: ${TF32}
bfloat16: ${BFLOAT16}
float16: ${FLOAT16}
gpu_memory_limit: ${GPU_MEMORY_LIMIT}
lora_on_cpu: ${LORA_ON_CPU}
datasets:
- path: ${DATASET_PATH}
type: ${DATASET_TYPE}
ds_type: ${DATASET_DS_TYPE}
data_files: ${DATASET_DATA_FILES}
shards: ${DATASET_SHARDS}
name: ${DATASET_NAME}
train_on_split: ${DATASET_TRAIN_ON_SPLIT}
revision: ${DATASET_REVISION}
trust_remote_code: ${DATASET_TRUST_REMOTE_CODE}
rl: ${RL}
dpo_use_weighting: ${DPO_USE_WEIGHTING}
chat_template: ${CHAT_TEMPLATE}
chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_processes: ${DATASET_PROCESSES}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}
hf_use_auth_token: ${HF_USE_AUTH_TOKEN}
val_set_size: ${VAL_SET_SIZE}
dataset_shard_num: ${DATASET_SHARD_NUM}
dataset_shard_idx: ${DATASET_SHARD_IDX}
sequence_len: ${SEQUENCE_LEN}
pad_to_sequence_len: ${PAD_TO_SEQUENCE_LEN}
sample_packing: ${SAMPLE_PACKING}
eval_sample_packing: ${EVAL_SAMPLE_PACKING}
sample_packing_eff_est: ${SAMPLE_PACKING_EFF_EST}
total_num_tokens: ${TOTAL_NUM_TOKENS}
sample_packing_group_size: ${SAMPLE_PACKING_GROUP_SIZE}
sample_packing_bin_size: ${SAMPLE_PACKING_BIN_SIZE}
batch_flattening: ${BATCH_FLATTENING}
device_map: ${DEVICE_MAP}
max_memory: ${MAX_MEMORY}
adapter: ${ADAPTER}
lora_model_dir: ${LORA_MODEL_DIR}
lora_r: ${LORA_R}
lora_alpha: ${LORA_ALPHA}
lora_dropout: ${LORA_DROPOUT}
lora_target_modules:
- ${LORA_TARGET_MODULES}
lora_target_linear: ${LORA_TARGET_LINEAR}
peft_layers_to_transform: ${PEFT_LAYERS_TO_TRANSFORM}
lora_modules_to_save: ${LORA_MODULES_TO_SAVE}
lora_fan_in_fan_out: ${LORA_FAN_IN_FAN_OUT}
loraplus_lr_ratio: ${LORAPLUS_LR_RATIO}
loraplus_lr_embedding: ${LORAPLUS_LR_EMBEDDING}
peft:
loftq_config:
loftq_bits: ${LOFTQ_BITS}
relora_steps: ${RELORA_STEPS}
relora_warmup_steps: ${RELORA_WARMUP_STEPS}
relora_anneal_steps: ${RELORA_ANNEAL_STEPS}
relora_prune_ratio: ${RELORA_PRUNE_RATIO}
relora_cpu_offload: ${RELORA_CPU_OFFLOAD}
wandb_mode: ${WANDB_MODE}
wandb_project: ${WANDB_PROJECT}
wandb_entity: ${WANDB_ENTITY}
wandb_watch: ${WANDB_WATCH}
wandb_name: ${WANDB_NAME}
wandb_run_id: ${WANDB_RUN_ID}
wandb_log_model: ${WANDB_LOG_MODEL}
mlflow_tracking_uri: ${MLFLOW_TRACKING_URI}
mlflow_experiment_name: ${MLFLOW_EXPERIMENT_NAME}
mlflow_run_name: ${MLFLOW_RUN_NAME}
hf_mlflow_log_artifacts: ${HF_MLFLOW_LOG_ARTIFACTS}
use_comet: ${USE_COMET}
comet_api_key: ${COMET_API_KEY}
comet_workspace: ${COMET_WORKSPACE}
comet_project_name: ${COMET_PROJECT_NAME}
comet_experiment_key: ${COMET_EXPERIMENT_KEY}
comet_mode: ${COMET_MODE}
comet_online: ${COMET_ONLINE}
comet_experiment_config: ${COMET_EXPERIMENT_CONFIG}
output_dir: ${OUTPUT_DIR}
torch_compile: ${TORCH_COMPILE}
torch_compile_backend: ${TORCH_COMPILE_BACKEND}
gradient_accumulation_steps: ${GRADIENT_ACCUMULATION_STEPS}
micro_batch_size: ${MICRO_BATCH_SIZE}
eval_batch_size: ${EVAL_BATCH_SIZE}
num_epochs: ${NUM_EPOCHS}
warmup_steps: ${WARMUP_STEPS}
warmup_ratio: ${WARMUP_RATIO}
learning_rate: ${LEARNING_RATE}
lr_quadratic_warmup: ${LR_QUADRATIC_WARMUP}
logging_steps: ${LOGGING_STEPS}
eval_steps: ${EVAL_STEPS}
evals_per_epoch: ${EVALS_PER_EPOCH}
save_strategy: ${SAVE_STRATEGY}
save_steps: ${SAVE_STEPS}
saves_per_epoch: ${SAVES_PER_EPOCH}
save_total_limit: ${SAVE_TOTAL_LIMIT}
max_steps: ${MAX_STEPS}
eval_table_size: ${EVAL_TABLE_SIZE}
eval_max_new_tokens: ${EVAL_MAX_NEW_TOKENS}
eval_causal_lm_metrics: ${EVAL_CAUSAL_LM_METRICS}
profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}
early_stopping_patience: ${EARLY_STOPPING_PATIENCE}
lr_scheduler: ${LR_SCHEDULER}
lr_scheduler_kwargs: ${LR_SCHEDULER_KWARGS}
cosine_min_lr_ratio: ${COSINE_MIN_LR_RATIO}
cosine_constant_lr_ratio: ${COSINE_CONSTANT_LR_RATIO}
lr_div_factor: ${LR_DIV_FACTOR}
optimizer: ${OPTIMIZER}
optim_args: ${OPTIM_ARGS}
optim_target_modules: ${OPTIM_TARGET_MODULES}
weight_decay: ${WEIGHT_DECAY}
adam_beta1: ${ADAM_BETA1}
adam_beta2: ${ADAM_BETA2}
adam_epsilon: ${ADAM_EPSILON}
max_grad_norm: ${MAX_GRAD_NORM}
neftune_noise_alpha: ${NEFTUNE_NOISE_ALPHA}
flash_optimum: ${FLASH_OPTIMUM}
xformers_attention: ${XFORMERS_ATTENTION}
flash_attention: ${FLASH_ATTENTION}
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
sdp_attention: ${SDP_ATTENTION}
s2_attention: ${S2_ATTENTION}
resume_from_checkpoint: ${RESUME_FROM_CHECKPOINT}
auto_resume_from_checkpoints: ${AUTO_RESUME_FROM_CHECKPOINTS}
local_rank: ${LOCAL_RANK}
special_tokens:
bos_token: ${SPECIAL_TOKEN_BOS}
eos_token: ${SPECIAL_TOKEN_EOS}
unk_token: ${SPECIAL_TOKEN_UNK}
pad_token: ${SPECIAL_TOKEN_PAD}
tokens: ${TOKENS}
fsdp: ${FSDP}
fsdp_config: ${FSDP_CONFIG}
deepspeed: ${DEEPSPEED}
ddp_timeout: ${DDP_TIMEOUT}
ddp_bucket_cap_mb: ${DDP_BUCKET_CAP_MB}
ddp_broadcast_buffers: ${DDP_BROADCAST_BUFFERS}
torchdistx_path: ${TORCHDISTX_PATH}
pretraining_dataset: ${PRETRAINING_DATASET}
debug: ${DEBUG}
seed: ${SEED}
strict: ${STRICT}

View File

@@ -1,64 +0,0 @@
"""
Runpod serverless entrypoint handler
"""
import os
import runpod
import yaml
from huggingface_hub._login import login
from train import train
from utils import get_output_dir
BASE_VOLUME = os.environ.get("BASE_VOLUME", "/runpod-volume")
if not os.path.exists(BASE_VOLUME):
os.makedirs(BASE_VOLUME)
logger = runpod.RunPodLogger()
async def handler(job):
runpod_job_id = job["id"]
inputs = job["input"]
run_id = inputs.get("run_id", "default_run_id")
args = inputs.get("args", {})
# Set output directory
output_dir = os.path.join(BASE_VOLUME, get_output_dir(run_id))
args["output_dir"] = output_dir
# First save args to a temporary config file
config_path = "/workspace/test_config.yaml"
# Add run_name and job_id to args before saving
args["run_name"] = run_id
args["runpod_job_id"] = runpod_job_id
yaml_data = yaml.dump(args, default_flow_style=False)
with open(config_path, "w", encoding="utf-8") as file:
file.write(yaml_data)
# Handle credentials
credentials = inputs.get("credentials", {})
if "wandb_api_key" in credentials:
os.environ["WANDB_API_KEY"] = credentials["wandb_api_key"]
if "hf_token" in credentials:
os.environ["HF_TOKEN"] = credentials["hf_token"]
if os.environ.get("HF_TOKEN"):
login(token=os.environ["HF_TOKEN"])
else:
logger.info("No HF_TOKEN provided. Skipping login.")
logger.info("Starting Training.")
async for result in train(config_path): # Pass the config path instead of args
logger.info(result)
logger.info("Training Complete.")
# Cleanup
del os.environ["WANDB_API_KEY"]
del os.environ["HF_TOKEN"]
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})

View File

@@ -1,61 +0,0 @@
{
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "NousResearch/Meta-Llama-3-8B",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_8bit": true,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca"
}
],
"val_set_size": 0.05,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 4,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": false,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|end_of_text|>"
}
}
}
}

View File

@@ -1,45 +0,0 @@
"""
Runpod train entrypoint
"""
import asyncio
async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
"""
Run preprocessing (if enabled) and training with the given config file
:param config_path: Path to the YAML config file
:param gpu_id: GPU ID to use (default: "0")
:param preprocess: Whether to run preprocessing (default: True)
"""
# First check if preprocessing is needed
if preprocess:
# Preprocess command
preprocess_cmd = (
f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}"
)
process = await asyncio.create_subprocess_shell(
preprocess_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
if process.stdout is not None:
async for line in process.stdout:
yield f"Preprocessing: {line.decode().strip()}"
await process.wait()
yield "Preprocessing completed."
else:
yield "Skipping preprocessing step."
# Training command
train_cmd = f"axolotl train {config_path}"
process = await asyncio.create_subprocess_shell(
train_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
)
if process.stdout is not None:
async for line in process.stdout:
yield f"Training: {line.decode().strip()}"
await process.wait()

View File

@@ -1,89 +0,0 @@
"""
Runpod launcher utils
"""
import os
import yaml
def get_output_dir(run_id):
path = f"fine-tuning/{run_id}"
return path
def make_valid_config(input_args):
"""
Creates and saves updated config file, returns the path to the new config
:param input_args: dict of input args
:return: str, path to the updated config file
"""
# Load default config
with open("config/config.yaml", "r", encoding="utf-8") as fin:
all_args = yaml.safe_load(fin)
if not input_args:
print("No args provided, using defaults")
else:
all_args.update(input_args)
# Create updated config path
updated_config_path = "config/updated_config.yaml"
# Save updated config to new file
with open(updated_config_path, "w", encoding="utf-8") as f:
yaml.dump(all_args, f)
return updated_config_path
def set_config_env_vars(args: dict):
"""
Convert API arguments into environment variables.
Handles nested dictionaries, lists, and special values.
Args:
args (dict): The arguments dictionary from the API request
"""
def process_value(value):
"""Convert Python values to string format for environment variables"""
if value is None:
return ""
if isinstance(value, bool):
return str(value).lower()
if isinstance(value, (list, dict)):
return str(value)
return str(value)
def set_env_vars(data, prefix=""):
"""Recursively set environment variables from nested dictionary"""
for key, value in data.items():
env_key = prefix + key.upper()
# Handle special cases
if isinstance(value, dict):
# For nested dictionaries (like special_tokens)
set_env_vars(value, f"{env_key}_")
elif isinstance(value, list):
# Handle list of dictionaries (like datasets)
if value and isinstance(value[0], dict):
for i, item in enumerate(value):
set_env_vars(item, f"{env_key}_{i}_")
else:
# For simple lists (like lora_target_modules)
os.environ[env_key] = process_value(value)
else:
# Handle all other cases
os.environ[env_key] = process_value(value)
# Clear any existing related environment variables
# This prevents old values from persisting
for key in list(os.environ.keys()):
if key.startswith(
("BASE_MODEL", "MODEL_TYPE", "TOKENIZER_TYPE", "DATASET", "LORA_", "WANDB_")
):
del os.environ[key]
# Set new environment variables
set_env_vars(args)

View File

@@ -1,86 +0,0 @@
{
"input": {
"name": "quick_smoke_test_sft",
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_4bit": true,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"split": "train[:10%]"
}
],
"val_set_size": 0.02,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
},
"max_steps": 20
},
"timeout": 100000
},
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,
"containerDiskInGb": 200,
"env": [
{
"key": "TOKENIZER",
"value": ""
},
{
"key": "DISABLE_LOG_STATS",
"value": "true"
}
],
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
]
}
}

View File

@@ -1,90 +0,0 @@
{
"tests": [
{
"name": "quick_smoke_test_sft",
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_4bit": true,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"split": "train[:10%]"
}
],
"val_set_size": 0.02,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
},
"max_steps": 20
}
},
"timeout": 100000
}
],
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,
"containerDiskInGb": 200,
"env": [
{
"key": "TOKENIZER",
"value": ""
},
{
"key": "DISABLE_LOG_STATS",
"value": "true"
}
],
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
]
}
}

View File

@@ -154,10 +154,6 @@ datasets:
# Key containing the messages (default: "messages")
field_messages: messages
# Key containing the system message (default: "system")
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
field_system: system
# Mapping of properties from the input dataset to the chat template.
# (default: message_property_mappings={'role':'role', 'content':'content'})
# If a property exists in the template but not in this mapping, the system will attempt
@@ -184,14 +180,10 @@ datasets:
# adding a system turn with empty content.
drop_system_message:
# Optional[bool]. Whether to split the assistant turn based on a reasoning trace inside delimited tags
# defaults to False
split_thinking:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
# Note: If the below 5 fields are empty, defaults to training only on the last message.
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["assistant"] # default
@@ -200,13 +192,7 @@ datasets:
# - turn (default): train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
train_on_eos: turn
# Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are:
# - all: train on all EOT tokens
# - turn: train on the EOT token at the end of each trainable turn
# - last: train on the last EOT token in the conversation
# If not specified, defaults to the value of train_on_eos for backward compatibility.
train_on_eot:
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
@@ -289,17 +275,8 @@ process_reward_model:
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training.
# These tokens mark the boundaries between conversation turns.
# For example: ["/INST", "</s>", "[/SYSTEM_PROMPT]"]
# If not specified, defaults to just the model's eos_token.
# This is useful for templates that use multiple delimiter tokens.
eot_tokens:
# - "</s>"
# - "[/INST]"
# - "[/SYSTEM_PROMPT]"
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Changes the default system message. Currently only supports chatml.
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
@@ -547,7 +524,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | 'linear' | 'cosine_with_restarts' | 'polynomial' | 'constant' | 'constant_with_warmup' | 'inverse_sqrt' | 'reduce_lr_on_plateau' | 'cosine_with_min_lr' | 'warmup_stable_decay' | empty for cosine
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -684,10 +661,8 @@ special_tokens:
# unk_token: "<unk>"
# pad_token: "[PAD]"
# Optional[list[str]]. Add extra tokens to the tokenizer.
# Add extra tokens.
tokens:
# - "<|startoftext|>"
# - "<|endoftext|>"
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).

View File

@@ -4,6 +4,18 @@ description: Conversation format for supervised fine-tuning.
order: 3
---
## sharegpt
::: {.callout-important}
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section below.
:::
## pygmalion
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "value": "..."}]}
```
## chat_template
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
@@ -52,7 +64,7 @@ We recommend checking the below examples for other usecases.
### Examples
1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
@@ -97,55 +109,10 @@ datasets:
```
::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
:::
5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
```yaml
eot_tokens:
- "[/INST]"
# - "[/SYSTEM_PROMPT]"
datasets:
- path: ...
type: chat_template
# optional
train_on_eot: turn # defaults read from train_on_eos (which defaults to turn)
```
::: {.callout-tip}
See [config documentation](../config.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
:::
::: {.callout-note}
Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
:::
6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
```yaml
eot_tokens:
- "[/INST]"
# ...
datasets:
- path: ...
type: chat_template
train_on_eos: last
train_on_eot: turn
```
::: {.callout-tip}
If EOS token only appears at the end of a prompt, `train_on_eos: last` is equivalent to `train_on_eos: turn`. Therefore, generally, you can leave them to their defaults and omit them.
:::
7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
@@ -195,15 +162,3 @@ datasets:
::: {.callout-tip}
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
## sharegpt
::: {.callout-important}
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section.
:::
## pygmalion
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "value": "..."}]}
```

View File

@@ -73,40 +73,10 @@ description: Frequently asked questions
> A: This is likely an empty turn.
**Q: The EOS token is incorrectly being masked or not being masked / `EOS token __ not found in chat template`.**
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: There can be two reasons:
> 1. This is because of the mismatch between `tokenizer.eos_token` and EOS token in template. Please make sure to set `eos_token: ` under `special_tokens: ` to the same EOS token as in template.
> 2. The EOS token is not in the template. Please check if your template is correct. As an example, `phi_35` template does not use its dedicated EOS token `<|endoftext|>` at the end.
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
**Q: The EOT token(s) are incorrectly being masked or not being masked / `EOT token __ not found in chat template`.**
> A: There can be two reasons:
> 1. The EOT token is different from the EOS token and was not specified under `eot_tokens: `. Please set `eot_tokens: ` to the same EOT token(s) as in template.
> 2. There is more than one EOT token per turn in the template. Please raise an issue with examples as we recognize this as an edge case.
**Q: `EOT token encoding failed. Please check if the token is valid and can be encoded.`**
> A: There could be some issue with the tokenizer or unicode encoding. Please raise an issue with examples with the EOT token & tokenizer causing the issue.
**Q: `EOT token __ is encoded as multiple tokens.`**
> A: This is because the EOT token is encoded as multiple tokens which can cause unexpected behavior. Please add it under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `.
**Q: `Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot`**
> A: This is because the EOS token is in the `eot_tokens: ` while mismatch between `train_on_eos: ` and `train_on_eot: `. This will cause one to override the other. Please ensure that `train_on_eos: ` and `train_on_eot: ` are the same or remove the EOS token from `eot_tokens: `.
**Q: If `eot_tokens: ` is not provided, what happens?**
> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
> Internally, `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.

View File

@@ -164,7 +164,7 @@ Here is an example of a multi-modal dataset:
{
"role": "user",
"content": [
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "text", "text": "Describe this image in detail."}
]
},

View File

@@ -502,7 +502,9 @@ The input format is a simple JSON input with customizable fields based on the ab
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
In the latest GRPO implementation, `vLLM` is used to significantly speedup trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM:
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
using 4 GPUs - 2 for training, and 2 for vLLM:
::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
@@ -537,10 +539,6 @@ Your `vLLM` instance will now attempt to spin up, and it's time to kick off trai
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```
::: {.callout-note}
Due to TRL's implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why in the example above, we use `CUDA_VISIBLE_DEVICES=2,3` for the vLLM instance.
:::
#### Reward functions
GRPO uses custom reward functions and transformations. Please have them ready locally.

View File

@@ -1,341 +0,0 @@
# Finetuning LLMs to output audio
In this example, we finetune Orpcanopylabs/orpheus-tts-0.1-pretrained (a LLaMA 3.2 3b model) to output audio.
The `finetune.yml` withe current settings will run on any Nvidia GPU with 45GB VRAM or more. If you adjust the batch size it can easily run on any GPU under 24GB.
## Dataset pre-processing for pre-training
If you are adding another voice in English, please jump ahead to finetuning pre-processing.
For this to work, we need to preprocess our dataset. Since we are expecting to output audio, we will need to add tokens to the tokenizer.
Using this code, it will download the SNAC model and add the correct tokens and upload the final dataset.
```python
import torch
from snac import SNAC
from datasets import load_dataset
from huggingface_hub import snapshot_download
from datasets import load_dataset
import random
import torchaudio.transforms as T
from transformers import AutoTokenizer
import os
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
dsn = my_original_dataset_name
snapshot_download(
repo_id=dsn,
repo_type="dataset",
revision="main",
max_workers=64,
)
ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to("mps")
def tokenise_audio(waveform):
waveform = torch.from_numpy(waveform).unsqueeze(0)
waveform = waveform.to(dtype=torch.float32)
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
waveform = resample_transform(waveform)
waveform = waveform.unsqueeze(0).to("cuda")
#generate the codes from snac
with torch.inference_mode():
codes = model.encode(waveform)
all_codes = []
for i in range(codes[0].shape[1]):
all_codes.append(codes[0][0][i].item()+128266)
all_codes.append(codes[1][0][2*i].item()+128266+4096)
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
return all_codes
def add_codes(example):
# Always initialize codes_list to None
codes_list = None
try:
answer_audio = example.get("audio")
# If there's a valid audio array, tokenise it
if answer_audio and "array" in answer_audio:
audio_array = answer_audio["array"]
codes_list = tokenise_audio(audio_array)
except Exception as e:
print(f"Skipping row due to error: {e}")
# Keep codes_list as None if we fail
example["codes_list"] = codes_list
return example
ds = ds.map(add_codes, remove_columns=["audio"])
#@title Load Tokenizer
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009
start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2
start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4
start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7
audio_tokens_start = tokeniser_length + 10
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
num_proc = os.cpu_count() - 2
ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
#@title Create Input Ids
def remove_duplicate_frames(example):
vals = example["codes_list"]
if len(vals) % 7 != 0:
raise ValueError("Input list length must be divisible by 7")
result = vals[:7]
removed_frames = 0
for i in range(7, len(vals), 7):
current_first = vals[i]
previous_first = result[-7]
if current_first != previous_first:
result.extend(vals[i:i+7])
else:
removed_frames += 1
example["codes_list"] = result
return example
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
def create_input_ids(example):
text_ids = tokenizer.encode({example['text']}, add_special_tokens=True)
text_ids.append(end_of_text)
example["text_tokens"] = text_ids
input_ids = (
[start_of_human]
+ example["text_tokens"]
+ [end_of_human]
+ [start_of_ai]
+ [start_of_speech]
+ example["codes_list"]
+ [end_of_speech]
+ [end_of_ai]
)
example["input_ids"] = input_ids
example["labels"] = input_ids
example["attention_mask"] = [1] * len(input_ids)
return example
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
#@title Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
ds = ds.remove_columns(columns_to_remove)
ds.push_to_hub(name_to_push_dataset_to)
```
## Finetune pre-processing
Use this code to add a new voice.
```python
import torch
from snac import SNAC
from datasets import load_dataset
from huggingface_hub import snapshot_download
from datasets import load_dataset
import random
import torchaudio.transforms as T
from transformers import AutoTokenizer
import os
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
dsn = my_original_dataset_name
snapshot_download(
repo_id=dsn,
repo_type="dataset",
revision="main",
max_workers=64,
)
ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to("mps")
def tokenise_audio(waveform):
waveform = torch.from_numpy(waveform).unsqueeze(0)
waveform = waveform.to(dtype=torch.float32)
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
waveform = resample_transform(waveform)
waveform = waveform.unsqueeze(0).to("cuda")
#generate the codes from snac
with torch.inference_mode():
codes = model.encode(waveform)
all_codes = []
for i in range(codes[0].shape[1]):
all_codes.append(codes[0][0][i].item()+128266)
all_codes.append(codes[1][0][2*i].item()+128266+4096)
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
return all_codes
def add_codes(example):
# Always initialize codes_list to None
codes_list = None
try:
answer_audio = example.get("audio")
# If there's a valid audio array, tokenise it
if answer_audio and "array" in answer_audio:
audio_array = answer_audio["array"]
codes_list = tokenise_audio(audio_array)
except Exception as e:
print(f"Skipping row due to error: {e}")
# Keep codes_list as None if we fail
example["codes_list"] = codes_list
return example
ds = ds.map(add_codes, remove_columns=["audio"])
#@title Load Tokenizer
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009
start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2
start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4
start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7
audio_tokens_start = tokeniser_length + 10
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
num_proc = os.cpu_count() - 2
ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
#@title Create Input Ids
def remove_duplicate_frames(example):
vals = example["codes_list"]
if len(vals) % 7 != 0:
raise ValueError("Input list length must be divisible by 7")
result = vals[:7]
removed_frames = 0
for i in range(7, len(vals), 7):
current_first = vals[i]
previous_first = result[-7]
if current_first != previous_first:
result.extend(vals[i:i+7])
else:
removed_frames += 1
example["codes_list"] = result
return example
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
tok_info = '''*** HERE you can modify the text prompt
i.e. if you wanted a multispeaker model like canopylabs/orpheus-3b-0.1-ft, you can pass:
f"{example["source"]}: {example["text"]}", as is passed.
'''
print(tok_info)
def create_input_ids(example):
text_ids = tokenizer.encode(f"{example['speaker_id']}: {example['text']}", add_special_tokens=True)
text_ids.append(end_of_text)
example["text_tokens"] = text_ids
input_ids = (
[start_of_human]
+ example["text_tokens"]
+ [end_of_human]
+ [start_of_ai]
+ [start_of_speech]
+ example["codes_list"]
+ [end_of_speech]
+ [end_of_ai]
)
example["input_ids"] = input_ids
example["labels"] = input_ids
example["attention_mask"] = [1] * len(input_ids)
return example
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
#@title Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
ds = ds.remove_columns(columns_to_remove)
ds.push_to_hub(name_to_push_dataset_to)
```
## Training
After preprocessing is done, fill out the blanks in finetune.yml and simply run `axolotl train finetune.yml`
## Inference
For inference, please refer to the original [orpheus github](https://github.com/canopyai/Orpheus-TTS/tree/main).

View File

@@ -1,52 +0,0 @@
base_model: canopylabs/orpheus-3b-0.1-pretrained
hub_model_id: <your-hub-model-id>
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
datasets:
- path: <your-hf-dataset-id>
type: # leave empty to load pre-tokenized
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 20
evals_per_epoch: 5
saves_per_epoch: 5
weight_decay: 0.05
special_tokens:
pad_token: <custom_token_7>

View File

@@ -1,69 +0,0 @@
base_model: Qwen/Qwen3-32B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: offload
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,68 +0,0 @@
base_model: Qwen/Qwen3-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:

View File

@@ -15,10 +15,10 @@ peft==0.15.2
transformers==4.51.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.1
datasets==3.5.0
deepspeed>=0.15.4
trl==0.17.0
hf_xet==1.1.0
hf_xet==1.0.0
hqq==0.2.5
optimum==1.16.2

View File

@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.5"]
extras_require_map["vllm"] = ["vllm==0.8.4"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.5"]
extras_require_map["vllm"] = ["vllm==0.8.4"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.10.0.dev0"
__version__ = "0.8.0"

View File

@@ -2,7 +2,4 @@
import os
from axolotl.logging_config import configure_logging
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
configure_logging()

View File

@@ -16,15 +16,8 @@ AXOLOTL_LOGO = """
@@@@ @@@@@@@@@@@@@@@@
"""
HAS_PRINTED_LOGO = False
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
global HAS_PRINTED_LOGO # pylint: disable=global-statement
if HAS_PRINTED_LOGO:
return
if is_main_process():
HAS_PRINTED_LOGO = True
print(AXOLOTL_LOGO)

View File

@@ -8,6 +8,9 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.logging_config import configure_logging
configure_logging()
LOG = logging.getLogger(__name__)

View File

@@ -5,7 +5,6 @@ import logging
import os
import tempfile
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Union
from urllib.parse import urlparse
@@ -153,15 +152,7 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager.register(plugin_name)
def plugin_set_cfg(cfg: DictDefault):
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup.
@@ -173,24 +164,13 @@ def load_cfg(
Returns:
`DictDefault` mapping configuration keys to values.
"""
if isinstance(config, (str, Path)):
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
else:
cfg = config
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
temp_file.write(yaml.dump(config.to_dict()))
temp_file.close()
cfg.axolotl_config_path = temp_file.name
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
@@ -204,6 +184,8 @@ def load_cfg(
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
@@ -231,6 +213,5 @@ def load_cfg(
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
plugin_set_cfg(cfg)
return cfg

View File

@@ -1,7 +1,6 @@
"""CLI to run evaluation on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -15,7 +14,6 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@@ -31,14 +29,10 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -28,8 +28,9 @@ from axolotl.cli.utils import (
fetch_from_github,
filter_none_kwargs,
)
from axolotl.cli.vllm_serve import do_vllm_serve
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -55,8 +56,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
patch_optimized_env()
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
@@ -102,7 +101,7 @@ def train(
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
set_pytorch_cuda_alloc_conf()
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
@@ -328,8 +327,6 @@ def fetch(directory: str, dest: Optional[str]) -> None:
@add_options_from_dataclass(VllmServeCliArgs)
@filter_none_kwargs
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
from axolotl.cli.vllm_serve import do_vllm_serve
do_vllm_serve(config, cli_args)

View File

@@ -1,6 +1,5 @@
"""CLI to run training on a model."""
import gc
import logging
import os
from pathlib import Path
@@ -18,7 +17,7 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils import patch_optimized_env
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
@@ -36,7 +35,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
cli_args: Training-specific CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
set_pytorch_cuda_alloc_conf()
print_axolotl_text_art()
check_accelerate_default_config()
@@ -49,11 +48,8 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
del model, tokenizer, trainer
gc.collect()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train_unload(cfg)

View File

@@ -20,9 +20,11 @@ from transformers import (
ProcessorMixin,
)
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__)

View File

@@ -11,6 +11,5 @@ MOE_ARCH_BLOCK = {
],
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
}

View File

@@ -47,8 +47,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
def load_datasets(
*,
cfg: DictDefault,
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
@@ -57,7 +56,6 @@ def load_datasets(
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample
Returns:
Dataclass with fields for training and evaluation datasets and the computed
@@ -66,8 +64,7 @@ def load_datasets(
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = (
cli_args
and hasattr(cli_args, "iterable")
hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
@@ -79,25 +76,20 @@ def load_datasets(
preprocess_iterable=preprocess_iterable,
)
if ( # pylint: disable=too-many-boolean-expressions
cli_args
and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
)
) or debug:
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
LOG.info("check_dataset_labels...")
num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False
train_samples = sample_dataset(train_dataset, num_examples)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=num_examples,
text_only=text_only,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
LOG.info("printing prompters...")

View File

@@ -60,7 +60,6 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
@@ -115,8 +114,6 @@ class TrainerBuilderBase(abc.ABC):
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
patch_trainer_get_lr()
@property
def model_ref(self):
return self._model_ref
@@ -168,9 +165,6 @@ class TrainerBuilderBase(abc.ABC):
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -252,6 +246,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -488,7 +485,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
self.cfg.max_steps if self.cfg.max_steps else -1
total_num_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs["per_device_train_batch_size"] = (

View File

@@ -3,29 +3,15 @@ DPO trainer for axolotl
"""
import gc
import random
from functools import wraps
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Union
import pandas as pd
import torch
import wandb
from accelerate import PartialState
from datasets import Dataset, IterableDataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
FeatureExtractionMixin,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
)
from transformers.trainer_utils import EvalLoopOutput
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
from trl.trainer.utils import log_table_to_comet_experiment
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.utils import (
@@ -95,64 +81,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[
PreTrainedTokenizerBase,
BaseImageProcessor,
FeatureExtractionMixin,
ProcessorMixin,
],
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Build the kwargs for the `map` function
map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc
with PartialState().main_process_first():
# Extract prompt if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
# Apply the chat template if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
dataset = dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
**map_kwargs,
)
# Tokenize the dataset
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
dataset = dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
remove_columns=["chosen", "rejected"],
fn_kwargs={
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": False,
},
**map_kwargs,
)
return dataset
@staticmethod
def tokenize_row(
features,
@@ -177,8 +105,12 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
@@ -192,69 +124,3 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
gc.collect()
torch.cuda.empty_cache()
return loss
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(
range(num_samples), k=self.args.eval_batch_size
)
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)
policy_output_decoded, ref_output_decoded = (
self.generate_from_model_and_ref(self.model, random_batch)
)
table = pd.DataFrame(
columns=["Prompt", "Policy", "Ref Model"],
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(
random_batch_dataset["prompt"],
policy_output_decoded,
ref_output_decoded,
)
],
)
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
wandb.log({"game_log": wandb.Table(data=table)})
if "comet_ml" in self.args.report_to:
log_table_to_comet_experiment(
name="game_log.csv",
table=table,
)
# Base evaluation
initial_output = super( # pylint: disable=bad-super-call
DPOTrainer, self
).evaluation_loop(
dataloader,
description,
prediction_loss_only,
ignore_keys,
metric_key_prefix,
)
return initial_output

View File

@@ -63,7 +63,6 @@ class GRPOStrategy:
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
@@ -71,13 +70,6 @@ class GRPOStrategy:
if trl.scale_rewards is not None:
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
if trl.loss_type is not None:
grpo_args_kwargs["loss_type"] = trl.loss_type
if trl.mask_truncated_completions is not None:
grpo_args_kwargs["mask_truncated_completions"] = (
trl.mask_truncated_completions
)
if trl.temperature is not None:
grpo_args_kwargs["temperature"] = trl.temperature
if trl.top_p is not None:
@@ -93,11 +85,6 @@ class GRPOStrategy:
grpo_args_kwargs["num_iterations"] = trl.num_iterations
if trl.epsilon is not None:
grpo_args_kwargs["epsilon"] = trl.epsilon
if trl.epsilon_high is not None:
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
return grpo_args_kwargs

View File

@@ -3,10 +3,9 @@
import logging
import torch
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from torch.optim.lr_scheduler import OneCycleLR
from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
@@ -26,9 +25,9 @@ class SchedulerMixin(Trainer):
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
) -> LRScheduler:
):
"""
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
@@ -48,16 +47,7 @@ class SchedulerMixin(Trainer):
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
plugin_manager = PluginManager.get_instance()
lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
trainer=self,
optimizer=optimizer,
num_training_steps=num_training_steps
)
if lr_scheduler is not None:
LOG.info(f"Using plugin-created lr_scheduler: {lr_scheduler}")
self.lr_scheduler = lr_scheduler
elif self.args.alternate_lr_scheduler_type == "one_cycle":
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
@@ -120,4 +110,4 @@ class SchedulerMixin(Trainer):
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler # type: ignore
return self.lr_scheduler

View File

@@ -1,7 +1,6 @@
"""Module for ReLoRA trainer"""
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -20,11 +19,9 @@ class ReLoRATrainer(AxolotlTrainer):
self,
num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None,
) -> LRScheduler:
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler: LRScheduler = super().create_scheduler(
num_training_steps, optimizer
)
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
@@ -33,7 +30,7 @@ class ReLoRATrainer(AxolotlTrainer):
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler( # type: ignore
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
@@ -41,6 +38,6 @@ class ReLoRATrainer(AxolotlTrainer):
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler # type: ignore
self.lr_scheduler = lr_scheduler
return self.lr_scheduler # type: ignore
return self.lr_scheduler

View File

@@ -11,19 +11,20 @@ from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.train import (
TrainDatasetMeta,
setup_model_and_tokenizer,
)
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
LOG = get_logger(__name__)
configure_logging()
LOG = get_logger("axolotl.evaluate")
def evaluate_dataset(
@@ -74,22 +75,37 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
Returns:
Dictionary mapping metric names to their values.
"""
# Load tokenizer, processor and model
LOG.debug("loading model for evaluation...")
model, tokenizer, _, processor = setup_model_and_tokenizer(cfg)
# pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed
processor = None
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
# Get datasets
# pylint: disable=duplicate-code
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(cfg, tokenizer, processor=processor)
# Set up trainer
trainer = setup_trainer(
cfg=cfg,
cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=model,
model=(model, None, None), # No need for model_ref or peft_config
tokenizer=tokenizer,
processor=processor,
total_num_steps=total_num_steps,

View File

@@ -24,7 +24,6 @@ import logging
from typing import OrderedDict
import torch
from torch.optim.lr_scheduler import LRScheduler
class BasePlugin:
@@ -37,12 +36,11 @@ class BasePlugin:
Methods:
register(cfg): Registers the plugin with the given configuration.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
post_model_load(cfg, model): Performs actions after the model is loaded.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
"""
@@ -79,14 +77,6 @@ class BasePlugin:
None
"""
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is built/loaded, but before any adapters are applied.
Args:
cfg (dict): The configuration for the plugin.
"""
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
@@ -147,8 +137,8 @@ class BasePlugin:
"""
def create_lr_scheduler(
self, cfg, trainer, optimizer, num_training_steps
) -> LRScheduler | None: # pylint: disable=unused-argument
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
@@ -156,10 +146,9 @@ class BasePlugin:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Returns:
object (LRScheduler): The created learning rate scheduler.
object: The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
@@ -272,7 +261,6 @@ class PluginManager:
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance = None
_cfg = None
def __new__(cls):
"""
@@ -280,9 +268,7 @@ class PluginManager:
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: OrderedDict[str, BasePlugin] = (
collections.OrderedDict()
)
cls._instance.plugins = collections.OrderedDict()
return cls._instance
@staticmethod
@@ -295,14 +281,6 @@ class PluginManager:
PluginManager()
return PluginManager._instance # type: ignore
@property
def cfg(self):
return self._cfg
@cfg.setter
def cfg(self, cfg):
self._cfg = cfg
def register(self, plugin_name: str):
"""
Registers a new plugin by its name.
@@ -351,22 +329,9 @@ class PluginManager:
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def post_model_build(self, cfg, model):
"""
Calls the post_model_build method of all registered plugins after the model has been built/loaded,
but before any adapters have been applied.
Args:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_model_build(cfg, model)
def post_model_load(self, cfg, model):
"""
Calls the post_model_load method of all registered plugins after the model has been loaded
inclusive of any adapters
Calls the post_model_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
@@ -422,29 +387,29 @@ class PluginManager:
return trainer_cls
return None
def create_optimizer(self, trainer):
def create_optimizer(self, cfg, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(self.cfg, trainer)
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
return None
def create_lr_scheduler(
self, trainer, optimizer, num_training_steps
) -> LRScheduler | None:
def create_lr_scheduler(self, cfg, trainer, optimizer):
"""
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
@@ -452,12 +417,7 @@ class PluginManager:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins.values():
scheduler: LRScheduler | None = plugin.create_lr_scheduler(
self.cfg,
trainer=trainer,
optimizer=optimizer,
num_training_steps=num_training_steps,
)
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
return None
@@ -498,20 +458,6 @@ class PluginManager:
callbacks.extend(plugin_callbacks)
return callbacks
def post_train(self, cfg, model):
"""
Calls the post_train method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_train(cfg, model)
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.

View File

@@ -32,8 +32,8 @@ plugins:
## Supported Models
- llama
- llama4
- llama4_text
- llama4
- mllama
- phi3
- gemma
@@ -43,11 +43,6 @@ plugins:
- mistral
- mistral3
- qwen2
- qwen2_moe
- qwen2_vl
- qwen2_5_vl
- qwen3
- qwen3_moe
- cohere
- cohere2
- glm

View File

@@ -25,7 +25,7 @@ import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import is_main_process
from axolotl.utils.distributed import zero_only
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
@@ -76,7 +76,7 @@ class CutCrossEntropyPlugin(BasePlugin):
cce_patch,
)
if is_main_process(use_environ=True):
with zero_only():
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)

View File

@@ -1,174 +0,0 @@
"""Llama CCE patch. Adapted from transformers v4.51.2"""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
_CONFIG_FOR_DOC,
LLAMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
_PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
if hidden_states is None:
raise ValueError("hidden_states is None")
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def patch_llama(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
"""Patch Llama for CCE."""
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.llama import modeling_llama
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_llama.LlamaForCausalLM
), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_llama.LlamaForCausalLM.forward = cce_forward
return None

View File

@@ -5,7 +5,9 @@
import transformers
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.llama import patch_llama
from cut_cross_entropy.transformers.phi3 import patch_phi3
from cut_cross_entropy.transformers.qwen2 import patch_qwen2
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
@@ -22,9 +24,6 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
patch_glm,
patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
patch_llama,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4,
patch_llama4_text,
@@ -34,22 +33,6 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
patch_mistral3,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
patch_qwen2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
patch_qwen2_5_vl,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
patch_qwen2_moe,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
patch_qwen2_vl,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
patch_qwen3_moe,
)
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"llama": patch_llama,
@@ -64,11 +47,6 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"mistral": patch_mistral,
"mistral3": patch_mistral3,
"qwen2": patch_qwen2,
"qwen2_moe": patch_qwen2_moe,
"qwen2_vl": patch_qwen2_vl,
"qwen2_5_vl": patch_qwen2_5_vl,
"qwen3": patch_qwen3,
"qwen3_moe": patch_qwen3_moe,
"cohere": patch_cohere,
"cohere2": patch_cohere2,
"glm": patch_glm,

View File

@@ -1,37 +0,0 @@
"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
# pylint: disable=duplicate-code
from types import MethodType
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
def patch_qwen2(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
from transformers.models.qwen2 import modeling_qwen2
# Set the _PATCH_OPTS in the llama patch file
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
cce_forward,
)
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen2.Qwen2ForCausalLM
), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward
return None

View File

@@ -1,246 +0,0 @@
"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2"""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLCausalLMOutputWithPast,
)
_PATCH_OPTS: PatchOptions | None = None
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
mask = input_ids == self.config.video_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
):
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
attention_mask,
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = (
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
if cache_position is not None
else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
position_ids = position_ids.add(delta) # type: ignore
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
outputs = self.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
logits = None
loss = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states,
self.lm_head.weight,
labels,
_PATCH_OPTS,
)
else:
logits = self.lm_head(hidden_states)
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Qwen2_5_VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
)
def patch_qwen2_5_vl(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
return maybe_model
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
cce_forward_multimodal
)
return None

View File

@@ -1,188 +0,0 @@
"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2"""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
_CONFIG_FOR_DOC,
QWEN2MOE_INPUTS_DOCSTRING,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
_PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM
>>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
)
hidden_states = outputs.last_hidden_state
loss = None
logits = None
if hidden_states is None:
raise ValueError("hidden_states is None")
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**loss_kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
loss.device # type: ignore
) # make sure to reside in the same device
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss, # type: ignore
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
def patch_qwen2_moe(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.qwen2_moe import modeling_qwen2_moe
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(forward, maybe_model)
return maybe_model
modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward
return None

View File

@@ -1,249 +0,0 @@
"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2"""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
_CONFIG_FOR_DOC,
QWEN2_VL_INPUTS_DOCSTRING,
Qwen2VLCausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
_PATCH_OPTS: PatchOptions | None = None
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
):
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = (
cache_position[0] + self.rope_deltas
if cache_position is not None
else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
delta = delta.to(position_ids.device) # type: ignore
position_ids = position_ids.add(delta) # type: ignore
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
outputs = self.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
logits = None
loss = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states,
self.lm_head.weight,
labels,
_PATCH_OPTS,
)
else:
logits = self.lm_head(hidden_states)
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Qwen2VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
)
def patch_qwen2_vl(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.qwen2_vl import modeling_qwen2_vl
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration
), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
return maybe_model
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal
return None

View File

@@ -1,35 +0,0 @@
"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
# pylint: disable=duplicate-code
from types import MethodType
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
def patch_qwen3(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
from transformers.models.qwen3 import modeling_qwen3
# Set the _PATCH_OPTS in the llama patch file
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen3.Qwen3ForCausalLM
), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward
return None

View File

@@ -1,194 +0,0 @@
"""Qwen3 MoE CCE patch. Adapted from transformers v4.51.2"""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
_CONFIG_FOR_DOC,
QWEN3_MOE_INPUTS_DOCSTRING,
KwargsForCausalLM,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
load_balancing_loss_func,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
_PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
>>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
if hidden_states is None:
raise ValueError("hidden_states is None")
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
loss.device # type: ignore
) # make sure to reside in the same device
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss, # type: ignore
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
def patch_qwen3_moe(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.qwen3_moe import modeling_qwen3_moe
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(forward, maybe_model)
return maybe_model
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward
return None

View File

@@ -35,9 +35,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
sequence_len,
roles_to_train=None,
train_on_eos=None,
train_on_eot=None,
eot_tokens=None,
split_thinking: bool | None = False,
logprobs_field="logprobs",
gen_temperature=1.0,
kd_temperature=1.0,
@@ -53,9 +50,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
sequence_len,
roles_to_train=roles_to_train,
train_on_eos=train_on_eos,
train_on_eot=train_on_eot,
eot_tokens=eot_tokens,
split_thinking=split_thinking,
)
@property

View File

@@ -23,8 +23,8 @@ import logging
import sys
from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_main_process
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable
@@ -85,7 +85,7 @@ class LigerPlugin(BasePlugin):
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
if is_main_process(use_environ=True):
with zero_only():
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)
@@ -151,30 +151,6 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3":
from axolotl.integrations.liger.models.qwen3 import (
apply_liger_kernel_to_qwen3,
)
apply_liger_kernel_to_qwen3(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
)
apply_liger_kernel_to_qwen3_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
else:
logging.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."

View File

@@ -1,160 +0,0 @@
"""
Liger FLCE for Qwen3. Based on transformers v4.51.3.
"""
import sys
from typing import Optional, Tuple, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> None:
# pylint: disable=duplicate-code
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3.modeling_qwen3 # noqa: F401 # pylint: disable=unused-import
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]
if rms_norm:
modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
if glu_activation:
modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
if layer_norm:
modeling_qwen3.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3.Qwen3ForCausalLM.forward = lce_forward

View File

@@ -1,191 +0,0 @@
"""
Liger FLCE for Qwen3 MoE. Based on transformers v4.51.3.
"""
import sys
from copy import deepcopy
from typing import List, Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(
loss.device
) # make sure to reside in the same device
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3_moe(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> None:
# pylint: disable=duplicate-code
"""
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401 # pylint: disable=unused-import
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]
if rms_norm:
modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"Accepts intermediate_size to pass to LigerSwiGLUMLP"
# clone config to avoid modifying the original
config = deepcopy(config)
if intermediate_size:
setattr(config, "intermediate_size", intermediate_size)
return LigerSwiGLUMLP(config, **kwargs)
modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward

View File

@@ -12,8 +12,10 @@ import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging()
LOG = get_logger(__name__)

View File

@@ -23,42 +23,22 @@ from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__)
QKV_PATCHES = [
(
"""
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
"""
"\n"
)
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
),
(
"""
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
),
]
"\n"
)
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
@@ -148,11 +128,10 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
module = __import__(
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
return attention_cls
except (ImportError, AttributeError) as e:
@@ -189,18 +168,10 @@ def patch_self_attn_lora(cfg: DictDefault):
attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
assert any(
qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES
), "Original QKV code not found"
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
for qkv_orig, qkv_patched in QKV_PATCHES:
if qkv_orig in self_attn_forward:
self_attn_forward = self_attn_forward.replace(
qkv_orig,
qkv_patched,
)
break
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",

View File

@@ -18,8 +18,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"qwen2_moe",
"qwen3",
"qwen3_moe",
"falcon",
"phi",
"phi3",

View File

@@ -1,42 +0,0 @@
"""
monkeypatch for Trainer _get_learning_rate method
"""
import logging
import torch
LOG = logging.getLogger(__name__)
# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release
def _get_learning_rate(self):
if self.is_deepspeed_enabled:
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
# not run for the first few dozen steps while loss scale is too large, and thus during
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
try:
last_lr = self.lr_scheduler.get_last_lr()[0]
except AssertionError as e:
if "need to call step" in str(e):
LOG.warning(
"tried to get lr value before scheduler/optimizer started stepping, returning lr=0"
)
last_lr = 0
else:
raise
else:
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
last_lr = self.optimizer.param_groups[0]["lr"]
else:
last_lr = self.lr_scheduler.get_last_lr()[0]
if torch.is_tensor(last_lr):
last_lr = last_lr.item()
return last_lr
def patch_trainer_get_lr():
from transformers.trainer import Trainer
Trainer._get_learning_rate = _get_learning_rate # pylint: disable=protected-access

View File

@@ -4,7 +4,7 @@ HF Chat Templates prompt strategy
import logging
from collections import defaultdict
from typing import Any, Dict, List, Set, Union
from typing import Any, Dict, List, Optional, Set, Union
from pydantic import BaseModel
from transformers import ProcessorMixin
@@ -29,12 +29,11 @@ class ChatTemplatePrompter(Prompter):
chat_template: str,
processor=None,
max_length=2048,
message_property_mappings: Dict[str, str] | None = None,
message_field_training: str | None = None,
message_field_training_detail: str | None = None,
message_property_mappings: Optional[Dict[str, str]] = None,
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
field_messages: str = "messages",
field_system: str = "system",
roles: Dict[str, List[str]] | None = None,
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
# check if message_property_mappings is None or empty dict
@@ -42,7 +41,6 @@ class ChatTemplatePrompter(Prompter):
message_property_mappings = {
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
}
if roles:
@@ -64,9 +62,8 @@ class ChatTemplatePrompter(Prompter):
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.field_messages = field_messages
self.field_system = field_system
self.tokenizer = tokenizer
self.processor: ProcessorMixin | None = processor
self.processor: Optional[ProcessorMixin] = processor
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
@@ -223,13 +220,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self,
prompter: "ChatTemplatePrompter",
tokenizer,
train_on_inputs: bool,
sequence_len: int,
roles_to_train: list[str] | None = None,
train_on_eos: str | None = None,
train_on_eot: str | None = None,
eot_tokens: list[str] | None = None,
split_thinking: bool | None = False,
train_on_inputs,
sequence_len,
roles_to_train=None,
train_on_eos=None,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.prompter: ChatTemplatePrompter = prompter
@@ -242,88 +236,12 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
]
self.train_on_eos = train_on_eos
# Backward compatibility, load from train_on_eos
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
# Default to eos_token if eot_tokens not provided
self.eot_tokens = (
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
)
self.split_thinking = split_thinking
self.images = "images"
LOG.debug(
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
)
self._validate_eot_and_eos_tokens()
def _validate_eot_and_eos_tokens(self):
"""
- Validates that EOT tokens (or eos_token) are in the chat_template
- Checks if EOT tokens are encoded as multiple tokens in the tokenizer.
- Checks for potential conflicts between train_on_eos and train_on_eot.
"""
if self.prompter.chat_template is None:
# Usually this should not happen
LOG.warning(
"No chat template provided, skipping EOT and EOS token validation"
)
return
# If the EOT token is the same as the EOS token, we need to check differently
if len(self.eot_tokens) == 1 and self.eot_tokens[0] == self.tokenizer.eos_token:
# Check if the eos_token is in the chat_template or as a variable `eos_token`
# Note: we check for `eos_token` in the string, but it could possibly not be a variable
if (
self.tokenizer.eos_token not in self.prompter.chat_template
and "eos_token" not in self.prompter.chat_template
):
LOG.warning(
f"EOS token '{self.tokenizer.eos_token}' not found in chat_template. Please check if your template/EOS token is correct."
)
return
# Create a new list to store tokens that should be kept
valid_eot_tokens = []
for token in self.eot_tokens:
# Check if EOT token is in the chat_template
if token not in self.prompter.chat_template:
LOG.warning(f"EOT token '{token}' not found in chat_template.")
# Don't add to the valid tokens list
continue
valid_eot_tokens.append(token)
# Replace the original list with the filtered one
self.eot_tokens = valid_eot_tokens
for token in self.eot_tokens:
# If token in template, check if EOT token is in tokenizer and not encoded as multiple tokens
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if not token_ids:
raise ValueError(
"EOT token encoding failed. Please check if the token is valid and can be encoded."
)
if token_ids and len(token_ids) > 1:
raise ValueError(
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config "
"or (recommended) override unused added_tokens via `added_tokens_overrides: `."
)
# If eos_token is in eot_tokens and conflict between train_on_eos and train_on_eot, raise an error
if (
self.tokenizer.eos_token in self.eot_tokens
and self.train_on_eos != self.train_on_eot
):
raise ValueError(
"Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot"
f"train_on_eos: {self.train_on_eos}, train_on_eot: {self.train_on_eot}"
f"eot_tokens: {self.eot_tokens}"
f"eos_token: {self.tokenizer.eos_token}"
)
@property
def supports_batched(self) -> bool:
# Let calling code know we can handle lists of examples
@@ -367,7 +285,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if (
not self.roles_to_train
and not self.train_on_eos
and not self.train_on_eot
and not self.prompter.message_field_training # type: ignore
and not self.prompter.message_field_training_detail # type: ignore
):
@@ -403,7 +320,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1
last_eot_idx = -1
for index, turn in enumerate(turns):
role = turn.get("role")
content = turn.get("content")
@@ -452,46 +368,25 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
LOG.debug(f"Labels after processing turn {index}: {labels}")
# Handle special tokens (EOT and EOS)
for token_type, find_func, train_option in [
("EOT", self.find_first_eot_token, self.train_on_eot),
("EOS", self.find_first_eos_token, self.train_on_eos),
]:
token_idx = find_func(input_ids, start_idx=turn_end_idx)
if (
token_idx != -1 and abs(token_idx - turn_end_idx) <= 3
): # Allow for some template padding
# Update the last token index
if token_type == "EOT": # nosec B105
last_eot_idx = token_idx
else:
last_eos_idx = token_idx
# Set labels if needed for this turn
if train_option == "all" or (
train_option == "turn" and should_train
):
labels[token_idx] = input_ids[token_idx]
LOG.debug(
f"{token_type} token set for training at index {token_idx}"
)
else:
LOG.debug(
f"{token_type} token missing after turn {turn}. {token_type.lower()}_idx: {token_idx}, turn_end_idx: {turn_end_idx}"
)
# Handle 'last' option for special tokens
for token_type, last_idx, train_option in [
("EOT", last_eot_idx, self.train_on_eot),
("EOS", last_eos_idx, self.train_on_eos),
]:
if train_option == "last" and last_idx != -1:
labels[last_idx] = input_ids[last_idx]
# Handle EOS token
eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
if abs(eos_idx - turn_end_idx) <= 3: # Allow for some template padding
last_eos_idx = eos_idx
if self.train_on_eos == "all" or (
self.train_on_eos == "turn" and should_train
):
labels[eos_idx] = input_ids[eos_idx]
LOG.debug(f"EOS token set for training at index {eos_idx}")
else:
LOG.debug(
f"Last {token_type} token set for training at index {last_idx}"
f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
)
# Handle 'last' option for train_on_eos
if self.train_on_eos == "last" and last_eos_idx != -1:
labels[last_eos_idx] = input_ids[last_eos_idx]
LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
LOG.debug(f"Final labels: {labels}")
return {
@@ -507,25 +402,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return i
return -1
def find_first_eot_token(self, input_ids, start_idx):
"""Find the first EOT token in the input_ids starting from start_idx."""
# Get token IDs for all EOT tokens
eot_token_ids = []
for token in self.eot_tokens:
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if len(token_ids) != 1:
raise ValueError(
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
)
eot_token_ids.append(token_ids[0]) # Use the last token ID if multiple
# Search for any of the EOT token IDs
for i in range(start_idx, len(input_ids)):
if input_ids[i] in eot_token_ids:
return i
return -1
def find_turn(self, turns: list[dict], turn_idx: int):
"""
Locate the starting and ending indices of the specified turn in a conversation.
@@ -612,17 +488,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
turns = []
possible_sys_turn = self.transform_message(
prompt[self.prompter.field_messages][0]
)
if (
possible_sys_turn["role"] != "system"
and self.prompter.field_system in prompt
):
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
turns.append(turn)
for message in prompt[self.prompter.field_messages]:
transformed_message = self.transform_message(message)
@@ -658,52 +523,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
transformed_message["role"], transformed_message["role"]
)
# TODO handle reasoning_content with split_thinking
# if the role is assistant that we want to use reasoning_content
if self.split_thinking and transformed_message["role"] == "assistant":
content = transformed_message["content"]
thinking_pairs = [
("<think>", "</think>"),
("<reasoning>", "</reasoning>"),
("<|begin_of_thought|>", "<|end_of_thought|>"),
]
content_pairs = [("<|begin_of_solution|>", "<|end_of_solution|>")]
for tpair in thinking_pairs:
# check if the thinking pair is in the content
if tpair[0] in content and tpair[1] in content:
# find the start and end index of the thinking pair
t_start_idx = content.find(tpair[0])
t_end_idx = content.find(tpair[1])
# get the thinking content
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
transformed_message["reasoning_content"] = thinking_content.strip()
# take remainder of the content
# strip whitespace from beginning of the remainder (thinking tokens)
remainder = content[t_end_idx + len(tpair[1]) :].lstrip()
# check if the content pair is in the remainder
cpair_found = False
for cpair in content_pairs:
if cpair[0] in remainder and cpair[1] in remainder:
# find the start and end index of the content pair
c_start_idx = remainder.find(cpair[0])
c_end_idx = remainder.find(cpair[1])
# get the content content
content_content = remainder[
c_start_idx + len(cpair[0]) : c_end_idx
]
transformed_message["content"] = content_content.strip()
cpair_found = True
break
# else, the content is the remainder
if not cpair_found:
transformed_message["content"] = remainder
break
# Determine which keys in the original message were not mapped
mapped_values = set(self.prompter.message_property_mappings.values())
remaining_keys = set(message) - mapped_values
@@ -736,16 +555,13 @@ class StrategyLoader:
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
"train_on_eot": ds_cfg.get("train_on_eot", None),
"eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg
"split_thinking": ds_cfg.get("split_thinking", False),
}
def __call__(
self,
tokenizer,
cfg,
ds_cfg: Union[Dict[str, Any], DatasetConfig] | None = None,
ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
processor=None,
):
if ds_cfg is None:

View File

@@ -21,7 +21,6 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.cli.art import print_axolotl_text_art
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
@@ -30,7 +29,7 @@ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuil
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
@@ -42,6 +41,7 @@ try:
except ImportError:
BetterTransformer = None
configure_logging()
LOG = get_logger(__name__)
@@ -517,8 +517,6 @@ def train(
Returns:
Tuple of (model, tokenizer) after training
"""
print_axolotl_text_art()
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,
@@ -550,7 +548,4 @@ def train(
if not cfg.use_ray:
cleanup_distributed()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, model)
return model, tokenizer, trainer

View File

@@ -43,12 +43,3 @@ def set_pytorch_cuda_alloc_conf():
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
"expandable_segments:True,roundup_power2_divisions:16"
)
def patch_optimized_env():
"""
Patch environment variables to improve VRAM usage and increase download speed
"""
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import gc
import json
import logging
import os
import traceback
@@ -809,44 +808,11 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
LOG.info(
"The Axolotl config has been saved to the WandB run under files."
)
LOG.info(
"The Axolotl config has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
if args.deepspeed:
try:
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
with NamedTemporaryFile(
mode="w",
delete=False,
suffix=".json",
prefix="deepspeed_config_",
) as temp_file:
skip_upload = False
if isinstance(args.deepspeed, dict):
json.dump(args.deepspeed, temp_file, indent=4)
elif isinstance(args.deepspeed, str) and os.path.exists(
args.deepspeed
):
copyfile(args.deepspeed, temp_file.name)
else:
skip_upload = True
if not skip_upload:
artifact = wandb.Artifact(
f"deepspeed-config-{wandb.run.id}",
type="deepspeed-config",
)
artifact.add_file(temp_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_file.name)
LOG.info(
"The DeepSpeed config has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving DeepSpeed config to WandB: {err}")
return control

File diff suppressed because one or more lines are too long

View File

@@ -59,7 +59,7 @@ def choose_device(cfg):
def resolve_dtype(cfg):
if (
not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray
cfg.bf16 == "auto" and not cfg.use_ray
): # if we use ray we want to defer this check to the worker node
if is_torch_bf16_gpu_available():
LOG.debug("bf16 support detected, enabling for this configuration.")
@@ -67,7 +67,7 @@ def resolve_dtype(cfg):
else:
LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False
if cfg.fp16 is None and not cfg.float16:
if cfg.fp16 is None:
cfg.fp16 = True
if cfg.device == "mps":

View File

@@ -204,37 +204,7 @@ def load_prepare_preference_datasets(cfg):
else:
eval_dataset = load_split(cfg.test_datasets, cfg)
if not eval_dataset:
if cfg.val_set_size:
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "train"
+ "|"
+ str(cfg.seed or 42)
)
to_hash_test = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "test"
+ "|"
+ str(cfg.seed or 42)
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
ds_w_test_split = train_dataset.train_test_split(
test_size=cfg.val_set_size,
seed=cfg.seed,
shuffle=False,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
eval_dataset = ds_w_test_split["test"]
train_dataset = ds_w_test_split["train"]
eval_dataset = None
if not train_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)

View File

@@ -69,27 +69,17 @@ def barrier():
dist.barrier()
def is_main_process(use_environ=False):
def is_main_process():
"""
Check if the current process is the main process. If not in distributed mode,
always return `True`.
Args:
- use_environ (bool, optional): Use environment variable to determine main process.
Returns:
- bool: `True` if the current process is the main process, `False` otherwise.
"""
if use_environ:
return os.environ.get("LOCAL_RANK", "0") == "0"
if not is_distributed():
return True
return dist.get_rank() == 0
def is_local_main_process(use_environ=False):
if use_environ:
return os.environ.get("LOCAL_RANK", "0") == "0"
def is_local_main_process():
return PartialState().is_local_main_process
@@ -109,6 +99,17 @@ def cleanup_distributed():
torch.distributed.destroy_process_group()
@contextmanager
def zero_only():
"""
Context manager that only runs the enclosed block on the main rank.
"""
if is_main_process():
yield
else:
yield None
@contextmanager
def zero_first(is_main):
"""

View File

@@ -53,7 +53,6 @@ from transformers.integrations.deepspeed import (
)
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -68,14 +67,13 @@ from axolotl.utils.distributed import (
get_device_count,
get_device_type,
is_local_main_process,
is_main_process,
zero_only,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
MULTIMODAL_AUTO_MODEL_MAPPING = {
"mllama": MllamaForConditionalGeneration,
@@ -453,7 +451,7 @@ def load_tokenizer(cfg):
{"additional_special_tokens": additional_special_tokens}
)
if is_main_process(use_environ=True):
with zero_only():
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
@@ -589,8 +587,10 @@ class ModelLoader:
patch_gemma3conditionalgeneration_forward()
# load any patches from plugins
from axolotl.integrations.base import PluginManager
PLUGIN_MANAGER.pre_model_load(self.cfg)
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
# monkey patch to allow additional Accelerator init kwargs
if self.cfg.fp8:
@@ -1268,7 +1268,6 @@ class ModelLoader:
try:
skip_move_to_device = self.build_model(qlora_fsdp)
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
except Exception as err: # pylint: disable=broad-exception-caught
LOG.exception(err)
raise err
@@ -1348,8 +1347,6 @@ class ModelLoader:
before_kbit_train_or_finetune=False,
)
PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
# ---------------------------------------------------------
# load lora or adapter
# ---------------------------------------------------------
@@ -1411,7 +1408,7 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
PLUGIN_MANAGER.post_model_load(self.cfg, self.model)
# TODO resume_from_checkpoint handling
return self.model, lora_config
@@ -1446,13 +1443,9 @@ def load_adapter(model, cfg, adapter, inference=False):
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
model, lora_config = load_lora(model, cfg, inference=inference)
PLUGIN_MANAGER.post_lora_load(cfg, model)
return model, lora_config
return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter":
model, lora_config = load_llama_adapter(model, cfg)
PLUGIN_MANAGER.post_lora_load(cfg, model)
return model, lora_config
return load_llama_adapter(model, cfg)
raise NotImplementedError(f"{adapter} peft adapter not available")

View File

@@ -190,7 +190,7 @@ class MultipackBatchSampler(BatchSampler):
self.len_across_ranks = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warning(
LOG.warn(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)

View File

@@ -309,7 +309,6 @@ class AxolotlInputConfig(
| Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
) | None = None
chat_template_jinja: str | None = None
eot_tokens: list[str] | None = None
default_system_message: str | None = None
fix_untrained_tokens: int | list[int] | None = None
@@ -512,17 +511,10 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def hint_sample_packing_padding(cls, data):
if data.get("sample_packing"):
pad_to_sequence_len = data.get("pad_to_sequence_len")
if pad_to_sequence_len is False:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
elif pad_to_sequence_len is None:
LOG.info(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
)
data["pad_to_sequence_len"] = True
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
return data
@model_validator(mode="before")
@@ -1157,18 +1149,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_grpo_peft_liger(cls, data):
if (
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
and data.get("adapter")
):
raise ValueError("PEFT + GRPO + Liger is not yet supported")
return data
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
@@ -1334,57 +1314,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_auto_enable_lora_kernels(cls, data):
# Only proceed if using LoRA or QLoRA adapter
if data.get("rl"):
# RL trainers not tested so don't enable kernels by default
return data
if data.get("adapter") in ["lora", "qlora"]:
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
if (
any(data.get(k) is not None for k in kernel_fields)
or any(data.get(k) for k in unsloth_fields)
or data.get("adapter") == "lora"
and data.get("load_in_8bit")
):
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
is_fsdp = data.get("fsdp") is not None
is_fsdp2 = (
data.get("fsdp_config") is not None
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
)
if (
not is_multi_gpu
or (is_multi_gpu and not is_fsdp)
or (is_multi_gpu and is_fsdp2)
):
# Auto-enable kernels if not explicitly set by user
if data.get("lora_mlp_kernel") is None:
data["lora_mlp_kernel"] = True
if data.get("lora_qkv_kernel") is None:
data["lora_qkv_kernel"] = True
if data.get("lora_o_kernel") is None:
data["lora_o_kernel"] = True
LOG.warning(
"Auto-enabling LoRA kernel optimizations for faster training. "
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
+ "See https://docs.axolotl.ai/docs/lora_optims.html for more info."
)
return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):

View File

@@ -50,7 +50,6 @@ class SFTDataset(BaseModel):
message_property_mappings: dict[str, str] | None = None
message_field_training: str | None = None
message_field_training_detail: str | None = None
split_thinking: bool | None = None
logprobs_field: str | None = None
temperature: float | None = None
roles_to_train: list[str] | None = None

View File

@@ -35,7 +35,6 @@ class ChatTemplate(str, Enum):
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
qwen3 = "qwen3" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
metharme = "metharme" # pylint: disable=invalid-name

View File

@@ -67,12 +67,6 @@ class TRLConfig(BaseModel):
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
num_completions_to_print: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
},
)
sync_ref_model: bool | None = Field(
default=False,
json_schema_extra={
@@ -139,25 +133,3 @@ class TRLConfig(BaseModel):
"description": "Epsilon value for clipping in the GRPO algorithm."
},
)
epsilon_high: float | None = Field(
default=None,
json_schema_extra={
"description": "Upper-bound epsilon value for clipping in the GRPO algorithm."
},
)
use_liger_loss: bool | None = Field(
default=None,
json_schema_extra={"description": "Whether to use Liger loss for GRPO."},
)
loss_type: str | None = Field(
default=None,
json_schema_extra={
"description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
},
)
mask_truncated_completions: bool = Field(
default=False,
json_schema_extra={
"description": "When enabled, truncated completions are excluded from the loss calculation."
},
)

View File

@@ -597,8 +597,6 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
elif cfg.fp16:
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
else:
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
def prepare_opinionated_env(cfg):

View File

@@ -4,7 +4,6 @@ shared pytest fixtures
import functools
import importlib
import os
import shutil
import sys
import tempfile
@@ -80,9 +79,9 @@ def download_smollm2_135m_model():
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_gptq_model():
def download_llama_68m_random_model():
# download the model
snapshot_download_w_retry("lilmeaty/SmolLM2-135M-Instruct-GPTQ", repo_type="model")
snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
@@ -91,12 +90,6 @@ def download_qwen_2_5_half_billion_model():
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_qwen3_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the dataset
@@ -530,32 +523,31 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
# # pylint: disable=redefined-outer-name,unused-argument
@pytest.mark.skipif(
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
reason="Not running in CI cache preload",
)
def test_load_fixtures(
download_smollm2_135m_model,
download_qwen_2_5_half_billion_model,
download_tatsu_lab_alpaca_dataset,
download_mhenrichsen_alpaca_2k_dataset,
download_mhenrichsen_alpaca_2k_w_revision_dataset,
download_mlabonne_finetome_100k_dataset,
download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
download_argilla_dpo_pairs_dataset,
download_tiny_shakespeare_dataset,
download_deepseek_model_fixture,
download_huggyllama_model_fixture,
download_llama_1b_model_fixture,
download_llama3_8b_model_fixture,
download_llama3_8b_instruct_model_fixture,
download_phi_35_mini_model_fixture,
download_phi_3_medium_model_fixture,
download_mistral_7b_model_fixture,
download_gemma_2b_model_fixture,
download_gemma2_9b_model_fixture,
download_mlx_mistral_7b_model_fixture,
download_llama2_model_fixture,
):
pass
# def test_load_fixtures(
# download_smollm2_135m_model,
# download_llama_68m_random_model,
# download_qwen_2_5_half_billion_model,
# download_tatsu_lab_alpaca_dataset,
# download_mhenrichsen_alpaca_2k_dataset,
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
# download_mlabonne_finetome_100k_dataset,
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
# download_fozzie_alpaca_dpo_dataset,
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
# download_argilla_dpo_pairs_dataset,
# download_tiny_shakespeare_dataset,
# download_deepseek_model_fixture,
# download_huggyllama_model_fixture,
# download_llama_1b_model_fixture,
# download_llama3_8b_model_fixture,
# download_llama3_8b_instruct_model_fixture,
# download_phi_35_mini_model_fixture,
# download_phi_3_medium_model_fixture,
# download_mistral_7b_model_fixture,
# download_gemma_2b_model_fixture,
# download_gemma2_9b_model_fixture,
# download_mlx_mistral_7b_model_fixture,
# download_llama2_model_fixture,
# ):
# pass

View File

@@ -1,184 +0,0 @@
"""
e2e tests to make sure all the hooks are fired on the plugin
"""
import os
from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import BasePlugin
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
class LogHooksPlugin(BasePlugin):
"""
fixture to capture in a log file each hook that was fired
"""
base_dir = Path("/tmp/axolotl-log-hooks")
def __init__(self):
self.base_dir.mkdir(parents=True, exist_ok=True)
try:
os.remove(self.base_dir.joinpath("plugin_hooks.log"))
except FileNotFoundError:
pass
def pre_model_load(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("pre_model_load\n")
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_model_build\n")
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("pre_lora_load\n")
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_lora_load\n")
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_model_load\n")
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("create_optimizer\n")
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("get_trainer_cls\n")
def create_lr_scheduler(
self, cfg, trainer, optimizer, num_training_steps
): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("create_lr_scheduler\n")
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("add_callbacks_pre_trainer\n")
return []
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("add_callbacks_post_trainer\n")
return []
def post_train(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_train\n")
def post_train_unload(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_train_unload\n")
class TestPluginHooks:
"""
e2e tests to make sure all the hooks are fired during the training
"""
def test_plugin_hooks(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"tests.e2e.integrations.test_hooks.LogHooksPlugin",
],
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"flash_attention": True,
"bf16": "auto",
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
with open(
"/tmp/axolotl-log-hooks" + "/plugin_hooks.log", "r", encoding="utf-8"
) as f:
file_contents = f.readlines()
file_contents = "\n".join(file_contents)
assert "pre_model_load" in file_contents
assert "post_model_build" in file_contents
assert "pre_lora_load" in file_contents
assert "post_lora_load" in file_contents
assert "post_model_load" in file_contents
# assert "create_optimizer" in file_contents # not implemented yet
assert "get_trainer_cls" in file_contents
assert "create_lr_scheduler" in file_contents
assert "add_callbacks_pre_trainer" in file_contents
assert "add_callbacks_post_trainer" in file_contents
assert "post_train" in file_contents
# assert "post_train_unload" in file_contents # not called from test train call
try:
os.remove("/tmp/axolotl-log-hooks" + "/plugin_hooks.log")
except FileNotFoundError:
pass

View File

@@ -30,18 +30,16 @@ MODELS = [
@pytest.mark.parametrize(
"save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"]
)
@require_llmcompressor
class TestLLMCompressorIntegration:
"""
e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
"""
@require_llmcompressor
@require_torch_2_4_1
def test_llmcompressor_plugin(
self, temp_dir, base_model: str, save_compressed: bool
):
from llmcompressor import active_session
# core cfg
cfg = DictDefault(
{
@@ -91,12 +89,9 @@ class TestLLMCompressorIntegration:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
try:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
finally:
active_session().reset()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
def _check_llmcompressor_model_outputs(temp_dir, save_compressed):

View File

@@ -2,19 +2,14 @@
# pylint: disable=redefined-outer-name
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.state import PartialState
from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention
from axolotl.cli.config import load_cfg
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
@@ -71,36 +66,29 @@ def small_llama_model():
return LlamaForCausalLM(LlamaConfig(**config))
@pytest.mark.parametrize(
"model_name,attention_cls",
[
("HuggingFaceTB/SmolLM2-135M", LlamaAttention),
("Qwen/Qwen3-30B-A3B", Qwen3MoeAttention),
],
)
def test_attention_patching_integration(model_name, attention_cls):
def test_attention_patching_integration():
"""Test attention patching in integration context."""
cfg = {"base_model": model_name}
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
# Store the original implementation
original_forward = getattr(attention_cls, "forward")
original_forward = getattr(LlamaAttention, "forward")
# Apply patch
patch_self_attn_lora(cfg)
# Get the new forward method
patched_forward = attention_cls.forward
patched_forward = LlamaAttention.forward
# Check the forward method was replaced
assert original_forward is not patched_forward
assert patched_forward.__name__ == "axolotl_attn_forward"
# Check original implementation was stored
assert hasattr(attention_cls, "_original_forward")
assert hasattr(LlamaAttention, "_original_forward")
# Clean up
setattr(attention_cls, "forward", original_forward)
delattr(attention_cls, "_original_forward")
setattr(LlamaAttention, "forward", original_forward)
delattr(LlamaAttention, "_original_forward")
def test_swiglu_mlp_integration(small_llama_model):
@@ -425,42 +413,3 @@ def test_kernel_training_integration():
# Verify correct activation function
layer = model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
def test_kernel_training_integration_auto_enable(temp_dir):
"""Test model loading with auto-enabled kernel patches."""
# Create minimal config without explicitly setting kernel options
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Verify kernel options were auto-enabled in the config
assert cfg.lora_mlp_kernel is True
assert cfg.lora_qkv_kernel is True
assert cfg.lora_o_kernel is True

View File

@@ -28,7 +28,7 @@ class Test4dMultipackLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"base_model": "JackFram/llama-68m",
"flash_attention": False,
"sdp_attention": True,
"sample_packing": True,
@@ -41,9 +41,6 @@ class Test4dMultipackLlama(unittest.TestCase):
"lora_target_linear": True,
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
@@ -76,7 +73,7 @@ class Test4dMultipackLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"base_model": "JackFram/llama-68m",
"flash_attention": False,
"sdp_attention": False,
"sample_packing": True,
@@ -89,9 +86,6 @@ class Test4dMultipackLlama(unittest.TestCase):
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",

View File

@@ -32,7 +32,7 @@ class TestFusedLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"base_model": "JackFram/llama-68m",
"flash_attention": True,
"pad_to_sequence_len": True,
"flash_attn_fuse_qkv": True,
@@ -41,7 +41,9 @@ class TestFusedLlama(unittest.TestCase):
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{

View File

@@ -31,8 +31,8 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 16384,
"sample_packing": False,
"flash_attention": True,
@@ -44,9 +44,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"special_tokens": {},
"datasets": [
{
"path": "Yukang/LongAlpaca-12k",
@@ -80,16 +78,14 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 16384,
"sample_packing": False,
"flash_attention": True,
"s2_attention": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"special_tokens": {},
"datasets": [
{
"path": "Yukang/LongAlpaca-12k",

View File

@@ -31,8 +31,8 @@ class TestLoraLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
@@ -44,7 +44,9 @@ class TestLoraLlama(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.2,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
@@ -82,9 +84,9 @@ class TestLoraLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
@@ -98,7 +100,9 @@ class TestLoraLlama(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{

View File

@@ -31,8 +31,8 @@ class TestDPOLlamaLora(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -40,9 +40,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"special_tokens": {},
"rl": "dpo",
"datasets": [
{
@@ -79,8 +77,8 @@ class TestDPOLlamaLora(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -88,9 +86,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"special_tokens": {},
"rl": "dpo",
"rpo_alpha": 0.5,
"datasets": [
@@ -128,8 +124,8 @@ class TestDPOLlamaLora(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -137,9 +133,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"special_tokens": {},
"rl": "dpo",
"dpo_use_weighting": True,
"datasets": [
@@ -178,8 +172,8 @@ class TestDPOLlamaLora(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -187,9 +181,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"special_tokens": {},
"rl": "kto_pair",
"datasets": [
{
@@ -226,8 +218,8 @@ class TestDPOLlamaLora(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -235,9 +227,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"special_tokens": {},
"rl": "ipo",
"datasets": [
{
@@ -274,8 +264,8 @@ class TestDPOLlamaLora(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -283,9 +273,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"special_tokens": {},
"rl": "orpo",
"orpo_alpha": 0.1,
"remove_unused_columns": False,
@@ -326,7 +314,7 @@ class TestDPOLlamaLora(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
@@ -335,9 +323,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"special_tokens": {},
"rl": "kto",
"rl_beta": 0.5,
"kto_desirable_weight": 1.0,

View File

@@ -1,65 +0,0 @@
"""E2E smoke test for evaluate CLI command"""
import os
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
os.environ["WANDB_DISABLED"] = "true"
class TestE2eEvaluate:
"""Test cases for evaluate CLI"""
def test_evaluate(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.evaluate",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -26,13 +26,15 @@ class TestLlama:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"trust_remote_code": True,
"sequence_len": 512,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{

View File

@@ -26,9 +26,9 @@ class TestLoadModelUtils:
# load config
self.cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"tokenizer_config": "JackFram/llama-68m",
"sequence_len": 1024,
"load_in_8bit": False,
"adapter": "lora",
@@ -38,7 +38,9 @@ class TestLoadModelUtils:
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{

View File

@@ -28,8 +28,8 @@ class TestLoraLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -39,7 +39,9 @@ class TestLoraLlama(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
@@ -48,13 +50,13 @@ class TestLoraLlama(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"max_steps": 20,
}
)

View File

@@ -28,9 +28,8 @@ class TestCustomOptimizers(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -40,7 +39,9 @@ class TestCustomOptimizers(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
@@ -74,9 +75,8 @@ class TestCustomOptimizers(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -86,7 +86,9 @@ class TestCustomOptimizers(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
@@ -120,9 +122,8 @@ class TestCustomOptimizers(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -132,7 +133,9 @@ class TestCustomOptimizers(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
@@ -167,7 +170,6 @@ class TestCustomOptimizers(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {

View File

@@ -28,8 +28,8 @@ class TestCustomSchedulers(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
@@ -39,7 +39,9 @@ class TestCustomSchedulers(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{

View File

@@ -105,7 +105,7 @@ def require_vllm(test_case):
return False
return unittest.skipUnless(
is_vllm_installed(), "test requires vllm to be installed"
is_vllm_installed(), "test requires a vllm to be installed"
)(test_case)
@@ -123,7 +123,7 @@ def require_llmcompressor(test_case):
return False
return unittest.skipUnless(
is_llmcompressor_installed(), "test requires llmcompressor to be installed"
is_llmcompressor_installed(), "test requires a llmcompressor to be installed"
)(test_case)

View File

@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": False,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
@@ -662,26 +662,6 @@ class TestValidation(BaseValidation):
for record in self._caplog.records
)
def test_packing_autoset(self, minimal_cfg):
cfg = (
DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
"flash_attention": True,
}
)
| minimal_cfg
)
with self._caplog.at_level(logging.INFO):
cfg = validate_config(cfg)
assert any(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
in record.message
for record in self._caplog.records
)
assert cfg.pad_to_sequence_len is True
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
"""
This is assumed to be run on a CPU machine, so bf16 is not supported.

View File

@@ -2,8 +2,6 @@
tests for chat_template prompt strategy
"""
# pylint: disable=too-many-lines
import logging
from copy import deepcopy
@@ -55,6 +53,14 @@ class TestChatTemplateConfigurations:
Test class for various configurations of ChatTemplateStrategy.
"""
@staticmethod
def find_sublist(full_list, sub_list):
token_count = len(sub_list)
for index in range(len(full_list) - token_count + 1):
if full_list[index : index + token_count] == sub_list:
return index
return -1
@staticmethod
def setup_tokenizer(
tokenizer_name,
@@ -62,7 +68,6 @@ class TestChatTemplateConfigurations:
chat_template_jinja=None,
eos_token=None,
request=None,
eot_token=None,
) -> tuple[PreTrainedTokenizer, str]:
"""
Helper function to set up the tokenizer and chat template for the test.
@@ -83,10 +88,6 @@ class TestChatTemplateConfigurations:
"CodeLlamaTokenizerFast",
):
tokenizer.update_post_processor()
if eot_token:
tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
return tokenizer, chat_template_jinja
def _should_skip_turn(self, tokenizer, turn, turn_idx, start_idx, end_idx):
@@ -973,311 +974,3 @@ class TestChatTemplateConfigurations:
raise ValueError(
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
)
def test_eot_tokens_conflict_with_eos_token(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
basic_dataset, # pylint: disable=unused-argument
request,
):
"""Test that an error is raised when eot_tokens contains eos_token and train_on_eot/train_on_eos conflict"""
LOG.info(
"Testing conflict between eot_tokens containing eos_token and train_on_eot/train_on_eos mismatch"
)
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
# Create a situation where eot_tokens contains eos_token
eot_tokens = [
tokenizer.eos_token,
"[/INST]",
] # Deliberately including eos_token
# Create conflicting train_on_eos and train_on_eot settings
with pytest.raises(
ValueError,
match=".*eos_token is in eot_tokens and train_on_eos != train_on_eot.*",
):
ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="none", # Setting to none
train_on_eot="turn", # Different from train_on_eos
eot_tokens=eot_tokens,
)
def test_eot_token_backward_compatibility(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
basic_dataset, # pylint: disable=unused-argument
request,
):
"""Test that eot_tokens inherits from eos_token when not specified"""
LOG.info("Testing backward compatibility that eot_token inherits eos_token")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="turn", # Setting train_on_eos to "turn"
)
# In backward compatibility mode, eot_tokens should be derived from eos_token
assert strategy.eot_tokens == [
tokenizer.eos_token
], f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
assert (
strategy.train_on_eot == "turn"
), f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
def test_token_not_in_template(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
basic_dataset,
request,
):
"""Test runs even when tokens are not found in the template"""
LOG.info("Testing runs even when tokens are not found in template")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
# Create a non-existent token that definitely won't be in the template
non_existent_token = "[DEFINITELY_NOT_IN_TEMPLATE]"
tokenizer.add_special_tokens(
{"additional_special_tokens": [non_existent_token]}
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
eot_tokens=[non_existent_token],
)
# Force template check by calling tokenize_prompt
strategy.tokenize_prompt(basic_dataset[0])
# We can also check that a warning was logged, but there's
# caplog conflicts when running with other tests
# assert any(
# "not found in chat_template" in record.message for record in self._caplog.records
# ), "Expected warning about token not found in template was not logged"
def test_custom_eot_tokens(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token, # pylint: disable=unused-argument
basic_dataset,
request,
):
"""Test with custom EOT tokens to ensure proper masking and training"""
LOG.info("Testing with custom EOT tokens")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, None, request
)
# Add custom EOT tokens to the tokenizer
custom_eot = "[EOT]"
tokenizer.add_special_tokens({"additional_special_tokens": [custom_eot]})
# Create a custom chat template that uses our EOT token
custom_template = """{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content'] }}{% elif message['role'] == 'user' %}User: {{ message['content'] }}{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}[EOT]{% endif %}{% endfor %}"""
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=custom_template,
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eot="turn", # Train on EOT token after each turn
eot_tokens=[custom_eot],
)
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Find indices of the EOT token
eot_token_id = tokenizer.convert_tokens_to_ids(custom_eot)
eot_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eot_token_id
]
assert len(eot_indices) > 0, "Expected at least one EOT token in the input"
# Verify labeling for EOT tokens based on role
turns = strategy.get_conversation_thread(basic_dataset[0])
assistant_turn_indices = []
non_assistant_turn_indices = []
for i, turn in enumerate(basic_dataset[0]["conversations"]):
start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
if start_idx != -1 and end_idx != -1: # If turn is found
if turn["from"] == "assistant":
assistant_turn_indices.append((start_idx, end_idx))
else:
non_assistant_turn_indices.append((start_idx, end_idx))
# Check EOT tokens after assistant turns are labeled
for eot_idx in eot_indices:
is_after_assistant = any(
start_idx <= eot_idx <= end_idx + 1 # +1 to include the EOT token
for start_idx, end_idx in assistant_turn_indices
)
if is_after_assistant:
assert (
labels[eot_idx] != IGNORE_TOKEN_ID
), f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
else:
assert (
labels[eot_idx] == IGNORE_TOKEN_ID
), f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
def test_multiple_train_on_eot_settings(
self,
tokenizer,
chat_template,
chat_template_jinja,
eos_token,
basic_dataset,
request,
):
"""Test different train_on_eot settings"""
LOG.info("Testing different train_on_eot settings")
tokenizer, chat_template_jinja = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
# Create a list to test different train_on_eot settings
test_settings = [
("none", lambda idx, is_assistant: False), # Never train on EOT
("all", lambda idx, is_assistant: True), # Always train on EOT
(
"turn",
lambda idx, is_assistant: is_assistant,
), # Train on EOT after assistant turns
("last", lambda idx, is_last: is_last), # Only train on last EOT
]
for setting, expected_train_func in test_settings:
LOG.info(f"Testing train_on_eot='{setting}'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eot=setting,
eot_tokens=[
tokenizer.eos_token
], # Use eos_token as the EOT token for simplicity
)
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
eos_token_id = tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert (
len(eos_indices) > 0
), "Expected at least one EOS/EOT token in the input"
# Check labeling for each EOS/EOT token
for idx, eos_idx in enumerate(eos_indices):
# Find which turn this EOS token belongs to
preceding_turn = None
for i, turn in enumerate(basic_dataset[0]["conversations"]):
start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
if (
start_idx != -1
and end_idx != -1
and start_idx <= eos_idx <= end_idx + 1
):
preceding_turn = turn
break
is_assistant = (
preceding_turn is not None and preceding_turn["from"] == "assistant"
)
is_last = idx == len(eos_indices) - 1
expected_label = not expected_train_func(
idx, is_assistant if setting != "last" else is_last
)
if expected_label:
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
else:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"

Some files were not shown because too many files have changed in this diff Show More