Compare commits
45 Commits
lora-quant
...
coderabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e23a5c9fda | ||
|
|
5d7a61576d | ||
|
|
5ecf22b54e | ||
|
|
9c5b8da22f | ||
|
|
c0a0c7534c | ||
|
|
7fa1089cea | ||
|
|
fea6649518 | ||
|
|
124ad2b968 | ||
|
|
80304c26a7 | ||
|
|
767c2340f1 | ||
|
|
f6623c34cc | ||
|
|
5dd8f0b2b8 | ||
|
|
67c4ea9c7c | ||
|
|
526ddb886d | ||
|
|
f34eef546a | ||
|
|
c7b6790614 | ||
|
|
be3c6bbd85 | ||
|
|
f07db4f853 | ||
|
|
17a5838d38 | ||
|
|
9f68918f13 | ||
|
|
47e0e71bc8 | ||
|
|
0f3587174d | ||
|
|
25e6c5f9bd | ||
|
|
32f51bca35 | ||
|
|
9daa04da90 | ||
|
|
0d71b0aa5f | ||
|
|
63aaccf85b | ||
|
|
ff0fe767c8 | ||
|
|
8e4158cc0b | ||
|
|
cd84325253 | ||
|
|
0b140fef83 | ||
|
|
e4cfebe995 | ||
|
|
a6cac5dd32 | ||
|
|
b71c0e3447 | ||
|
|
ddaebf8309 | ||
|
|
679743087a | ||
|
|
f720b6e72d | ||
|
|
a980618fd0 | ||
|
|
54960d4de0 | ||
|
|
ed922796b7 | ||
|
|
3dd9c3bf3f | ||
|
|
0ba7d362fa | ||
|
|
e4f73bc98e | ||
|
|
bcb59c70e2 | ||
|
|
6a3e6f8c53 |
2
.github/workflows/multi-gpu-e2e.yml
vendored
2
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -3,7 +3,7 @@ name: docker-multigpu-tests-biweekly
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'tests/e2e/multigpu/*.py'
|
||||
- 'tests/e2e/multigpu/**.py'
|
||||
- 'requirements.txt'
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
|
||||
6
.github/workflows/preview-docs.yml
vendored
6
.github/workflows/preview-docs.yml
vendored
@@ -4,6 +4,12 @@ on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
# Run the workflow only when one of these files changes
|
||||
paths:
|
||||
- '**/*.md' # any Markdown file
|
||||
- '**/*.qmd' # any Quarto file
|
||||
- '_quarto.yaml'
|
||||
|
||||
permissions:
|
||||
checks: write
|
||||
contents: write
|
||||
|
||||
87
.github/workflows/tests-nightly.yml
vendored
87
.github/workflows/tests-nightly.yml
vendored
@@ -18,9 +18,96 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
preload-cache:
|
||||
name: Preload HF cache
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
env:
|
||||
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v tests/conftest.py
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
files: ./coverage.xml
|
||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||
fail_ci_if_error: false
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
needs: [preload-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
|
||||
204
.github/workflows/tests.yml
vendored
204
.github/workflows/tests.yml
vendored
@@ -44,12 +44,104 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
# preload-cache:
|
||||
# name: Preload HF cache
|
||||
# runs-on: ubuntu-latest
|
||||
# strategy:
|
||||
# fail-fast: false
|
||||
# matrix:
|
||||
# python_version: ["3.11"]
|
||||
# pytorch_version: ["2.6.0"]
|
||||
# timeout-minutes: 20
|
||||
#
|
||||
# env:
|
||||
# AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
|
||||
#
|
||||
# steps:
|
||||
# - name: Check out repository code
|
||||
# uses: actions/checkout@v4
|
||||
#
|
||||
# - name: Restore HF cache
|
||||
# id: hf-cache-restore
|
||||
# uses: actions/cache/restore@v4
|
||||
# with:
|
||||
# path: |
|
||||
# /home/runner/.cache/huggingface/hub/datasets--*
|
||||
# /home/runner/.cache/huggingface/hub/models--*
|
||||
# key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
#
|
||||
# - name: Restore Cache from S3
|
||||
# id: hf-cache-restore-s3
|
||||
# run: |
|
||||
# mkdir -p /home/runner/.cache/huggingface/hub
|
||||
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
#
|
||||
# - name: Setup Python
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: ${{ matrix.python_version }}
|
||||
# cache: 'pip' # caching pip dependencies
|
||||
#
|
||||
# - name: upgrade pip
|
||||
# run: |
|
||||
# pip3 install --upgrade pip
|
||||
# pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
#
|
||||
# - name: Install PyTorch
|
||||
# run: |
|
||||
# pip3 install torch==${{ matrix.pytorch_version }}
|
||||
#
|
||||
# - name: Install dependencies
|
||||
# run: |
|
||||
# pip3 show torch
|
||||
# pip3 install --no-build-isolation -U -e .
|
||||
# python scripts/unsloth_install.py | sh
|
||||
# python scripts/cutcrossentropy_install.py | sh
|
||||
# pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
#
|
||||
# - name: Make sure PyTorch version wasn't clobbered
|
||||
# run: |
|
||||
# python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
#
|
||||
# - name: Ensure axolotl CLI was installed
|
||||
# run: |
|
||||
# axolotl --help
|
||||
#
|
||||
# - name: Pre-Download dataset fixture
|
||||
# run: |
|
||||
# huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
#
|
||||
# - name: Run tests
|
||||
# run: |
|
||||
# pytest -v tests/conftest.py
|
||||
#
|
||||
# - name: Upload coverage to Codecov
|
||||
# uses: codecov/codecov-action@v5
|
||||
# with:
|
||||
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||
# files: ./coverage.xml
|
||||
# flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||
# fail_ci_if_error: false
|
||||
#
|
||||
# - name: cleanup pip cache
|
||||
# run: |
|
||||
# find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
#
|
||||
# - name: Save HF cache
|
||||
# id: hf-cache
|
||||
# uses: actions/cache/save@v4
|
||||
# with:
|
||||
# path: |
|
||||
# /home/runner/.cache/huggingface/hub/datasets--*
|
||||
# /home/runner/.cache/huggingface/hub/models--*
|
||||
# key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
# needs: [preload-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
|
||||
@@ -59,14 +151,20 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
# - name: Restore HF cache
|
||||
# id: hf-cache-restore
|
||||
# uses: actions/cache/restore@v4
|
||||
# with:
|
||||
# path: |
|
||||
# /home/runner/.cache/huggingface/hub/datasets--*
|
||||
# /home/runner/.cache/huggingface/hub/models--*
|
||||
# key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -121,21 +219,12 @@ jobs:
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
pytest-sdist:
|
||||
name: PyTest from Source Dist
|
||||
runs-on: ubuntu-latest
|
||||
# needs: [preload-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 1
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
|
||||
@@ -145,14 +234,20 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
# - name: Restore HF cache
|
||||
# id: hf-cache-restore
|
||||
# uses: actions/cache/restore@v4
|
||||
# with:
|
||||
# path: |
|
||||
# /home/runner/.cache/huggingface/hub/datasets--*
|
||||
# /home/runner/.cache/huggingface/hub/models--*
|
||||
# key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -199,15 +294,6 @@ jobs:
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
@@ -267,12 +353,6 @@ jobs:
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras: llmcompressor
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
@@ -309,3 +389,43 @@ jobs:
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
docker-e2e-cleanup:
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [docker-e2e-tests]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras: vllm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.cleanup
|
||||
|
||||
@@ -57,8 +57,10 @@ async def handler(job):
|
||||
logger.info("Training Complete.")
|
||||
|
||||
# Cleanup
|
||||
del os.environ["WANDB_API_KEY"]
|
||||
del os.environ["HF_TOKEN"]
|
||||
if "WANDB_API_KEY" in os.environ:
|
||||
del os.environ["WANDB_API_KEY"]
|
||||
if "HF_TOKEN" in os.environ:
|
||||
del os.environ["HF_TOKEN"]
|
||||
|
||||
|
||||
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
|
||||
|
||||
20
_quarto.yml
20
_quarto.yml
@@ -48,8 +48,23 @@ quartodoc:
|
||||
contents:
|
||||
- core.trainers.base
|
||||
- core.trainers.trl
|
||||
- core.trainers.mamba
|
||||
- core.trainers.relora
|
||||
- core.trainers.dpo.trainer
|
||||
- core.trainers.grpo.trainer
|
||||
- core.trainers.grpo.sampler
|
||||
- core.trainers.utils
|
||||
- title: Mixins
|
||||
desc: Mixin classes for augmenting trainers
|
||||
contents:
|
||||
- core.trainers.mixins.optimizer
|
||||
- core.trainers.mixins.rng_state_loader
|
||||
- core.trainers.mixins.scheduler
|
||||
- core.trainers.mixins.sequence_parallel
|
||||
- title: Context Managers
|
||||
desc: Context managers for altering trainer behaviors
|
||||
contents:
|
||||
- utils.ctx_managers.sequence_parallel
|
||||
- title: Prompt Strategies
|
||||
desc: Prompt formatting strategies
|
||||
contents:
|
||||
@@ -86,7 +101,7 @@ quartodoc:
|
||||
- kernels.swiglu
|
||||
- kernels.quantize
|
||||
- kernels.utils
|
||||
- title: MonkeyPatches
|
||||
- title: Monkey Patches
|
||||
desc: Runtime patches for model optimizations
|
||||
contents:
|
||||
- monkeypatch.llama_attn_hijack_flash
|
||||
@@ -124,7 +139,8 @@ quartodoc:
|
||||
- utils.optimizers.adopt
|
||||
- utils.data.pretraining
|
||||
- utils.data.sft
|
||||
- utils.gradient_checkpointing.unsloth
|
||||
- utils.gradient_checkpointing.offload_cpu
|
||||
- utils.gradient_checkpointing.offload_disk
|
||||
- title: Schemas
|
||||
desc: Pydantic data models for Axolotl config
|
||||
contents:
|
||||
|
||||
0
cicd/__init__.py
Normal file
0
cicd/__init__.py
Normal file
@@ -18,7 +18,7 @@ pytest -v --durations=10 \
|
||||
--cov-append
|
||||
|
||||
# Run patched tests excluding lora kernels with coverage append
|
||||
pytest -v --durations=10 \
|
||||
pytest --full-trace -vvv --durations=10 \
|
||||
--ignore=tests/e2e/patched/lora_kernels \
|
||||
/workspace/axolotl/tests/e2e/patched \
|
||||
--cov=axolotl \
|
||||
|
||||
19
cicd/cleanup.py
Normal file
19
cicd/cleanup.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Modal app to run axolotl GPU cleanup"""
|
||||
|
||||
from .single_gpu import VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cicd_image,
|
||||
timeout=60 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
volumes=VOLUME_CONFIG,
|
||||
)
|
||||
def cleanup():
|
||||
run_cmd("./cicd/cleanup.sh", "/workspace/axolotl")
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
cleanup.remote()
|
||||
6
cicd/cleanup.sh
Executable file
6
cicd/cleanup.sh
Executable file
@@ -0,0 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# cleanup old cache files for datasets processing and intermediate mappings
|
||||
find /workspace/data/huggingface-cache/hub/datasets -name "cache-*" -type f -mtime +1 -exec rm {} \;
|
||||
find /workspace/data/huggingface-cache/hub/datasets -name "*.lock" -type f -mtime +1 -exec rm {} \;
|
||||
@@ -1,75 +1,12 @@
|
||||
"""Modal app to run axolotl GPU tests"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
import jinja2
|
||||
import modal
|
||||
from jinja2 import select_autoescape
|
||||
from modal import App, Image
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
df_template = template_env.get_template("Dockerfile.jinja")
|
||||
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
|
||||
"CUDA": os.environ.get("CUDA", "121"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
f.write(dockerfile_contents)
|
||||
|
||||
cicd_image = Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
context_mount=None,
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
).env(df_args)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
|
||||
hf_cache_volume = modal.Volume.from_name(
|
||||
"axolotl-ci-hf-hub-cache", create_if_missing=True
|
||||
)
|
||||
VOLUME_CONFIG = {
|
||||
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=60 * 60,
|
||||
timeout=90 * 60, # 90 min
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
volumes=VOLUME_CONFIG,
|
||||
|
||||
66
cicd/single_gpu.py
Normal file
66
cicd/single_gpu.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Modal app to run axolotl GPU tests"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
import jinja2
|
||||
import modal
|
||||
from jinja2 import select_autoescape
|
||||
from modal import App, Image
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
df_template = template_env.get_template("Dockerfile.jinja")
|
||||
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
|
||||
"CUDA": os.environ.get("CUDA", "121"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
f.write(dockerfile_contents)
|
||||
|
||||
cicd_image = Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
context_mount=None,
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
).env(df_args)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
|
||||
hf_cache_volume = modal.Volume.from_name(
|
||||
"axolotl-ci-hf-hub-cache", create_if_missing=True
|
||||
)
|
||||
VOLUME_CONFIG = {
|
||||
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
@@ -19,7 +19,7 @@ coverage:
|
||||
if_no_uploads: error
|
||||
if_not_found: success
|
||||
if_ci_failed: error
|
||||
only_pulls: false
|
||||
only_pulls: true
|
||||
flags: null
|
||||
paths: null
|
||||
patch:
|
||||
|
||||
@@ -32,6 +32,8 @@ tokenizer_legacy:
|
||||
resize_token_embeddings_to_32x:
|
||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||
shrink_embeddings:
|
||||
# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
|
||||
embeddings_skip_upcast:
|
||||
# Whether to load the model with randomly initialized weights. Useful for
|
||||
# pre-training a model from scratch or debugging purposes.
|
||||
random_init_weights:
|
||||
@@ -73,11 +75,12 @@ load_in_8bit: true
|
||||
load_in_4bit:
|
||||
|
||||
# Use CUDA bf16
|
||||
bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
|
||||
bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
|
||||
# Use CUDA fp16
|
||||
fp16: true
|
||||
# Use CUDA tf32
|
||||
tf32: true # require >=ampere
|
||||
# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting
|
||||
|
||||
# No AMP (automatic mixed precision)
|
||||
bfloat16: true # require >=ampere
|
||||
@@ -184,8 +187,8 @@ datasets:
|
||||
# adding a system turn with empty content.
|
||||
drop_system_message:
|
||||
|
||||
# Optional[bool]. Whether to split the assistant turn based on a reasoning trace inside delimited tags
|
||||
# defaults to False
|
||||
# Optional[bool]. (for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags
|
||||
# See example at `docs/dataset-formats/conversation.qmd`
|
||||
split_thinking:
|
||||
|
||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||
@@ -329,6 +332,8 @@ dataset_shard_idx:
|
||||
# The maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# How to handle sequences that overflow the sequence_len: 'drop' (default, removes sample) or 'truncate' (cuts off excess tokens).
|
||||
sequence_len_overflow_handling: drop
|
||||
# Pad inputs so each step uses constant sized buffers
|
||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
pad_to_sequence_len:
|
||||
@@ -502,6 +507,7 @@ save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of eac
|
||||
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
|
||||
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints.
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
@@ -535,7 +541,7 @@ train_on_inputs: false
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
|
||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
|
||||
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||
@@ -547,7 +553,7 @@ gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | 'linear' | 'cosine_with_restarts' | 'polynomial' | 'constant' | 'constant_with_warmup' | 'inverse_sqrt' | 'reduce_lr_on_plateau' | 'cosine_with_min_lr' | 'warmup_stable_decay' | empty for cosine
|
||||
lr_scheduler_kwargs:
|
||||
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
||||
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
||||
@@ -609,6 +615,7 @@ lr_div_factor: # Learning rate div factor
|
||||
# - optimi_adamw
|
||||
# - ao_adamw_8bit
|
||||
# - ao_adamw_fp8
|
||||
# - came_pytorch
|
||||
optimizer:
|
||||
# Dictionary of arguments to pass to the optimizer
|
||||
optim_args:
|
||||
|
||||
@@ -196,6 +196,34 @@ datasets:
|
||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||
:::
|
||||
|
||||
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
chat_template: qwen3
|
||||
split_thinking: true
|
||||
```
|
||||
|
||||
For example, a content can look like:
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "<think>Some thinking outputs</think>Output after thinking."
|
||||
}
|
||||
```
|
||||
|
||||
After split, it will look like:
|
||||
|
||||
```json
|
||||
{
|
||||
"reasoning_content": "Some thinking outputs",
|
||||
"content": "Output after thinking..."
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## sharegpt
|
||||
|
||||
::: {.callout-important}
|
||||
|
||||
@@ -3,8 +3,6 @@ title: Sequence Parallelism
|
||||
description: Train with long sequences split across multiple GPUs.
|
||||
---
|
||||
|
||||
# Sequence Parallelism
|
||||
|
||||
Sequence parallelism is a technique that splits sequences across multiple GPUs,
|
||||
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
|
||||
GPU processes a different portion of the sequence, and the results are aggregated
|
||||
@@ -27,7 +25,7 @@ To enable sequence parallelism, add the following to your configuration file:
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
|
||||
ring_attn_func:
|
||||
```
|
||||
|
||||
@@ -34,3 +34,5 @@ We provide a script to delinearize Llama 4 linearized models into regular Huggin
|
||||
```bash
|
||||
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
|
||||
```
|
||||
|
||||
Note: This only works with the non-quantized linearized model. If you have an adapter, merge it with the *non-quantized linearized* model before delinearizing.
|
||||
|
||||
341
examples/orpheus/README.md
Normal file
341
examples/orpheus/README.md
Normal file
@@ -0,0 +1,341 @@
|
||||
# Finetuning LLMs to output audio
|
||||
|
||||
In this example, we finetune Orpcanopylabs/orpheus-tts-0.1-pretrained (a LLaMA 3.2 3b model) to output audio.
|
||||
|
||||
The `finetune.yml` withe current settings will run on any Nvidia GPU with 45GB VRAM or more. If you adjust the batch size it can easily run on any GPU under 24GB.
|
||||
|
||||
## Dataset pre-processing for pre-training
|
||||
If you are adding another voice in English, please jump ahead to finetuning pre-processing.
|
||||
|
||||
For this to work, we need to preprocess our dataset. Since we are expecting to output audio, we will need to add tokens to the tokenizer.
|
||||
|
||||
Using this code, it will download the SNAC model and add the correct tokens and upload the final dataset.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from snac import SNAC
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from datasets import load_dataset
|
||||
import random
|
||||
import torchaudio.transforms as T
|
||||
from transformers import AutoTokenizer
|
||||
import os
|
||||
|
||||
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
|
||||
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
|
||||
|
||||
dsn = my_original_dataset_name
|
||||
|
||||
snapshot_download(
|
||||
repo_id=dsn,
|
||||
repo_type="dataset",
|
||||
revision="main",
|
||||
max_workers=64,
|
||||
)
|
||||
|
||||
|
||||
ds = load_dataset(dsn, split="train")
|
||||
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
|
||||
|
||||
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
|
||||
model = model.to("mps")
|
||||
|
||||
def tokenise_audio(waveform):
|
||||
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
||||
waveform = waveform.to(dtype=torch.float32)
|
||||
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
|
||||
waveform = resample_transform(waveform)
|
||||
|
||||
waveform = waveform.unsqueeze(0).to("cuda")
|
||||
|
||||
#generate the codes from snac
|
||||
with torch.inference_mode():
|
||||
codes = model.encode(waveform)
|
||||
|
||||
all_codes = []
|
||||
for i in range(codes[0].shape[1]):
|
||||
all_codes.append(codes[0][0][i].item()+128266)
|
||||
all_codes.append(codes[1][0][2*i].item()+128266+4096)
|
||||
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
|
||||
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
|
||||
|
||||
|
||||
return all_codes
|
||||
|
||||
def add_codes(example):
|
||||
# Always initialize codes_list to None
|
||||
codes_list = None
|
||||
|
||||
try:
|
||||
answer_audio = example.get("audio")
|
||||
# If there's a valid audio array, tokenise it
|
||||
if answer_audio and "array" in answer_audio:
|
||||
audio_array = answer_audio["array"]
|
||||
codes_list = tokenise_audio(audio_array)
|
||||
except Exception as e:
|
||||
print(f"Skipping row due to error: {e}")
|
||||
# Keep codes_list as None if we fail
|
||||
example["codes_list"] = codes_list
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(add_codes, remove_columns=["audio"])
|
||||
|
||||
#@title Load Tokenizer
|
||||
tokeniser_length = 128256
|
||||
start_of_text = 128000
|
||||
end_of_text = 128009
|
||||
|
||||
start_of_speech = tokeniser_length + 1
|
||||
end_of_speech = tokeniser_length + 2
|
||||
|
||||
start_of_human = tokeniser_length + 3
|
||||
end_of_human = tokeniser_length + 4
|
||||
|
||||
start_of_ai = tokeniser_length + 5
|
||||
end_of_ai = tokeniser_length + 6
|
||||
pad_token = tokeniser_length + 7
|
||||
|
||||
audio_tokens_start = tokeniser_length + 10
|
||||
|
||||
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
num_proc = os.cpu_count() - 2
|
||||
|
||||
ds = ds.filter(lambda x: x["codes_list"] is not None)
|
||||
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
|
||||
|
||||
#@title Create Input Ids
|
||||
def remove_duplicate_frames(example):
|
||||
vals = example["codes_list"]
|
||||
if len(vals) % 7 != 0:
|
||||
raise ValueError("Input list length must be divisible by 7")
|
||||
|
||||
result = vals[:7]
|
||||
|
||||
removed_frames = 0
|
||||
|
||||
for i in range(7, len(vals), 7):
|
||||
current_first = vals[i]
|
||||
previous_first = result[-7]
|
||||
|
||||
if current_first != previous_first:
|
||||
result.extend(vals[i:i+7])
|
||||
else:
|
||||
removed_frames += 1
|
||||
|
||||
example["codes_list"] = result
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
|
||||
|
||||
|
||||
def create_input_ids(example):
|
||||
text_ids = tokenizer.encode({example['text']}, add_special_tokens=True)
|
||||
text_ids.append(end_of_text)
|
||||
example["text_tokens"] = text_ids
|
||||
input_ids = (
|
||||
[start_of_human]
|
||||
+ example["text_tokens"]
|
||||
+ [end_of_human]
|
||||
+ [start_of_ai]
|
||||
+ [start_of_speech]
|
||||
+ example["codes_list"]
|
||||
+ [end_of_speech]
|
||||
+ [end_of_ai]
|
||||
)
|
||||
example["input_ids"] = input_ids
|
||||
example["labels"] = input_ids
|
||||
example["attention_mask"] = [1] * len(input_ids)
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
|
||||
|
||||
#@title Remove unnecessary columns
|
||||
columns_to_keep = ["input_ids", "labels", "attention_mask"]
|
||||
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
|
||||
|
||||
ds = ds.remove_columns(columns_to_remove)
|
||||
|
||||
ds.push_to_hub(name_to_push_dataset_to)
|
||||
```
|
||||
|
||||
|
||||
## Finetune pre-processing
|
||||
Use this code to add a new voice.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from snac import SNAC
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from datasets import load_dataset
|
||||
import random
|
||||
import torchaudio.transforms as T
|
||||
from transformers import AutoTokenizer
|
||||
import os
|
||||
|
||||
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
|
||||
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
|
||||
|
||||
dsn = my_original_dataset_name
|
||||
|
||||
snapshot_download(
|
||||
repo_id=dsn,
|
||||
repo_type="dataset",
|
||||
revision="main",
|
||||
max_workers=64,
|
||||
)
|
||||
|
||||
|
||||
ds = load_dataset(dsn, split="train")
|
||||
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
|
||||
|
||||
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
|
||||
model = model.to("mps")
|
||||
|
||||
def tokenise_audio(waveform):
|
||||
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
||||
waveform = waveform.to(dtype=torch.float32)
|
||||
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
|
||||
waveform = resample_transform(waveform)
|
||||
|
||||
waveform = waveform.unsqueeze(0).to("cuda")
|
||||
|
||||
#generate the codes from snac
|
||||
with torch.inference_mode():
|
||||
codes = model.encode(waveform)
|
||||
|
||||
all_codes = []
|
||||
for i in range(codes[0].shape[1]):
|
||||
all_codes.append(codes[0][0][i].item()+128266)
|
||||
all_codes.append(codes[1][0][2*i].item()+128266+4096)
|
||||
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
|
||||
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
|
||||
|
||||
|
||||
return all_codes
|
||||
|
||||
def add_codes(example):
|
||||
# Always initialize codes_list to None
|
||||
codes_list = None
|
||||
|
||||
try:
|
||||
answer_audio = example.get("audio")
|
||||
# If there's a valid audio array, tokenise it
|
||||
if answer_audio and "array" in answer_audio:
|
||||
audio_array = answer_audio["array"]
|
||||
codes_list = tokenise_audio(audio_array)
|
||||
except Exception as e:
|
||||
print(f"Skipping row due to error: {e}")
|
||||
# Keep codes_list as None if we fail
|
||||
example["codes_list"] = codes_list
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(add_codes, remove_columns=["audio"])
|
||||
|
||||
#@title Load Tokenizer
|
||||
tokeniser_length = 128256
|
||||
start_of_text = 128000
|
||||
end_of_text = 128009
|
||||
|
||||
start_of_speech = tokeniser_length + 1
|
||||
end_of_speech = tokeniser_length + 2
|
||||
|
||||
start_of_human = tokeniser_length + 3
|
||||
end_of_human = tokeniser_length + 4
|
||||
|
||||
start_of_ai = tokeniser_length + 5
|
||||
end_of_ai = tokeniser_length + 6
|
||||
pad_token = tokeniser_length + 7
|
||||
|
||||
audio_tokens_start = tokeniser_length + 10
|
||||
|
||||
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
num_proc = os.cpu_count() - 2
|
||||
|
||||
ds = ds.filter(lambda x: x["codes_list"] is not None)
|
||||
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
|
||||
|
||||
#@title Create Input Ids
|
||||
def remove_duplicate_frames(example):
|
||||
vals = example["codes_list"]
|
||||
if len(vals) % 7 != 0:
|
||||
raise ValueError("Input list length must be divisible by 7")
|
||||
|
||||
result = vals[:7]
|
||||
|
||||
removed_frames = 0
|
||||
|
||||
for i in range(7, len(vals), 7):
|
||||
current_first = vals[i]
|
||||
previous_first = result[-7]
|
||||
|
||||
if current_first != previous_first:
|
||||
result.extend(vals[i:i+7])
|
||||
else:
|
||||
removed_frames += 1
|
||||
|
||||
example["codes_list"] = result
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
|
||||
|
||||
tok_info = '''*** HERE you can modify the text prompt
|
||||
i.e. if you wanted a multispeaker model like canopylabs/orpheus-3b-0.1-ft, you can pass:
|
||||
f"{example["source"]}: {example["text"]}", as is passed.
|
||||
'''
|
||||
print(tok_info)
|
||||
|
||||
def create_input_ids(example):
|
||||
text_ids = tokenizer.encode(f"{example['speaker_id']}: {example['text']}", add_special_tokens=True)
|
||||
text_ids.append(end_of_text)
|
||||
example["text_tokens"] = text_ids
|
||||
input_ids = (
|
||||
[start_of_human]
|
||||
+ example["text_tokens"]
|
||||
+ [end_of_human]
|
||||
+ [start_of_ai]
|
||||
+ [start_of_speech]
|
||||
+ example["codes_list"]
|
||||
+ [end_of_speech]
|
||||
+ [end_of_ai]
|
||||
)
|
||||
example["input_ids"] = input_ids
|
||||
example["labels"] = input_ids
|
||||
example["attention_mask"] = [1] * len(input_ids)
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
|
||||
|
||||
#@title Remove unnecessary columns
|
||||
columns_to_keep = ["input_ids", "labels", "attention_mask"]
|
||||
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
|
||||
|
||||
ds = ds.remove_columns(columns_to_remove)
|
||||
|
||||
ds.push_to_hub(name_to_push_dataset_to)
|
||||
```
|
||||
|
||||
## Training
|
||||
After preprocessing is done, fill out the blanks in finetune.yml and simply run `axolotl train finetune.yml`
|
||||
|
||||
## Inference
|
||||
For inference, please refer to the original [orpheus github](https://github.com/canopyai/Orpheus-TTS/tree/main).
|
||||
52
examples/orpheus/finetune.yml
Normal file
52
examples/orpheus/finetune.yml
Normal file
@@ -0,0 +1,52 @@
|
||||
base_model: canopylabs/orpheus-3b-0.1-pretrained
|
||||
|
||||
hub_model_id: <your-hub-model-id>
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
datasets:
|
||||
- path: <your-hf-dataset-id>
|
||||
type: # leave empty to load pre-tokenized
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 5
|
||||
saves_per_epoch: 5
|
||||
weight_decay: 0.05
|
||||
|
||||
special_tokens:
|
||||
pad_token: <custom_token_7>
|
||||
@@ -6,16 +6,17 @@ triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.5.8
|
||||
liger-kernel==0.5.9
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub==0.31.0
|
||||
peft==0.15.2
|
||||
transformers==4.51.3
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.6.0
|
||||
datasets==3.5.0
|
||||
datasets==3.5.1
|
||||
deepspeed>=0.15.4
|
||||
trl==0.17.0
|
||||
hf_xet==1.1.0
|
||||
|
||||
5
setup.py
5
setup.py
@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
|
||||
if (major, minor) >= (2, 7):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5"]
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append(
|
||||
"xformers==0.0.29.post2"
|
||||
) # vllm needs post2 w torch 2.6
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5"]
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
@@ -142,6 +142,7 @@ extras_require = {
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
"came_pytorch==0.1.3",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
|
||||
@@ -82,6 +82,12 @@ class VllmServeCliArgs:
|
||||
"hardware support this feature."
|
||||
},
|
||||
)
|
||||
serve_module: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Module to serve. If not set, the default module will be used."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -16,8 +16,15 @@ AXOLOTL_LOGO = """
|
||||
@@@@ @@@@@@@@@@@@@@@@
|
||||
"""
|
||||
|
||||
HAS_PRINTED_LOGO = False
|
||||
|
||||
|
||||
def print_axolotl_text_art():
|
||||
"""Prints axolotl ASCII art."""
|
||||
|
||||
global HAS_PRINTED_LOGO # pylint: disable=global-statement
|
||||
if HAS_PRINTED_LOGO:
|
||||
return
|
||||
if is_main_process():
|
||||
HAS_PRINTED_LOGO = True
|
||||
print(AXOLOTL_LOGO)
|
||||
|
||||
@@ -15,7 +15,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.evaluate import evaluate
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -32,7 +32,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
cli_args: CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
patch_optimized_env()
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
|
||||
@@ -29,7 +29,7 @@ from axolotl.cli.utils import (
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
|
||||
@@ -55,6 +55,8 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
patch_optimized_env()
|
||||
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_preprocess
|
||||
|
||||
@@ -100,7 +102,7 @@ def train(
|
||||
config options.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
patch_optimized_env()
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
|
||||
@@ -18,6 +18,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
|
||||
@@ -47,7 +48,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
with disable_datasets_caching():
|
||||
if cfg.rl:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
if plugin_manager.load_datasets(cfg, preprocess=True):
|
||||
pass
|
||||
elif cfg.rl:
|
||||
load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -18,7 +18,7 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -36,17 +36,20 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
patch_optimized_env()
|
||||
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
||||
if not dataset_meta:
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
from trl.scripts.vllm_serve import main as vllm_serve_main
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
|
||||
@@ -28,6 +27,9 @@ def do_vllm_serve(
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
||||
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_processor, load_tokenizer
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -48,6 +49,7 @@ def load_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
|
||||
debug: bool = False,
|
||||
) -> TrainDatasetMeta:
|
||||
"""
|
||||
Loads one or more training or evaluation datasets, calling
|
||||
@@ -56,6 +58,7 @@ def load_datasets(
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Command-specific CLI arguments.
|
||||
debug: Whether to print out tokenization of sample
|
||||
|
||||
Returns:
|
||||
Dataclass with fields for training and evaluation datasets and the computed
|
||||
@@ -77,20 +80,25 @@ def load_datasets(
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
|
||||
if cli_args and (
|
||||
cli_args.debug
|
||||
or cfg.debug
|
||||
or cli_args.debug_text_only
|
||||
or int(cli_args.debug_num_examples) > 0
|
||||
):
|
||||
if ( # pylint: disable=too-many-boolean-expressions
|
||||
cli_args
|
||||
and (
|
||||
cli_args.debug
|
||||
or cfg.debug
|
||||
or cli_args.debug_text_only
|
||||
or int(cli_args.debug_num_examples) > 0
|
||||
)
|
||||
) or debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||
text_only = cli_args.debug_text_only if cli_args else False
|
||||
train_samples = sample_dataset(train_dataset, num_examples)
|
||||
check_dataset_labels(
|
||||
train_samples,
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
num_examples=num_examples,
|
||||
text_only=text_only,
|
||||
)
|
||||
|
||||
LOG.info("printing prompters...")
|
||||
@@ -126,7 +134,7 @@ def load_preference_datasets(
|
||||
total_num_steps: Optional[int] = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
if cfg.rl == "grpo":
|
||||
if cfg.rl is RLType.GRPO:
|
||||
total_num_steps = None
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
|
||||
@@ -21,6 +21,7 @@ import importlib.util
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
@@ -72,6 +73,7 @@ from axolotl.utils.callbacks import (
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
colab_inference_post_train_callback,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
@@ -85,7 +87,7 @@ from axolotl.utils.collators import (
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.models import ensure_dtype
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
|
||||
|
||||
try:
|
||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||
@@ -168,6 +170,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
)
|
||||
)
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
@@ -249,9 +254,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -293,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
callbacks.append(lisa_callback_factory(trainer))
|
||||
|
||||
if any("COLAB_" in key for key in os.environ):
|
||||
ColabCallback = colab_inference_post_train_callback(trainer)
|
||||
callbacks.append(ColabCallback(self.cfg))
|
||||
|
||||
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
||||
return callbacks
|
||||
|
||||
@@ -347,7 +353,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
if self.cfg.seed:
|
||||
if self.cfg.seed is not None:
|
||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
@@ -541,8 +547,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
report_to = []
|
||||
if self.cfg.use_wandb:
|
||||
report_to.append("wandb")
|
||||
if self.cfg.wandb_name:
|
||||
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
|
||||
if self.cfg.use_mlflow:
|
||||
report_to.append("mlflow")
|
||||
if self.cfg.use_tensorboard:
|
||||
@@ -702,6 +706,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
optimizer_cls = ADOPT
|
||||
adam_kwargs["decouple"] = True
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "came_pytorch":
|
||||
from came_pytorch import CAME
|
||||
|
||||
optimizer_cls = CAME
|
||||
|
||||
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
|
||||
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
|
||||
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
|
||||
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
|
||||
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
|
||||
adam_kwargs["betas"] = (beta1, beta2, beta3)
|
||||
adam_kwargs["eps"] = (eps1, eps2)
|
||||
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
# Parse any additional optimizer args from config
|
||||
if self.cfg.optim_args:
|
||||
@@ -801,14 +819,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
data_collator_kwargs = {
|
||||
"padding": True, # True/"longest" is the default
|
||||
}
|
||||
multiple = 64
|
||||
if self.cfg.pad_to_sequence_len:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
||||
self.cfg.sequence_len / 64
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||
self.cfg.sequence_len / multiple
|
||||
)
|
||||
else:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||
|
||||
if self.cfg.reward_model:
|
||||
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
||||
@@ -1014,6 +1033,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["dataloader_prefetch_factor"] = (
|
||||
self.cfg.dataloader_prefetch_factor
|
||||
)
|
||||
|
||||
if self.cfg.seed is not None:
|
||||
training_args_kwargs["seed"] = self.cfg.seed
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
@@ -1037,6 +1060,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
|
||||
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
@@ -1054,9 +1079,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
training_args_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl == "simpo":
|
||||
if self.cfg.rl is RLType.SIMPO:
|
||||
training_args_cls = AxolotlCPOConfig
|
||||
training_args_kwargs["loss_type"] = "simpo"
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
@@ -1064,13 +1093,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
elif self.cfg.rl == "orpo":
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl == "kto":
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
|
||||
training_args_kwargs["desirable_weight"] = (
|
||||
@@ -1084,14 +1113,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl == "grpo":
|
||||
elif self.cfg.rl is RLType.GRPO:
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
|
||||
else:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rl == "ipo":
|
||||
if self.cfg.rl is RLType.IPO:
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
@@ -1134,67 +1163,73 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
def build(self, total_num_steps):
|
||||
training_args = self.build_training_arguments(total_num_steps)
|
||||
dpo_trainer_kwargs = {}
|
||||
if self.cfg.rl == "ipo":
|
||||
trainer_kwargs = {}
|
||||
if self.cfg.rl is RLType.IPO:
|
||||
if self.cfg.dpo_label_smoothing:
|
||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
if self.eval_dataset:
|
||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
if self.cfg.adapter and self.peft_config:
|
||||
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.precompute_ref_log_probs is not None:
|
||||
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
self.cfg.precompute_ref_log_probs
|
||||
)
|
||||
if self.cfg.rl == "grpo":
|
||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
)
|
||||
trainer_cls_args = [self.model]
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
elif self.cfg.rl in ["dpo", "ipo"]:
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
elif self.cfg.rl == "orpo":
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
elif self.cfg.rl in ["kto"]:
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
trainer_cls = AxolotlKTOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
elif self.cfg.rl in ["simpo"]:
|
||||
elif self.cfg.rl is RLType.SIMPO:
|
||||
trainer_cls = AxolotlCPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "tokenizer" in sig.parameters.keys():
|
||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
else:
|
||||
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
|
||||
if self.cfg.datasets is not None and (
|
||||
trainer_cls is DPOStrategy.get_trainer_class()
|
||||
):
|
||||
dpo_trainer_kwargs["dataset_tags"] = [
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
dpo_trainer = trainer_cls(
|
||||
trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
callbacks=self.get_callbacks(),
|
||||
**dpo_trainer_kwargs,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
if self.cfg.fsdp:
|
||||
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
|
||||
if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
|
||||
ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
|
||||
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
|
||||
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
|
||||
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
|
||||
|
||||
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
||||
dpo_trainer.add_callback(callback)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
return dpo_trainer
|
||||
return trainer
|
||||
|
||||
|
||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .relora import ReLoRATrainer
|
||||
from .trl import (
|
||||
|
||||
@@ -114,6 +114,8 @@ class AxolotlTrainer(
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
group_size=self.args.sample_packing_group_size,
|
||||
bin_size=self.args.sample_packing_bin_size,
|
||||
sequential=self.args.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
)
|
||||
@@ -371,15 +373,13 @@ class AxolotlTrainer(
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
loss = super().compute_loss(
|
||||
return super().compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=return_outputs,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||
concatenated_batch = {}
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
"""
|
||||
DPO Specific Strategy for training
|
||||
"""
|
||||
"""DPO Specific Strategy for training"""
|
||||
|
||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
|
||||
class DPOStrategy:
|
||||
"""
|
||||
Strategy for DPO training
|
||||
"""
|
||||
"""Strategy for DPO training"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
@@ -23,7 +20,7 @@ class DPOStrategy:
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
training_args_kwargs = {}
|
||||
if cfg.rl == "ipo":
|
||||
if cfg.rl is RLType.IPO:
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
|
||||
@@ -177,12 +177,8 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
@@ -251,7 +247,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||
)
|
||||
|
||||
# Base evaluation
|
||||
initial_output = super().evaluation_loop(
|
||||
initial_output = super( # pylint: disable=bad-super-call
|
||||
DPOTrainer, self
|
||||
).evaluation_loop(
|
||||
dataloader,
|
||||
description,
|
||||
prediction_loss_only,
|
||||
|
||||
@@ -1,37 +1,41 @@
|
||||
"""
|
||||
GRPO Specific Strategy for training
|
||||
"""
|
||||
"""GRPO Specific Strategy for training"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
from axolotl.core.trainers.grpo.trainer import (
|
||||
AxolotlGRPOSequenceParallelTrainer,
|
||||
AxolotlGRPOTrainer,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GRPOStrategy:
|
||||
"""
|
||||
Strategy for GRPO training
|
||||
"""
|
||||
"""Strategy for GRPO training"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
def get_trainer_class(
|
||||
cls, sequence_parallel: bool
|
||||
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
|
||||
if sequence_parallel:
|
||||
return AxolotlGRPOSequenceParallelTrainer
|
||||
return AxolotlGRPOTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(cls):
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
|
||||
def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
|
||||
return AxolotlGRPOConfig
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
grpo_args_kwargs = {}
|
||||
def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
|
||||
grpo_args_kwargs: dict[str, Any] = {}
|
||||
|
||||
if not hasattr(cfg, "trl") or not cfg.trl:
|
||||
return grpo_args_kwargs
|
||||
@@ -40,8 +44,8 @@ class GRPOStrategy:
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host
|
||||
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port
|
||||
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined]
|
||||
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined]
|
||||
if trl.vllm_server_timeout:
|
||||
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
||||
if trl.vllm_guided_decoding_regex:
|
||||
@@ -102,17 +106,18 @@ class GRPOStrategy:
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
def set_trainer_args(cls, cfg):
|
||||
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
|
||||
trainer_args = []
|
||||
if cfg.trl and cfg.trl.reward_funcs:
|
||||
reward_funcs = []
|
||||
for reward_func_fqn in cfg.trl.reward_funcs:
|
||||
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||
trainer_args.append(reward_funcs)
|
||||
|
||||
return trainer_args
|
||||
|
||||
@classmethod
|
||||
def set_trainer_kwargs(cls, cfg):
|
||||
def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
|
||||
trainer_kwargs = {}
|
||||
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||
trainer_kwargs["reward_processing_classes"] = (
|
||||
@@ -126,7 +131,7 @@ class GRPOStrategy:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_blocklist_args_kwargs(cls):
|
||||
def get_blocklist_args_kwargs(cls) -> list[str]:
|
||||
return ["dataset_num_proc"]
|
||||
|
||||
@classmethod
|
||||
@@ -137,13 +142,13 @@ class GRPOStrategy:
|
||||
Args:
|
||||
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
|
||||
or a HF hub path to the reward model.
|
||||
Raises:
|
||||
ValueError: If the reward function does not accept at least two arguments.
|
||||
|
||||
Returns:
|
||||
RewardFunc: A callable that accepts prompts and completions and returns rewards,
|
||||
or a path to a reward model.
|
||||
|
||||
Raises:
|
||||
ValueError: If the reward function does not accept at least two arguments.
|
||||
"""
|
||||
try:
|
||||
# use importlib to dynamically load the reward function from the module
|
||||
|
||||
@@ -11,6 +11,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
@dataclass
|
||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""
|
||||
Axolotl GRPO Config for GRPO training
|
||||
"""
|
||||
"""Axolotl GRPO Config for GRPO training"""
|
||||
|
||||
172
src/axolotl/core/trainers/grpo/sampler.py
Normal file
172
src/axolotl/core/trainers/grpo/sampler.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""Repeat random sampler (similar to the one implemented in
|
||||
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
|
||||
sequence parallelism functionality; i.e., duplicating data across ranks in the same
|
||||
sequence parallel group.
|
||||
"""
|
||||
|
||||
from typing import Iterator, Sized
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
|
||||
class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
"""Sampler for GRPO training with sequence parallelism.
|
||||
|
||||
This sampler ensures:
|
||||
- Ranks in the same sequence parallel (SP) group receive identical data.
|
||||
- Each index is repeated multiple times for sampling different completions.
|
||||
- Entire batches are repeated for reuse in multiple updates.
|
||||
- Data is properly distributed across SP groups.
|
||||
|
||||
In the table below, the values represent dataset indices. Each SP group has
|
||||
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
|
||||
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||
|
||||
Sequence Parallel Groups
|
||||
| SP0 | SP1 |
|
||||
| GPU 0 | GPU 1 | GPU 2 | GPU 3 |
|
||||
global_step step <---> mini_repeat_count=3
|
||||
<----------> batch_size=2 per SP group
|
||||
grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data
|
||||
▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU
|
||||
|
|
||||
| 1 2 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations
|
||||
num_iterations=2 ▼ 1 3 [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation
|
||||
|
||||
2 4 [4 4 4 5 5 5] [6 6 6 7 7 7] <- New batch of data indices
|
||||
2 5 [4 4 4 5 5 5] [6 6 6 7 7 7]
|
||||
...
|
||||
|
||||
Args:
|
||||
dataset: Dataset to sample from.
|
||||
mini_repeat_count: How many times to repeat each sample immediately.
|
||||
world_size: Total number of processes.
|
||||
rank: Rank of current process.
|
||||
batch_size: Number of samples per batch.
|
||||
repeat_count: How many times to repeat the full sampling process.
|
||||
sequence_parallel_degree: Number of ranks in a sequence parallel group.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
seed: Random seed for shuffling.
|
||||
drop_last: Whether to drop the last incomplete batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Sized,
|
||||
mini_repeat_count: int,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
batch_size: int = 1,
|
||||
repeat_count: int = 1,
|
||||
sequence_parallel_degree: int = 1,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
):
|
||||
self.dataset = dataset
|
||||
self.mini_repeat_count = mini_repeat_count
|
||||
self.batch_size = batch_size
|
||||
self.repeat_count = repeat_count
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self.drop_last = drop_last
|
||||
self.epoch = 0
|
||||
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
|
||||
# Sequence parallelism parameters
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.num_sp_groups = world_size // sequence_parallel_degree
|
||||
self.sp_group_id = rank // sequence_parallel_degree
|
||||
|
||||
# Adjust dataset size for distributed sampling
|
||||
self.num_samples = len(self.dataset)
|
||||
self.total_size = self.num_samples
|
||||
|
||||
# Calculate effective number of samples per SP group
|
||||
if (
|
||||
self.drop_last
|
||||
and self.total_size % (self.num_sp_groups * self.batch_size) != 0
|
||||
):
|
||||
# Drop last incomplete batch if drop_last is True
|
||||
self.num_samples_per_sp_group = (
|
||||
self.total_size // self.batch_size // self.num_sp_groups
|
||||
) * self.batch_size
|
||||
else:
|
||||
# Round up to include last batch if drop_last is False
|
||||
self.num_samples_per_sp_group = (
|
||||
(self.total_size + self.batch_size * self.num_sp_groups - 1)
|
||||
// (self.batch_size * self.num_sp_groups)
|
||||
* self.batch_size
|
||||
)
|
||||
|
||||
if shuffle:
|
||||
self.generator = torch.Generator()
|
||||
self.generator.manual_seed(seed)
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
"""Creates iterator over dataset indices.
|
||||
|
||||
Returns:
|
||||
Iterator that yields indices into the dataset.
|
||||
"""
|
||||
# Deterministically shuffle based on epoch and seed
|
||||
if self.shuffle:
|
||||
indices = torch.randperm(
|
||||
self.num_samples, generator=self.generator
|
||||
).tolist()
|
||||
else:
|
||||
indices = list(range(self.num_samples))
|
||||
|
||||
# Add extra samples to make it evenly divisible by batch_size
|
||||
if len(indices) % self.batch_size != 0:
|
||||
padding = indices[: self.batch_size - len(indices) % self.batch_size]
|
||||
indices += padding
|
||||
|
||||
# Subsample based on SP group ID
|
||||
# Each SP group gets distinct batches of data
|
||||
batch_indices = []
|
||||
for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
|
||||
start_idx = i + self.sp_group_id * self.batch_size
|
||||
end_idx = min(start_idx + self.batch_size, len(indices))
|
||||
if start_idx < len(indices):
|
||||
for j in range(self.batch_size):
|
||||
if start_idx + j < end_idx:
|
||||
batch_indices.append(indices[start_idx + j])
|
||||
|
||||
# Make sure batch_indices is exactly batch_size * num_batches_per_sp_group
|
||||
if self.drop_last:
|
||||
num_batches_per_sp_group = self.num_samples_per_sp_group // self.batch_size
|
||||
target_len = self.batch_size * num_batches_per_sp_group
|
||||
if len(batch_indices) > target_len:
|
||||
batch_indices = batch_indices[:target_len]
|
||||
|
||||
# Apply the GRPO repeat pattern
|
||||
final_indices = []
|
||||
for _ in range(self.repeat_count):
|
||||
for idx in batch_indices:
|
||||
for _ in range(self.mini_repeat_count):
|
||||
final_indices.append(idx)
|
||||
|
||||
return iter(final_indices)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the total length of the iterable including repetitions.
|
||||
|
||||
Returns:
|
||||
Total number of samples.
|
||||
"""
|
||||
# Total length including all repetitions
|
||||
return (
|
||||
self.num_samples_per_sp_group * self.mini_repeat_count * self.repeat_count
|
||||
)
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
"""Sets the epoch for this sampler.
|
||||
|
||||
Args:
|
||||
epoch: Epoch number to use for shuffling.
|
||||
"""
|
||||
self.epoch = epoch
|
||||
@@ -1,23 +1,63 @@
|
||||
"""
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
|
||||
|
||||
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
|
||||
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
|
||||
from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||
import datasets
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data
|
||||
from accelerate.utils import (
|
||||
broadcast_object_list,
|
||||
gather,
|
||||
gather_object,
|
||||
is_peft_model,
|
||||
)
|
||||
from datasets import Dataset, IterableDataset
|
||||
from torch import nn
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
Sampler,
|
||||
)
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from transformers.utils import is_peft_available
|
||||
from trl import GRPOTrainer
|
||||
from trl.extras.profiling import profiling_decorator
|
||||
from trl.data_utils import (
|
||||
apply_chat_template,
|
||||
is_conversational,
|
||||
maybe_apply_chat_template,
|
||||
)
|
||||
from trl.extras.profiling import profiling_context, profiling_decorator
|
||||
from trl.import_utils import is_deepspeed_available
|
||||
from trl.models import unwrap_model_for_generation
|
||||
from trl.trainer.grpo_config import GRPOConfig
|
||||
from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||
from trl.trainer.utils import pad
|
||||
|
||||
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
|
||||
|
||||
if is_peft_available():
|
||||
# pylint: disable=unused-import
|
||||
from peft import PeftConfig
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
"""
|
||||
Extend the base GRPOTrainer for axolotl helpers
|
||||
"""
|
||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
@@ -67,3 +107,600 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||
# Reset cache on main process
|
||||
if self.accelerator.is_main_process:
|
||||
self.vllm_client.reset_prefix_cache()
|
||||
|
||||
|
||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str | PreTrainedModel,
|
||||
reward_funcs: RewardFunc | list[RewardFunc],
|
||||
args: GRPOConfig | None = None,
|
||||
train_dataset: Dataset | IterableDataset | None = None,
|
||||
eval_dataset: (
|
||||
Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None
|
||||
) = None,
|
||||
processing_class: PreTrainedTokenizerBase | None = None,
|
||||
reward_processing_classes: (
|
||||
PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None
|
||||
) = None,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
optimizers: tuple[
|
||||
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
|
||||
] = (None, None),
|
||||
peft_config: "PeftConfig | None" = None,
|
||||
):
|
||||
# First call the superclass constructor with all arguments
|
||||
super().__init__(
|
||||
model=model,
|
||||
reward_funcs=reward_funcs,
|
||||
args=args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
reward_processing_classes=reward_processing_classes,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
# Get number of SP groups (number of processes divided by SP degree)
|
||||
num_processes = self.accelerator.num_processes
|
||||
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
||||
|
||||
# Calculate batch size per SP group (not per process)
|
||||
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||
possible_values = [
|
||||
n_gen
|
||||
for n_gen in range(2, sp_group_batch_size + 1)
|
||||
if (sp_group_batch_size) % n_gen == 0
|
||||
]
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The batch size per SP group ({num_sp_groups} x "
|
||||
f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
|
||||
f"the number of generations per prompt ({self.num_generations}). Given "
|
||||
"the current configuration, the valid values for the number of "
|
||||
f"generations are: {possible_values}."
|
||||
)
|
||||
|
||||
if self.args.eval_strategy != "no":
|
||||
# If sequence parallelism is enabled, calculate batch size per SP group
|
||||
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr]
|
||||
possible_values = [
|
||||
n_gen
|
||||
for n_gen in range(2, sp_group_eval_batch_size + 1)
|
||||
if (sp_group_eval_batch_size) % n_gen == 0
|
||||
]
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
||||
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||
f"must be evenly divisible by the number of generations per prompt "
|
||||
f"({self.num_generations}). Given the current eval batch size, "
|
||||
f"the valid values for the number of generations are: {possible_values}."
|
||||
)
|
||||
|
||||
# Initialize the SP group
|
||||
self.sp_group = get_ring_attn_group()
|
||||
self.rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
self.local_rank = dist.get_rank(group=self.sp_group)
|
||||
self.local_world_size = dist.get_world_size(group=self.sp_group)
|
||||
|
||||
def _get_train_sampler(self) -> Sampler:
|
||||
effective_batch_size = (
|
||||
self.args.per_device_train_batch_size
|
||||
* self.world_size
|
||||
* self.args.gradient_accumulation_steps
|
||||
)
|
||||
|
||||
return SequenceParallelRepeatRandomSampler(
|
||||
dataset=self.train_dataset,
|
||||
mini_repeat_count=self.num_generations,
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
batch_size=effective_batch_size
|
||||
// self.num_generations
|
||||
// self.args.sequence_parallel_degree,
|
||||
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
||||
sequence_parallel_degree=self.args.sequence_parallel_degree,
|
||||
shuffle=True,
|
||||
seed=self.args.seed,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
|
||||
"""Create common dataloader parameters for train or eval."""
|
||||
batch_size = custom_batch_size or (
|
||||
self.args.eval_batch_size if is_eval else self._train_batch_size
|
||||
)
|
||||
|
||||
params = {
|
||||
"batch_size": batch_size,
|
||||
"collate_fn": self.data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
|
||||
# Add persistent workers only for training
|
||||
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
|
||||
params["persistent_workers"] = self.args.dataloader_persistent_workers
|
||||
|
||||
# Add prefetch factor if specified
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
return params
|
||||
|
||||
def _prepare_dataloader(
|
||||
self, dataset, sampler, is_eval=False, custom_batch_size=None
|
||||
):
|
||||
"""Prepare a dataloader with the given dataset and sampler."""
|
||||
# Get base parameters
|
||||
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
|
||||
|
||||
# Add sampler configuration
|
||||
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
if isinstance(sampler, BatchSampler):
|
||||
# batch_size and batch_sampler are mutually exclusive
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
if not is_eval:
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
# Create the dataloader
|
||||
dataloader = DataLoader(dataset, **dataloader_params)
|
||||
|
||||
if self.args.sample_packing and (
|
||||
(not is_eval and not self.args.pretraining)
|
||||
or (is_eval and self.args.eval_sample_packing is not False)
|
||||
):
|
||||
self.accelerator.even_batches = False
|
||||
|
||||
# Return unprepared dataloader if using sequence parallelism
|
||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||
# slice each batch along the sequence dimension).
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
return self.accelerator.prepare_data_loader(dataloader)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
"""Get dataloader for training"""
|
||||
train_dataset = self.train_dataset
|
||||
# pylint: disable=access-member-before-definition
|
||||
data_collator = self.data_collator # type: ignore
|
||||
|
||||
# Handle dataset preprocessing
|
||||
if isinstance(train_dataset, datasets.Dataset):
|
||||
# Add debug print before any modifications
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
if not self.args.sample_packing or self.args.pretraining:
|
||||
train_dataset = self._remove_unused_columns(
|
||||
train_dataset, description="training"
|
||||
)
|
||||
else:
|
||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||
data_collator,
|
||||
description="training",
|
||||
)
|
||||
|
||||
# Get sampler and create dataloader
|
||||
sampler = self._get_train_sampler()
|
||||
dataloader = self._prepare_dataloader(train_dataset, sampler, is_eval=False)
|
||||
|
||||
return dataloader
|
||||
|
||||
def _generate_and_score_completions(
|
||||
self, inputs: list[dict[str, torch.Tensor | Any]]
|
||||
) -> dict[str, torch.Tensor | Any]:
|
||||
device = self.accelerator.device
|
||||
mode = "eval" if self.control.should_evaluate else "train"
|
||||
|
||||
prompts = [x["prompt"] for x in inputs]
|
||||
prompts_text = [
|
||||
maybe_apply_chat_template(example, self.processing_class)["prompt"]
|
||||
for example in inputs
|
||||
]
|
||||
prompt_inputs = self.processing_class(
|
||||
text=prompts_text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
add_special_tokens=False,
|
||||
)
|
||||
prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
|
||||
prompt_ids, prompt_mask = (
|
||||
prompt_inputs["input_ids"],
|
||||
prompt_inputs["attention_mask"],
|
||||
)
|
||||
|
||||
if self.max_prompt_length is not None:
|
||||
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
||||
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
||||
|
||||
# Generate completions using either vLLM or regular generation
|
||||
if self.args.use_vllm:
|
||||
# First, have main process load weights if needed
|
||||
# pylint: disable=access-member-before-definition
|
||||
if self.state.global_step != self._last_loaded_step: # type: ignore[has-type]
|
||||
self._move_model_to_vllm()
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self._last_loaded_step = self.state.global_step
|
||||
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate sequence parallel group information
|
||||
world_size = self.accelerator.num_processes
|
||||
sequence_parallel_degree = self.args.sequence_parallel_degree
|
||||
num_sp_groups = world_size // sequence_parallel_degree
|
||||
|
||||
# Since processes in the same SP group have the same prompts, we need to ensure
|
||||
# we only take one copy of each prompt from each SP group
|
||||
ordered_set_of_prompts = []
|
||||
for sp_group_id in range(num_sp_groups):
|
||||
# Get the first process from each SP group (typically the group leader)
|
||||
group_leader_rank = sp_group_id * sequence_parallel_degree
|
||||
|
||||
# Extract prompts from this SP group, accounting for num_generations duplicates
|
||||
# We only need prompts from one rank in each SP group
|
||||
group_prompts = all_prompts_text[
|
||||
group_leader_rank
|
||||
* len(prompts_text) : (group_leader_rank + 1)
|
||||
* len(prompts_text) : self.num_generations
|
||||
]
|
||||
|
||||
ordered_set_of_prompts.extend(group_prompts)
|
||||
else:
|
||||
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
|
||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||
# prompt individually.
|
||||
ordered_set_of_prompts = all_prompts_text[
|
||||
:: self.num_generations * self.args.sequence_parallel_degree
|
||||
]
|
||||
|
||||
with profiling_context(self, "vLLM.generate"):
|
||||
completion_ids = self.vllm_client.generate(
|
||||
prompts=ordered_set_of_prompts,
|
||||
n=self.num_generations,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=-1 if self.top_k is None else self.top_k,
|
||||
min_p=0.0 if self.min_p is None else self.min_p,
|
||||
max_tokens=self.max_completion_length,
|
||||
guided_decoding_regex=self.guided_decoding_regex,
|
||||
)
|
||||
else:
|
||||
completion_ids = [None] * (
|
||||
len(all_prompts_text) // self.args.sequence_parallel_degree
|
||||
)
|
||||
|
||||
# Broadcast the completions from the main process to all processes
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
|
||||
# Determine the appropriate slice based on sequence parallelism
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
# Calculate the start index for this SP group
|
||||
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
|
||||
|
||||
# All ranks in the same SP group get the same data slice
|
||||
process_slice = slice(
|
||||
sp_group_start,
|
||||
sp_group_start + len(prompts),
|
||||
)
|
||||
completion_ids = completion_ids[process_slice]
|
||||
else:
|
||||
# Original behavior for non-sequence parallel case
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
)
|
||||
completion_ids = completion_ids[process_slice]
|
||||
|
||||
# Pad the completions, and concatenate them with the prompts
|
||||
completion_ids = [
|
||||
torch.tensor(ids, device=device) for ids in completion_ids
|
||||
]
|
||||
completion_ids = pad(
|
||||
completion_ids, padding_value=self.processing_class.pad_token_id
|
||||
)
|
||||
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
||||
else:
|
||||
# Regular generation path
|
||||
with unwrap_model_for_generation(
|
||||
self.model_wrapped,
|
||||
self.accelerator,
|
||||
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
||||
) as unwrapped_model:
|
||||
prompt_completion_ids = unwrapped_model.generate(
|
||||
prompt_ids,
|
||||
attention_mask=prompt_mask,
|
||||
generation_config=self.generation_config,
|
||||
)
|
||||
|
||||
# Compute prompt length and extract completion ids
|
||||
prompt_length = prompt_ids.size(1)
|
||||
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
||||
completion_ids = prompt_completion_ids[:, prompt_length:]
|
||||
|
||||
# Mask everything after the first EOS token
|
||||
is_eos = completion_ids == self.processing_class.eos_token_id
|
||||
eos_idx = torch.full(
|
||||
(is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
|
||||
)
|
||||
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
|
||||
is_eos.size(0), -1
|
||||
)
|
||||
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
||||
|
||||
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
|
||||
if self.args.mask_truncated_completions:
|
||||
truncated_completions = ~is_eos.any(dim=1)
|
||||
completion_mask = (
|
||||
completion_mask * (~truncated_completions).unsqueeze(1).int()
|
||||
)
|
||||
|
||||
# Concatenate prompt_mask with completion_mask for logit computation
|
||||
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
|
||||
|
||||
logits_to_keep = completion_ids.size(
|
||||
1
|
||||
) # we only need to compute the logits for the completion tokens
|
||||
batch_size = (
|
||||
self.args.per_device_train_batch_size
|
||||
if mode == "train"
|
||||
else self.args.per_device_eval_batch_size
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
|
||||
# computation here, and use per_token_logps.detach() instead.
|
||||
if self.num_iterations > 1:
|
||||
old_per_token_logps = self._get_per_token_logps(
|
||||
self.model,
|
||||
prompt_completion_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
batch_size,
|
||||
)
|
||||
else:
|
||||
old_per_token_logps = None
|
||||
|
||||
if self.beta == 0.0:
|
||||
ref_per_token_logps = None
|
||||
elif self.ref_model is not None:
|
||||
ref_per_token_logps = self._get_per_token_logps(
|
||||
self.ref_model,
|
||||
prompt_completion_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
batch_size,
|
||||
)
|
||||
else:
|
||||
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
||||
ref_per_token_logps = self._get_per_token_logps(
|
||||
self.model,
|
||||
prompt_completion_ids,
|
||||
attention_mask,
|
||||
logits_to_keep,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
# Decode the generated completions
|
||||
completions_text = self.processing_class.batch_decode(
|
||||
completion_ids, skip_special_tokens=True
|
||||
)
|
||||
if is_conversational(inputs[0]):
|
||||
completions = []
|
||||
for prompt, completion in zip(prompts, completions_text):
|
||||
bootstrap = (
|
||||
prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
|
||||
)
|
||||
completions.append(
|
||||
[{"role": "assistant", "content": bootstrap + completion}]
|
||||
)
|
||||
else:
|
||||
completions = completions_text
|
||||
|
||||
rewards_per_func = torch.zeros(
|
||||
len(prompts), len(self.reward_funcs), device=device
|
||||
)
|
||||
for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
|
||||
zip(
|
||||
self.reward_funcs,
|
||||
self.reward_processing_classes,
|
||||
self.reward_func_names,
|
||||
)
|
||||
):
|
||||
with profiling_context(self, reward_func_name):
|
||||
if isinstance(
|
||||
reward_func, nn.Module
|
||||
): # Module instead of PretrainedModel for compat with compiled models
|
||||
if is_conversational(inputs[0]):
|
||||
messages = [
|
||||
{"messages": p + c} for p, c in zip(prompts, completions)
|
||||
]
|
||||
texts = [
|
||||
apply_chat_template(x, reward_processing_class)["text"]
|
||||
for x in messages
|
||||
]
|
||||
else:
|
||||
texts = [p + c for p, c in zip(prompts, completions)]
|
||||
reward_inputs = reward_processing_class(
|
||||
text=texts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
padding_side="right",
|
||||
add_special_tokens=False,
|
||||
)
|
||||
reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
|
||||
with torch.inference_mode():
|
||||
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
|
||||
:, 0
|
||||
] # Shape (B*G,)
|
||||
else:
|
||||
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
||||
keys = [
|
||||
key for key in inputs[0] if key not in ["prompt", "completion"]
|
||||
]
|
||||
reward_kwargs = {
|
||||
key: [example[key] for example in inputs] for key in keys
|
||||
}
|
||||
output_reward_func = reward_func(
|
||||
prompts=prompts, completions=completions, **reward_kwargs
|
||||
)
|
||||
# Convert None values to NaN
|
||||
output_reward_func = [
|
||||
reward if reward is not None else torch.nan
|
||||
for reward in output_reward_func
|
||||
]
|
||||
|
||||
rewards_per_func[:, i] = torch.tensor(
|
||||
output_reward_func, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
# If all reward functions return None for a given row, issue a detailed warning
|
||||
if torch.isnan(rewards_per_func).all(dim=1).any():
|
||||
nan_row_idx = (
|
||||
torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
|
||||
)
|
||||
row_reward_kwargs = {
|
||||
key: value[nan_row_idx] for key, value in reward_kwargs.items()
|
||||
}
|
||||
row_reward_kwargs["prompt"] = prompts[nan_row_idx]
|
||||
row_reward_kwargs["completion"] = completions[nan_row_idx]
|
||||
warnings.warn(
|
||||
f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
|
||||
"Please ensure that at least one reward function returns a valid reward."
|
||||
)
|
||||
|
||||
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
|
||||
# completions may be distributed across processes
|
||||
rewards_per_func = gather(rewards_per_func)
|
||||
|
||||
# Apply weights to each reward function's output and sum
|
||||
rewards = (
|
||||
rewards_per_func * self.reward_weights.to(device).unsqueeze(0)
|
||||
).nansum(dim=1)
|
||||
|
||||
# Compute grouped-wise rewards
|
||||
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
||||
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
||||
|
||||
# Normalize the rewards to compute the advantages
|
||||
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
|
||||
self.num_generations, dim=0
|
||||
)
|
||||
std_grouped_rewards = std_grouped_rewards.repeat_interleave(
|
||||
self.num_generations, dim=0
|
||||
)
|
||||
advantages = rewards - mean_grouped_rewards
|
||||
if self.args.scale_rewards:
|
||||
advantages = advantages / (std_grouped_rewards + 1e-4)
|
||||
|
||||
# Slice to keep only the local part of the data
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
# Calculate the start index for this SP group
|
||||
sp_group_start = sp_group_id * len(prompts) * self.local_world_size
|
||||
|
||||
# All ranks in the same SP group get the same data slice
|
||||
process_slice = slice(
|
||||
sp_group_start,
|
||||
sp_group_start + len(prompts),
|
||||
)
|
||||
else:
|
||||
# Original behavior for non-sequence parallel case
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
)
|
||||
advantages = advantages[process_slice]
|
||||
|
||||
# Log the metrics
|
||||
if mode == "train":
|
||||
self._total_train_tokens += (
|
||||
self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
|
||||
)
|
||||
self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
|
||||
|
||||
# log completion lengths, mean, min, max
|
||||
agg_completion_mask = self.accelerator.gather_for_metrics(
|
||||
completion_mask.sum(1)
|
||||
)
|
||||
self._metrics[mode]["completions/mean_length"].append(
|
||||
agg_completion_mask.float().mean().item()
|
||||
)
|
||||
self._metrics[mode]["completions/min_length"].append(
|
||||
agg_completion_mask.float().min().item()
|
||||
)
|
||||
self._metrics[mode]["completions/max_length"].append(
|
||||
agg_completion_mask.float().max().item()
|
||||
)
|
||||
|
||||
# identify sequences that terminated with EOS and log their lengths
|
||||
agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))
|
||||
term_completion_mask = agg_completion_mask[agg_terminated_with_eos]
|
||||
clipped_completions_ratio = 1 - len(term_completion_mask) / len(
|
||||
agg_completion_mask
|
||||
)
|
||||
self._metrics[mode]["completions/clipped_ratio"].append(
|
||||
clipped_completions_ratio
|
||||
)
|
||||
if len(term_completion_mask) == 0:
|
||||
# edge case where no completed sequences are found
|
||||
term_completion_mask = torch.zeros(1, device=device)
|
||||
self._metrics[mode]["completions/mean_terminated_length"].append(
|
||||
term_completion_mask.float().mean().item()
|
||||
)
|
||||
self._metrics[mode]["completions/min_terminated_length"].append(
|
||||
term_completion_mask.float().min().item()
|
||||
)
|
||||
self._metrics[mode]["completions/max_terminated_length"].append(
|
||||
term_completion_mask.float().max().item()
|
||||
)
|
||||
|
||||
# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
|
||||
for i, reward_func_name in enumerate(self.reward_func_names):
|
||||
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
|
||||
self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
|
||||
std_rewards = nanstd(rewards_per_func[:, i]).item()
|
||||
self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
|
||||
self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
|
||||
self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
|
||||
|
||||
# Log prompt and completion texts
|
||||
self._textual_logs["prompt"].extend(gather_object(prompts_text))
|
||||
self._textual_logs["completion"].extend(gather_object(completions_text))
|
||||
for i, name in enumerate(self.reward_func_names):
|
||||
self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
|
||||
|
||||
return {
|
||||
"prompt_ids": prompt_ids,
|
||||
"prompt_mask": prompt_mask,
|
||||
"completion_ids": completion_ids,
|
||||
"completion_mask": completion_mask,
|
||||
"advantages": advantages,
|
||||
"old_per_token_logps": old_per_token_logps,
|
||||
"ref_per_token_logps": ref_per_token_logps,
|
||||
}
|
||||
|
||||
@@ -6,4 +6,4 @@
|
||||
from .optimizer import OptimizerMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin
|
||||
from .sequence_parallel import SequenceParallelMixin
|
||||
|
||||
@@ -1,85 +1,13 @@
|
||||
"""
|
||||
Module for Axolotl trainer sequence parallelism mixin and training context manager
|
||||
"""
|
||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from datasets import Dataset
|
||||
from torch import nn
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import (
|
||||
RingAttnFunc,
|
||||
get_ring_attn_group,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def apply_sequence_parallelism(
|
||||
batch: dict[str, torch.Tensor],
|
||||
local_rank: int,
|
||||
local_world_size: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Apply sequence parallelism slicing to a batch.
|
||||
|
||||
Args:
|
||||
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
|
||||
local_rank: Local rank in the sequence parallel group
|
||||
local_world_size: World size of the sequence parallel group
|
||||
ring_attn_func: The ring attention function to use
|
||||
|
||||
Returns:
|
||||
Sliced batch dictionary.
|
||||
"""
|
||||
# Update ring attention params if needed
|
||||
if batch.get("position_ids") is not None:
|
||||
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||
|
||||
# Slice batch for sequence parallel processing
|
||||
total_seq_len = batch["input_ids"].size(1)
|
||||
for key in batch:
|
||||
if (
|
||||
key in batch
|
||||
and isinstance(batch[key], torch.Tensor)
|
||||
and batch[key].dim() > 1
|
||||
and batch[key].size(1) == total_seq_len
|
||||
):
|
||||
|
||||
if ring_attn_func in [
|
||||
RingAttnFunc.VARLEN_LLAMA3,
|
||||
RingAttnFunc.BATCH_RING,
|
||||
]:
|
||||
# Split in sequential fashion and grab this rank's chunk
|
||||
batch[key] = (
|
||||
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
|
||||
)
|
||||
elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
||||
chunks = batch[key].chunk(2 * local_world_size, dim=1)
|
||||
|
||||
# Take rank's chunk and opposing chunk for zigzag pattern
|
||||
selected_chunks = [
|
||||
chunks[local_rank],
|
||||
chunks[2 * local_world_size - local_rank - 1],
|
||||
]
|
||||
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
|
||||
elif ring_attn_func is RingAttnFunc.BATCH_STRIPE:
|
||||
# Split into striped data and stack
|
||||
tensor = torch.stack(
|
||||
batch[key].split(local_world_size, dim=1),
|
||||
dim=1,
|
||||
).transpose(1, 2)
|
||||
batch[key] = tensor[:, local_rank].contiguous()
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class SequenceParallelMixin:
|
||||
"""
|
||||
@@ -157,157 +85,3 @@ class SequenceParallelMixin:
|
||||
return self._create_sequence_parallel_sampler(
|
||||
eval_dataset, shuffle=False, is_eval=True
|
||||
)
|
||||
|
||||
|
||||
class SequenceParallelContextManager:
|
||||
"""
|
||||
Context manager for sequence parallelism operations.
|
||||
|
||||
This class provides a context that will automatically apply sequence parallelism
|
||||
during model forward passes using a pre-forward hook, and gather outputs from
|
||||
across the sequence parallelism group using a post-forward hook.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
sequence_parallel_degree: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
):
|
||||
self.model = model
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.process_group = get_ring_attn_group()
|
||||
|
||||
# Initialize sequence parallel group details
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
# Will store hook handles for removal
|
||||
self.hook_handles: list[RemovableHandle] = []
|
||||
|
||||
# Create a partially applied version of the apply_sequence_parallelism function
|
||||
# with pre-configured params
|
||||
self.apply_sequence_parallelism = functools.partial(
|
||||
apply_sequence_parallelism,
|
||||
local_rank=self.local_rank,
|
||||
local_world_size=self.local_world_size,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
# Forward pre-hook to apply sequence parallelism
|
||||
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||
# Apply sequence parallelism to kwargs
|
||||
kwargs = self.apply_sequence_parallelism(batch=kwargs)
|
||||
return args, kwargs
|
||||
|
||||
# Forward post-hook to gather outputs
|
||||
def sequence_parallel_post_hook(_, __, output):
|
||||
# Gather the sharded outputs
|
||||
return self.gather_outputs(output)
|
||||
|
||||
# Register both hooks
|
||||
self.hook_handles.append(
|
||||
self.model.register_forward_pre_hook(
|
||||
sequence_parallel_pre_hook, with_kwargs=True
|
||||
)
|
||||
)
|
||||
self.hook_handles.append(
|
||||
self.model.register_forward_hook(sequence_parallel_post_hook)
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
self.hook_handles = []
|
||||
|
||||
def gather_outputs(self, output):
|
||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||
# Handle different output formats (dict, tensor, etc.)
|
||||
if isinstance(output, dict):
|
||||
gathered_output = {}
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
# Gather logits or other sequence-sharded tensors
|
||||
gathered_value = self.gather_tensor(value)
|
||||
gathered_output[key] = gathered_value
|
||||
else:
|
||||
gathered_value = value.clone()
|
||||
dist.all_reduce(
|
||||
gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
|
||||
)
|
||||
gathered_output[key] = gathered_value
|
||||
return gathered_output
|
||||
if isinstance(output, torch.Tensor):
|
||||
return self.gather_tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
def gather_tensor(self, tensor):
|
||||
"""Gather a sharded tensor from all ranks."""
|
||||
# Prepare tensors for all_gather
|
||||
world_size = self.local_world_size
|
||||
|
||||
# Create list to store tensors from all ranks
|
||||
gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
|
||||
|
||||
# All-gather operation
|
||||
dist.all_gather(gathered_tensors, tensor, group=self.process_group)
|
||||
|
||||
# Concatenate along sequence dimension (typically dim=1)
|
||||
if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
|
||||
# Simple concatenation for standard sharding
|
||||
return torch.cat(gathered_tensors, dim=1)
|
||||
|
||||
if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
||||
# Each rank has a pattern of (rank, world_size*2-rank-1)
|
||||
reconstituted_tensors = [None] * (world_size * 2)
|
||||
|
||||
# First, split each gathered tensor into its two chunks
|
||||
for rank, gathered_tensor in enumerate(gathered_tensors):
|
||||
# Each tensor contains two chunks in the sequence dimension
|
||||
chunk_size = gathered_tensor.size(1) // 2
|
||||
chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
|
||||
|
||||
# Place chunks in their original positions
|
||||
reconstituted_tensors[rank] = chunk1
|
||||
reconstituted_tensors[world_size * 2 - rank - 1] = chunk2
|
||||
|
||||
# Concatenate the reconstituted tensors in the correct order
|
||||
return torch.cat(reconstituted_tensors, dim=1)
|
||||
|
||||
# Otherwise, RingAttnFunc.BATCH_STRIPE
|
||||
# In striping, each rank has every world_size-th slice
|
||||
batch_size = tensor.size(0)
|
||||
hidden_dim = tensor.size(-1)
|
||||
|
||||
# First, determine the full sequence length
|
||||
total_seq_len = 0
|
||||
for t in gathered_tensors:
|
||||
total_seq_len += t.size(1)
|
||||
|
||||
# Create a tensor to hold the unstriped result
|
||||
result = torch.zeros(
|
||||
batch_size,
|
||||
total_seq_len,
|
||||
hidden_dim,
|
||||
dtype=tensor.dtype,
|
||||
device=tensor.device,
|
||||
)
|
||||
|
||||
# For each rank's tensor, distribute its slices to the correct positions
|
||||
for rank, gathered_tensor in enumerate(gathered_tensors):
|
||||
# The rank's tensor contains every world_size-th slice
|
||||
# starting from its rank position
|
||||
seq_len = gathered_tensor.size(1)
|
||||
for i in range(seq_len):
|
||||
# Calculate the position in the full tensor
|
||||
pos = i * world_size + rank
|
||||
if pos < total_seq_len:
|
||||
result[:, pos] = gathered_tensor[:, i]
|
||||
|
||||
return result
|
||||
|
||||
@@ -9,7 +9,7 @@ from PIL.Image import Resampling
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -26,6 +26,8 @@ from typing import OrderedDict
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
"""
|
||||
@@ -36,11 +38,13 @@ class BasePlugin:
|
||||
|
||||
Methods:
|
||||
register(cfg): Registers the plugin with the given configuration.
|
||||
load_datasets(cfg): Loads and preprocesses the dataset for training.
|
||||
pre_model_load(cfg): Performs actions before the model is loaded.
|
||||
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
|
||||
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
|
||||
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
|
||||
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
|
||||
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
|
||||
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
|
||||
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
|
||||
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
|
||||
@@ -63,20 +67,32 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
def get_input_args(self) -> str | None:
|
||||
"""
|
||||
Returns a pydantic model for the plugin's input arguments.
|
||||
"""
|
||||
|
||||
def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
|
||||
"""
|
||||
Loads and preprocesses the dataset for training.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugin.
|
||||
preprocess: Whether this is the preprocess step of the datasets.
|
||||
|
||||
Returns:
|
||||
dataset_meta: The metadata for the training dataset.
|
||||
"""
|
||||
|
||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions before the model is loaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
|
||||
@@ -91,59 +107,71 @@ class BasePlugin:
|
||||
"""
|
||||
Performs actions after the model is loaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions before LoRA weights are loaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after LoRA weights are loaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
|
||||
"""
|
||||
Returns a custom class for the trainer.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The global axolotl configuration.
|
||||
Args:
|
||||
cfg (dict): The global axolotl configuration.
|
||||
|
||||
Returns:
|
||||
class: The class for the trainer.
|
||||
class: The class for the trainer.
|
||||
"""
|
||||
|
||||
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after the trainer is created.
|
||||
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
"""
|
||||
Creates and returns an optimizer for training.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
object: The created optimizer.
|
||||
object: The created optimizer.
|
||||
"""
|
||||
|
||||
def create_lr_scheduler(
|
||||
@@ -152,26 +180,26 @@ class BasePlugin:
|
||||
"""
|
||||
Creates and returns a learning rate scheduler.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
optimizer (object): The optimizer for training.
|
||||
num_training_steps (int): Total number of training steps
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
optimizer (object): The optimizer for training.
|
||||
num_training_steps (int): Total number of training steps
|
||||
|
||||
Returns:
|
||||
object (LRScheduler): The created learning rate scheduler.
|
||||
object (LRScheduler): The created learning rate scheduler.
|
||||
"""
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
setup callbacks before creating the trainer.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
return []
|
||||
|
||||
@@ -182,12 +210,12 @@ class BasePlugin:
|
||||
Adds callbacks to the trainer after creating the trainer.
|
||||
This is useful for callbacks that require access to the model or trainer.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added
|
||||
List[callable]: A list of callback functions to be added
|
||||
"""
|
||||
return []
|
||||
|
||||
@@ -195,23 +223,23 @@ class BasePlugin:
|
||||
"""
|
||||
Performs actions after training is complete.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The axolotl configuration
|
||||
model (object): The loaded model.
|
||||
Args:
|
||||
cfg (dict): The axolotl configuration
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
def post_train_unload(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after training is complete and the model is unloaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
|
||||
@@ -338,6 +366,27 @@ class PluginManager:
|
||||
input_args.append(input_args_from_plugin)
|
||||
return input_args
|
||||
|
||||
def load_datasets(self, cfg, preprocess: bool = False):
|
||||
"""
|
||||
Calls the load_datasets method of each registered plugin.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugins.
|
||||
preprocess : Whether this is preprocess step of the datasets.
|
||||
|
||||
Returns:
|
||||
dataset_meta: The dataset metadata loaded from all registered plugins.
|
||||
"""
|
||||
return_ds_meta = None
|
||||
for plugin in self.plugins.values():
|
||||
dataset_meta = plugin.load_datasets(cfg, preprocess)
|
||||
if dataset_meta is not None:
|
||||
if return_ds_meta is None:
|
||||
return_ds_meta = dataset_meta
|
||||
else:
|
||||
raise RuntimeError("Multiple plugins loaded datasets")
|
||||
return return_ds_meta
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""
|
||||
Calls the pre_model_load method of all registered plugins.
|
||||
@@ -422,6 +471,20 @@ class PluginManager:
|
||||
return trainer_cls
|
||||
return None
|
||||
|
||||
def post_trainer_create(self, cfg, trainer):
|
||||
"""
|
||||
Calls the post_trainer_create method of all registered plugins.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
|
||||
def create_optimizer(self, trainer):
|
||||
"""
|
||||
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
|
||||
|
||||
@@ -151,6 +151,30 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3":
|
||||
from axolotl.integrations.liger.models.qwen3 import (
|
||||
apply_liger_kernel_to_qwen3,
|
||||
)
|
||||
|
||||
apply_liger_kernel_to_qwen3(
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
glu_activation=cfg.liger_glu_activation,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_moe":
|
||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||
apply_liger_kernel_to_qwen3_moe,
|
||||
)
|
||||
|
||||
apply_liger_kernel_to_qwen3_moe(
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
glu_activation=cfg.liger_glu_activation,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
||||
|
||||
160
src/axolotl/integrations/liger/models/qwen3.py
Normal file
160
src/axolotl/integrations/liger/models/qwen3.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Liger FLCE for Qwen3. Based on transformers v4.51.3.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = None
|
||||
loss = None
|
||||
# if in training mode, don't materialize logits
|
||||
if self.training and (labels is not None):
|
||||
loss = LigerForCausalLMLoss(
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=self.lm_head.weight,
|
||||
labels=labels,
|
||||
hidden_size=self.config.hidden_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else: # if in inference mode materialize logits
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def apply_liger_kernel_to_qwen3(
|
||||
cross_entropy: bool = False,
|
||||
fused_linear_cross_entropy: bool = False,
|
||||
rms_norm: bool = False,
|
||||
glu_activation: bool = False,
|
||||
layer_norm: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
"""
|
||||
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
|
||||
|
||||
Args:
|
||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||
fused_linear_cross_entropy (bool):
|
||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
|
||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||
"""
|
||||
|
||||
import transformers.models.qwen3.modeling_qwen3 # noqa: F401 # pylint: disable=unused-import
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
assert not (
|
||||
cross_entropy and fused_linear_cross_entropy
|
||||
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||
|
||||
modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]
|
||||
|
||||
if rms_norm:
|
||||
modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm
|
||||
|
||||
if glu_activation:
|
||||
modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
|
||||
|
||||
if layer_norm:
|
||||
modeling_qwen3.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if fused_linear_cross_entropy:
|
||||
modeling_qwen3.Qwen3ForCausalLM.forward = lce_forward
|
||||
191
src/axolotl/integrations/liger/models/qwen3_moe.py
Normal file
191
src/axolotl/integrations/liger/models/qwen3_moe.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Liger FLCE for Qwen3 MoE. Based on transformers v4.51.3.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func
|
||||
|
||||
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> MoeCausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_router_logits=output_router_logits,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = None
|
||||
loss = None
|
||||
# if in training mode, don't materialize logits
|
||||
if self.training and (labels is not None):
|
||||
loss = LigerForCausalLMLoss(
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=self.lm_head.weight,
|
||||
labels=labels,
|
||||
hidden_size=self.config.hidden_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else: # if in inference mode materialize logits
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits,
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss.to(
|
||||
loss.device
|
||||
) # make sure to reside in the same device
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def apply_liger_kernel_to_qwen3_moe(
|
||||
cross_entropy: bool = False,
|
||||
fused_linear_cross_entropy: bool = False,
|
||||
rms_norm: bool = False,
|
||||
glu_activation: bool = False,
|
||||
layer_norm: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
"""
|
||||
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
|
||||
|
||||
Args:
|
||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||
fused_linear_cross_entropy (bool):
|
||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be False.
|
||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||
"""
|
||||
|
||||
import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401 # pylint: disable=unused-import
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
assert not (
|
||||
cross_entropy and fused_linear_cross_entropy
|
||||
), "cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||
|
||||
modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]
|
||||
|
||||
if rms_norm:
|
||||
modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
|
||||
|
||||
if glu_activation:
|
||||
|
||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
||||
"Accepts intermediate_size to pass to LigerSwiGLUMLP"
|
||||
# clone config to avoid modifying the original
|
||||
config = deepcopy(config)
|
||||
if intermediate_size:
|
||||
setattr(config, "intermediate_size", intermediate_size)
|
||||
return LigerSwiGLUMLP(config, **kwargs)
|
||||
|
||||
modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper
|
||||
|
||||
if layer_norm:
|
||||
modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if fused_linear_cross_entropy:
|
||||
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
attention module for attention monkeypatches
|
||||
"""
|
||||
|
||||
from transformers.integrations.flash_attention import flash_attention_forward
|
||||
|
||||
|
||||
def patch_xformers_attn_over_fa2():
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from .xformers import xformers_attention_forward
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = xformers_attention_forward
|
||||
|
||||
|
||||
def unpatch_xformers_attn_over_fa2():
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward()
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
# flake8: noqa
|
||||
|
||||
from .patch import (
|
||||
RingAttnFunc,
|
||||
get_ring_attn_group,
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
|
||||
@@ -16,11 +16,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
import transformers.modeling_flash_attention_utils
|
||||
from ring_flash_attn import (
|
||||
ring_flash_attn_func,
|
||||
stripe_flash_attn_func,
|
||||
zigzag_ring_flash_attn_func,
|
||||
)
|
||||
from ring_flash_attn import ring_flash_attn_func
|
||||
from ring_flash_attn.adapters.hf_adapter import check_params
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size,
|
||||
@@ -28,12 +24,12 @@ from transformers.modeling_flash_attention_utils import (
|
||||
)
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
RING_ATTN_FUNC_MAPPING = {
|
||||
RingAttnFunc.BATCH_RING: ring_flash_attn_func,
|
||||
RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
|
||||
RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
|
||||
RingAttnFunc.BATCH_RING: torch.compile(ring_flash_attn_func),
|
||||
# RingAttnFunc.BATCH_ZIGZAG: torch.compile(zigzag_ring_flash_attn_func),
|
||||
# RingAttnFunc.BATCH_STRIPE: torch.compile(stripe_flash_attn_func),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,13 +6,12 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -41,17 +40,6 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
class RingAttnFunc(str, Enum):
|
||||
"""Enum class for supported `ring-flash-attn` implementations"""
|
||||
|
||||
# VARLEN_RING = "varlen_ring"
|
||||
# VARLEN_ZIGZAG = "varlen_zigzag"
|
||||
VARLEN_LLAMA3 = "varlen_llama3"
|
||||
BATCH_RING = "batch_ring"
|
||||
BATCH_ZIGZAG = "batch_zigzag"
|
||||
BATCH_STRIPE = "batch_stripe"
|
||||
|
||||
|
||||
def register_ring_attn(
|
||||
sequence_parallel_degree: int,
|
||||
heads_k_stride: int | None,
|
||||
@@ -117,11 +105,7 @@ def register_ring_attn(
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
||||
)
|
||||
elif ring_attn_func in [
|
||||
RingAttnFunc.BATCH_RING,
|
||||
RingAttnFunc.BATCH_ZIGZAG,
|
||||
RingAttnFunc.BATCH_STRIPE,
|
||||
]:
|
||||
elif ring_attn_func is RingAttnFunc.BATCH_RING:
|
||||
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
|
||||
substitute_hf_flash_attn,
|
||||
)
|
||||
|
||||
160
src/axolotl/monkeypatch/attention/xformers.py
Normal file
160
src/axolotl/monkeypatch/attention/xformers.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
xformers attention implementation for packing
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import xformers
|
||||
import xformers.ops.fmha
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_upad_input,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
xformers_attention = xformers.ops.fmha.memory_efficient_attention
|
||||
|
||||
|
||||
def xformers_attention_forward(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
dropout: float = 0.0, # pylint: disable=unused-argument
|
||||
scaling: Optional[float] = None, # pylint: disable=unused-argument
|
||||
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
|
||||
softcap: Optional[float] = None, # pylint: disable=unused-argument
|
||||
cu_seq_lens_q: Optional[torch.LongTensor] = None,
|
||||
cu_seq_lens_k: Optional[torch.LongTensor] = None,
|
||||
max_length_q: Optional[int] = None,
|
||||
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
# Get dimensions
|
||||
# query: [batch, heads, seq_len, hidden_dim]
|
||||
batch_size = query.size(0)
|
||||
query_length = query.shape[2]
|
||||
key_length = key.shape[2]
|
||||
|
||||
# Default causal mask
|
||||
attn_bias = xformers.ops.LowerTriangularMask()
|
||||
|
||||
# Check if we have sliding window attention
|
||||
has_sliding_window = sliding_window is not None and sliding_window < query_length
|
||||
|
||||
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# Get GQA parameters
|
||||
num_attention_heads = module.config.num_attention_heads
|
||||
num_key_value_heads = module.config.num_key_value_heads
|
||||
head_dim = query.size(-1)
|
||||
is_gqa = num_attention_heads != num_key_value_heads
|
||||
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
|
||||
|
||||
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
||||
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
||||
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||
if position_ids is not None and (
|
||||
max_length_q is not None
|
||||
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
||||
):
|
||||
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
|
||||
cu_seq_lens_q = cu_seq_lens_q.squeeze()
|
||||
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
|
||||
attn_bias = (
|
||||
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||
q_seqlen=seq_lengths.tolist(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.reshape(-1, query.size(-2), query.size(-1))
|
||||
key = key.reshape(-1, key.size(-2), key.size(-1))
|
||||
value = value.reshape(-1, value.size(-2), value.size(-1))
|
||||
|
||||
# Handle GQA
|
||||
if is_gqa:
|
||||
key = key.repeat_interleave(n_groups, dim=2)
|
||||
value = value.repeat_interleave(n_groups, dim=2)
|
||||
|
||||
elif attention_mask is not None:
|
||||
query, key, value, _, cu_seq_lens, _ = _upad_input(
|
||||
query, key, value, attention_mask, query_length
|
||||
)
|
||||
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
||||
seq_lengths = []
|
||||
for i in range(len(cu_seq_lens_q) - 1):
|
||||
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
|
||||
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||
q_seqlen=seq_lengths,
|
||||
kv_seqlen=seq_lengths,
|
||||
)
|
||||
|
||||
# Handle GQA
|
||||
if is_gqa:
|
||||
key = key.repeat_interleave(n_groups, dim=2)
|
||||
value = value.repeat_interleave(n_groups, dim=2)
|
||||
else:
|
||||
# Handle Group Query Attention (GQA) using view/expand approach from reference
|
||||
key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
|
||||
value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
|
||||
key = key.expand(
|
||||
batch_size, key_length, num_key_value_heads, n_groups, head_dim
|
||||
)
|
||||
value = value.expand(
|
||||
batch_size, key_length, num_key_value_heads, n_groups, head_dim
|
||||
)
|
||||
|
||||
if module.training:
|
||||
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
|
||||
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
|
||||
|
||||
if has_sliding_window:
|
||||
query = query.view(
|
||||
1, batch_size * query_length, num_attention_heads, head_dim
|
||||
)
|
||||
key = key.view(
|
||||
1, batch_size * key_length, num_attention_heads, head_dim
|
||||
)
|
||||
value = value.view(
|
||||
1, batch_size * key_length, num_attention_heads, head_dim
|
||||
)
|
||||
else:
|
||||
query = query.view(
|
||||
batch_size, query_length, num_key_value_heads, n_groups, head_dim
|
||||
)
|
||||
|
||||
# If we need a sliding window attention
|
||||
if has_sliding_window:
|
||||
query = query.view(
|
||||
1,
|
||||
batch_size * query_length,
|
||||
num_key_value_heads,
|
||||
n_groups,
|
||||
head_dim,
|
||||
)
|
||||
key = key.view(
|
||||
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
|
||||
)
|
||||
value = value.view(
|
||||
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
|
||||
)
|
||||
|
||||
# Run the xformers attention
|
||||
attn_output = xformers_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(
|
||||
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
|
||||
)
|
||||
return attn_output, None
|
||||
@@ -18,6 +18,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"mixtral",
|
||||
"qwen2",
|
||||
"qwen2_moe",
|
||||
"qwen3",
|
||||
"qwen3_moe",
|
||||
"falcon",
|
||||
"phi",
|
||||
"phi3",
|
||||
|
||||
0
src/axolotl/monkeypatch/peft/__init__.py
Normal file
0
src/axolotl/monkeypatch/peft/__init__.py
Normal file
78
src/axolotl/monkeypatch/peft/utils.py
Normal file
78
src/axolotl/monkeypatch/peft/utils.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Patch prepare_model_for_kbit_training to not upcast everything
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
import peft
|
||||
|
||||
import axolotl
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
ORIGINAL_PREPARE_CODE = """
|
||||
for param in model.parameters():
|
||||
if (
|
||||
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
||||
) and param.__class__.__name__ != "Params4bit":
|
||||
param.data = param.data.to(torch.float32)
|
||||
"""
|
||||
|
||||
PATCHED_PREPARE_CODE = """
|
||||
for name, param in model.named_parameters():
|
||||
if (
|
||||
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
||||
) and param.__class__.__name__ != "Params4bit" and all(embed_name not in name for embed_name in ["embed_tokens", "lm_head"]):
|
||||
param.data = param.data.to(torch.float32)
|
||||
"""
|
||||
|
||||
|
||||
def get_peft_prep_code() -> str:
|
||||
prepare = inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)
|
||||
return prepare
|
||||
|
||||
|
||||
def check_peft_prep_code_is_patchable() -> bool:
|
||||
prep_code = get_peft_prep_code()
|
||||
prep_code, _ = detab_code(prep_code)
|
||||
return ORIGINAL_PREPARE_CODE in prep_code
|
||||
|
||||
|
||||
def patch_peft_prep_code():
|
||||
"""
|
||||
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
|
||||
"""
|
||||
|
||||
try:
|
||||
prep_code = get_peft_prep_code()
|
||||
except OSError:
|
||||
return
|
||||
peft.utils.other._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access
|
||||
prep_code
|
||||
)
|
||||
prep_code, _ = detab_code(prep_code)
|
||||
if ORIGINAL_PREPARE_CODE not in prep_code:
|
||||
return
|
||||
|
||||
prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
|
||||
prep_code = prep_code.replace(
|
||||
"def prepare_model_for_kbit_training(",
|
||||
"def fixed_prepare_model_for_kbit_training(",
|
||||
1,
|
||||
)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(peft.utils.other):
|
||||
if item in prep_code:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from peft.utils.other import (" + ", ".join(x for x in items_to_import) + ")",
|
||||
globals(),
|
||||
)
|
||||
exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
|
||||
peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
|
||||
axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
|
||||
@@ -2,17 +2,17 @@
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import weakref
|
||||
from contextlib import nullcontext
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import save_fsdp_model
|
||||
from datasets import Dataset
|
||||
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||
@@ -21,19 +21,19 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.common.datasets import TrainDatasetMeta
|
||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.core.trainers.mixins.sequence_parallel import (
|
||||
SequenceParallelContextManager,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
try:
|
||||
@@ -41,7 +41,7 @@ try:
|
||||
except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_model_and_tokenizer(
|
||||
@@ -62,7 +62,6 @@ def setup_model_and_tokenizer(
|
||||
# Load tokenizer
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
main_process_only=True,
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
@@ -107,7 +106,7 @@ def setup_reference_model(
|
||||
Reference model if needed for RL training, `None` otherwise.
|
||||
"""
|
||||
model_ref = None
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
if cfg.rl and cfg.rl != RLType.ORPO:
|
||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
# use built-in trl autounwrap
|
||||
LOG.debug("Passing model_ref: None to RL trainer")
|
||||
@@ -188,28 +187,32 @@ def execute_training(
|
||||
trainer: The configured trainer object.
|
||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||
"""
|
||||
# Define the context managers to use
|
||||
flash_context = (
|
||||
torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=True,
|
||||
)
|
||||
if cfg.flash_optimum
|
||||
else nullcontext()
|
||||
)
|
||||
sequence_parallel_context = (
|
||||
SequenceParallelContextManager(
|
||||
model=trainer.model,
|
||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
)
|
||||
if cfg.sequence_parallel_degree > 1
|
||||
else nullcontext()
|
||||
)
|
||||
with ExitStack() as stack:
|
||||
# Define the context managers to use
|
||||
if cfg.flash_optimum:
|
||||
stack.enter_context(
|
||||
torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=True,
|
||||
)
|
||||
)
|
||||
|
||||
LOG.info("Starting trainer...")
|
||||
with flash_context, sequence_parallel_context:
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
models = [trainer.model]
|
||||
if hasattr(trainer, "ref_model"):
|
||||
models.append(trainer.ref_model)
|
||||
|
||||
stack.enter_context(
|
||||
SequenceParallelContextManager(
|
||||
models=models,
|
||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
)
|
||||
)
|
||||
|
||||
LOG.info("Starting trainer...")
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
|
||||
@@ -516,6 +519,8 @@ def train(
|
||||
Returns:
|
||||
Tuple of (model, tokenizer) after training
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||
(
|
||||
trainer,
|
||||
@@ -525,6 +530,9 @@ def train(
|
||||
processor,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_trainer_create(cfg, trainer)
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
@@ -547,7 +555,6 @@ def train(
|
||||
if not cfg.use_ray:
|
||||
cleanup_distributed()
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_train(cfg, model)
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
@@ -43,3 +43,12 @@ def set_pytorch_cuda_alloc_conf():
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
|
||||
"expandable_segments:True,roundup_power2_divisions:16"
|
||||
)
|
||||
|
||||
|
||||
def patch_optimized_env():
|
||||
"""
|
||||
Patch environment variables to improve VRAM usage and increase download speed
|
||||
"""
|
||||
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
@@ -868,3 +868,28 @@ class GCCallback(TrainerCallback):
|
||||
):
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
def colab_inference_post_train_callback(trainer: Trainer):
|
||||
class ColabCallback(TrainerCallback):
|
||||
"""Callback to prep model for inference on Google Colab"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.gpu_name = torch.cuda.get_device_name(0)
|
||||
self.cfg = cfg
|
||||
|
||||
def on_train_end(
|
||||
self, args, state, control, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
handle T4 gpu, we need to convert attention to eager for inference
|
||||
"""
|
||||
if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
|
||||
trainer.model.config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
trainer.model.gradient_checkpointing_disable()
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.eval()
|
||||
|
||||
return ColabCallback
|
||||
|
||||
@@ -59,7 +59,7 @@ def choose_device(cfg):
|
||||
|
||||
def resolve_dtype(cfg):
|
||||
if (
|
||||
cfg.bf16 == "auto" and not cfg.use_ray
|
||||
not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray
|
||||
): # if we use ray we want to defer this check to the worker node
|
||||
if is_torch_bf16_gpu_available():
|
||||
LOG.debug("bf16 support detected, enabling for this configuration.")
|
||||
@@ -70,6 +70,9 @@ def resolve_dtype(cfg):
|
||||
if cfg.fp16 is None and not cfg.float16:
|
||||
cfg.fp16 = True
|
||||
|
||||
if cfg.fp16 and cfg.bf16 == "auto":
|
||||
cfg.bf16 = False
|
||||
|
||||
if cfg.device == "mps":
|
||||
cfg.load_in_8bit = False
|
||||
cfg.tf32 = False
|
||||
|
||||
6
src/axolotl/utils/ctx_managers/__init__.py
Normal file
6
src/axolotl/utils/ctx_managers/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Init for context manager submodule"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .sequence_parallel import SequenceParallelContextManager
|
||||
335
src/axolotl/utils/ctx_managers/sequence_parallel.py
Normal file
335
src/axolotl/utils/ctx_managers/sequence_parallel.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import (
|
||||
get_ring_attn_group,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
|
||||
# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
|
||||
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
|
||||
def apply_sequence_parallelism(
|
||||
batch: dict[str, torch.Tensor],
|
||||
local_rank: int,
|
||||
local_world_size: int,
|
||||
gradient_accumulation_steps: int,
|
||||
ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument
|
||||
) -> tuple[dict[str, torch.Tensor], int, int]:
|
||||
"""
|
||||
Apply sequence parallelism slicing to a batch.
|
||||
|
||||
Special handling is implemented for integer logits_to_keep, which indicates
|
||||
to only keep the last N tokens in the sequence during generation.
|
||||
|
||||
Args:
|
||||
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
|
||||
local_rank: Local rank in the sequence parallel group.
|
||||
local_world_size: World size of the sequence parallel group.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused, but
|
||||
related to above TODO.
|
||||
|
||||
Returns:
|
||||
tuple of:
|
||||
- Batch dictionary with sliced tensors.
|
||||
- The original sequence length before padding.
|
||||
- The number of padding tokens added.
|
||||
"""
|
||||
original_seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# Update ring attention params if needed
|
||||
if batch.get("position_ids") is not None:
|
||||
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||
else:
|
||||
# If position_ids aren't already in the batch, create them
|
||||
batch["position_ids"] = torch.arange(
|
||||
0,
|
||||
original_seq_len,
|
||||
dtype=torch.long,
|
||||
device=batch["input_ids"].device,
|
||||
).expand(batch["input_ids"].size(0), -1)
|
||||
|
||||
if "logits_to_keep" in batch and isinstance(batch["logits_to_keep"], int):
|
||||
logits_to_keep = batch["logits_to_keep"]
|
||||
|
||||
# Calculate which positions in the full sequence contain the last N tokens
|
||||
start_position = max(0, original_seq_len - logits_to_keep)
|
||||
chunk_size = original_seq_len // local_world_size
|
||||
rank_start = local_rank * chunk_size
|
||||
rank_end = rank_start + chunk_size
|
||||
|
||||
# Create a boolean mask tensor for this rank's chunk
|
||||
mask = torch.zeros(
|
||||
chunk_size,
|
||||
dtype=torch.bool,
|
||||
device=batch["input_ids"].device,
|
||||
)
|
||||
|
||||
if rank_end > start_position:
|
||||
# Calculate how many of the last N tokens fall within this rank's range
|
||||
tokens_in_rank = min(rank_end, original_seq_len) - max(
|
||||
rank_start, start_position
|
||||
)
|
||||
|
||||
# Calculate where these tokens start in the local chunk
|
||||
local_start_idx = max(0, start_position - rank_start)
|
||||
|
||||
# Set the appropriate positions in the mask to True
|
||||
mask[local_start_idx : local_start_idx + tokens_in_rank] = True
|
||||
|
||||
# Replace the integer with the boolean mask
|
||||
batch["logits_to_keep"] = mask
|
||||
|
||||
# Add padding to make sequence length divisible by local_world_size
|
||||
total_seq_len = original_seq_len
|
||||
pad_len = 0
|
||||
divisor = min(local_world_size, 64)
|
||||
if total_seq_len % divisor != 0:
|
||||
pad_len = divisor - (total_seq_len % divisor)
|
||||
|
||||
# Apply padding to all relevant tensors
|
||||
for key in batch:
|
||||
if (
|
||||
isinstance(batch[key], torch.Tensor)
|
||||
and batch[key].dim() > 1
|
||||
and batch[key].size(1) == total_seq_len
|
||||
):
|
||||
# Create padding tensor
|
||||
pad_value = -100 if key == "labels" else 0
|
||||
padding = torch.full(
|
||||
(batch[key].size(0), pad_len, *batch[key].shape[2:]),
|
||||
pad_value,
|
||||
dtype=batch[key].dtype,
|
||||
device=batch[key].device,
|
||||
)
|
||||
|
||||
# Concatenate padding to the right side of the tensor
|
||||
batch[key] = torch.cat([batch[key], padding], dim=1)
|
||||
if key == "logits_to_keep":
|
||||
# Create padding tensor
|
||||
padding = torch.ones(
|
||||
1,
|
||||
dtype=batch[key].dtype,
|
||||
device=batch[key].device,
|
||||
)
|
||||
|
||||
# Concatenate padding to the right side of the tensor
|
||||
batch[key] = torch.cat([batch[key], padding], dim=0)
|
||||
|
||||
# Update the total sequence length after padding
|
||||
total_seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# Slice batch for sequence parallel
|
||||
for key in batch:
|
||||
if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
|
||||
continue
|
||||
|
||||
# Split in sequential fashion and grab this rank's chunk
|
||||
if batch[key].size(1) == total_seq_len:
|
||||
batch[key] = (
|
||||
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
|
||||
)
|
||||
elif key == "logits_to_keep":
|
||||
batch[key] = (
|
||||
batch[key].chunk(local_world_size, dim=0)[local_rank].contiguous()
|
||||
)
|
||||
|
||||
# Handle num_items_in_batch
|
||||
if "num_items_in_batch" in batch:
|
||||
# Approximation; this needed since num_items_in_batch may be counted across
|
||||
# all samples in a gradient accumulated batch, not on a per-step basis.
|
||||
batch["num_items_in_batch"] = (
|
||||
batch["labels"] != -100
|
||||
).sum() * gradient_accumulation_steps
|
||||
|
||||
return batch, original_seq_len, pad_len
|
||||
|
||||
|
||||
class SequenceParallelContextManager:
|
||||
"""Context manager for sequence parallelism operations.
|
||||
|
||||
This class provides a context that will automatically apply sequence parallelism
|
||||
during model forward passes using a pre-forward hook, and gather outputs from
|
||||
across the sequence parallelism group using a post-forward hook.
|
||||
|
||||
Args:
|
||||
models: List of models to apply sequence parallelism to pre- and post- forward
|
||||
hooks.
|
||||
sequence_parallel_degree: Number of processes to split sequences over.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
models: list[nn.Module],
|
||||
sequence_parallel_degree: int,
|
||||
gradient_accumulation_steps: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
):
|
||||
self.models = models
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.process_group = get_ring_attn_group()
|
||||
|
||||
# Initialize sequence parallel group details
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
# Will store hook handles for removal
|
||||
self.hook_handles: list[RemovableHandle] = []
|
||||
|
||||
# Store original sequence length and padding information
|
||||
self.original_seq_len = 0
|
||||
self.pad_len = 0
|
||||
|
||||
# Create a partially applied version of the apply_sequence_parallelism function
|
||||
self.apply_sequence_parallelism = functools.partial(
|
||||
apply_sequence_parallelism,
|
||||
local_rank=self.local_rank,
|
||||
local_world_size=self.local_world_size,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
# Forward pre-hook to apply sequence parallelism
|
||||
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||
# Apply sequence parallelism to kwargs and get original sequence length and padding info
|
||||
kwargs, self.original_seq_len, self.pad_len = (
|
||||
self.apply_sequence_parallelism(batch=kwargs)
|
||||
)
|
||||
|
||||
return args, kwargs
|
||||
|
||||
# Forward post-hook to gather outputs
|
||||
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
||||
# Gather the sharded outputs
|
||||
output = self.gather_outputs(output)
|
||||
|
||||
# Remove padding if it was added
|
||||
if self.pad_len > 0:
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
if value.size(1) == self.original_seq_len + self.pad_len:
|
||||
# Slice to remove padding
|
||||
output[key] = value[:, : self.original_seq_len].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
# Register both hooks
|
||||
for model in self.models:
|
||||
self.hook_handles.append(
|
||||
model.register_forward_pre_hook(
|
||||
sequence_parallel_pre_hook, with_kwargs=True
|
||||
)
|
||||
)
|
||||
self.hook_handles.append(
|
||||
model.register_forward_hook(sequence_parallel_post_hook)
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
self.hook_handles = []
|
||||
|
||||
def gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
output[key] = AllGatherWithGrad.apply(value, self.process_group)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class AllGatherWithGrad(torch.autograd.Function):
|
||||
"""Custom autograd function for all-gather to preserve gradients."""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
input_tensor: torch.Tensor,
|
||||
group: dist.ProcessGroup,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of all-gather of data with sequence dimension.
|
||||
|
||||
Args:
|
||||
ctx: `torch.autograd` function context.
|
||||
input_tensor: Tensor from model output with sequence dimension.
|
||||
group: `torch.distributed` process group.
|
||||
|
||||
Returns:
|
||||
Tensor from gathering the `input_tensor` from across the process group and
|
||||
concatenating along the sequence dimension.
|
||||
"""
|
||||
ctx.group = group
|
||||
ctx.rank = dist.get_rank(group)
|
||||
world_size = dist.get_world_size(group)
|
||||
|
||||
# Gather shape metadata
|
||||
local_shape = torch.tensor(list(input_tensor.shape), device=input_tensor.device)
|
||||
all_shapes = [torch.zeros_like(local_shape) for _ in range(world_size)]
|
||||
dist.all_gather(all_shapes, local_shape, group=group)
|
||||
|
||||
# Store sequence lengths for backward pass
|
||||
seq_lens = [int(shape[1].item()) for shape in all_shapes]
|
||||
ctx.seq_lens = seq_lens
|
||||
|
||||
# Perform all_gather operation
|
||||
gathered = [
|
||||
torch.zeros(
|
||||
tuple(shape.tolist()),
|
||||
dtype=input_tensor.dtype,
|
||||
device=input_tensor.device,
|
||||
)
|
||||
for shape in all_shapes
|
||||
]
|
||||
dist.all_gather(gathered, input_tensor, group=group)
|
||||
|
||||
# Concatenate tensors along sequence dimension
|
||||
result = torch.cat(gathered, dim=1)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def backward(
|
||||
ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
"""
|
||||
Backward pass for all-gather operation.
|
||||
|
||||
Extracts the gradient slice corresponding to this rank's original input
|
||||
from the full gradient tensor.
|
||||
|
||||
Args:
|
||||
ctx: `torch.autograd` function context.
|
||||
grad_output: Gradient from subsequent layers with respect to the
|
||||
concatenated output tensor.
|
||||
|
||||
Returns:
|
||||
Tuple containing the gradient slice for this rank's input tensor and `None`
|
||||
for the process group parameter which doesn't require gradients.
|
||||
"""
|
||||
rank = ctx.rank
|
||||
seq_lens = ctx.seq_lens
|
||||
|
||||
# Extract gradient for this rank's chunk
|
||||
offset = sum(seq_lens[:rank])
|
||||
grad_slice = grad_output[:, offset : offset + seq_lens[rank]].contiguous()
|
||||
|
||||
return grad_slice, None
|
||||
@@ -11,6 +11,7 @@ from torch.utils.data import RandomSampler
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
||||
|
||||
@@ -250,6 +251,22 @@ def encode_packed_pretraining(
|
||||
# pylint: disable=duplicate-code
|
||||
# tokenize all the examples
|
||||
# rows get split with stride (overlap)
|
||||
"""
|
||||
Encodes and packs input examples into fixed-length batches for pretraining with optional multipack attention.
|
||||
|
||||
Wraps and processes input examples into a dataset, applies sequence packing with configurable overflow handling, and batches the data using a multipack sampler. Each batch is collated and features are aggregated into lists keyed by feature name.
|
||||
|
||||
Args:
|
||||
collate_fn: Function to collate individual feature dictionaries into batch tensors.
|
||||
ds_wrapper: Callable that wraps a Hugging Face Dataset for further processing.
|
||||
examples: Dictionary of input examples to encode and pack.
|
||||
max_seq_length: Maximum sequence length for each packed sequence.
|
||||
batch_size: Number of sequences to pack per batch.
|
||||
multipack_attn: If True, enables multipack attention and drops attention masks.
|
||||
|
||||
Returns:
|
||||
Dictionary where each key is a feature name and each value is a list of packed feature tensors.
|
||||
"""
|
||||
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
|
||||
|
||||
train_dataset = process_pretraining_datasets_for_packing(
|
||||
@@ -259,6 +276,10 @@ def encode_packed_pretraining(
|
||||
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
|
||||
# workaround by using the position id logic for now in trainer
|
||||
drop_attention_mask=multipack_attn,
|
||||
# pass through handling mode from config via ds_wrapper function
|
||||
handling=getattr(ds_wrapper, "cfg", {}).get(
|
||||
"sequence_len_overflow_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
|
||||
),
|
||||
)
|
||||
|
||||
sampler = MultipackBatchSampler(
|
||||
|
||||
@@ -18,8 +18,9 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_path(ds_hash, cfg):
|
||||
@@ -78,9 +79,34 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||
|
||||
|
||||
def drop_long_rl_seq(
|
||||
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
|
||||
sample,
|
||||
rl,
|
||||
tokenizer,
|
||||
sequence_len,
|
||||
handling="drop", # Use the default handling mode
|
||||
):
|
||||
if rl in ("dpo", "ipo", "orpo", "simpo"):
|
||||
"""
|
||||
Handles samples exceeding a maximum sequence length for various RL dataset types by either truncating or dropping them.
|
||||
|
||||
Depending on the RL type and the `handling` mode, this function either truncates response fields to fit within the specified sequence length or determines whether the sample should be dropped. For DPO, IPO, ORPO, and SIMPO types, both "chosen" and "rejected" responses are considered; for KTO, the "completion" is considered. For GRPO, samples are always retained. If truncation is not possible (e.g., the prompt alone exceeds the limit), the sample is returned unchanged for mapping, or dropped during filtering.
|
||||
|
||||
Args:
|
||||
sample: A dictionary representing a single dataset sample.
|
||||
rl: The RLType indicating the dataset type.
|
||||
tokenizer: The tokenizer used to compute token lengths and perform truncation.
|
||||
sequence_len: The maximum allowed sequence length.
|
||||
handling: Specifies how to handle overlong sequences ("drop" or "truncate").
|
||||
|
||||
Returns:
|
||||
For "truncate": The modified sample with responses truncated as needed, or the original sample if truncation is not possible.
|
||||
For "drop": True if the sample fits within the sequence length, otherwise False.
|
||||
|
||||
Raises:
|
||||
ValueError: If required keys are missing for the specified RL type, or if the RL type is unknown.
|
||||
"""
|
||||
result = None
|
||||
|
||||
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
|
||||
if not (
|
||||
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||
):
|
||||
@@ -96,11 +122,65 @@ def drop_long_rl_seq(
|
||||
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||
|
||||
return (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len
|
||||
# Truncate first, then drop if still invalid (although truncate should handle it)
|
||||
if handling == "truncate":
|
||||
# If both sequences fit, return sample unchanged
|
||||
if (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len:
|
||||
result = sample
|
||||
else:
|
||||
# Calculate maximum response length that can fit with the prompt
|
||||
max_response_len = sequence_len - len_prompt
|
||||
|
||||
if rl == "kto":
|
||||
if max_response_len <= 0:
|
||||
# Prompt is already too long, behavior depends on handling
|
||||
# If truncate is chosen, we technically can't truncate, but drop seems harsh.
|
||||
# Returning the sample might be unexpected. Let's stick to the filter logic
|
||||
# which would drop this in the `filter` step later if needed.
|
||||
# For now, return sample to map, or False to filter.
|
||||
# Let's simplify: truncate *should* result in a valid sample if possible.
|
||||
# If prompt >= seq_len, truncate won't work. Filter will catch this later.
|
||||
# So, if max_response_len <= 0, we pass it through for map, drop for filter.
|
||||
# However, the filter/map logic is applied *after* this function.
|
||||
# This function needs to return the *modified* sample for map, or bool for filter.
|
||||
|
||||
# Re-think: If handling==truncate, return the modified sample if possible.
|
||||
# If prompt >= seq_len, modification is impossible. What should map return?
|
||||
# Maybe return the original sample? But map expects *modified* sample.
|
||||
# Let's stick to the original logic: if prompt is too long, return False for filter
|
||||
# and original sample for map.
|
||||
|
||||
result = (
|
||||
sample # For map, let downstream handle it if still invalid?
|
||||
)
|
||||
# Or maybe return None/empty dict? Let's return sample for now.
|
||||
# If handling was drop, filter would remove this.
|
||||
|
||||
else:
|
||||
# Truncate the chosen and rejected responses if needed
|
||||
if len_chosen > max_response_len:
|
||||
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
|
||||
"input_ids"
|
||||
][:max_response_len]
|
||||
sample["chosen"] = tokenizer.decode(
|
||||
chosen_tokens, skip_special_tokens=True
|
||||
)
|
||||
|
||||
if len_rejected > max_response_len:
|
||||
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
|
||||
"input_ids"
|
||||
][:max_response_len]
|
||||
sample["rejected"] = tokenizer.decode(
|
||||
rejected_tokens, skip_special_tokens=True
|
||||
)
|
||||
result = sample
|
||||
else: # handling == "drop"
|
||||
result = (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len
|
||||
|
||||
elif rl == RLType.KTO:
|
||||
if not (sample.get("prompt") and sample.get("completion")):
|
||||
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
||||
|
||||
@@ -112,15 +192,54 @@ def drop_long_rl_seq(
|
||||
tokenizer(completion, add_special_tokens=False)["input_ids"]
|
||||
)
|
||||
|
||||
return (len_prompt + len_completion) <= sequence_len
|
||||
# Truncate first
|
||||
if handling == "truncate":
|
||||
# If sequence fits, return sample unchanged
|
||||
if (len_prompt + len_completion) <= sequence_len:
|
||||
result = sample
|
||||
else:
|
||||
# Calculate maximum completion length
|
||||
max_completion_len = sequence_len - len_prompt
|
||||
|
||||
if rl == "grpo":
|
||||
return True
|
||||
if max_completion_len <= 0:
|
||||
# Prompt too long, return sample for map
|
||||
result = sample
|
||||
else:
|
||||
# Truncate the completion if needed
|
||||
if len_completion > max_completion_len:
|
||||
completion_tokens = tokenizer(
|
||||
completion, add_special_tokens=False
|
||||
)["input_ids"][:max_completion_len]
|
||||
sample["completion"] = tokenizer.decode(
|
||||
completion_tokens, skip_special_tokens=True
|
||||
)
|
||||
result = sample
|
||||
else: # handling == "drop"
|
||||
result = (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
raise ValueError("Unknown RL type")
|
||||
elif rl == RLType.GRPO:
|
||||
# GRPO doesn't involve sequence length checks in the same way?
|
||||
# The original code returned True for drop. What should it return for truncate?
|
||||
# Let's assume for now it always passes.
|
||||
result = sample if handling == "truncate" else True
|
||||
else:
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def load_prepare_preference_datasets(cfg):
|
||||
"""
|
||||
Loads, preprocesses, and prepares preference datasets for RL training and evaluation.
|
||||
|
||||
This function orchestrates the loading, transformation, sequence length handling, optional deduplication, and caching of datasets for Direct Preference Optimization (DPO) and related RL types. It supports configurable handling of overlong sequences (dropping or truncating), applies dataset-specific transformations, and manages train/validation/test splits as needed.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object specifying dataset sources, RL type, tokenizer, sequence length, and processing options.
|
||||
|
||||
Returns:
|
||||
A tuple containing the prepared training and evaluation datasets.
|
||||
"""
|
||||
def load_split(dataset_cfgs, _cfg):
|
||||
split_datasets: List[Any] = []
|
||||
use_auth_token = _cfg.hf_use_auth_token
|
||||
@@ -137,9 +256,9 @@ def load_prepare_preference_datasets(cfg):
|
||||
if _type:
|
||||
if isinstance(_type, DictDefault):
|
||||
_type = "user_defined.default"
|
||||
if _cfg.rl == "orpo":
|
||||
if _cfg.rl is RLType.ORPO:
|
||||
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
||||
elif _cfg.rl == "kto":
|
||||
elif _cfg.rl is RLType.KTO:
|
||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||
else:
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
@@ -150,7 +269,7 @@ def load_prepare_preference_datasets(cfg):
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
elif _cfg.rl == "kto":
|
||||
elif _cfg.rl is RLType.KTO:
|
||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
@@ -164,28 +283,46 @@ def load_prepare_preference_datasets(cfg):
|
||||
split_datasets[i] = data_set
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
# Determine handling mode
|
||||
handling = cfg.get("sequence_len_overflow_handling", "drop")
|
||||
|
||||
drop_long = partial(
|
||||
drop_long_rl_seq,
|
||||
rl=_cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
handling=handling, # Pass the handling mode
|
||||
)
|
||||
|
||||
prior_len = len(split_datasets[i])
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} long samples from dataset index {i}"
|
||||
|
||||
# Use map for truncate mode and filter for drop mode
|
||||
if handling == "truncate":
|
||||
split_datasets[i] = split_datasets[i].map(
|
||||
drop_long, # Function now returns modified sample or original
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Truncating Long Sequences",
|
||||
)
|
||||
# Note: Length might not change if truncation always occurs
|
||||
LOG.info(
|
||||
f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
|
||||
)
|
||||
else: # handling == "drop"
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long, # Function now returns boolean
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} long samples from dataset index {i}"
|
||||
)
|
||||
|
||||
combined_datasets = concatenate_datasets(split_datasets)
|
||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
|
||||
|
||||
return combined_datasets
|
||||
|
||||
@@ -205,6 +342,8 @@ def load_prepare_preference_datasets(cfg):
|
||||
eval_dataset = load_split(cfg.test_datasets, cfg)
|
||||
if not eval_dataset:
|
||||
if cfg.val_set_size:
|
||||
seed = cfg.seed if cfg.seed is not None else 42
|
||||
|
||||
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
||||
to_hash_train = (
|
||||
train_dataset._fingerprint # pylint: disable=protected-access
|
||||
@@ -213,7 +352,7 @@ def load_prepare_preference_datasets(cfg):
|
||||
+ "|"
|
||||
+ "train"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
+ str(seed)
|
||||
)
|
||||
to_hash_test = (
|
||||
train_dataset._fingerprint # pylint: disable=protected-access
|
||||
@@ -222,13 +361,13 @@ def load_prepare_preference_datasets(cfg):
|
||||
+ "|"
|
||||
+ "test"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
+ str(seed)
|
||||
)
|
||||
train_fingerprint = md5(to_hash_train)
|
||||
test_fingerprint = md5(to_hash_test)
|
||||
ds_w_test_split = train_dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
seed=cfg.seed,
|
||||
seed=seed,
|
||||
shuffle=False,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
|
||||
@@ -148,7 +148,7 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||
ds_wrapper_partial,
|
||||
max_tokens=cfg.sequence_len,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seed=cfg.seed or 42,
|
||||
seed=cfg.seed if cfg.seed is not None else 42,
|
||||
buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
|
||||
)
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
@@ -416,6 +416,8 @@ def load_prepare_datasets(
|
||||
)
|
||||
|
||||
if split == "train" and val_set_size:
|
||||
seed = cfg.seed if cfg.seed is not None else 42
|
||||
|
||||
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
||||
to_hash_train = (
|
||||
dataset._fingerprint # pylint: disable=protected-access
|
||||
@@ -424,7 +426,7 @@ def load_prepare_datasets(
|
||||
+ "|"
|
||||
+ "train"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
+ str(seed)
|
||||
)
|
||||
to_hash_test = (
|
||||
dataset._fingerprint # pylint: disable=protected-access
|
||||
@@ -433,7 +435,7 @@ def load_prepare_datasets(
|
||||
+ "|"
|
||||
+ "test"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
+ str(seed)
|
||||
)
|
||||
train_fingerprint = md5(to_hash_train)
|
||||
test_fingerprint = md5(to_hash_test)
|
||||
@@ -442,7 +444,7 @@ def load_prepare_datasets(
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed or 42,
|
||||
seed=seed,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
|
||||
@@ -281,6 +281,10 @@ def load_dataset_w_config(
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
if not ds:
|
||||
raise ValueError("unhandled dataset load")
|
||||
raise ValueError(
|
||||
"The dataset could not be loaded. This could be due to a misconfigured dataset path "
|
||||
f"({config_dataset.path}). Try double-check your path / name / data_files. "
|
||||
"This is not caused by the dataset type."
|
||||
)
|
||||
|
||||
return ds
|
||||
|
||||
@@ -13,10 +13,12 @@ from datasets import Dataset, IterableDataset
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||
from axolotl.utils.trainer import drop_long_seq
|
||||
from axolotl.utils.trainer import truncate_or_drop_long_seq
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
|
||||
|
||||
|
||||
class RetryStrategy(Enum):
|
||||
"""
|
||||
@@ -159,16 +161,33 @@ def deduplicate_and_log_datasets(
|
||||
|
||||
|
||||
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
"""
|
||||
Processes a dataset to handle sequences exceeding a configured maximum length by either truncating or dropping them.
|
||||
|
||||
If the dataset lacks an "input_ids" column, the function returns the dataset unchanged. The handling mode is determined by the configuration parameter "sequence_len_overflow_handling", defaulting to "drop". In "truncate" mode, sequences longer than the maximum length are truncated; in "drop" mode, such sequences are removed from the dataset. The function logs information about sequence lengths and the number of samples affected when applicable.
|
||||
|
||||
Args:
|
||||
dataset: The Huggingface Dataset to process.
|
||||
cfg: Configuration object specifying sequence length parameters and handling mode.
|
||||
|
||||
Returns:
|
||||
The processed dataset with long sequences either truncated or dropped according to the configuration.
|
||||
"""
|
||||
if "input_ids" not in dataset.column_names:
|
||||
LOG.warning(
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
|
||||
)
|
||||
return dataset
|
||||
|
||||
drop_long = functools.partial(
|
||||
drop_long_seq,
|
||||
# Get the handling method from config, default to "drop" for backward compatibility
|
||||
handling = cfg.get("sequence_len_overflow_handling", "drop")
|
||||
|
||||
# Use the new function with the specified handling mode
|
||||
seq_handler = functools.partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=cfg.sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
handling=handling,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -193,17 +212,31 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||
if handling == "truncate":
|
||||
drop_long_kwargs["desc"] = "Truncating Long Sequences"
|
||||
else: # handling == "drop"
|
||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||
|
||||
dataset = dataset.filter(
|
||||
drop_long,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||
if handling == "truncate":
|
||||
# Use map for truncate mode
|
||||
dataset = dataset.map(
|
||||
seq_handler,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
|
||||
else: # handling == "drop"
|
||||
# Use filter for drop mode
|
||||
dataset = dataset.filter(
|
||||
seq_handler,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -1,16 +1,59 @@
|
||||
"""custom checkpointing utils"""
|
||||
|
||||
import importlib
|
||||
from functools import partial
|
||||
|
||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
||||
from packaging import version
|
||||
|
||||
from axolotl.utils.gradient_checkpointing.offload_cpu import (
|
||||
CPU_Offloaded_Gradient_Checkpointer,
|
||||
)
|
||||
from axolotl.utils.gradient_checkpointing.offload_disk import (
|
||||
Disco,
|
||||
)
|
||||
|
||||
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
||||
if transformers_version > version.parse("4.51.3"):
|
||||
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
|
||||
def uses_gc_layers(decoder_layer):
|
||||
return isinstance(decoder_layer.func.__self__, GradientCheckpointingLayer)
|
||||
|
||||
else:
|
||||
|
||||
def uses_gc_layers(_):
|
||||
return False
|
||||
|
||||
|
||||
def hf_grad_checkpoint_offload_wrapper(
|
||||
decoder_layer, *args, use_reentrant=None
|
||||
): # pylint: disable=unused-argument
|
||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||
if uses_gc_layers(decoder_layer):
|
||||
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
||||
decoder_layer,
|
||||
*args,
|
||||
)
|
||||
|
||||
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
||||
(
|
||||
decoder_layer.func.__self__
|
||||
if isinstance(decoder_layer, partial)
|
||||
else decoder_layer.__self__
|
||||
),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def hf_grad_checkpoint_disk_offload_wrapper(
|
||||
decoder_layer, *args, use_reentrant=None
|
||||
): # pylint: disable=unused-argument
|
||||
if uses_gc_layers(decoder_layer):
|
||||
return Disco.apply(
|
||||
decoder_layer,
|
||||
*args,
|
||||
)
|
||||
|
||||
return Disco.apply(
|
||||
(
|
||||
decoder_layer.func.__self__
|
||||
if isinstance(decoder_layer, partial)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Unsloth checkpointing"""
|
||||
"""CPU offloaded checkpointing"""
|
||||
|
||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||
#
|
||||
@@ -26,7 +26,7 @@ else:
|
||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||
|
||||
|
||||
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
torch.autograd.Function
|
||||
):
|
||||
"""
|
||||
531
src/axolotl/utils/gradient_checkpointing/offload_disk.py
Normal file
531
src/axolotl/utils/gradient_checkpointing/offload_disk.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""
|
||||
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
|
||||
"""
|
||||
|
||||
# Copyright 2025 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from concurrent.futures import Future
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||
|
||||
# Setup logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DiskOffloadManager:
|
||||
"""
|
||||
Manages offloaded tensors and handles prefetching in a separate thread.
|
||||
Includes synchronization to prevent race conditions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prefetch_size: int = 3,
|
||||
prefetch_to_gpu: bool = True,
|
||||
save_workers: int = 4,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
prefetch_size: Maximum number of tensors to prefetch in the background.
|
||||
prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
|
||||
save_workers: Maximum number of concurrent save operations.
|
||||
"""
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="disco_")
|
||||
|
||||
# Track tensor paths and their status
|
||||
self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO)
|
||||
self.file_locks: Dict[str, threading.Lock] = (
|
||||
{}
|
||||
) # Maps file_path -> threading.Lock()
|
||||
# Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
|
||||
self.file_status: Dict[str, str] = {}
|
||||
|
||||
self.max_prefetch = prefetch_size
|
||||
self.prefetch_to_gpu = prefetch_to_gpu
|
||||
|
||||
# Thread synchronization
|
||||
self.manager_lock = threading.RLock() # Used for thread-safe operations
|
||||
|
||||
# Prefetch queue and cache
|
||||
self.prefetch_queue: queue.Queue = queue.Queue()
|
||||
self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor
|
||||
|
||||
# Save queue and thread pool
|
||||
self.save_queue: queue.Queue = queue.Queue()
|
||||
self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
|
||||
self.save_futures: Dict[str, Future] = {}
|
||||
self.save_semaphore = threading.Semaphore(
|
||||
save_workers * 2
|
||||
) # Limit concurrent save operations
|
||||
|
||||
# Start prefetch worker thread
|
||||
self.stop_event = threading.Event()
|
||||
# start multiple threads for prefetching
|
||||
self.prefetch_worker_count = 2
|
||||
self.prefetch_workers = []
|
||||
for _ in range(self.prefetch_worker_count):
|
||||
worker = threading.Thread(target=self._prefetch_worker, daemon=True)
|
||||
worker.start()
|
||||
self.prefetch_workers.append(worker)
|
||||
|
||||
# Start save worker thread
|
||||
self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
|
||||
self.save_worker.start()
|
||||
self.idx = 0
|
||||
|
||||
atexit.register(self.cleanup)
|
||||
|
||||
def _save_worker(self):
|
||||
"""Background thread that processes the save queue"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
save_item = self.save_queue.get(timeout=0.5)
|
||||
if save_item is None:
|
||||
continue
|
||||
|
||||
tensor, file_path = save_item
|
||||
|
||||
# Submit the save task to the thread pool
|
||||
future = self.save_pool.submit(
|
||||
self._save_tensor_to_disk, tensor, file_path
|
||||
)
|
||||
with self.manager_lock:
|
||||
self.save_futures[file_path] = future
|
||||
|
||||
self.save_queue.task_done()
|
||||
|
||||
except queue.Empty:
|
||||
time.sleep(0.01) # Small sleep to prevent CPU spinning
|
||||
continue
|
||||
|
||||
def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
|
||||
"""Actually save the tensor to disk"""
|
||||
try:
|
||||
# Save tensor to disk
|
||||
cpu_tensor = tensor.detach().cpu()
|
||||
torch.save(cpu_tensor, file_path)
|
||||
del cpu_tensor
|
||||
|
||||
with self.manager_lock:
|
||||
# Mark file as ready
|
||||
self.file_status[file_path] = "ready"
|
||||
|
||||
# Release semaphore
|
||||
self.save_semaphore.release()
|
||||
|
||||
return True
|
||||
except FileNotFoundError as e:
|
||||
logger.error(f"Error saving tensor to {file_path}: {e}")
|
||||
with self.manager_lock:
|
||||
self.file_status[file_path] = "error"
|
||||
|
||||
# Release semaphore
|
||||
self.save_semaphore.release()
|
||||
|
||||
return False
|
||||
|
||||
def _prefetch_worker(self):
|
||||
"""Background thread that loads tensors from disk ahead of time"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
file_path = self.prefetch_queue.get(timeout=0.5)
|
||||
if file_path is None:
|
||||
continue
|
||||
|
||||
# Check if file is available and not already in cache
|
||||
with self.manager_lock:
|
||||
if (
|
||||
file_path not in self.file_status
|
||||
or self.file_status[file_path] == "deleted"
|
||||
):
|
||||
self.prefetch_queue.task_done()
|
||||
if file_path in self.prefetch_cache:
|
||||
self.prefetch_queue.task_done()
|
||||
continue
|
||||
|
||||
# If file is still being saved, wait for it
|
||||
if (
|
||||
self.file_status[file_path] == "saving"
|
||||
and file_path in self.save_futures
|
||||
):
|
||||
# Re-queue this prefetch request with a little delay
|
||||
self.prefetch_queue.task_done()
|
||||
time.sleep(0.1)
|
||||
self.prefetch_queue.put(file_path)
|
||||
continue
|
||||
|
||||
# Mark file as being prefetched
|
||||
self.file_status[file_path] = "prefetching"
|
||||
|
||||
# Load tensor from disk and store in cache
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
if self.prefetch_to_gpu:
|
||||
tensor = torch.load(
|
||||
file_path,
|
||||
map_location=torch.device("cuda"),
|
||||
weights_only=True,
|
||||
)
|
||||
else:
|
||||
tensor = torch.load(file_path, weights_only=True)
|
||||
|
||||
with self.manager_lock:
|
||||
self.prefetch_cache[file_path] = tensor
|
||||
self.file_status[file_path] = "ready"
|
||||
else:
|
||||
with self.manager_lock:
|
||||
if self.file_status.get(file_path) != "deleted":
|
||||
logger.warning(
|
||||
f"Prefetch error: File not found {file_path}"
|
||||
)
|
||||
self.file_status[file_path] = "missing"
|
||||
|
||||
except FileNotFoundError as e:
|
||||
with self.manager_lock:
|
||||
if self.file_status.get(file_path) != "deleted":
|
||||
logger.warning(f"Prefetch error for {file_path}: {e}")
|
||||
self.file_status[file_path] = "error"
|
||||
|
||||
self.prefetch_queue.task_done()
|
||||
|
||||
except queue.Empty:
|
||||
time.sleep(0.01) # Small sleep to prevent CPU spinning
|
||||
continue
|
||||
|
||||
def save_tensor(self, tensor: torch.Tensor):
|
||||
"""Save tensor to disk asynchronously and return file path with thread-safe operations"""
|
||||
# Generate unique file path
|
||||
self.idx += 1
|
||||
file_path: str = os.path.join(
|
||||
self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
|
||||
)
|
||||
|
||||
with self.manager_lock:
|
||||
# Mark file as being saved
|
||||
self.file_locks[file_path] = threading.Lock()
|
||||
self.file_status[file_path] = "saving"
|
||||
# Add to history
|
||||
self.tensor_paths.append(file_path)
|
||||
|
||||
# Acquire semaphore to limit concurrent save operations
|
||||
self.save_semaphore.acquire() # pylint: disable=consider-using-with
|
||||
# Queue tensor for saving in background
|
||||
self.save_queue.put((tensor.detach(), file_path))
|
||||
|
||||
return file_path
|
||||
|
||||
def wait_for_save(self, file_path, timeout=None) -> None:
|
||||
"""Wait for a tensor to be saved to disk"""
|
||||
start_time = time.time()
|
||||
while timeout is None or time.time() - start_time < timeout:
|
||||
with self.manager_lock:
|
||||
if self.file_status.get(file_path) == "ready":
|
||||
return
|
||||
if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
|
||||
return
|
||||
|
||||
if file_path in self.save_futures:
|
||||
future = self.save_futures[file_path]
|
||||
if future.done():
|
||||
return
|
||||
|
||||
# Small sleep to prevent CPU spinning
|
||||
time.sleep(0.01)
|
||||
|
||||
# Timeout
|
||||
logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
|
||||
return
|
||||
|
||||
def load_tensor(self, file_path, target_device="cuda"):
|
||||
"""Load tensor from disk or prefetch cache with proper synchronization"""
|
||||
# Wait for tensor to be saved if it's still in progress
|
||||
self.wait_for_save(file_path)
|
||||
|
||||
tensor = None
|
||||
|
||||
# Try to get from cache first
|
||||
with self.manager_lock:
|
||||
# Check if tensor is already in cache
|
||||
if file_path in self.prefetch_cache:
|
||||
tensor = self.prefetch_cache[file_path]
|
||||
del self.prefetch_cache[file_path]
|
||||
self.file_status[file_path] = "loaded"
|
||||
|
||||
if tensor is not None:
|
||||
# Ensure tensor is on correct device
|
||||
if target_device != "cpu" and tensor.device.type == "cpu":
|
||||
tensor = tensor.to(target_device, non_blocking=True)
|
||||
return tensor
|
||||
|
||||
# If not in cache, load directly from disk
|
||||
try:
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"File not found for loading: {file_path}")
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
tensor = torch.load(file_path, weights_only=True)
|
||||
|
||||
with self.manager_lock:
|
||||
self.file_status[file_path] = "loaded"
|
||||
|
||||
if target_device != "cpu":
|
||||
tensor = tensor.to(target_device, non_blocking=True)
|
||||
|
||||
return tensor
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading tensor from {file_path}: {e}")
|
||||
raise
|
||||
|
||||
def _safe_delete_file(self, file_path):
|
||||
"""Safely delete a file with proper synchronization"""
|
||||
with self.manager_lock:
|
||||
# Make sure any save operation is completed
|
||||
if file_path in self.save_futures:
|
||||
future = self.save_futures[file_path]
|
||||
try:
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
del self.save_futures[file_path]
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(
|
||||
f"Error canceling save operation for {file_path}: {e}"
|
||||
)
|
||||
|
||||
# Only delete if file exists and is not being prefetched
|
||||
status = self.file_status.get(file_path)
|
||||
if status in ["ready", "loaded", "error", "missing"]:
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
self.file_status[file_path] = "deleted"
|
||||
return True
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(f"Error deleting file {file_path}: {e}")
|
||||
return False
|
||||
|
||||
def trigger_prefetch(self, n=None):
|
||||
"""Trigger prefetching of the next N tensors with proper synchronization"""
|
||||
if n is None:
|
||||
n = self.max_prefetch
|
||||
|
||||
prefetch_paths = []
|
||||
with self.manager_lock:
|
||||
# Find files that are ready to be prefetched (not already in cache or being prefetched)
|
||||
for path in reversed(self.tensor_paths):
|
||||
if (
|
||||
path not in self.prefetch_cache
|
||||
and self.file_status.get(path) == "ready"
|
||||
):
|
||||
prefetch_paths.append(path)
|
||||
if len(prefetch_paths) >= n:
|
||||
break
|
||||
|
||||
# Queue files for prefetching
|
||||
for path in prefetch_paths:
|
||||
self.prefetch_queue.put(path)
|
||||
|
||||
def cleanup_tensor(self, file_path: str):
|
||||
"""Clean up a specific tensor file after it's been used"""
|
||||
with self.manager_lock:
|
||||
if file_path in self.tensor_paths:
|
||||
self.tensor_paths.remove(file_path)
|
||||
|
||||
# Remove from prefetch cache if present
|
||||
if file_path in self.prefetch_cache:
|
||||
del self.prefetch_cache[file_path]
|
||||
|
||||
# Remove from save futures if present
|
||||
if file_path in self.save_futures:
|
||||
future = self.save_futures[file_path]
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
del self.save_futures[file_path]
|
||||
|
||||
# Try to delete the file
|
||||
self._safe_delete_file(file_path)
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up all temp files and stop prefetch thread with proper synchronization"""
|
||||
self.stop_event.set()
|
||||
|
||||
# Cancel all pending save operations
|
||||
with self.manager_lock:
|
||||
for _, future in self.save_futures.items():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self.save_futures.clear()
|
||||
|
||||
# Drain the save queue
|
||||
while not self.save_queue.empty():
|
||||
try:
|
||||
self.save_queue.get_nowait()
|
||||
self.save_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Shutdown the save pool
|
||||
self.save_pool.shutdown(wait=False)
|
||||
|
||||
# Join the save worker thread
|
||||
if self.save_worker.is_alive():
|
||||
self.save_worker.join(timeout=2.0)
|
||||
|
||||
# Join the prefetch worker threads
|
||||
for thread in self.prefetch_workers:
|
||||
if thread.is_alive():
|
||||
thread.join(timeout=2.0)
|
||||
|
||||
# Clear cache and remove all temporary files
|
||||
with self.manager_lock:
|
||||
self.prefetch_cache.clear()
|
||||
paths_to_delete = list(self.tensor_paths)
|
||||
self.tensor_paths.clear()
|
||||
|
||||
# Delete all temporary files
|
||||
for path in paths_to_delete:
|
||||
self._safe_delete_file(path)
|
||||
|
||||
# Remove temp directory
|
||||
try:
|
||||
if os.path.exists(self.temp_dir):
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")
|
||||
|
||||
|
||||
class Disco(torch.autograd.Function):
|
||||
"""
|
||||
Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
|
||||
Advanced disk-based gradient checkpointer with prefetching.
|
||||
"""
|
||||
|
||||
# Shared manager instance across all checkpointing operations
|
||||
_manager = None
|
||||
|
||||
@staticmethod
|
||||
def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
|
||||
"""Get or create the offload manager"""
|
||||
if Disco._manager is None:
|
||||
Disco._manager = DiskOffloadManager(
|
||||
prefetch_size=prefetch_size,
|
||||
prefetch_to_gpu=prefetch_to_gpu,
|
||||
save_workers=save_workers,
|
||||
)
|
||||
return Disco._manager
|
||||
|
||||
@staticmethod
|
||||
@torch_cuda_amp_custom_fwd
|
||||
def forward(
|
||||
ctx,
|
||||
forward_function,
|
||||
hidden_states,
|
||||
*args,
|
||||
prefetch_size=1,
|
||||
prefetch_to_gpu=True,
|
||||
save_workers=4,
|
||||
):
|
||||
"""Forward pass that offloads activations to disk asynchronously"""
|
||||
# Get or create the manager
|
||||
manager = Disco.get_instance(
|
||||
prefetch_size=prefetch_size,
|
||||
prefetch_to_gpu=prefetch_to_gpu,
|
||||
save_workers=save_workers,
|
||||
)
|
||||
|
||||
# Save tensor to disk asynchronously
|
||||
file_path = manager.save_tensor(hidden_states)
|
||||
|
||||
# Run forward pass immediately without waiting for save to complete
|
||||
with torch.no_grad():
|
||||
output = forward_function(hidden_states, *args)
|
||||
|
||||
# Store what we need for backward
|
||||
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
|
||||
ctx.file_path = file_path
|
||||
ctx.forward_function = forward_function
|
||||
ctx.args = args
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch_cuda_amp_custom_bwd
|
||||
def backward(ctx, *grad_outputs):
|
||||
"""Backward pass that loads activations from disk with prefetching"""
|
||||
# Get the manager
|
||||
manager = Disco._manager
|
||||
|
||||
# Trigger prefetching for future tensors
|
||||
# This happens at the start of backward, so should have time to complete
|
||||
manager.trigger_prefetch()
|
||||
|
||||
# Load hidden states from disk or prefetch cache
|
||||
file_path = ctx.file_path
|
||||
try:
|
||||
# Ensure the file is saved before we try to load it
|
||||
manager.wait_for_save(file_path)
|
||||
|
||||
hidden_states = manager.load_tensor(file_path)
|
||||
hidden_states.requires_grad = True
|
||||
|
||||
# Compute gradients
|
||||
with torch.enable_grad():
|
||||
output = ctx.forward_function(hidden_states, *ctx.args)
|
||||
|
||||
# Handle tuple outputs properly
|
||||
if isinstance(output, tuple):
|
||||
if len(grad_outputs) == len(output):
|
||||
torch.autograd.backward(output, grad_outputs)
|
||||
else:
|
||||
torch.autograd.backward(output, grad_outputs[0])
|
||||
else:
|
||||
torch.autograd.backward(output, grad_outputs[0])
|
||||
|
||||
# Clean up the file after we're done with it
|
||||
manager.cleanup_tensor(file_path)
|
||||
|
||||
return (
|
||||
(
|
||||
None, # forward_function
|
||||
hidden_states.grad, # hidden_states grad
|
||||
)
|
||||
+ (None,) * len(ctx.args) # for each arg
|
||||
+ (
|
||||
None, # prefetch_size
|
||||
None, # prefetch_to_gpu
|
||||
None, # save_workers
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in backward pass: {e}")
|
||||
# Clean up the file even on error
|
||||
manager.cleanup_tensor(file_path)
|
||||
raise
|
||||
@@ -70,9 +70,13 @@ from axolotl.utils.distributed import (
|
||||
is_local_main_process,
|
||||
is_main_process,
|
||||
)
|
||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
||||
from axolotl.utils.gradient_checkpointing import (
|
||||
hf_grad_checkpoint_disk_offload_wrapper,
|
||||
hf_grad_checkpoint_offload_wrapper,
|
||||
)
|
||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||
@@ -556,11 +560,21 @@ class ModelLoader:
|
||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||
|
||||
def apply_patches(self) -> None:
|
||||
if self.cfg.xformers_attention and self.cfg.sample_packing:
|
||||
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
|
||||
|
||||
patch_xformers_attn_over_fa2()
|
||||
self.cfg.flash_attention = True
|
||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
|
||||
|
||||
patch_accelerate_fsdp_utils()
|
||||
|
||||
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
|
||||
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
|
||||
|
||||
patch_peft_prep_code()
|
||||
|
||||
if self.cfg.flex_attention:
|
||||
from axolotl.monkeypatch.attention.flex_attn import (
|
||||
patch_flex_make_mask,
|
||||
@@ -609,6 +623,10 @@ class ModelLoader:
|
||||
|
||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||
if self.cfg.gradient_checkpointing == "offload_disk":
|
||||
transformers.modeling_utils.checkpoint = (
|
||||
hf_grad_checkpoint_disk_offload_wrapper
|
||||
)
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
self.patch_attention()
|
||||
@@ -1180,7 +1198,7 @@ class ModelLoader:
|
||||
],
|
||||
)
|
||||
|
||||
def prepare_model(self, qlora_fsdp) -> None:
|
||||
def prepare_model(self, qlora_fsdp: bool) -> None:
|
||||
skip_prepare_model_for_kbit_training = False
|
||||
if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
|
||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||
@@ -1310,7 +1328,10 @@ class ModelLoader:
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
|
||||
if not self.cfg.fsdp:
|
||||
# FSDP doesn't like mixed Float and BFloat16
|
||||
# we don't run this during FSDP because this will leave mixed
|
||||
# float and bfloat16 dtypes in the model which FSDP doesn't like
|
||||
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
|
||||
embedding_modules = []
|
||||
self.convert_embedding_modules_dtype(
|
||||
embedding_modules,
|
||||
dist_dtype=torch.float32,
|
||||
@@ -1359,7 +1380,7 @@ class ModelLoader:
|
||||
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
||||
if (
|
||||
self.cfg.adapter
|
||||
and self.cfg.rl in ["dpo", "ipo", "kto"]
|
||||
and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
|
||||
and not self.cfg.merge_lora
|
||||
):
|
||||
_, lora_config = load_lora(
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# pylint: skip-file
|
||||
"""
|
||||
Multipack Batch Sampler
|
||||
Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences
|
||||
into fixed-capacity batches to optimize memory usage and training throughput.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Iterable, List, Union
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from multiprocessing import cpu_count, get_context
|
||||
from typing import Iterable, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
@@ -13,26 +16,39 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler
|
||||
from axolotl.utils.distributed import reduce_and_broadcast
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
# First-fit-decreasing bin packing
|
||||
# Check if a[] could fit in n bins with capacity c
|
||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
|
||||
"""
|
||||
First-fit-decreasing bin packing algorithm check
|
||||
|
||||
a = np.sort(a)[::-1]
|
||||
bins = np.full((n,), c, dtype=a.dtype)
|
||||
for size in a:
|
||||
Checks if sequences with the given lengths could fit in the specified number of bins
|
||||
|
||||
Args:
|
||||
sequence_lengths: Array of sequence lengths
|
||||
bin_capacity: Maximum capacity of each bin
|
||||
num_bins: Number of bins available
|
||||
|
||||
Returns:
|
||||
True if all sequences can be packed, False otherwise
|
||||
"""
|
||||
# Sort sequence lengths in descending order for optimal packing
|
||||
sequence_lengths = np.sort(sequence_lengths)[::-1]
|
||||
# Initialize all bins with full capacity
|
||||
bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype)
|
||||
|
||||
# Try to place each sequence in the first bin it fits
|
||||
for size in sequence_lengths:
|
||||
not_found = True
|
||||
for idx in range(n):
|
||||
for idx in range(num_bins):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
not_found = False
|
||||
break
|
||||
|
||||
# If no bin could fit this sequence, packing failed
|
||||
if not_found:
|
||||
return False
|
||||
|
||||
@@ -40,86 +56,155 @@ def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||
# First-fit-decreasing bin packing (with result return)
|
||||
def pack_group(
|
||||
sequence_lengths: np.ndarray,
|
||||
group_offset: int,
|
||||
bin_capacity: int,
|
||||
max_bins: int,
|
||||
bin_size: int,
|
||||
safe_mode: bool = True,
|
||||
):
|
||||
"""
|
||||
Pack a group of sequences into bins using First-Fit Decreasing algorithm
|
||||
|
||||
indices = np.argsort(a)[::-1]
|
||||
a = a[indices]
|
||||
Args:
|
||||
sequence_lengths: Array of sequence lengths
|
||||
group_offset: Offset to apply to indices when returning results
|
||||
bin_capacity: Maximum capacity of each bin
|
||||
max_bins: Maximum number of bins to use
|
||||
bin_size: Maximum number of sequences per bin
|
||||
safe_mode: If True, use a more conservative packing approach
|
||||
|
||||
bins: List[Any] = []
|
||||
bins_result: List[Any] = []
|
||||
for a_id, size in enumerate(a):
|
||||
add_new = True
|
||||
for idx in range(len(bins)):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
bins_result[idx].append(indices[a_id] + start_index)
|
||||
add_new = False
|
||||
Returns:
|
||||
List of bins, where each bin contains indices of sequences assigned to it
|
||||
"""
|
||||
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
|
||||
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
|
||||
|
||||
for seq_id, size in enumerate(sequence_lengths):
|
||||
global_idx = seq_id + group_offset
|
||||
|
||||
# Try to place sequence in existing bins
|
||||
add_new_bin = True
|
||||
for bin_idx, _ in enumerate(bins_remaining_space):
|
||||
if (
|
||||
bins_remaining_space[bin_idx] >= size
|
||||
and len(bins_assigned_sequences[bin_idx]) < bin_size
|
||||
):
|
||||
bins_remaining_space[bin_idx] -= size
|
||||
bins_assigned_sequences[bin_idx].append(global_idx)
|
||||
add_new_bin = False
|
||||
break
|
||||
|
||||
if add_new:
|
||||
bins.append(c - size)
|
||||
bins_result.append([indices[a_id] + start_index])
|
||||
# Create a new bin if needed and if we haven't reached the limit
|
||||
if add_new_bin:
|
||||
if len(bins_remaining_space) >= max_bins and safe_mode:
|
||||
# In safe mode, skip items that would exceed max_bins
|
||||
continue
|
||||
bins_remaining_space.append(bin_capacity - size)
|
||||
bins_assigned_sequences.append([global_idx])
|
||||
|
||||
return bins_result
|
||||
# Safety check to avoid infinite bins
|
||||
if len(bins_remaining_space) > len(sequence_lengths):
|
||||
break
|
||||
|
||||
return bins_assigned_sequences
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate(
|
||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||
# Define a standalone function for multiprocessing
|
||||
def _process_group(args):
|
||||
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args
|
||||
return pack_group(
|
||||
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode
|
||||
)
|
||||
|
||||
|
||||
def pack_parallel(
|
||||
sequence_lengths: np.ndarray,
|
||||
bin_capacity: int,
|
||||
group_size: int,
|
||||
bin_size: int,
|
||||
num_processes: int | None = None,
|
||||
safe_mode: bool = True,
|
||||
mp_start_method: str | None = "spawn",
|
||||
):
|
||||
# Dynamic batch allocator, similar to Multifit
|
||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||
"""
|
||||
Pack sequences into bins using parallel processing
|
||||
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
Args:
|
||||
sequence_lengths: Array of sequence lengths
|
||||
bin_capacity: Maximum capacity of each bin as total number of tokens
|
||||
group_size: Number of sequences to process in each group
|
||||
bin_size: Maximum number of bins to use
|
||||
num_processes: Number of parallel processes to use
|
||||
safe_mode: If True, use a more conservative packing approach
|
||||
mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver').
|
||||
'spawn' is often safer with Numba/PyTorch.
|
||||
Set to None to use system default.
|
||||
Returns:
|
||||
List of bins, where each bin contains indices of sequences assigned to it
|
||||
"""
|
||||
num_items = len(sequence_lengths)
|
||||
if num_processes is None:
|
||||
num_processes = max(1, min(num_items // group_size, cpu_count()))
|
||||
|
||||
while True:
|
||||
# binary search [l, r)
|
||||
left = 1
|
||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||
# Create tasks for parallel processing
|
||||
tasks = []
|
||||
for i in range(0, num_items, group_size):
|
||||
group_lengths = sequence_lengths[i : i + group_size]
|
||||
max_bins = len(group_lengths) # Allow as many bins as items in the group
|
||||
tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode))
|
||||
|
||||
while right - left > 1:
|
||||
mid = (left + right) // 2
|
||||
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
||||
left = mid
|
||||
else:
|
||||
right = mid
|
||||
# Process groups in parallel
|
||||
all_bins = []
|
||||
|
||||
# use length l
|
||||
batch = ffd_with_result(
|
||||
lengths[start_index : start_index + left], c, start_index
|
||||
)
|
||||
assert len(batch) <= n
|
||||
if len(batch) < n:
|
||||
break
|
||||
mp_ctx = None
|
||||
if mp_start_method:
|
||||
try:
|
||||
mp_ctx = get_context(mp_start_method)
|
||||
except ValueError:
|
||||
LOG.warning(
|
||||
f"Failed to get multiprocessing context '{mp_start_method}'. "
|
||||
f"Falling back to default. Available: {get_context().get_all_start_methods()}"
|
||||
)
|
||||
mp_ctx = (
|
||||
None # Fallback to default context if specified one is not available
|
||||
)
|
||||
|
||||
start_index += left
|
||||
s = lengths_cumsum[start_index - 1]
|
||||
if num_processes == 1:
|
||||
LOG.debug("Using single process for pack_parallel, running sequentially.")
|
||||
for task_args in tasks:
|
||||
group_bins = _process_group(task_args)
|
||||
all_bins.extend(group_bins)
|
||||
else:
|
||||
# Use ProcessPoolExecutor only if num_processes > 1
|
||||
# Pass mp_context if available
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=num_processes, mp_context=mp_ctx
|
||||
) as executor:
|
||||
for group_bins in executor.map(_process_group, tasks):
|
||||
all_bins.extend(group_bins)
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
|
||||
return result, s, len(result) * c * n
|
||||
return all_bins
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
|
||||
def allocate_sequentially(
|
||||
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
|
||||
):
|
||||
"""
|
||||
Sequential allocator that preserves example order
|
||||
|
||||
Parameters:
|
||||
- lengths: The lengths of all examples
|
||||
- rank: The current rank (for distributed training)
|
||||
- c: The capacity of each bin (maximum sequence length)
|
||||
- n: Number of ranks
|
||||
Args:
|
||||
sequence_lengths: The lengths of all examples
|
||||
rank: The current rank (for distributed training)
|
||||
bin_capacity: The capacity of each bin (maximum sequence length)
|
||||
num_ranks: Number of ranks (processes/GPUs)
|
||||
|
||||
Returns:
|
||||
- result: List of batches for the current rank
|
||||
- total_used: Number of actual example tokens
|
||||
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
|
||||
rank_batches: List of batches for the current rank
|
||||
total_tokens_used: Number of actual example tokens
|
||||
total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
|
||||
"""
|
||||
result = []
|
||||
total_used = 0
|
||||
@@ -127,9 +212,9 @@ def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
|
||||
# First, do sequential packing into bins
|
||||
all_bins = []
|
||||
current_bin = [0 for i in range(0)] # numba hint
|
||||
remaining_capacity = c
|
||||
remaining_capacity = bin_capacity
|
||||
|
||||
for idx, size in enumerate(lengths):
|
||||
for idx, size in enumerate(sequence_lengths):
|
||||
if size <= remaining_capacity:
|
||||
# Example fits in current bin
|
||||
current_bin.append(idx)
|
||||
@@ -140,7 +225,7 @@ def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
|
||||
if current_bin: # Add non-empty bin to all_bins
|
||||
all_bins.append(current_bin)
|
||||
current_bin = [idx]
|
||||
remaining_capacity = c - size
|
||||
remaining_capacity = bin_capacity - size
|
||||
total_used += size
|
||||
|
||||
# Add the last bin if not empty
|
||||
@@ -148,132 +233,227 @@ def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
|
||||
all_bins.append(current_bin)
|
||||
|
||||
# Assign bins to ranks - each rank gets every n-th bin
|
||||
for bin_idx in range(rank, len(all_bins), n):
|
||||
for bin_idx in range(rank, len(all_bins), num_ranks):
|
||||
result.append(all_bins[bin_idx])
|
||||
|
||||
return result, total_used, len(all_bins) * c
|
||||
return result, total_used, len(all_bins) * bin_capacity
|
||||
|
||||
|
||||
class MultipackBatchSampler(BatchSampler):
|
||||
"""Batch sampler class for multipack"""
|
||||
"""
|
||||
Batch sampler class for efficient packing of variable-length sequences
|
||||
|
||||
This sampler packs sequences into fixed-capacity bins (batches) to maximize
|
||||
GPU memory utilization and training throughput by reducing padding.
|
||||
|
||||
It supports both parallel packing (using FFD algorithm) and
|
||||
sequential packing (preserving original sequence order).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampler: Union[Sampler[int], Iterable[int]],
|
||||
batch_size: int,
|
||||
batch_max_len: int,
|
||||
lengths: np.ndarray,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
drop_last: bool = False,
|
||||
num_count_samples: int = 16,
|
||||
sequential: bool = False,
|
||||
**kwargs,
|
||||
batch_size: int, # Number of bins per batch
|
||||
batch_max_len: int, # Maximum sequence length (bin capacity)
|
||||
lengths: np.ndarray, # Sequence lengths
|
||||
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
||||
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
|
||||
num_count_samples: int = 16, # Number of times to estimate batch count
|
||||
sequential: bool = False, # Whether to use sequential packing
|
||||
group_size: int = 100_000, # Size of groups for parallel packing
|
||||
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
||||
num_processes: int | None = None, # Number of processes for parallel packing
|
||||
safe_mode: bool = True, # Conservative packing to prevent training instability
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.batch_size = batch_size
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.lengths = np.array(lengths, dtype=np.int32)
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.sequential = sequential
|
||||
self.group_size = group_size
|
||||
self.bin_size = bin_size
|
||||
self.num_processes = num_processes
|
||||
self.safe_mode = safe_mode
|
||||
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
self.epoch = 0
|
||||
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
# Efficiency statistics tracking
|
||||
self.total_tokens_used = 0
|
||||
self.total_token_slots = 0
|
||||
|
||||
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
|
||||
# The number of times to calculate batches to determine minimum packed dataset length
|
||||
self.num_count_samples = num_count_samples
|
||||
# the minimum packed dataset length across all ranks determined by a gather/broadcast
|
||||
# Minimum packed dataset length across all ranks (determined by gather/broadcast)
|
||||
self.len_across_ranks = None
|
||||
|
||||
# Cache for batches
|
||||
self._batches = None
|
||||
|
||||
if self.sequential and not isinstance(sampler, SequentialSampler):
|
||||
LOG.warn(
|
||||
LOG.warning(
|
||||
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
|
||||
)
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
"""Set the epoch number, used for reproducible shuffling across epochs"""
|
||||
self.epoch = epoch
|
||||
self._batches = None # Invalidate batch cache
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
indices = [idx for idx in self.sampler]
|
||||
"""
|
||||
Generate packed batches for training
|
||||
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
Args:
|
||||
set_stats: Whether to update efficiency statistics
|
||||
|
||||
if self.sequential:
|
||||
batches, total_used, total_slots = allocate_sequentially(
|
||||
lengths=lengths,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
else:
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
Returns:
|
||||
List of batches, where each batch contains multiple bins,
|
||||
and each bin contains multiple sequence indices
|
||||
"""
|
||||
if self._batches is not None:
|
||||
return self._batches
|
||||
|
||||
batches = [
|
||||
[
|
||||
[indices[b_idx] for b_idx in batch]
|
||||
for batch in batches[i : i + self.batch_size]
|
||||
]
|
||||
for i in range(0, len(batches), self.batch_size)
|
||||
# Get indices from the sampler
|
||||
indices = [ # pylint: disable=unnecessary-comprehension
|
||||
idx for idx in self.sampler
|
||||
]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
# Get lengths of the selected sequences
|
||||
lengths = self.lengths[indices]
|
||||
|
||||
# Pack sequences into bins using either sequential or parallel packing
|
||||
if self.sequential:
|
||||
bins, total_used, total_slots = allocate_sequentially(
|
||||
lengths,
|
||||
rank=0,
|
||||
bin_capacity=self.batch_max_len,
|
||||
num_ranks=1,
|
||||
)
|
||||
# Map bin indices back to original indices
|
||||
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
|
||||
else:
|
||||
# Use parallel packing
|
||||
all_bins = pack_parallel(
|
||||
lengths,
|
||||
bin_capacity=self.batch_max_len,
|
||||
group_size=self.group_size,
|
||||
bin_size=self.bin_size,
|
||||
num_processes=self.num_processes,
|
||||
safe_mode=self.safe_mode,
|
||||
)
|
||||
|
||||
# Map bin indices back to original indices
|
||||
bins = [
|
||||
[indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins
|
||||
]
|
||||
|
||||
# Calculate efficiency statistics
|
||||
total_used = lengths.sum()
|
||||
total_slots = len(all_bins) * self.batch_max_len
|
||||
|
||||
# Group bins into batches (each batch contains batch_size bins)
|
||||
batches = [
|
||||
bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size)
|
||||
]
|
||||
|
||||
# Drop last batch if requested and it's incomplete
|
||||
if self.drop_last and len(batches[-1]) < self.batch_size:
|
||||
batches = batches[:-1]
|
||||
# Adjust total_slots if we dropped a batch
|
||||
if not self.sequential:
|
||||
total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len
|
||||
|
||||
# Update statistics if requested
|
||||
if set_stats:
|
||||
self.total_tokens_used += total_used
|
||||
self.total_token_slots += total_slots
|
||||
|
||||
self._batches = batches
|
||||
return batches
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Return an iterator over batches
|
||||
|
||||
The batches are truncated to match the minimum number of batches across all ranks
|
||||
to ensure distributed training balance
|
||||
"""
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
if self.len_across_ranks:
|
||||
# make sure the batches we iterate over is truncated to the same min length across all ranks
|
||||
# Truncate batches to ensure all ranks have the same number of batches
|
||||
batches = batches[: self.len_across_ranks]
|
||||
return iter(batches)
|
||||
|
||||
def num_batches(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return len(batches)
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
"""
|
||||
Calculate the packing efficiency (ratio of tokens used to total token slots)
|
||||
Higher is better - 1.0 would mean perfect packing with no wasted space
|
||||
"""
|
||||
if self.total_token_slots == 0:
|
||||
self.generate_batches(set_stats=True)
|
||||
if self.total_token_slots == 0:
|
||||
return 0.0
|
||||
# Return a Python float instead of potentially a numpy float
|
||||
return float(self.total_tokens_used / self.total_token_slots)
|
||||
|
||||
def gather_efficiency(self):
|
||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||
return math.floor(0.997 * max(estimates))
|
||||
"""
|
||||
Gather and synchronize packing efficiency estimates across all distributed ranks
|
||||
Returns a conservative efficiency estimate based on the measurements
|
||||
"""
|
||||
|
||||
def calc_sample_packing_eff_est(estimates: list[float]):
|
||||
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||
# Use 99.7% of max observed efficiency as a safe estimate
|
||||
max_eff = max(float(eff) for eff in estimates)
|
||||
return math.floor(0.997 * max_eff)
|
||||
|
||||
# Gather efficiency from all ranks and apply the calculation function
|
||||
sample_packing_actual_eff_all = reduce_and_broadcast(
|
||||
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
|
||||
lambda: float(self.efficiency()), # pylint: disable=unnecessary-lambda
|
||||
calc_sample_packing_eff_est,
|
||||
)
|
||||
|
||||
# Quantize to 0.5% intervals for stability
|
||||
sample_packing_eff_est = (
|
||||
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
|
||||
)
|
||||
return sample_packing_eff_est
|
||||
|
||||
def gather_len_batches(self, num):
|
||||
"""
|
||||
Gather and synchronize batch counts across all distributed ranks
|
||||
Returns the minimum number of batches available on any rank
|
||||
"""
|
||||
|
||||
def calc_min_len(estimates: list[(int, float)]):
|
||||
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
||||
return math.floor(min(estimates))
|
||||
|
||||
# Find minimum batch count across ranks to ensure balance
|
||||
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
|
||||
return min_len_batches
|
||||
|
||||
def __len__(self):
|
||||
if not self.len_across_ranks:
|
||||
len_batches = min(
|
||||
[self.num_batches() for _ in range(self.num_count_samples)]
|
||||
"""
|
||||
Return the total number of batches that will be yielded by this sampler
|
||||
|
||||
This is calculated as the minimum number of batches available on any rank
|
||||
to ensure balanced distributed training
|
||||
"""
|
||||
if self._batches is None:
|
||||
self._batches = self.generate_batches(set_stats=True)
|
||||
|
||||
if self.len_across_ranks is None:
|
||||
# Sample multiple times to get stable estimate
|
||||
len_batches = min( # pylint: disable=consider-using-generator
|
||||
[len(self._batches) for _ in range(self.num_count_samples)]
|
||||
)
|
||||
# Gather minimum across all ranks
|
||||
self.len_across_ranks = self.gather_len_batches(len_batches)
|
||||
|
||||
return self.len_across_ranks
|
||||
|
||||
@@ -27,7 +27,7 @@ from axolotl.utils.schemas.datasets import (
|
||||
StepwiseSupervisedDataset,
|
||||
)
|
||||
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
||||
from axolotl.utils.schemas.enums import ChatTemplate, RLType
|
||||
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
|
||||
from axolotl.utils.schemas.integrations import (
|
||||
CometConfig,
|
||||
GradioConfig,
|
||||
@@ -82,6 +82,7 @@ class AxolotlInputConfig(
|
||||
mean_resizing_embeddings: bool | None = False
|
||||
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||
shrink_embeddings: bool | None = None
|
||||
embeddings_skip_upcast: bool | None = None
|
||||
|
||||
rl: RLType | None = None
|
||||
trl: TRLConfig | None = Field(
|
||||
@@ -177,7 +178,7 @@ class AxolotlInputConfig(
|
||||
|
||||
# torch_dtype: torch.dtype | None
|
||||
|
||||
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
|
||||
gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
|
||||
default=False
|
||||
)
|
||||
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
||||
@@ -185,6 +186,12 @@ class AxolotlInputConfig(
|
||||
unfrozen_parameters: list[str] | None = None
|
||||
|
||||
sequence_len: int = Field(default=512)
|
||||
sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
|
||||
default="drop",
|
||||
json_schema_extra={
|
||||
"description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
|
||||
},
|
||||
)
|
||||
min_sample_len: int | None = None
|
||||
max_prompt_len: int = Field(
|
||||
default=512,
|
||||
@@ -259,7 +266,7 @@ class AxolotlInputConfig(
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
heads_k_stride: int | None = None
|
||||
ring_attn_func: str | None = None
|
||||
ring_attn_func: RingAttnFunc | None = None
|
||||
|
||||
special_tokens: SpecialTokensConfig | None = None
|
||||
tokens: list[str] | None = None
|
||||
@@ -435,16 +442,6 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sample_packing_w_xformers(cls, data):
|
||||
if data.get("sample_packing") and data.get("xformers_attention"):
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -471,9 +468,10 @@ class AxolotlInputConfig(
|
||||
and not data.get("flash_attention")
|
||||
and not data.get("sdp_attention")
|
||||
and not data.get("flex_attention")
|
||||
and not data.get("xformers_attention")
|
||||
):
|
||||
LOG.warning(
|
||||
"sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
|
||||
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
|
||||
)
|
||||
|
||||
return data
|
||||
@@ -512,10 +510,17 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def hint_sample_packing_padding(cls, data):
|
||||
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
)
|
||||
if data.get("sample_packing"):
|
||||
pad_to_sequence_len = data.get("pad_to_sequence_len")
|
||||
if pad_to_sequence_len is False:
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
)
|
||||
elif pad_to_sequence_len is None:
|
||||
LOG.info(
|
||||
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
|
||||
)
|
||||
data["pad_to_sequence_len"] = True
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -783,7 +788,7 @@ class AxolotlInputConfig(
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_simpo_warmup(self):
|
||||
if self.rl == "simpo" and self.warmup_ratio:
|
||||
if self.rl is RLType.SIMPO and self.warmup_ratio:
|
||||
raise ValueError(
|
||||
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
|
||||
)
|
||||
@@ -1150,16 +1155,28 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
# @model_validator(mode="before")
|
||||
# @classmethod
|
||||
# def check_grpo_peft_liger(cls, data):
|
||||
# if (
|
||||
# data.get("rl") == "grpo"
|
||||
# and data.get("trl", {})
|
||||
# and data.get("trl").get("use_liger_loss")
|
||||
# and data.get("adapter")
|
||||
# ):
|
||||
# raise ValueError("PEFT + GRPO + Liger is not yet supported")
|
||||
# return data
|
||||
#
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_grpo_peft_liger(cls, data):
|
||||
def check_grpo_liger_sequence_parallel(cls, data):
|
||||
if (
|
||||
data.get("rl") == "grpo"
|
||||
and data.get("trl", {})
|
||||
and data.get("trl").get("use_liger_loss")
|
||||
and data.get("adapter")
|
||||
and data.get("sequence_parallel_degree", 1) > 1
|
||||
):
|
||||
raise ValueError("PEFT + GRPO + Liger is not yet supported")
|
||||
raise ValueError("GRPO + SP + Liger not currently supported")
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -1174,7 +1191,7 @@ class AxolotlInputConfig(
|
||||
|
||||
if self.sample_packing and self.micro_batch_size > 1:
|
||||
raise ValueError(
|
||||
"micro_batch_size must be set to 1 when sample_packing is enabled"
|
||||
"micro_batch_size must be set to 1 when sample_packing is enabled "
|
||||
"due to a `ring-flash-attn` requirement"
|
||||
)
|
||||
|
||||
@@ -1206,16 +1223,8 @@ class AxolotlInputConfig(
|
||||
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
||||
return self
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
if self.ring_attn_func is not None:
|
||||
valid_funcs = list(RingAttnFunc)
|
||||
if self.ring_attn_func in valid_funcs:
|
||||
self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
|
||||
)
|
||||
self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
|
||||
else:
|
||||
# Default ring attention function selection
|
||||
sample_packing = getattr(self, "sample_packing", False)
|
||||
@@ -1346,6 +1355,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
):
|
||||
return data
|
||||
|
||||
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
|
||||
if data.get("lora_dropout") != 0:
|
||||
return data
|
||||
|
||||
# Check multi-GPU compatibility
|
||||
capabilities = data.get("capabilities")
|
||||
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
||||
|
||||
@@ -6,12 +6,12 @@ from enum import Enum
|
||||
class RLType(str, Enum):
|
||||
"""RL trainer type configuration subset"""
|
||||
|
||||
dpo = "dpo" # pylint: disable=invalid-name
|
||||
grpo = "grpo" # pylint: disable=invalid-name
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
kto = "kto" # pylint: disable=invalid-name
|
||||
simpo = "simpo" # pylint: disable=invalid-name
|
||||
DPO = "dpo" # pylint: disable=invalid-name
|
||||
GRPO = "grpo" # pylint: disable=invalid-name
|
||||
IPO = "ipo" # pylint: disable=invalid-name
|
||||
ORPO = "orpo" # pylint: disable=invalid-name
|
||||
KTO = "kto" # pylint: disable=invalid-name
|
||||
SIMPO = "simpo" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
@@ -53,4 +53,16 @@ class CustomSupportedOptimizers(str, Enum):
|
||||
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||
came_pytorch = "came_pytorch" # pylint: disable=invalid-name
|
||||
muon = "muon" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class RingAttnFunc(str, Enum):
|
||||
"""Enum class for supported `ring-flash-attn` implementations"""
|
||||
|
||||
# VARLEN_RING = "varlen_ring"
|
||||
# VARLEN_ZIGZAG = "varlen_zigzag"
|
||||
VARLEN_LLAMA3 = "varlen_llama3"
|
||||
BATCH_RING = "batch_ring"
|
||||
# BATCH_ZIGZAG = "batch_zigzag"
|
||||
# BATCH_STRIPE = "batch_stripe"
|
||||
|
||||
@@ -75,8 +75,10 @@ class HyperparametersConfig(BaseModel):
|
||||
lr_groups: list[LrGroup] | None = None
|
||||
|
||||
adam_epsilon: float | None = None
|
||||
adam_epsilon2: float | None = None
|
||||
adam_beta1: float | None = None
|
||||
adam_beta2: float | None = None
|
||||
adam_beta3: float | None = None
|
||||
max_grad_norm: float | None = None
|
||||
num_epochs: float = Field(default=1.0)
|
||||
|
||||
|
||||
@@ -207,10 +207,18 @@ def add_length(sample):
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
"""
|
||||
Drop samples whose sequence length is either too long (> sequence_len)
|
||||
or too short (< min_sequence_len).
|
||||
|
||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||
Determines whether samples should be kept based on sequence length constraints.
|
||||
|
||||
For a single example or a batch, returns True (or a list of booleans) if each sequence's length is within the specified range; otherwise, returns False (or a list with False for out-of-range sequences).
|
||||
|
||||
Args:
|
||||
sample: A dictionary containing "input_ids" as a list of ints or a list of lists of ints.
|
||||
sequence_len: Maximum allowed sequence length (inclusive).
|
||||
min_sequence_len: Minimum allowed sequence length (inclusive).
|
||||
|
||||
Returns:
|
||||
True if the single example is within the length range, False otherwise.
|
||||
For batched input, returns a list of booleans indicating which sequences are within the range.
|
||||
"""
|
||||
min_sequence_len = min_sequence_len or 2
|
||||
|
||||
@@ -235,7 +243,121 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
return results
|
||||
|
||||
|
||||
def truncate_or_drop_long_seq(
|
||||
sample, sequence_len=2048, min_sequence_len=2, handling="drop"
|
||||
):
|
||||
"""
|
||||
Drops or truncates samples based on sequence length constraints.
|
||||
|
||||
If handling is "drop", returns a boolean or list of booleans indicating whether each sample's sequence length is within the specified range. If handling is "truncate", returns the sample with sequences longer than sequence_len truncated and sequences shorter than min_sequence_len omitted. Supports both single-example and batched inputs.
|
||||
|
||||
Args:
|
||||
sample: A dictionary containing at least an "input_ids" field, representing either a single sequence or a batch of sequences.
|
||||
sequence_len: Maximum allowed sequence length.
|
||||
min_sequence_len: Minimum allowed sequence length.
|
||||
handling: "drop" to filter out samples outside the range, "truncate" to truncate long sequences.
|
||||
|
||||
Returns:
|
||||
In "drop" mode, a boolean or list of booleans indicating which samples to keep. In "truncate" mode, the modified sample with sequences truncated as needed.
|
||||
"""
|
||||
min_sequence_len = min_sequence_len or 2
|
||||
result = None
|
||||
|
||||
if handling == "drop":
|
||||
return drop_long_seq(sample, sequence_len, min_sequence_len)
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
|
||||
# Edge case: if input_ids is empty
|
||||
if not input_ids:
|
||||
result = False if handling == "drop" else sample
|
||||
# Single example (input_ids is a list of int)
|
||||
elif isinstance(input_ids[0], int):
|
||||
length = len(input_ids)
|
||||
|
||||
# Handle samples that are too short - always drop them
|
||||
if length < min_sequence_len:
|
||||
result = False if handling == "drop" else sample
|
||||
# If truncation is enabled and the sample is too long, truncate it
|
||||
elif length > sequence_len and handling == "truncate":
|
||||
sample["input_ids"] = input_ids[:sequence_len]
|
||||
|
||||
# Also truncate attention_mask if present
|
||||
if "attention_mask" in sample:
|
||||
sample["attention_mask"] = sample["attention_mask"][:sequence_len]
|
||||
|
||||
# Also truncate labels if present
|
||||
if "labels" in sample:
|
||||
sample["labels"] = sample["labels"][:sequence_len]
|
||||
|
||||
# Also truncate position_ids if present
|
||||
if "position_ids" in sample:
|
||||
sample["position_ids"] = sample["position_ids"][:sequence_len]
|
||||
|
||||
# Update length if present
|
||||
if "length" in sample:
|
||||
sample["length"] = sequence_len
|
||||
|
||||
result = sample
|
||||
# For drop mode or if the sample doesn't exceed max length
|
||||
else:
|
||||
result = (
|
||||
min_sequence_len <= length <= sequence_len
|
||||
if handling == "drop"
|
||||
else sample
|
||||
)
|
||||
# Batched (input_ids is a list of lists)
|
||||
else:
|
||||
if handling == "drop":
|
||||
results = []
|
||||
for seq in input_ids:
|
||||
length = len(seq)
|
||||
results.append(min_sequence_len <= length <= sequence_len)
|
||||
result = results
|
||||
else: # truncate
|
||||
# Check each sequence in the batch
|
||||
for i, seq in enumerate(input_ids):
|
||||
length = len(seq)
|
||||
|
||||
# Skip sequences that are too short
|
||||
if length < min_sequence_len:
|
||||
continue
|
||||
|
||||
# Truncate sequences that are too long
|
||||
if length > sequence_len:
|
||||
input_ids[i] = seq[:sequence_len]
|
||||
|
||||
# Also truncate attention_mask if present
|
||||
if "attention_mask" in sample:
|
||||
sample["attention_mask"][i] = sample["attention_mask"][i][
|
||||
:sequence_len
|
||||
]
|
||||
|
||||
# Also truncate labels if present
|
||||
if "labels" in sample:
|
||||
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
||||
|
||||
# Also truncate position_ids if present
|
||||
if "position_ids" in sample:
|
||||
sample["position_ids"][i] = sample["position_ids"][i][
|
||||
:sequence_len
|
||||
]
|
||||
|
||||
# Update length if present
|
||||
if "length" in sample:
|
||||
sample["length"][i] = sequence_len
|
||||
|
||||
result = sample
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
"""
|
||||
Prepares training and evaluation datasets for sample packing and model-specific requirements.
|
||||
|
||||
Removes unnecessary columns based on model type, filters out samples with no trainable tokens, and optionally adds length or position ID columns for sample packing or PoSE techniques. Returns the processed training and evaluation datasets.
|
||||
"""
|
||||
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
|
||||
if drop_attn_mask:
|
||||
LOG.info("dropping attention_mask column")
|
||||
@@ -370,15 +492,48 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
|
||||
|
||||
def process_pretraining_datasets_for_packing(
|
||||
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
|
||||
train_dataset,
|
||||
sequence_len,
|
||||
skip_position_ids=True,
|
||||
drop_attention_mask=False,
|
||||
handling="drop",
|
||||
):
|
||||
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
|
||||
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_long,
|
||||
desc="Dropping Long Sequences",
|
||||
load_from_cache_file=False,
|
||||
# Define the function to use for handling sequences based on the mode
|
||||
"""
|
||||
Processes a pretraining dataset by truncating or dropping sequences based on length.
|
||||
|
||||
Depending on the handling mode, sequences longer than `sequence_len` are either truncated or dropped, and sequences shorter than `min_sequence_len` are dropped. Optionally adds position IDs and removes the attention mask column.
|
||||
|
||||
Args:
|
||||
train_dataset: The dataset to process.
|
||||
sequence_len: Maximum allowed sequence length.
|
||||
skip_position_ids: If False, adds position IDs to each sample.
|
||||
drop_attention_mask: If True, removes the attention mask column.
|
||||
handling: "drop" to remove long sequences, "truncate" to truncate them.
|
||||
|
||||
Returns:
|
||||
The processed dataset with sequences handled according to the specified mode.
|
||||
"""
|
||||
seq_handler_fn = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
handling=handling, # Pass handling mode
|
||||
)
|
||||
|
||||
# Use map for truncate mode and filter for drop mode
|
||||
if handling == "truncate":
|
||||
train_dataset = train_dataset.map(
|
||||
seq_handler_fn,
|
||||
desc="Truncating Long Sequences",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
else: # handling == "drop"
|
||||
train_dataset = train_dataset.filter(
|
||||
seq_handler_fn, # Use the same function, it returns boolean for drop mode
|
||||
desc="Dropping Long Sequences",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
|
||||
if not skip_position_ids:
|
||||
train_dataset = train_dataset.map(
|
||||
add_position_ids,
|
||||
|
||||
@@ -4,6 +4,7 @@ shared pytest fixtures
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -529,31 +530,32 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
|
||||
|
||||
# # pylint: disable=redefined-outer-name,unused-argument
|
||||
# def test_load_fixtures(
|
||||
# download_smollm2_135m_model,
|
||||
# download_llama_68m_random_model,
|
||||
# download_qwen_2_5_half_billion_model,
|
||||
# download_tatsu_lab_alpaca_dataset,
|
||||
# download_mhenrichsen_alpaca_2k_dataset,
|
||||
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
|
||||
# download_mlabonne_finetome_100k_dataset,
|
||||
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
|
||||
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
|
||||
# download_argilla_dpo_pairs_dataset,
|
||||
# download_tiny_shakespeare_dataset,
|
||||
# download_deepseek_model_fixture,
|
||||
# download_huggyllama_model_fixture,
|
||||
# download_llama_1b_model_fixture,
|
||||
# download_llama3_8b_model_fixture,
|
||||
# download_llama3_8b_instruct_model_fixture,
|
||||
# download_phi_35_mini_model_fixture,
|
||||
# download_phi_3_medium_model_fixture,
|
||||
# download_mistral_7b_model_fixture,
|
||||
# download_gemma_2b_model_fixture,
|
||||
# download_gemma2_9b_model_fixture,
|
||||
# download_mlx_mistral_7b_model_fixture,
|
||||
# download_llama2_model_fixture,
|
||||
# ):
|
||||
# pass
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
|
||||
reason="Not running in CI cache preload",
|
||||
)
|
||||
def test_load_fixtures(
|
||||
download_smollm2_135m_model,
|
||||
download_qwen_2_5_half_billion_model,
|
||||
download_tatsu_lab_alpaca_dataset,
|
||||
download_mhenrichsen_alpaca_2k_dataset,
|
||||
download_mhenrichsen_alpaca_2k_w_revision_dataset,
|
||||
download_mlabonne_finetome_100k_dataset,
|
||||
download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
|
||||
download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
|
||||
download_argilla_dpo_pairs_dataset,
|
||||
download_tiny_shakespeare_dataset,
|
||||
download_deepseek_model_fixture,
|
||||
download_huggyllama_model_fixture,
|
||||
download_llama_1b_model_fixture,
|
||||
download_llama3_8b_model_fixture,
|
||||
download_llama3_8b_instruct_model_fixture,
|
||||
download_phi_35_mini_model_fixture,
|
||||
download_phi_3_medium_model_fixture,
|
||||
download_mistral_7b_model_fixture,
|
||||
download_gemma_2b_model_fixture,
|
||||
download_gemma2_9b_model_fixture,
|
||||
download_mlx_mistral_7b_model_fixture,
|
||||
download_llama2_model_fixture,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -29,6 +29,12 @@ class LogHooksPlugin(BasePlugin):
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_trainer_create\n")
|
||||
|
||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
@@ -165,6 +171,7 @@ class TestPluginHooks:
|
||||
) as f:
|
||||
file_contents = f.readlines()
|
||||
file_contents = "\n".join(file_contents)
|
||||
assert "post_trainer_create" in file_contents
|
||||
assert "pre_model_load" in file_contents
|
||||
assert "post_model_build" in file_contents
|
||||
assert "pre_lora_load" in file_contents
|
||||
|
||||
@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ class TestSequenceParallelism:
|
||||
micro_batch_size=1,
|
||||
pad_to_sequence_len=True,
|
||||
ring_attn_func=None,
|
||||
threshold=2.0,
|
||||
):
|
||||
"""Helper method to run sequence parallel tests with different configurations"""
|
||||
cfg = DictDefault(
|
||||
@@ -93,22 +94,22 @@ class TestSequenceParallelism:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", threshold, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
|
||||
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func, threshold",
|
||||
[
|
||||
(True, 1, True, None), # defaults to varlen_llama3 ring_attn_func
|
||||
(False, 2, True, None), # defaults to batch_ring ring_attn_func
|
||||
(False, 2, True, "batch_zigzag"),
|
||||
# (False, 2, False), # not yet working
|
||||
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
|
||||
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, True, "batch_zigzag", 2.5),
|
||||
(False, 2, False, None, 2.5), # defaults to batch_ring ring_attn_func
|
||||
],
|
||||
ids=[
|
||||
"sample_packing, varlen_llama3 ring_attn_func",
|
||||
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
|
||||
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
|
||||
# "no sample_packing, pad_to_sequence_len", # not yet working
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_training(
|
||||
@@ -118,6 +119,7 @@ class TestSequenceParallelism:
|
||||
micro_batch_size,
|
||||
pad_to_sequence_len,
|
||||
ring_attn_func,
|
||||
threshold,
|
||||
):
|
||||
"""Test sequence parallel training with different configurations"""
|
||||
self._run_sequence_parallel_test(
|
||||
@@ -126,4 +128,5 @@ class TestSequenceParallelism:
|
||||
micro_batch_size=micro_batch_size,
|
||||
pad_to_sequence_len=pad_to_sequence_len,
|
||||
ring_attn_func=ring_attn_func,
|
||||
threshold=threshold,
|
||||
)
|
||||
|
||||
@@ -166,6 +166,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"""
|
||||
)
|
||||
|
||||
@pytest.mark.skip(reason="flaky test")
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
@@ -227,7 +228,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
|
||||
current_env = os.environ.copy()
|
||||
env = {
|
||||
"NCCL_P2P_LEVEL": "NVL",
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
||||
@@ -257,7 +258,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
],
|
||||
env={
|
||||
"NCCL_P2P_LEVEL": "NVL",
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
"NCCL_DEBUG": "INFO",
|
||||
**current_env,
|
||||
},
|
||||
@@ -265,6 +266,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
finally:
|
||||
recursive_kill(vllm_process)
|
||||
|
||||
@pytest.mark.skip(reason="flaky test")
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
@@ -320,7 +322,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
|
||||
current_env = os.environ.copy()
|
||||
env = {
|
||||
"NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable
|
||||
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
||||
@@ -350,7 +352,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
],
|
||||
env={
|
||||
"NCCL_P2P_LEVEL": "NVL",
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
"NCCL_DEBUG": "INFO",
|
||||
**current_env,
|
||||
},
|
||||
|
||||
@@ -479,7 +479,7 @@ class TestMultiGPULlama:
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
|
||||
@@ -29,12 +29,12 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
MODEL_CONFIGS = [
|
||||
{
|
||||
"name": "openaccess-ai-collective/tiny-mistral",
|
||||
"name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"expected_activation": apply_lora_mlp_swiglu,
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
{
|
||||
"name": "Qwen/Qwen2-7B",
|
||||
"name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
|
||||
"expected_activation": apply_lora_mlp_swiglu,
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
@@ -44,7 +44,7 @@ MODEL_CONFIGS = [
|
||||
"dtype": torch.float32,
|
||||
},
|
||||
{
|
||||
"name": "mhenrichsen/gemma-2b",
|
||||
"name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
|
||||
"expected_activation": apply_lora_mlp_geglu,
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
@@ -156,7 +156,9 @@ def test_swiglu_mlp_integration(small_llama_model):
|
||||
def test_geglu_model_integration():
|
||||
"""Test GeGLU activation with Gemma model."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
|
||||
"trl-internal-testing/tiny-Gemma2ForCausalLM",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda:0",
|
||||
)
|
||||
peft_config = get_peft_config(
|
||||
{
|
||||
|
||||
@@ -57,9 +57,9 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"max_steps": 5,
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
@@ -105,9 +105,9 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"max_steps": 5,
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -26,10 +26,15 @@ class TestActivationCheckpointing:
|
||||
E2E tests for activation checkpointing
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_checkpointing",
|
||||
["offload", "offload_disk"],
|
||||
)
|
||||
def test_activation_checkpointing_offload(
|
||||
self,
|
||||
temp_dir,
|
||||
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
|
||||
gradient_checkpointing,
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -64,7 +69,7 @@ class TestActivationCheckpointing:
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"gradient_checkpointing": "offload",
|
||||
"gradient_checkpointing": gradient_checkpointing,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
@@ -23,6 +25,7 @@ class TestFalconPatched(unittest.TestCase):
|
||||
Test case for Falcon models
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_qlora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -71,6 +74,7 @@ class TestFalconPatched(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestMistral(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
@@ -57,9 +57,9 @@ class TestMistral(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"max_steps": 5,
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
@@ -76,7 +76,7 @@ class TestMistral(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
@@ -99,9 +99,9 @@ class TestMistral(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"max_steps": 5,
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -54,9 +54,9 @@ class TestMixtral(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"max_steps": 5,
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
@@ -93,9 +93,9 @@ class TestMixtral(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"max_steps": 5,
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -56,7 +56,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
def test_mistral_multipack(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 2048,
|
||||
|
||||
63
tests/e2e/patched/test_peft_embeddings.py
Normal file
63
tests/e2e/patched/test_peft_embeddings.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Test case for handling embeddings when using peft
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.train import setup_model_and_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestLlamaPeftEmbeddings:
|
||||
"""
|
||||
test class for handling embeddings when using peft
|
||||
"""
|
||||
|
||||
def test_peft_embeddings_upcast(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_target_linear": True,
|
||||
"trust_remote_code": True,
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.01,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"sample_packing": False,
|
||||
"bf16": "auto",
|
||||
"save_safetensors": True,
|
||||
"embeddings_skip_upcast": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
model, _, _, _ = setup_model_and_tokenizer(cfg)
|
||||
|
||||
# Check if the embeddings are upcast correctly
|
||||
# only embed_tokens is a parameter that may be upcast
|
||||
assert model.base_model.model.model.embed_tokens.weight.dtype == torch.bfloat16
|
||||
assert model.base_model.model.lm_head.weight.dtype == torch.bfloat16
|
||||
@@ -56,9 +56,9 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"eval_steps": 10,
|
||||
"save_steps": 10,
|
||||
"max_steps": 5,
|
||||
"eval_steps": 3,
|
||||
"save_steps": 4,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
@@ -108,9 +108,9 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"eval_steps": 10,
|
||||
"save_steps": 10,
|
||||
"max_steps": 5,
|
||||
"eval_steps": 3,
|
||||
"save_steps": 4,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, most_recent_subdir
|
||||
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -26,6 +26,7 @@ class TestResumeLlama:
|
||||
Test case for resuming training of llama models
|
||||
"""
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_resume_lora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -62,6 +63,7 @@ class TestResumeLlama:
|
||||
"save_total_limit": 5,
|
||||
"max_steps": 15,
|
||||
"use_tensorboard": True,
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
|
||||
@@ -10,14 +10,15 @@ import pytest
|
||||
import torch
|
||||
from accelerate.state import PartialState
|
||||
|
||||
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
|
||||
from axolotl.monkeypatch.attention.ring_attn import (
|
||||
RingAttnFunc,
|
||||
get_ring_attn_group,
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -62,12 +63,14 @@ def sequence_parallel_batch():
|
||||
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
|
||||
attention_mask = torch.ones(batch_size, seq_len)
|
||||
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
|
||||
labels = input_ids.clone()
|
||||
|
||||
# Create test batch
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
return batch
|
||||
@@ -179,12 +182,44 @@ class TestConfigValidation:
|
||||
False,
|
||||
"micro_batch_size must be set to 1",
|
||||
),
|
||||
# Valid: Basic GRPO config
|
||||
(
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": TRLConfig(use_liger_loss=True),
|
||||
},
|
||||
True,
|
||||
"GRPO + SP + Liger not currently supported",
|
||||
),
|
||||
# Invalid: GRPO config with Liger loss
|
||||
(
|
||||
{
|
||||
"rl": "grpo",
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
None,
|
||||
False,
|
||||
"GRPO + SP + Liger not currently supported",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"valid_config",
|
||||
"default_sp_degree",
|
||||
"without_flash_attention",
|
||||
"sample_packing_with_large_batch",
|
||||
"valid_grpo",
|
||||
"grpo_with_liger_loss",
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_config_validation(
|
||||
@@ -256,7 +291,7 @@ class TestConfigValidation:
|
||||
AxolotlInputConfig(**cfg)
|
||||
|
||||
# Verify error message
|
||||
assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value)
|
||||
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
|
||||
|
||||
|
||||
class TestApplySequenceParallelism:
|
||||
@@ -290,10 +325,11 @@ class TestApplySequenceParallelism:
|
||||
|
||||
def test_world_size_one(self, sequence_parallel_batch):
|
||||
"""Test that function returns original batch when world size is 1."""
|
||||
result = apply_sequence_parallelism(
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=sequence_parallel_batch,
|
||||
local_rank=0,
|
||||
local_world_size=1,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
@@ -305,10 +341,11 @@ class TestApplySequenceParallelism:
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
|
||||
result = apply_sequence_parallelism(
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
@@ -328,57 +365,59 @@ class TestApplySequenceParallelism:
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
result = apply_sequence_parallelism(
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=1,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Verify content: rank 1 should get the second half of the sequence
|
||||
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
|
||||
|
||||
def test_batch_zigzag(self, sequence_parallel_batch):
|
||||
"""Test BATCH_ZIGZAG sharding pattern."""
|
||||
batch = sequence_parallel_batch
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
# TODO(djsaunde): add back once implemented.
|
||||
# def test_batch_zigzag(self, sequence_parallel_batch):
|
||||
# """Test BATCH_ZIGZAG sharding pattern."""
|
||||
# batch = sequence_parallel_batch
|
||||
# original_input_ids = batch["input_ids"].clone()
|
||||
# seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# Test rank 0
|
||||
result_rank0 = apply_sequence_parallelism(
|
||||
batch={k: v.clone() for k, v in batch.items()},
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
|
||||
)
|
||||
# # Test rank 0
|
||||
# result_rank0 = apply_sequence_parallelism(
|
||||
# batch={k: v.clone() for k, v in batch.items()},
|
||||
# local_rank=0,
|
||||
# local_world_size=2,
|
||||
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
|
||||
# )
|
||||
|
||||
# Test rank 1
|
||||
result_rank1 = apply_sequence_parallelism(
|
||||
batch={k: v.clone() for k, v in batch.items()},
|
||||
local_rank=1,
|
||||
local_world_size=2,
|
||||
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
|
||||
)
|
||||
# # Test rank 1
|
||||
# result_rank1 = apply_sequence_parallelism(
|
||||
# batch={k: v.clone() for k, v in batch.items()},
|
||||
# local_rank=1,
|
||||
# local_world_size=2,
|
||||
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
|
||||
# )
|
||||
|
||||
# Checks for both ranks
|
||||
assert result_rank0["input_ids"].shape[1] == seq_len // 2
|
||||
assert result_rank1["input_ids"].shape[1] == seq_len // 2
|
||||
# # Checks for both ranks
|
||||
# assert result_rank0["input_ids"].shape[1] == seq_len // 2
|
||||
# assert result_rank1["input_ids"].shape[1] == seq_len // 2
|
||||
|
||||
# For a 2-rank system with 8 tokens, check specific zigzag pattern
|
||||
# Rank 0 should get chunks [0, 1] and [6, 7]
|
||||
# Rank 1 should get chunks [2, 3] and [4, 5]
|
||||
if seq_len == 8:
|
||||
# Create expected tensors for comparison
|
||||
rank0_expected = torch.cat(
|
||||
[original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
|
||||
)
|
||||
# # For a 2-rank system with 8 tokens, check specific zigzag pattern
|
||||
# # Rank 0 should get chunks [0, 1] and [6, 7]
|
||||
# # Rank 1 should get chunks [2, 3] and [4, 5]
|
||||
# if seq_len == 8:
|
||||
# # Create expected tensors for comparison
|
||||
# rank0_expected = torch.cat(
|
||||
# [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
|
||||
# )
|
||||
|
||||
rank1_expected = torch.cat(
|
||||
[original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
|
||||
)
|
||||
# rank1_expected = torch.cat(
|
||||
# [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
|
||||
# )
|
||||
|
||||
assert torch.equal(result_rank0["input_ids"], rank0_expected)
|
||||
assert torch.equal(result_rank1["input_ids"], rank1_expected)
|
||||
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
|
||||
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
|
||||
|
||||
def test_partial_application(self, sequence_parallel_batch):
|
||||
"""Test that we can create a partially applied version of the function."""
|
||||
@@ -390,11 +429,12 @@ class TestApplySequenceParallelism:
|
||||
apply_sequence_parallelism,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Use the partially applied function
|
||||
result = rank0_ring_parallel(batch=batch)
|
||||
result, _, _ = rank0_ring_parallel(batch=batch)
|
||||
|
||||
# Verify it works as expected
|
||||
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
|
||||
@@ -412,13 +452,15 @@ class TestApplySequenceParallelism:
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
# This should run without error even though position_ids is missing
|
||||
result = apply_sequence_parallelism(
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Verification should pass
|
||||
assert "position_ids" not in result
|
||||
assert "position_ids" in result
|
||||
assert result["input_ids"].shape[1] == result["position_ids"].shape[1]
|
||||
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
|
||||
|
||||
@@ -19,14 +19,11 @@ class TestE2eEvaluate:
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
|
||||
@@ -6,6 +6,8 @@ import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
@@ -23,6 +25,7 @@ class TestFalcon(unittest.TestCase):
|
||||
Test case for falcon
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -74,6 +77,7 @@ class TestFalcon(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_lora_added_vocab(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -129,6 +133,7 @@ class TestFalcon(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestMistral(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
@@ -77,7 +77,7 @@ class TestMistral(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.02,
|
||||
|
||||
@@ -199,3 +199,50 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_came_pytorch(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "came_pytorch",
|
||||
"adam_beta3": 0.9999,
|
||||
"adam_epsilon2": 1e-16,
|
||||
"max_steps": 5,
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -648,7 +648,7 @@ class TestValidation(BaseValidation):
|
||||
DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": None,
|
||||
"pad_to_sequence_len": False,
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
@@ -662,6 +662,26 @@ class TestValidation(BaseValidation):
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
def test_packing_autoset(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": None,
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
with self._caplog.at_level(logging.INFO):
|
||||
cfg = validate_config(cfg)
|
||||
assert any(
|
||||
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
assert cfg.pad_to_sequence_len is True
|
||||
|
||||
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
|
||||
"""
|
||||
This is assumed to be run on a CPU machine, so bf16 is not supported.
|
||||
|
||||
@@ -3,10 +3,12 @@ test module for the axolotl.utils.data module
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
from axolotl.utils.data import encode_pretraining, md5
|
||||
from axolotl.utils.data.rl import drop_long_rl_seq
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
@@ -58,11 +60,328 @@ class TestEncodePretraining(unittest.TestCase):
|
||||
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
|
||||
|
||||
def test_md5(self):
|
||||
"""
|
||||
Tests that the md5 function returns the correct hash for a given string and encoding.
|
||||
"""
|
||||
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
|
||||
self.assertEqual(
|
||||
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
||||
)
|
||||
|
||||
|
||||
class TestDropLongRLSeq(unittest.TestCase):
|
||||
"""
|
||||
Tests for the drop_long_rl_seq function.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
# Mock tokenizer that returns length based on input string length
|
||||
"""
|
||||
Sets up a mock tokenizer and sequence length for RL sequence length tests.
|
||||
|
||||
The mock tokenizer simulates tokenization by returning input IDs equal to the input string's length and decodes tokens as repeated "x" characters. The sequence length limit is set to 20.
|
||||
"""
|
||||
self.tokenizer = MagicMock()
|
||||
|
||||
def side_effect_func(
|
||||
text, add_special_tokens=False
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Simulates tokenization by returning input IDs as a sequence of integers equal to the input text length.
|
||||
|
||||
Args:
|
||||
text: The input string to tokenize.
|
||||
add_special_tokens: Ignored parameter included for interface compatibility.
|
||||
|
||||
Returns:
|
||||
A dictionary with 'input_ids' as a list of integers from 0 to len(text) - 1.
|
||||
"""
|
||||
return {"input_ids": list(range(len(text)))}
|
||||
|
||||
self.tokenizer.side_effect = side_effect_func
|
||||
self.tokenizer.decode = lambda tokens, skip_special_tokens: "".join(
|
||||
["x"] * len(tokens)
|
||||
) # pylint: disable=unused-argument
|
||||
|
||||
self.sequence_len = 20
|
||||
|
||||
def test_dpo_drop_mode_valid(self):
|
||||
"""
|
||||
Tests that drop_long_rl_seq returns True in drop mode for a DPO sample within the sequence length limit.
|
||||
"""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 7,
|
||||
"rejected": "r" * 6,
|
||||
} # 5+7=12 <= 20, 5+6=11 <= 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_dpo_drop_mode_invalid_chosen(self):
|
||||
"""
|
||||
Tests that in DPO drop mode, a sample is rejected when the prompt and chosen lengths exceed the sequence limit.
|
||||
"""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 16,
|
||||
"rejected": "r" * 6,
|
||||
} # 5+16=21 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_dpo_drop_mode_invalid_rejected(self):
|
||||
"""
|
||||
Tests that in DPO drop mode, a sample is rejected when the prompt plus rejected response exceeds the sequence length limit.
|
||||
"""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 7,
|
||||
"rejected": "r" * 16,
|
||||
} # 5+16=21 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_dpo_truncate_mode_no_truncation_needed(self):
|
||||
"""
|
||||
Verifies that in DPO truncate mode, samples within the sequence length limit are returned unchanged.
|
||||
"""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 7,
|
||||
"rejected": "r" * 6,
|
||||
} # 5+7=12 <= 20, 5+6=11 <= 20
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(
|
||||
result, original_sample
|
||||
) # Should return the original sample unchanged
|
||||
|
||||
def test_dpo_truncate_mode_prompt_too_long(self):
|
||||
"""
|
||||
Tests that in DPO truncate mode, if the prompt exceeds the sequence length limit,
|
||||
the original sample is returned unchanged.
|
||||
"""
|
||||
sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
# Even though truncation isn't possible, the function should return the original sample
|
||||
# for the map operation, assuming downstream filtering will catch it.
|
||||
self.assertEqual(result, original_sample)
|
||||
|
||||
def test_dpo_truncate_mode_chosen_truncated(self):
|
||||
"""
|
||||
Tests that in DPO truncate mode, only the 'chosen' field is truncated when it exceeds the allowed sequence length, while 'prompt' and 'rejected' remain unchanged.
|
||||
"""
|
||||
prompt_len = 5
|
||||
max_resp_len = self.sequence_len - prompt_len # 20 - 5 = 15
|
||||
sample = {
|
||||
"prompt": "p" * prompt_len,
|
||||
"chosen": "c" * 18,
|
||||
"rejected": "r" * 10,
|
||||
} # 5+18=23 > 20, 5+10=15 <= 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 15
|
||||
self.assertEqual(
|
||||
result["chosen"], "x" * max_resp_len
|
||||
) # Check decoded truncated value
|
||||
self.assertEqual(len(result["rejected"]), 10) # Unchanged
|
||||
|
||||
def test_dpo_truncate_mode_rejected_truncated(self):
|
||||
"""
|
||||
Tests that in DPO truncate mode, only the 'rejected' field is truncated when it exceeds the sequence length limit, while 'prompt' and 'chosen' remain unchanged.
|
||||
"""
|
||||
prompt_len = 5
|
||||
max_resp_len = self.sequence_len - prompt_len # 15
|
||||
sample = {
|
||||
"prompt": "p" * prompt_len,
|
||||
"chosen": "c" * 10,
|
||||
"rejected": "r" * 18,
|
||||
} # 5+10=15 <= 20, 5+18=23 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||
self.assertEqual(len(result["chosen"]), 10) # Unchanged
|
||||
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 15
|
||||
self.assertEqual(
|
||||
result["rejected"], "x" * max_resp_len
|
||||
) # Check decoded truncated value
|
||||
|
||||
def test_dpo_truncate_mode_both_truncated(self):
|
||||
"""
|
||||
Tests that in DPO truncate mode, both 'chosen' and 'rejected' fields are truncated when their combined lengths with the prompt exceed the sequence limit.
|
||||
|
||||
Verifies that both fields are truncated to fit within the allowed response length and replaced with decoded placeholder content.
|
||||
"""
|
||||
prompt_len = 8
|
||||
max_resp_len = self.sequence_len - prompt_len # 20 - 8 = 12
|
||||
sample = {
|
||||
"prompt": "p" * prompt_len,
|
||||
"chosen": "c" * 15,
|
||||
"rejected": "r" * 14,
|
||||
} # 8+15=23 > 20, 8+14=22 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 12
|
||||
self.assertEqual(result["chosen"], "x" * max_resp_len)
|
||||
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 12
|
||||
self.assertEqual(result["rejected"], "x" * max_resp_len)
|
||||
|
||||
def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
|
||||
"""
|
||||
Tests DPO truncate mode where only the overlong response is truncated.
|
||||
|
||||
Verifies that when the prompt plus one response exceeds the sequence length, only the response exceeding the maximum allowed length is truncated, while the other remains unchanged.
|
||||
"""
|
||||
# This tests the case where len(chosen) <= max_resp_len and len(rejected) <= max_resp_len
|
||||
# but the initial check failed because e.g. prompt + chosen > sequence_len
|
||||
# The current logic *will* truncate if len(chosen) > max_resp_len.
|
||||
# Let's test a case where one is slightly too long causing the initial fail,
|
||||
# but the other fits *within* the max_response_len, so only one gets truncated.
|
||||
prompt_len = 10
|
||||
max_resp_len = self.sequence_len - prompt_len # 10
|
||||
sample = {
|
||||
"prompt": "p" * prompt_len,
|
||||
"chosen": "c" * 11,
|
||||
"rejected": "r" * 9,
|
||||
} # 10+11=21 > 20, 10+9=19 <= 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 10
|
||||
self.assertEqual(result["chosen"], "x" * max_resp_len)
|
||||
self.assertEqual(len(result["rejected"]), 9) # Unchanged, as 9 <= 10
|
||||
|
||||
# Add similar tests for KTO if needed, checking prompt + completion length
|
||||
|
||||
def test_kto_drop_mode_valid(self):
|
||||
"""
|
||||
Tests that drop_long_rl_seq returns True for a KTO sample within the sequence length limit.
|
||||
"""
|
||||
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_kto_drop_mode_invalid(self):
|
||||
"""
|
||||
Tests that drop_long_rl_seq returns False when a KTO sample exceeds the sequence length limit in drop mode.
|
||||
"""
|
||||
sample = {"prompt": "p" * 5, "completion": "c" * 16} # 5+16=21 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_kto_truncate_mode_no_truncation_needed(self):
|
||||
"""
|
||||
Tests that KTO truncate mode returns the original sample unchanged when the combined prompt and completion length does not exceed the sequence limit.
|
||||
"""
|
||||
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(result, original_sample)
|
||||
|
||||
def test_kto_truncate_mode_prompt_too_long(self):
|
||||
"""
|
||||
Tests that in KTO truncate mode, if the prompt exceeds the sequence length limit, the original sample is returned unchanged.
|
||||
"""
|
||||
sample = {"prompt": "p" * 25, "completion": "c" * 7}
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(result, original_sample) # Returns original sample
|
||||
|
||||
def test_kto_truncate_mode_completion_truncated(self):
|
||||
"""
|
||||
Tests that in KTO truncate mode, the completion is truncated when the combined prompt and completion exceed the sequence length limit.
|
||||
|
||||
Verifies that the prompt remains unchanged and the completion is truncated to fit within the allowed length, with the truncated completion replaced by decoded "x" characters.
|
||||
"""
|
||||
prompt_len = 8
|
||||
max_comp_len = self.sequence_len - prompt_len # 20 - 8 = 12
|
||||
sample = {"prompt": "p" * prompt_len, "completion": "c" * 15} # 8+15=23 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||
self.assertEqual(len(result["completion"]), max_comp_len) # Truncated to 12
|
||||
self.assertEqual(result["completion"], "x" * max_comp_len)
|
||||
|
||||
def test_missing_keys_dpo(self):
|
||||
"""
|
||||
Tests that a ValueError is raised when required keys are missing for DPO samples.
|
||||
|
||||
Verifies that the function raises an error if the sample does not contain 'chosen' and 'rejected' keys.
|
||||
"""
|
||||
sample = {"prompt": "p"}
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Prompt, chosen and rejected keys are required"
|
||||
):
|
||||
drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
|
||||
|
||||
def test_missing_keys_kto(self):
|
||||
"""
|
||||
Tests that a ValueError is raised when required keys are missing for RL type "kto".
|
||||
|
||||
Verifies that calling drop_long_rl_seq with a sample missing the "completion" key raises
|
||||
a ValueError with the expected error message.
|
||||
"""
|
||||
sample = {"prompt": "p"}
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Prompt and completion keys are required"
|
||||
):
|
||||
drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
|
||||
|
||||
def test_unknown_rl_type(self):
|
||||
"""
|
||||
Tests that a ValueError is raised when an unknown RL type is provided to drop_long_rl_seq.
|
||||
"""
|
||||
sample = {}
|
||||
with self.assertRaisesRegex(ValueError, "Unknown RL type"):
|
||||
drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
|
||||
|
||||
# GRPO test - current implementation always passes
|
||||
def test_grpo_drop(self):
|
||||
"""
|
||||
Tests that drop_long_rl_seq in GRPO drop mode always returns True, regardless of input.
|
||||
"""
|
||||
sample = {}
|
||||
result = drop_long_rl_seq(
|
||||
sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||
)
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_grpo_truncate(self):
|
||||
"""
|
||||
Tests that in truncate mode for RL type "grpo", the original sample is returned unchanged.
|
||||
"""
|
||||
sample = {"a": 1}
|
||||
result = drop_long_rl_seq(
|
||||
sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
)
|
||||
self.assertEqual(result, sample)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -414,7 +414,6 @@ class TestDatasetPreparation:
|
||||
snapshot_path = snapshot_download(
|
||||
repo_id="mhenrichsen/alpaca_2k_test",
|
||||
repo_type="dataset",
|
||||
local_dir=tmp_ds_path,
|
||||
)
|
||||
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
|
||||
|
||||
|
||||
@@ -106,3 +106,4 @@ class TestBatchedSamplerPacking:
|
||||
|
||||
original_idxs = set(range(len(train_dataset)))
|
||||
assert original_idxs == set(batch_idxs)
|
||||
assert len(batch_idxs) == len(set(batch_idxs))
|
||||
|
||||
175
tests/test_trainer_utils.py
Normal file
175
tests/test_trainer_utils.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Module containing tests for trainer utility functions."""
|
||||
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
from axolotl.utils.trainer import truncate_or_drop_long_seq
|
||||
|
||||
|
||||
# Test cases for truncate_or_drop_long_seq
|
||||
class TestTruncateOrDropLongSeq(unittest.TestCase):
|
||||
"""
|
||||
Test suite for truncate_or_drop_long_seq function.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
# Example sequence length settings
|
||||
"""
|
||||
Sets up default sequence length parameters for the test cases.
|
||||
"""
|
||||
self.sequence_len = 10
|
||||
self.min_sequence_len = 3
|
||||
|
||||
def test_drop_mode_single(self):
|
||||
"""
|
||||
Verifies that 'drop' mode correctly filters single sequence examples based on length.
|
||||
|
||||
Tests that sequences shorter than the minimum, longer than the maximum, or empty are dropped,
|
||||
while sequences within the valid length range are kept.
|
||||
"""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="drop",
|
||||
)
|
||||
|
||||
# Too short
|
||||
sample_short = {"input_ids": [1, 2]}
|
||||
self.assertFalse(handler(sample_short))
|
||||
|
||||
# Too long
|
||||
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
|
||||
self.assertFalse(handler(sample_long))
|
||||
|
||||
# Just right
|
||||
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
|
||||
self.assertTrue(handler(sample_ok))
|
||||
|
||||
# Empty
|
||||
sample_empty = {"input_ids": []}
|
||||
self.assertFalse(handler(sample_empty))
|
||||
|
||||
def test_truncate_mode_single(self):
|
||||
"""
|
||||
Tests that 'truncate_or_drop_long_seq' correctly truncates or preserves single examples in "truncate" mode.
|
||||
|
||||
Verifies that sequences longer than the maximum length are truncated, while sequences that are too short, empty, or within the valid range remain unchanged.
|
||||
"""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="truncate",
|
||||
)
|
||||
|
||||
# Too short (should still be dropped implicitly by filter/map logic upstream,
|
||||
# but the function itself might return the sample or False based on impl.)
|
||||
# Current impl returns the original sample for map if too short, assuming upstream filters.
|
||||
# Let's refine this test - the function *itself* returns the sample if too short when truncating.
|
||||
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
|
||||
result_short = handler(sample_short)
|
||||
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
|
||||
|
||||
# Too long
|
||||
original_long = list(range(self.sequence_len + 5))
|
||||
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
|
||||
result_long = handler(sample_long)
|
||||
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
|
||||
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
|
||||
self.assertEqual(len(result_long["labels"]), self.sequence_len)
|
||||
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
|
||||
|
||||
# Just right
|
||||
sample_ok = {
|
||||
"input_ids": list(range(self.min_sequence_len)),
|
||||
"labels": list(range(self.min_sequence_len)),
|
||||
}
|
||||
result_ok = handler(sample_ok)
|
||||
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
|
||||
self.assertEqual(result_ok, sample_ok) # Should be unchanged
|
||||
|
||||
# Empty
|
||||
sample_empty = {"input_ids": [], "labels": []}
|
||||
result_empty = handler(sample_empty)
|
||||
self.assertEqual(result_empty, sample_empty) # Unchanged
|
||||
|
||||
def test_drop_mode_batched(self):
|
||||
"""
|
||||
Tests that the "drop" handling mode correctly filters batched input sequences based on length constraints.
|
||||
|
||||
Verifies that sequences shorter than the minimum length, longer than the maximum length, or empty are dropped (returns False), while sequences within the valid range are kept (returns True).
|
||||
"""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="drop",
|
||||
)
|
||||
sample = {
|
||||
"input_ids": [
|
||||
[1, 2], # Too short
|
||||
list(range(self.sequence_len + 1)), # Too long
|
||||
list(range(self.sequence_len)), # OK (len = 10)
|
||||
list(range(self.min_sequence_len)), # OK (len = 3)
|
||||
[], # Empty
|
||||
]
|
||||
}
|
||||
expected = [False, False, True, True, False]
|
||||
self.assertEqual(handler(sample), expected)
|
||||
|
||||
def test_truncate_mode_batched(self):
|
||||
"""
|
||||
Tests that batched examples are correctly truncated in "truncate" mode.
|
||||
|
||||
Verifies that sequences in both "input_ids" and "labels" longer than the maximum
|
||||
allowed length are truncated, while sequences that are too short or empty remain
|
||||
unchanged.
|
||||
"""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="truncate",
|
||||
)
|
||||
sample = {
|
||||
"input_ids": [
|
||||
[1, 2], # Too short
|
||||
list(range(self.sequence_len + 5)), # Too long
|
||||
list(range(self.sequence_len)), # OK
|
||||
list(range(self.min_sequence_len)), # OK
|
||||
[], # Empty
|
||||
],
|
||||
"labels": [ # Add labels to test truncation
|
||||
[1, 2],
|
||||
list(range(self.sequence_len + 5)),
|
||||
list(range(self.sequence_len)),
|
||||
list(range(self.min_sequence_len)),
|
||||
[],
|
||||
],
|
||||
}
|
||||
|
||||
result = handler(sample)
|
||||
|
||||
# Expected results after truncation (too short and empty remain unchanged by this function)
|
||||
expected_input_ids = [
|
||||
[1, 2], # Unchanged (too short)
|
||||
list(range(self.sequence_len)), # Truncated
|
||||
list(range(self.sequence_len)), # Unchanged (OK)
|
||||
list(range(self.min_sequence_len)), # Unchanged (OK)
|
||||
[], # Unchanged (Empty)
|
||||
]
|
||||
expected_labels = [
|
||||
[1, 2], # Unchanged (too short)
|
||||
list(range(self.sequence_len)), # Truncated
|
||||
list(range(self.sequence_len)), # Unchanged (OK)
|
||||
list(range(self.min_sequence_len)), # Unchanged (OK)
|
||||
[], # Unchanged (Empty)
|
||||
]
|
||||
|
||||
self.assertEqual(result["input_ids"], expected_input_ids)
|
||||
self.assertEqual(result["labels"], expected_labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user