Compare commits

..

3 Commits

Author SHA1 Message Date
NanoCode012
1f2f285173 fix: missing key in enum 2025-07-03 13:46:16 +08:00
NanoCode012
98e912e416 feat: add custom processing strategy for phi35 vl 2025-07-03 13:46:16 +08:00
NanoCode012
e1528fb381 feat: add phi_35_vl support 2025-07-03 13:46:16 +08:00
189 changed files with 1075 additions and 4741 deletions

View File

@@ -1,3 +1,3 @@
[bandit]
exclude = tests
skips = B101,B615
skips = B101

View File

@@ -5,13 +5,11 @@ on:
branches:
- "main"
paths:
- 'docker/Dockerfile-base'
- 'docker/Dockerfile-uv-base'
- 'Dockerfile-base'
- '.github/workflows/base.yml'
pull_request:
paths:
- 'docker/Dockerfile-base'
- 'docker/Dockerfile-uv-base'
- 'Dockerfile-base'
- '.github/workflows/base.yml'
workflow_dispatch:
@@ -29,11 +27,11 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
@@ -43,7 +41,7 @@ jobs:
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"

View File

@@ -15,15 +15,15 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.5.1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
pytorch: 2.6.0
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
@@ -82,17 +82,17 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"

View File

@@ -33,6 +33,13 @@ jobs:
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"

View File

@@ -12,6 +12,11 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -63,10 +68,10 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.5.1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:

View File

@@ -28,8 +28,6 @@ jobs:
steps:
- name: Check out repository
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up Quarto
uses: quarto-dev/quarto-actions/setup@v2
@@ -52,11 +50,10 @@ jobs:
- name: Netlify Publish
uses: nwtgck/actions-netlify@v3.0
id: netlify
with:
publish-dir: './_site'
enable-pull-request-comment: false
enable-github-deployment: false
enable-pull-request-comment: true
enable-github-deployment: true
github-token: ${{ secrets.GITHUB_TOKEN }}
deploy-message: "Deployed On Netlify"
github-deployment-environment: 'preview'
@@ -64,13 +61,3 @@ jobs:
env:
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
- name: Update PR with preview link
if: ${{ steps.netlify.outcome == 'success' }}
uses: marocchino/sticky-pull-request-comment@v2
with:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
message: |
📖 **Documentation Preview**: ${{ steps.netlify.outputs.deploy-url }}
Deployed on Netlify from commit ${{ github.event.pull_request.head.sha }}

View File

@@ -18,26 +18,116 @@ jobs:
env:
SKIP: no-commit-to-branch
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
@@ -78,11 +168,15 @@ jobs:
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |
@@ -99,8 +193,15 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1

View File

@@ -52,7 +52,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -102,9 +102,9 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
@@ -125,7 +125,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -175,9 +175,9 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |
@@ -198,7 +198,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
- cuda: 126
@@ -252,6 +252,18 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: llmcompressor
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1

View File

@@ -36,7 +36,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.8.6
rev: 1.8.5
hooks:
- id: bandit
args: [

View File

@@ -97,7 +97,7 @@
# # 'no_input_format' cannot include {input}
# no_input_format: "{instruction} "
# # For `completion` datasets only, uses the provided field instead of `text` column
# # For `completion` datsets only, uses the provided field instead of `text` column
# field:
# # Axolotl attempts to save the dataset as an arrow after packing the data together so

View File

@@ -2,5 +2,4 @@ include requirements.txt
include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
include src/axolotl/utils/chat_templates/templates/*.jinja
recursive-include axolotl *.py

View File

@@ -55,7 +55,7 @@ Features:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.6.0
- PyTorch ≥2.5.1
### Installation

View File

@@ -9,7 +9,6 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_PROCESSES="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev

View File

@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
"CUDA": os.environ.get("CUDA", "126"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),

View File

@@ -24,9 +24,9 @@ df_template = template_env.get_template(dockerfile)
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
"CUDA": os.environ.get("CUDA", "126"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),

View File

@@ -37,7 +37,3 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
fi

View File

@@ -7,7 +7,6 @@ toc-depth: 3
```{python}
#| echo: false
import os
import re
def process_readme(integration_name):
@@ -54,24 +53,6 @@ sections = [
("LLMCompressor", "llm_compressor")
]
for folder_name in os.listdir("../src/axolotl/integrations/"):
if folder_name in [path for name, path in sections]:
# skip if already in sections
continue
if os.path.exists(f"../src/axolotl/integrations/{folder_name}/README.md"):
# grab the first heading in README.md as the section name
with open(f"../src/axolotl/integrations/{folder_name}/README.md", "r") as f:
txt = f.read()
matches = re.search(r'^# (.*)\n?', txt, flags=re.MULTILINE)
if matches:
name = matches.group(1)
else:
continue
sections.append((name, folder_name))
# sort sections by name
sections = sorted(sections, key=lambda x: x[0])
for section_name, folder_name in sections:
print(print_section(section_name, folder_name))
```

View File

@@ -187,7 +187,6 @@ Instead of passing `tools` via the system prompt, an alternative method would be
"role": "assistant", // call the function via assistant
"tool_calls": [
{
"id": "...", // required only for mistral
"type": "function",
"function": {
"name": "...",
@@ -200,7 +199,6 @@ Instead of passing `tools` via the system prompt, an alternative method would be
},
{
"role": "tool",
"tool_call_id": "...", // required only for mistral
"name": "...",
"content": "..."
},

View File

@@ -34,9 +34,9 @@ Tags examples:
- `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu126-2.6.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
## Main
@@ -76,12 +76,12 @@ Tags examples:
- `main-py3.11-cu128-2.7.1`
- `main-py3.11-cu126-2.7.1`
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu126-2.6.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu126-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `0.10.1`
## Cloud

View File

@@ -51,18 +51,6 @@ description: Frequently asked questions
> pad_token: "..."
> ```
**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI**
> A: This is because you may be using `preprocess` CLI with `pretraining_dataset:` or `skip_prepare_dataset: true` respectively. Please use `axolotl train` CLI directly instead as these datasets are prepared on demand.
**Q: vLLM is not working with Axolotl**
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

View File

@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
> See the [example config](#example-config) file in addition to reading these instructions.
1. Set `adapter: qlora` in your axolotl config file.
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Example Config

View File

@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.11
- PyTorch ≥2.6.0
- PyTorch ≥2.5.1
## Installation Methods {#sec-installation-methods}

View File

@@ -23,6 +23,8 @@ Axolotl supports several methods for multi-GPU training:
## DeepSpeed {#sec-deepspeed}
DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.
### Configuration {#sec-deepspeed-config}
Add to your YAML config:
@@ -30,6 +32,7 @@ Add to your YAML config:
```{.yaml}
deepspeed: deepspeed_configs/zero1.json
```
### Usage {#sec-deepspeed-usage}
```{.bash}
@@ -63,75 +66,9 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
:::
::: {.callout-tip}
## FSDP {#sec-fsdp}
Using ZeRO Stage 3 with Single-GPU training
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
:::
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
::: {.callout-note}
FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.
:::
### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}
To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and
also follow the config field mapping below to update field names.
#### Config mapping
FSDP1 | FSDP2
-------- | --------
fsdp_sharding_strategy | reshard_after_forward
fsdp_backward_prefetch_policy | **REMOVED**
fsdp_backward_prefetch | **REMOVED**
fsdp_forward_prefetch | **REMOVED**
fsdp_sync_module_states | **REMOVED**
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED**
For example, if you were using the following FSDP1 config:
```{.yaml}
fsdp_version: 1
fsdp_config:
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
```
You can migrate to the following FSDP2 config:
```{.yaml}
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
state_dict_type: FULL_STATE_DICT
reshard_after_forward: true
```
### FSDP1 (deprecated) {#sec-fsdp-config}
::: {.callout-note}
Using `fsdp` to configure FSDP is deprecated and will be removed in an upcoming release of Axolotl. Please use `fsdp_config` as above instead.
:::
### Basic FSDP Configuration {#sec-fsdp-config}
```{.yaml}
fsdp:
@@ -143,7 +80,6 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Sequence parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the

View File

@@ -40,13 +40,13 @@ use_cpu: false
Configure your model to use FSDP in the Axolotl yaml. For example:
```yaml
fsdp_version: 2
fsdp:
- full_shard
- auto_wrap
fsdp_config:
offload_params: true
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.

View File

@@ -16,6 +16,7 @@ format:
- [Gemma-3](#sec-gemma-3)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [Phi3-V](#sec-phi3-v)
## Usage
@@ -126,6 +127,15 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Phi3-V {#sec-phi3-v}
```yaml
base_model: microsoft/Phi-3.5-vision-instruct
trust_remote_code: true
chat_template: phi_35_vl
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.

View File

@@ -17,6 +17,7 @@ feedback. Various methods include, but not limited to:
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!)
## RLHF using Axolotl
@@ -274,14 +275,15 @@ rl: dpo
datasets:
- path: ...
split: train
type:
field_prompt: "prompt"
field_system: "system"
field_chosen: "chosen"
field_rejected: "rejected"
prompt_format: "{prompt}"
chosen_format: "{chosen}"
rejected_format: "{rejected}"
type: user_defined.default
field_prompt: "prompt"
field_system: "system"
field_chosen: "chosen"
field_rejected: "rejected"
prompt_format: "{prompt}"
chosen_format: "{chosen}"
rejected_format: "{rejected}"
```
The input format is a simple JSON input with customizable fields based on the above config.
@@ -474,13 +476,14 @@ rl: kto
datasets:
- path: ...
split: train
type:
field_prompt: "prompt"
field_system: "system"
field_completion: "completion"
field_label: "label"
prompt_format: "{prompt}"
completion_format: "{completion}"
type: user_defined.default
field_prompt: "prompt"
field_system: "system"
field_completion: "completion"
field_label: "label"
prompt_format: "{prompt}"
completion_format: "{completion}"
```
The input format is a simple JSON input with customizable fields based on the above config.

View File

@@ -1,5 +0,0 @@
# Archived Examples
This directory contains examples that are no longer maintained and may no longer be functional.
We keep them around for archival purposes in case they are useful to others.

View File

@@ -1,70 +0,0 @@
# Finetune Devstral with Axolotl
Devstral Small is a 24B parameter opensource model from MistralAI found on HuggingFace [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505) and [Devstral-Small-2507](https://huggingface.co/mistralai/Devstral-Small-2507). `Devstral-Small-2507` is the latest version of the model and has [function calling](https://mistralai.github.io/mistral-common/usage/tools/) support.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
The model was fine-tuned ontop of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context of up to 128k tokens.
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:
```bash
axolotl train examples/devstral/devstral-small-qlora.yml
```
This config uses about 21GB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- Learn how to use function calling with Axolotl at [docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels)
## Limitations
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
In addition, we do not support overriding tokens yet.
## Related Resources
- [MistralAI Devstral Blog](https://mistral.ai/news/devstral)
- [MistralAI Devstral 1.1 Blog](https://mistral.ai/news/devstral-2507)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
## Future Work
- Add parity to Preference Tuning, RL, Multi-modal, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -1,64 +0,0 @@
base_model: mistralai/Devstral-Small-2507
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
load_in_8bit: false
load_in_4bit: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_linear: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_ratio: 0.05
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,7 +0,0 @@
# Liquid Foundation Models 2
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
```bash
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
```

View File

@@ -1,48 +0,0 @@
base_model: LiquidAI/LFM2-350M
chunked_cross_entropy: true
chat_template: tokenizer_default
eot_tokens:
- "<|im_end|>"
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
bf16: true
tf32: true
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -18,10 +18,16 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
```
2. Run the finetuning example:
2. Download the example config:
```bash
axolotl fetch examples
```
3. Run the finetuning example:
```bash
axolotl train examples/magistral/magistral-small-qlora.yaml
@@ -36,7 +42,7 @@ Let us know how it goes. Happy finetuning! 🚀
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
@@ -48,7 +54,7 @@ Let us know how it goes. Happy finetuning! 🚀
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
In addition, we do not support overriding tokens yet.
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
## Related Resources

View File

@@ -13,7 +13,7 @@ packaging==23.2
huggingface_hub==0.32.2
peft==0.15.2
transformers==4.53.1
transformers==4.52.4
tokenizers>=0.21.1
accelerate==1.8.1
datasets==3.6.0
@@ -68,4 +68,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
mistral-common==1.7.0
mistral-common==1.6.3

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154"'
)

View File

@@ -66,11 +66,8 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
else:
_install_requires.append("xformers==0.0.31.post1")
extras_require_map["vllm"] = ["vllm>=0.9.0"]
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
@@ -117,7 +114,7 @@ extras_require = {
"flash-attn": ["flash-attn==2.8.0.post2"],
"ring-flash-attn": [
"flash-attn==2.8.0.post2",
"ring-flash-attn>=0.1.5",
"ring-flash-attn>=0.1.4",
"yunchang==0.6.0",
],
"deepspeed": [

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.12.0.dev"
__version__ = "0.11.0.dev"

View File

@@ -16,7 +16,6 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
migrate_fsdp_config,
normalize_cfg_datasets,
normalize_config,
validate_config,
@@ -227,7 +226,6 @@ def load_cfg(
},
)
migrate_fsdp_config(cfg)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)

View File

@@ -35,12 +35,6 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get("key"):
raise ValueError(
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
)
if not cfg.dataset_prepared_path:
msg = (
Fore.RED

View File

@@ -109,13 +109,6 @@ def ray_train_func(kwargs: dict):
# initialize accelerator before model instantiation
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
# Register plugins in Ray workers
if cfg.get("plugins"):
from axolotl.cli.config import plugin_set_cfg, prepare_plugins
prepare_plugins(cfg)
plugin_set_cfg(cfg)
kwargs["cfg"] = cfg
do_train(**kwargs)

View File

@@ -1,162 +0,0 @@
"""
monkeypatch for flex + packing
"""
import sys
from typing import Callable, Optional, Union
import torch
from torch.nn.attention.flex_attention import BlockMask
from transformers import Cache, PretrainedConfig
from transformers.masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
_preprocess_mask_arguments,
and_masks,
causal_mask_function,
or_masks,
)
from transformers.utils import is_torch_greater_or_equal
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
def create_causal_mask(
config: PretrainedConfig,
input_embeds: torch.Tensor,
attention_mask: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
"""
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
to what is needed in the `modeling_xxx.py` files).
Args:
config (`PretrainedConfig`):
The model config.
input_embeds (`torch.Tensor`):
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
batch size, query length and dtype.
attention_mask (`torch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
It can also be an already prepared 4D mask, in which case it is returned as-is.
cache_position (`torch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the union of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
and_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the full layers
if (
past_key_values
and hasattr(past_key_values, "is_sliding")
and False in past_key_values.is_sliding
):
layer_idx = past_key_values.is_sliding.index(False)
else:
layer_idx = 0
original_attention_mask = (
None
if attention_mask is None
else attention_mask.clone().to(cache_position.device)
)
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
)
if early_exit:
return attention_mask
batch_size, total_seq_len = cache_position.shape
key_length = total_seq_len
document_ids = torch.nn.functional.pad(
original_attention_mask, value=0, pad=(0, key_length)
)
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
if attention_mask is not None:
def causal_doc_mask_mod(
batch_idx, head_idx, q_idx, kv_idx
): # pylint: disable=unused-argument
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask_ = q_idx >= kv_idx # not valid when decoding
document_mask = (
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
)
final_mask = causal_mask_ & document_mask
return final_mask
mask_factory_function = causal_doc_mask_mod
else:
mask_factory_function = causal_mask_function
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
config._attn_implementation # pylint: disable=protected-access
]
# Do not allow skip if we are compiling (this is to match BC)
allow_is_causal_skip = (
not past_key_values.is_compileable if past_key_values is not None else True
)
# Allow slight deviations from causal mask
if or_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError(
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
)
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_causal_skip = False
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError(
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
)
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_causal_skip = False
# We now create the mask
causal_mask = mask_interface(
batch_size=batch_size,
cache_position=cache_position,
kv_length=kv_length,
kv_offset=kv_offset,
mask_function=mask_factory_function,
attention_mask=attention_mask,
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
)
return causal_mask
def patch_create_causal_mask(model_type):
import transformers.masking_utils
transformers.masking_utils.create_causal_mask = create_causal_mask
if model_type:
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(module_path)
module.create_causal_mask = create_causal_mask
del sys.modules[module_path]
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e

View File

@@ -219,9 +219,7 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
bf16 = self.cfg.bf16 or self.cfg.bfloat16
bf16 = bf16 if bf16 is not None else False
training_args_kwargs["bf16"] = bf16
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
def _configure_scheduler(self, training_args_kwargs: dict):
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
@@ -501,10 +499,6 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.fsdp_config or self.cfg.fsdp:
training_args_kwargs["fsdp_config"] = self.cfg.fsdp_config
training_args_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp else True
self._configure_reporting(training_args_kwargs)
self._configure_hub_parameters(training_args_kwargs)
self._configure_scheduler(training_args_kwargs)

View File

@@ -151,6 +151,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
@@ -237,19 +245,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
self.cfg.flash_attention
or self.cfg.xformers_attention
or self.cfg.flex_attention
)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.multipack_real_batches
if self.cfg.multipack_real_batches is not None
else not (
self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.xformers_attention
)
else not self.cfg.flash_attention
)
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing

View File

@@ -208,7 +208,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
if self.cfg.fsdp_config or self.cfg.fsdp:
if self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
@@ -218,3 +218,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer.add_callback(callback)
return trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):
# TODO: build PPOConfig
raise NotImplementedError("PPO trainer builder is not implemented yet.")

View File

@@ -14,4 +14,5 @@ from .trl import (
AxolotlORPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
TRLPPOTrainer,
)

View File

@@ -27,7 +27,6 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import (
CheckpointSaveMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
SchedulerMixin,
)
@@ -43,12 +42,7 @@ LOG = get_logger(__name__)
class AxolotlTrainer(
PackingMixin,
SchedulerMixin,
OptimizerMixin,
RngLoaderMixin,
CheckpointSaveMixin,
Trainer,
SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
):
"""Extend the base Trainer for axolotl helpers"""
@@ -212,14 +206,6 @@ class AxolotlTrainer(
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if (
dataset.column_names
and "position_ids" in dataset.column_names
and "attention_mask" in dataset.column_names
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
dataset = dataset.remove_columns(["attention_mask"])
if isinstance(dataset, datasets.Dataset):
if is_training:

View File

@@ -5,6 +5,5 @@
from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin

View File

@@ -1,20 +0,0 @@
"""Trainer mixin to support packing"""
from transformers import Trainer
class PackingMixin(Trainer):
"""
Trainer mixin to support packing
"""
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
if (
self._signature_columns
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
set_sig_columns = set(self._signature_columns)
set_sig_columns.remove("attention_mask")
self._signature_columns = list(set_sig_columns)

View File

@@ -1,9 +1,12 @@
"""Module for TRL RL trainers"""
"""Module for TRL PPO trainer"""
import torch
from tqdm import tqdm
from trl import (
CPOTrainer,
KTOTrainer,
ORPOTrainer,
PPOTrainer,
PRMTrainer,
RewardTrainer,
)
@@ -13,6 +16,64 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TRLPPOTrainer(PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"]
def train(
self,
reward_pipe,
resume_from_checkpoint=None, # pylint: disable=unused-argument
):
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": 32,
}
sent_kwargs = {
"return_all_scores": True,
"function_to_apply": "none",
"batch_size": 16,
}
for _, batch in tqdm(enumerate(self.dataloader)):
query_tensors = batch["input_ids"]
# generate model response
response_tensors, ref_response_tensors = self.generate(
query_tensors,
return_prompt=False,
generate_ref_response=True,
**generation_kwargs,
)
batch["response"] = self.tokenizer.batch_decode(response_tensors)
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = reward_pipe(texts, **sent_kwargs)
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
ref_rewards = [
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
]
batch["ref_rewards"] = ref_rewards
# Run PPO step
stats = self.step(query_tensors, response_tensors, rewards)
self.log_stats(
stats,
batch,
rewards,
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
)
class AxolotlORPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
):

View File

@@ -42,10 +42,6 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "The multiprocessing start method to use."},
)
sample_packing_drop_attention_mask: bool = field(
default=False,
metadata={"help": "Drop attention mask from inputs when using packing."},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},

View File

@@ -48,6 +48,13 @@ class TokenizedPromptDataset(Dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
LOG.info(
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
)
num_proc = 1
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True

Some files were not shown because too many files have changed in this diff Show More