Compare commits

..

22 Commits

Author SHA1 Message Date
Dan Saunders
f144319697 distributed fix 2025-02-26 03:04:58 +00:00
Dan Saunders
07bb41812b fix issue with tests in ci 2025-02-26 03:04:58 +00:00
Dan Saunders
cae8c7636b fixes 2025-02-26 03:04:58 +00:00
Dan Saunders
09611cea10 remove duplicate info 2025-02-26 03:04:58 +00:00
Dan Saunders
61266ab843 adding runtime metrics / system info additional accelerator support, etc. 2025-02-26 03:04:58 +00:00
Dan Saunders
aea0e760e4 adding runtime metrics / system info additional accelerator support, etc. 2025-02-26 03:04:58 +00:00
Dan Saunders
5afad670da improved redaction, send system info during model config load telemetry, etc. 2025-02-26 03:04:58 +00:00
Dan Saunders
49ac79ed1e doc update 2025-02-26 03:04:58 +00:00
Dan Saunders
c9af72cd7a fix 2025-02-26 03:04:58 +00:00
Dan Saunders
d3d63c1432 adding back in base_model redaction w/ whitelist 2025-02-26 03:04:58 +00:00
Dan Saunders
675b65d711 sleep on all ranks in distributed setting 2025-02-26 03:04:58 +00:00
Dan Saunders
b23187daea simplifying path redaction 2025-02-26 03:04:58 +00:00
Dan Saunders
e373a6b8d0 small update / fix 2025-02-26 03:04:58 +00:00
Dan Saunders
fd5d5aecdc tests for runtime metrics telemetry and assoc. callback 2025-02-26 03:04:58 +00:00
Dan Saunders
3760175440 adding runtime metrics (cpu + gpu memory, steps/s, etc.) 2025-02-26 03:04:58 +00:00
Dan Saunders
7927abff90 updated sanitization logic, tests 2025-02-26 03:04:58 +00:00
Dan Saunders
ec36839316 update error file path sanitization function; adding more error tracking 2025-02-26 03:04:58 +00:00
Dan Saunders
3076b8df00 progress on telemetry: config load, process, model load, train start / end, error tracking 2025-02-26 03:04:58 +00:00
Dan Saunders
c50610375f updates 2025-02-26 03:04:58 +00:00
Dan Saunders
07ffd47f2b updates 2025-02-26 03:04:58 +00:00
Dan Saunders
76d951afd2 adding todo 2025-02-26 03:04:58 +00:00
Dan Saunders
5220e8ccf4 initial telemetry manager impl 2025-02-26 03:04:58 +00:00
235 changed files with 4544 additions and 4930 deletions

View File

@@ -40,12 +40,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -67,7 +61,7 @@ jobs:
uses: docker/build-push-action@v4 uses: docker/build-push-action@v4
with: with:
context: . context: .
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }} file: ./docker/Dockerfile-base
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -20,11 +20,9 @@ jobs:
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: '3.11' python-version: '3.11'
- name: Install dependencies - name: install dependencies
run: | run: |
python3 -m pip install jupyter quartodoc python3 -m pip install jupyter
- name: Build autodoc
run: quartodoc build
- name: Publish to GitHub Pages (and render) - name: Publish to GitHub Pages (and render)
uses: quarto-dev/quarto-actions/publish@v2 uses: quarto-dev/quarto-actions/publish@v2
with: with:

View File

@@ -88,11 +88,6 @@ jobs:
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -80,11 +80,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -1,49 +0,0 @@
name: Pre-commit auto-update
on:
schedule:
- cron: '0 0 * * 0' # Run weekly
workflow_dispatch: # Manual kickoff
jobs:
auto-update:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Update pre-commit hooks
id: update
run: |
pip install pre-commit
pre-commit autoupdate
if [[ -n $(git status --porcelain) ]]; then
echo "changes=true" >> $GITHUB_OUTPUT
git diff .pre-commit-config.yaml > pre-commit-update.diff
fi
- name: Create Pull Request
if: steps.update.outputs.changes == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
branch: update/pre-commit-hooks
delete-branch: true
title: "chore: update pre-commit hooks"
commit-message: "chore: update pre-commit hooks"
body: |
Automated PR to update pre-commit hooks to their latest versions.
<details>
<summary>Changes:</summary>
```diff
${{ steps.update.outputs.diff }}
```
</details>

View File

@@ -40,7 +40,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 install wheel packaging==23.2 pip3 install wheel packaging
pip3 install --no-build-isolation -e . pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -42,7 +42,7 @@ jobs:
- name: upgrade pip - name: upgrade pip
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch - name: Install PyTorch
run: | run: |
@@ -59,7 +59,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 pip3 install --upgrade packaging
pip3 install --no-build-isolation -U -e . pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh python scripts/cutcrossentropy_install.py | sh

View File

@@ -74,7 +74,7 @@ jobs:
- name: upgrade pip - name: upgrade pip
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch - name: Install PyTorch
run: | run: |
@@ -98,9 +98,8 @@ jobs:
- name: Run tests - name: Run tests
run: | run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v tests/patched/ pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache - name: cleanup pip cache
run: | run: |
@@ -148,7 +147,7 @@ jobs:
- name: upgrade pip - name: upgrade pip
run: | run: |
pip3 install --upgrade pip pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel pip3 install --upgrade packaging setuptools setuptools_scm build wheel
- name: Install PyTorch - name: Install PyTorch
run: | run: |
@@ -173,9 +172,8 @@ jobs:
- name: Run tests - name: Run tests
run: | run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v tests/patched/ pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache - name: cleanup pip cache
run: | run: |

4
.gitignore vendored
View File

@@ -181,10 +181,6 @@ prepared-datasets/
submit.sh submit.sh
*.out* *.out*
# Quartodoc generated files
objects.json
site_libs/
typings/ typings/
out/ out/

View File

@@ -3,7 +3,7 @@ default_language_version:
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0 rev: v4.4.0
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: end-of-file-fixer - id: end-of-file-fixer
@@ -11,23 +11,23 @@ repos:
- id: no-commit-to-branch - id: no-commit-to-branch
args: ['--branch', 'main'] args: ['--branch', 'main']
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 25.1.0 rev: 23.3.0
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
rev: 6.0.1 rev: 5.12.0
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 7.1.2 rev: 6.1.0
hooks: hooks:
- id: flake8 - id: flake8
- repo: https://github.com/pylint-dev/pylint - repo: https://github.com/PyCQA/pylint
rev: v3.3.6 rev: v3.3.0
hooks: hooks:
- id: pylint - id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0 rev: v1.3.0
hooks: hooks:
- id: mypy - id: mypy
additional_dependencies: additional_dependencies:
@@ -36,7 +36,7 @@ repos:
'pydantic>=2.5.3', 'pydantic>=2.5.3',
] ]
- repo: https://github.com/PyCQA/bandit - repo: https://github.com/PyCQA/bandit
rev: 1.8.3 rev: 1.7.5
hooks: hooks:
- id: bandit - id: bandit
args: [ args: [

View File

@@ -19,6 +19,9 @@
<br/> <br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly"> <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests"> <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p> </p>
Axolotl is a tool designed to streamline post-training for various AI models. Axolotl is a tool designed to streamline post-training for various AI models.
@@ -55,7 +58,6 @@ Features:
### Installation ### Installation
```bash ```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed] pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs # Download example axolotl configs, deepspeed configs
@@ -97,7 +99,6 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html) - [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html) - [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html) - [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions - [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
## 🤝 Getting Help ## 🤝 Getting Help

View File

@@ -1,178 +1,6 @@
project: project:
type: website type: website
quartodoc:
dir: docs/api
package: axolotl
title: API Reference
parser: google
sections:
- title: Core
desc: Core functionality for training
contents:
- train
- evaluate
- datasets
- convert
- prompt_tokenizers
- logging_config
- core.trainer_builder
- core.training_args
- core.chat.messages
- core.chat.format.chatml
- core.chat.format.llama3x
- core.chat.format.shared
- core.datasets.chat
- core.datasets.transforms.chat_builder
- title: CLI
desc: Command-line interface
contents:
- cli.main
- cli.train
- cli.evaluate
- cli.args
- cli.checks
- cli.config
- cli.inference
- cli.merge_lora
- cli.merge_sharded_fsdp_weights
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.cloud.base
- cli.cloud.modal_
- title: Trainers
desc: Training implementations
contents:
- core.trainers.base
- core.trainers.trl
- core.trainers.dpo.trainer
- core.trainers.grpo.trainer
- title: Prompt Strategies
desc: Prompt formatting strategies
contents:
- prompt_strategies.base
- prompt_strategies.chat_template
- prompt_strategies.alpaca_chat
- prompt_strategies.alpaca_instruct
- prompt_strategies.alpaca_w_system
- prompt_strategies.user_defined
- prompt_strategies.llama2_chat
- prompt_strategies.completion
- prompt_strategies.input_output
- prompt_strategies.stepwise_supervised
- prompt_strategies.metharme
- prompt_strategies.orcamini
- prompt_strategies.pygmalion
- prompt_strategies.messages.chat
- prompt_strategies.dpo.chat_template
- prompt_strategies.dpo.llama3
- prompt_strategies.dpo.chatml
- prompt_strategies.dpo.zephyr
- prompt_strategies.dpo.user_defined
- prompt_strategies.dpo.passthrough
- prompt_strategies.kto.llama3
- prompt_strategies.kto.chatml
- prompt_strategies.kto.user_defined
- prompt_strategies.orpo.chat_template
- prompt_strategies.bradley_terry.llama3
- title: Kernels
desc: Low-level performance optimizations
contents:
- kernels.lora
- kernels.geglu
- kernels.swiglu
- kernels.quantize
- kernels.utils
- title: MonkeyPatches
desc: Runtime patches for model optimizations
contents:
- monkeypatch.llama_attn_hijack_flash
- monkeypatch.llama_attn_hijack_xformers
- monkeypatch.mistral_attn_hijack_flash
- monkeypatch.multipack
- monkeypatch.relora
- monkeypatch.llama_expand_mask
- monkeypatch.lora_kernels
- monkeypatch.utils
- monkeypatch.btlm_attn_hijack_flash
- monkeypatch.llama_patch_multipack
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
- monkeypatch.unsloth_
- monkeypatch.attention.mllama
- monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral
- title: Utils
desc: Utility functions
contents:
- utils.models
- utils.tokenization
- utils.chat_templates
- utils.lora
- utils.lora_embeddings
- utils.model_shard_quant
- utils.bench
- utils.freeze
- utils.trainer
- utils.schedulers
- utils.distributed
- utils.dict
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.gradient_checkpointing.unsloth
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:
- utils.schemas.config
- utils.schemas.model
- utils.schemas.training
- utils.schemas.datasets
- utils.schemas.peft
- utils.schemas.trl
- utils.schemas.integrations
- utils.schemas.enums
- utils.schemas.utils
- title: Integrations
desc: Third-party integrations and extensions
contents:
- integrations.base
- integrations.cut_cross_entropy.args
- integrations.grokfast.optimizer
- integrations.kd.trainer
- integrations.liger.args
- integrations.lm_eval.args
- integrations.spectrum.args
- title: Common
desc: Common utilities and shared functionality
contents:
- common.architectures
- common.const
- common.datasets
- title: Models
desc: Custom model implementations
contents:
- models.mamba.modeling_mamba
- title: Data Processing
desc: Data processing utilities
contents:
- utils.collators.core
- utils.collators.batching
- utils.collators.mamba
- utils.collators.mm_chat
- utils.samplers.multipack
- title: Callbacks
desc: Training callbacks
contents:
- utils.callbacks.perplexity
- utils.callbacks.profiler
- utils.callbacks.lisa
- utils.callbacks.mlflow_
- utils.callbacks.comet_
website: website:
title: "Axolotl" title: "Axolotl"
description: "We make fine-tuning accessible, scalable, and fun" description: "We make fine-tuning accessible, scalable, and fun"
@@ -204,18 +32,14 @@ website:
contents: contents:
- docs/getting-started.qmd - docs/getting-started.qmd
- docs/installation.qmd - docs/installation.qmd
- docs/inference.qmd
- docs/cli.qmd - docs/cli.qmd
- docs/config.qmd - docs/inference.qmd
- text: "API Reference"
href: docs/api
- section: "Dataset Formats" - section: "Dataset Formats"
contents: docs/dataset-formats/* contents: docs/dataset-formats/*
- section: "Deployments" - section: "Deployments"
contents: contents:
- docs/docker.qmd
- docs/multi-gpu.qmd - docs/multi-gpu.qmd
- docs/multi-node.qmd - docs/multi-node.qmd
- docs/ray-integration.qmd - docs/ray-integration.qmd
@@ -249,27 +73,12 @@ website:
- docs/debugging.qmd - docs/debugging.qmd
- docs/nccl.qmd - docs/nccl.qmd
- section: "Reference"
contents:
- docs/config.qmd
format: format:
html: html:
theme: darkly theme: darkly
css: styles.css css: styles.css
toc: true toc: true
# Enable better handling of line breaks in markdown
preserve-tabs: true
html-math-method: mathjax
# Improved markdown processing options
md-extensions:
- markdown_it
- def_list
- attr_list
- fenced_divs
- tables
- html_admonition
- lineblocks
- fancy_lists
# Control whitespace handling
whitespace: preserve
# Process newlines in paragraphs
wrap: preserve
# Better line break handling
preserve-linebreaks: true

View File

@@ -31,11 +31,10 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \ sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi fi
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \ pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi fi
RUN python scripts/unsloth_install.py | sh RUN python scripts/unsloth_install.py | sh

View File

@@ -3,10 +3,9 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 /workspace/axolotl/tests/cli pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/

View File

@@ -1,7 +1,6 @@
""" """
modal application to run axolotl gpu tests in Modal modal application to run axolotl gpu tests in Modal
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import os import os

View File

@@ -1,5 +1,4 @@
"""Modal app to run axolotl GPU tests""" """Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import os import os

View File

@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \ RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \ python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \ python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

View File

@@ -1,39 +0,0 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \ RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \ RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \ mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \ chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \ printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

2
docs/.gitignore vendored
View File

@@ -1,4 +1,2 @@
/.quarto/ /.quarto/
_site/ _site/
/api/*.qmd
/api/*.html

View File

@@ -1,5 +1,5 @@
--- ---
title: "Command Line Interface (CLI)" title: "CLI Reference"
format: format:
html: html:
toc: true toc: true

View File

@@ -1,5 +1,5 @@
--- ---
title: Config Reference title: Config options
description: A complete list of all configuration options. description: A complete list of all configuration options.
--- ---
@@ -30,11 +30,6 @@ tokenizer_legacy:
# Resize the model embeddings when new tokens are added to multiples of 32 # Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models # This is reported to improve training speed on some models
resize_token_embeddings_to_32x: resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init_weights:
# (Internal use only) # (Internal use only)
# Used to identify which the model is based on # Used to identify which the model is based on
@@ -88,12 +83,6 @@ gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge # Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
lora_on_cpu: true lora_on_cpu: true
# List[str]. Add plugins to extend the pipeline.
# See `src/axolotl/integrations` for the available plugins or doc below for more details.
# https://axolotl-ai-cloud.github.io/axolotl/docs/custom_integrations.html
plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# A list of one or more datasets to finetune the model with # A list of one or more datasets to finetune the model with
datasets: datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
@@ -165,6 +154,8 @@ datasets:
content: value content: value
# ... # ...
message_property_mappings:
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is: # Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles: roles:
user: ["human", "user"] user: ["human", "user"]
@@ -172,12 +163,6 @@ datasets:
system: ["system"] system: ["system"]
tool: ["tool"] tool: ["tool"]
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
# This does not drop the default system message from chat_template if it exists. If you wish to,
# we recommend using a custom jinja template with the default system message removed or
# adding a system turn with empty content.
drop_system_message:
# IMPORTANT: The following fields determine which parts of the conversation to train on. # IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train # Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd` # See examples at `docs/dataset-formats/conversation.qmd`
@@ -216,46 +201,10 @@ test_datasets:
data_files: data_files:
- /workspace/data/eval.jsonl - /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo' # use RL training: 'dpo', 'ipo', 'kto'
rl: rl:
rl_beta: # Optional[float]. The beta parameter for the RL training. # whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
# dpo
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
# orpo
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
# kto
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
# simpo
cpo_alpha: 1.0 # Weight of the BC regularizer
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_device: # Optional[str]. Device to use for VLLM.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
vllm_dtype: # Optional[str]. Data type for VLLM.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
num_generations: # Optional[int]. Number of generations to sample.
log_completions: # Optional[bool]. Whether to log completions.
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
# reward modelling: `True` or `False` # reward modelling: `True` or `False`
reward_model: reward_model:
@@ -273,13 +222,13 @@ process_reward_model:
chat_template: tokenizer_default chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null. # custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null chat_template_jinja: null
# Changes the default system message. Currently only supports chatml. # Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so # Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path # subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub # Push prepared dataset to hub
push_dataset_to_hub: # Optional[str] repo_org/repo_name push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set. # if not set.
dataset_processes: # defaults to os.cpu_count() if not set dataset_processes: # defaults to os.cpu_count() if not set
@@ -496,7 +445,7 @@ gradient_checkpointing: false
early_stopping_patience: 3 early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer # Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs: lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf) cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -579,8 +528,6 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
sdp_attention: sdp_attention:
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf # Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention: s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Resume from a specific checkpoint dir # Resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off. # If resume_from_checkpoint isn't set and you simply want it to start where it left off.
@@ -601,13 +548,6 @@ special_tokens:
# Add extra tokens. # Add extra tokens.
tokens: tokens:
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# Can be checked if they exist in tokenizer.json added_tokens.
added_tokens_overrides: # Dict[int, str]
# 128041: "<|im_start|>"
# 128042: "<|im_end|>"
# FSDP # FSDP
fsdp: fsdp:
fsdp_config: fsdp_config:
@@ -620,14 +560,6 @@ ddp_timeout:
ddp_bucket_cap_mb: ddp_bucket_cap_mb:
ddp_broadcast_buffers: ddp_broadcast_buffers:
# Sequence parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Path to torch distx for optim 'adamw_anyprecision' # Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path: torchdistx_path:

View File

@@ -55,47 +55,3 @@ sections = [
for section_name, folder_name in sections: for section_name, folder_name in sections:
print(print_section(section_name, folder_name)) print(print_section(section_name, folder_name))
``` ```
## Adding a new integration
Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.
To add a new integration, please follow these steps:
1. Create a new folder in the `src/axolotl/integrations` directory.
2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
3. Add `__init__.py` and `args.py` files to the new folder.
- `__init__.py` should import the integration and hook into the appropriate functions.
- `args.py` should define the arguments for the integration.
4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.
::: {.callout-tip}
See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.
:::
::: {.callout-warning}
If you could not load your integration, please ensure you are pip installing in editable mode.
```bash
pip install -e .
```
and correctly spelled the integration name in the config file.
```yaml
plugins:
- axolotl.integrations.your_integration_name.YourIntegrationPlugin
```
:::
::: {.callout-note}
It is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed in a package in your python env.
See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)
:::

View File

@@ -74,10 +74,6 @@ datasets:
train_on_eos: train_on_eos:
``` ```
::: {.callout-tip}
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages. 2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml ```yaml

View File

@@ -129,7 +129,6 @@ You can mix and match within each approach or across approaches to train a model
We suggest this approach when you want to bring your own tokenized dataset. We suggest this approach when you want to bring your own tokenized dataset.
Axolotl expects the dataset to have three keys: Axolotl expects the dataset to have three keys:
- `input_ids`: from tokenizing formatted prompt - `input_ids`: from tokenizing formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]` - `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`. - `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.

View File

@@ -6,7 +6,7 @@ description: How datasets are processed
## Overview ## Overview
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
the [dataset format](dataset-formats) and prompt strategies to: the [dataset format](docs/dataset-formats) and prompt strategies to:
- parse the dataset based on the *dataset format* - parse the dataset based on the *dataset format*
- transform the dataset to how you would interact with the model based on the *prompt strategy* - transform the dataset to how you would interact with the model based on the *prompt strategy*

View File

@@ -1,140 +0,0 @@
---
title: "Docker"
format:
html:
toc: true
toc-depth: 4
---
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
#### Image
```
axolotlai/axolotl-base
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
#### Tags format
```bash
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
```
Tags examples:
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
## Main
The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.
#### Image
```
axolotlai/axolotl
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
#### Tags format {#sec-main-tags}
```bash
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest
# nightly build
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}
# tagged release
{version}
```
:::{.callout-tip}
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
:::
Tags examples:
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`
## Cloud
The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.
:::{.callout-tip}
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
:::
#### Image
```
axolotlai/axolotl-cloud
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
#### Tags format
This uses the same tags as the [`main` image](#sec-main-tags).
#### Environment variables
- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- `PUBLIC_KEY`: Add a public key for the SSH service.
- `SSH_KEY`: Add a private key for the SSH service.
#### Volume mounts
:::{.callout-tip}
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
:::
- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.
## Cloud-no-tmux
This is the same as the [`cloud` image](#sec-cloud) but without tmux.
#### Image
```
axolotlai/axolotl-cloud-term
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)
:::{.callout-note}
The naming may be a bit confusing as it has `-term` appended to the end.
:::
#### Tags format
This uses the same tags as the [`cloud` image](#sec-cloud-tags).

View File

@@ -19,28 +19,12 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'** **Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed** > A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
**Q: The codes is stuck on saving preprocessed datasets.** **Q: The codes is stuck on saving preprocessed datasets.**
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it. > A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
**Q: How to call Axolotl via custom python scripts?**
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
### Chat templates ### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`** **Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
@@ -66,7 +50,3 @@ description: Frequently asked questions
**Q: The EOS/EOT token is incorrectly being masked or not being masked.** **Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template. > A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.

View File

@@ -36,9 +36,7 @@ The YAML configuration file controls everything about your training. Here's what
```yaml ```yaml
base_model: NousResearch/Llama-3.2-1B base_model: NousResearch/Llama-3.2-1B
# hub_model_id: username/custom_model_name
load_in_8bit: true
adapter: lora
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
@@ -46,15 +44,11 @@ datasets:
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.1 val_set_size: 0.1
output_dir: ./outputs/lora-out output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
``` ```
::: {.callout-tip}
`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
- To perform Full finetuning, remove these two lines.
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
:::
See our [Config options](config.qmd) for more details. See our [Config options](config.qmd) for more details.
### Training {#sec-training} ### Training {#sec-training}
@@ -62,7 +56,7 @@ See our [Config options](config.qmd) for more details.
When you run `axolotl train`, Axolotl: When you run `axolotl train`, Axolotl:
1. Downloads the base model 1. Downloads the base model
2. (If specified) applies QLoRA/LoRA adapter layers 2. (If specified) applies LoRA adapter layers
3. Loads and processes the dataset 3. Loads and processes the dataset
4. Runs the training loop 4. Runs the training loop
5. Saves the trained model and / or LoRA weights 5. Saves the trained model and / or LoRA weights
@@ -75,8 +69,6 @@ Let's modify the example for your own data:
```yaml ```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1 base_model: NousResearch/Nous-Hermes-llama-1b-v1
load_in_8bit: true
adapter: lora adapter: lora
# Training settings # Training settings
@@ -112,6 +104,8 @@ format):
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"} {"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
``` ```
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
3. Run the training: 3. Run the training:
```bash ```bash

View File

@@ -1,5 +1,5 @@
--- ---
title: "Inference and Merging" title: "Inference"
format: format:
html: html:
toc: true toc: true
@@ -9,14 +9,10 @@ execute:
enabled: false enabled: false
--- ---
This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps. This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
## Quick Start {#sec-quickstart} ## Quick Start {#sec-quickstart}
::: {.callout-tip}
Use the same config used for training on inference/merging.
:::
### Basic Inference {#sec-basic} ### Basic Inference {#sec-basic}
::: {.panel-tabset} ::: {.panel-tabset}

View File

@@ -22,7 +22,6 @@ This guide covers all the ways you can install and set up Axolotl for your envir
### PyPI Installation (Recommended) {#sec-pypi} ### PyPI Installation (Recommended) {#sec-pypi}
```{.bash} ```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed] pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
``` ```
@@ -38,7 +37,7 @@ For the latest features between releases:
```{.bash} ```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl cd axolotl
pip3 install -U packaging setuptools wheel ninja pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]' pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
``` ```
@@ -66,8 +65,6 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
``` ```
::: :::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
## Cloud Environments {#sec-cloud} ## Cloud Environments {#sec-cloud}
### Cloud GPU Providers {#sec-cloud-gpu} ### Cloud GPU Providers {#sec-cloud-gpu}
@@ -79,7 +76,6 @@ For providers supporting Docker:
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c) - [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl) - [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz) - [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Novita](https://novita.ai/gpus-console?templateId=311)
### Google Colab {#sec-colab} ### Google Colab {#sec-colab}
@@ -109,7 +105,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
2. Install PyTorch: https://pytorch.org/get-started/locally/ 2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl: 3. Install Axolotl:
```{.bash} ```{.bash}
pip3 install -U packaging setuptools wheel ninja pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]' pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
``` ```
4. (Optional) Login to Hugging Face: 4. (Optional) Login to Hugging Face:

View File

@@ -66,10 +66,6 @@ logic to be compatible with more of them.
</details> </details>
::: {.callout-tip}
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
:::
## Usage ## Usage
These optimizations can be enabled in your Axolotl config YAML file. The These optimizations can be enabled in your Axolotl config YAML file. The

View File

@@ -28,23 +28,8 @@ val_set_size: 0.1
eval_steps: 100 eval_steps: 100
``` ```
Bradley-Terry chat templates expect single-turn conversations in the following format:
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### Process Reward Models (PRM) ### Process Reward Models (PRM)
::: {.callout-tip}
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
:::
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning. Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
```yaml ```yaml
base_model: Qwen/Qwen2.5-3B base_model: Qwen/Qwen2.5-3B
@@ -60,5 +45,3 @@ datasets:
val_set_size: 0.1 val_set_size: 0.1
eval_steps: 100 eval_steps: 100
``` ```
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.

View File

@@ -3,7 +3,6 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback." description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true back-to-top-navigation: true
toc: true toc: true
toc-expand: 2
toc-depth: 4 toc-depth: 4
--- ---
@@ -298,7 +297,7 @@ The input format is a simple JSON input with customizable fields based on the ab
### IPO ### IPO
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO. As IPO is just DPO with a different loss function, all supported options for DPO works here.
```yaml ```yaml
rl: ipo rl: ipo
@@ -344,9 +343,8 @@ ORPO supports the following types with the following dataset format:
```yaml ```yaml
rl: kto rl: kto
rl_beta: 0.1 # default rl_beta: 0.5
kto_desirable_weight: 1.0 # default kto_desirable_weight: 0.2
kto_undesirable_weight: 1.0 # default
remove_unused_columns: false remove_unused_columns: false
@@ -498,10 +496,6 @@ The input format is a simple JSON input with customizable fields based on the ab
### GRPO ### GRPO
::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
GRPO uses custom reward functions and transformations. Please have them ready locally. GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions: For ex, to load OpenAI's GSM8K and use a random reward for completions:
@@ -534,7 +528,6 @@ trl:
vllm_gpu_memory_utilization: 0.15 vllm_gpu_memory_utilization: 0.15
num_generations: 4 num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}' reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
datasets: datasets:
- path: openai/gsm8k - path: openai/gsm8k
name: main name: main
@@ -543,21 +536,6 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function). To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
```yaml
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
```
This method uses the same dataset format as [DPO](#dpo).
### Using local dataset files ### Using local dataset files
```yaml ```yaml

View File

@@ -1,90 +0,0 @@
---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---
# Sequence Parallelism
Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
## When to Use Sequence Parallelism
Use sequence parallelism when:
- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (Out Of Memory) errors with long sequences
## Configuration
To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
## Implementation Details
When sequence parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
4. The trainer uses special ring communication patterns for attention operations
## Requirements
To use sequence parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
- `pip install axolotl[ring-flash-attn]` (preferred)
- `pip install ring-flash-attn>=0.1.4`
## Limitations
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- May have a small performance overhead due to communication between GPUs
## Example
```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 4 parts
flash_attention: true # Required with sequence parallelism
...
```
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.
## Sample Packing with Sequence Parallelism
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

59
docs/telemetry.qmd Normal file
View File

@@ -0,0 +1,59 @@
---
title: Telemetry
description: A description of the opt-out telemetry implementation in Axolotl.
---
# Telemetry in Axolotl
Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.
## Data Collection
We collect:
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: Training progress, memory usage, timing information
- Usage patterns: Models (from a whitelist) and configurations used
- Error tracking: Stack traces and error messages (sanitized to remove personal
information)
No personally identifiable information (PII) is collected.
## Implementation
Telemetry is implemented using PostHog and consists of:
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
runtime metrics telemetry.
The telemetry system will block training startup for 15 seconds to ensure users are
aware of data collection, unless telemetry is explicitly enabled or disabled.
## Opt-Out Mechanism
Telemetry is **enabled by default** on an opt-out basis. To disable it, set either:
- `AXOLOTL_DO_NOT_TRACK=1` (Axolotl-specific)
- `DO_NOT_TRACK=1` (Global standard; see https://consoledonottrack.com/)
To acknowledge and explicitly enable telemetry (and remove the warning message), set:
`AXOLOTL_DO_NOT_TRACK=0`.
## Privacy
- All path-like config information is automatically redacted from telemetry data
- Model information is only collected for whitelisted organizations
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
- This allows us to link different telemetry events in a single same training run
- Telemetry is only sent from the main process to avoid duplicate events

View File

@@ -55,7 +55,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
gradient_checkpointing_kwargs: gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: true
early_stopping_patience: early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:

View File

@@ -1,5 +1,5 @@
[build-system] [build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"] requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project] [project]
@@ -8,7 +8,6 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer" description = "LLM Trainer"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
# license = "Apache-2.0"
[project.scripts] [project.scripts]
axolotl = "axolotl.cli.main:main" axolotl = "axolotl.cli.main:main"

View File

@@ -2,5 +2,3 @@ pre-commit
black black
mypy mypy
types-requests types-requests
quartodoc
jupyter

View File

@@ -1,9 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.3 bitsandbytes==0.45.2
triton>=3.0.0 triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
autoawq==0.2.7.post3 autoawq==0.2.7.post3
liger-kernel==0.5.3 liger-kernel==0.5.3
@@ -11,12 +12,12 @@ liger-kernel==0.5.3
packaging==23.2 packaging==23.2
peft==0.15.0 peft==0.14.0
transformers==4.49.0 transformers==4.49.0
tokenizers>=0.21.1 tokenizers>=0.21.0
accelerate==1.5.2 accelerate==1.3.0
datasets==3.4.1 datasets==3.2.0
deepspeed==0.16.4 deepspeed==0.16.1
trl==0.15.1 trl==0.15.1
optimum==1.16.2 optimum==1.16.2
@@ -35,7 +36,6 @@ einops
colorama colorama
numba numba
numpy>=1.24.4,<=2.0.1 numpy>=1.24.4,<=2.0.1
# qlora things # qlora things
evaluate==0.4.1 evaluate==0.4.1
scipy scipy
@@ -62,5 +62,7 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0 torchao==0.7.0
schedulefree==1.3.0 schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.6 axolotl-contribs-lgpl==0.0.3
axolotl-contribs-mit==0.0.3
# telemetry
posthog>=3.15.1

View File

@@ -1,7 +1,6 @@
""" """
helper script to parse chat datasets into a usable yaml helper script to parse chat datasets into a usable yaml
""" """
import click import click
import yaml import yaml
from datasets import load_dataset from datasets import load_dataset

View File

@@ -1,5 +1,4 @@
"""Script to output the correct installation command for cut-cross-entropy.""" """Script to output the correct installation command for cut-cross-entropy."""
import importlib.util import importlib.util
import sys import sys
@@ -25,5 +24,5 @@ if cce_spec:
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"' + 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
) )

View File

@@ -17,7 +17,11 @@ def parse_requirements():
lines = [r.strip() for r in requirements_file.readlines()] lines = [r.strip() for r in requirements_file.readlines()]
for line in lines: for line in lines:
is_extras = ( is_extras = (
"deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line "flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
) )
if line.startswith("--extra-index-url"): if line.startswith("--extra-index-url"):
# Handle custom index URLs # Handle custom index URLs
@@ -35,6 +39,7 @@ def parse_requirements():
"bitsandbytes", "bitsandbytes",
"triton", "triton",
"mamba-ssm", "mamba-ssm",
"flash-attn",
"xformers", "xformers",
"autoawq", "autoawq",
"liger-kernel", "liger-kernel",
@@ -119,10 +124,11 @@ setup(
], ],
}, },
extras_require={ extras_require={
"flash-attn": ["flash-attn==2.7.4.post1"], "flash-attn": [
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"], "flash-attn==2.7.4.post1",
],
"deepspeed": [ "deepspeed": [
"deepspeed==0.16.4", "deepspeed==0.16.1",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [

View File

@@ -1,7 +1,6 @@
""" """
launch axolotl in supported cloud platforms launch axolotl in supported cloud platforms
""" """
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union

View File

@@ -1,7 +1,6 @@
""" """
base class for cloud platforms from cli base class for cloud platforms from cli
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod

View File

@@ -1,7 +1,6 @@
""" """
Modal Cloud support from CLI Modal Cloud support from CLI
""" """
import copy import copy
import json import json
import os import os
@@ -114,7 +113,7 @@ class ModalCloud(Cloud):
[ [
# Random id for cache busting of branch commits # Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311 f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull", f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
] ]
) )
@@ -271,7 +270,6 @@ def _preprocess(config_yaml: str, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs): def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out: with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml) f_out.write(config_yaml)
run_folder = "/workspace/mounts" run_folder = "/workspace/mounts"
@@ -290,7 +288,6 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _lm_eval(config_yaml: str, volumes=None): def _lm_eval(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out: with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml) f_out.write(config_yaml)
run_folder = "/workspace/mounts" run_folder = "/workspace/mounts"

View File

@@ -14,6 +14,8 @@ import yaml
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import ( from axolotl.utils.config import (
normalize_cfg_datasets, normalize_cfg_datasets,
@@ -27,6 +29,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
""" """
@@ -152,6 +156,7 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager.register(plugin_name) plugin_manager.register(plugin_name)
@send_errors
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault: def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
""" """
Loads the `axolotl` configuration stored at `config`, validates it, and performs Loads the `axolotl` configuration stored at `config`, validates it, and performs
@@ -171,6 +176,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
# Load the config from the yaml file # Load the config from the yaml file
with open(config, encoding="utf-8") as file: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file)) cfg: DictDefault = DictDefault(yaml.safe_load(file))
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
# If there are any options passed in the cli, if it is something that seems valid # If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value # from the yaml, then overwrite the value
@@ -214,4 +220,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
setup_mlflow_env_vars(cfg) setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg) setup_comet_env_vars(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
return cfg return cfg

View File

@@ -17,6 +17,7 @@ from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import ( from axolotl.utils.chat_templates import (
get_chat_template, get_chat_template,
get_chat_template_from_config, get_chat_template_from_config,
@@ -42,6 +43,7 @@ def get_multi_line_input() -> str:
return instruction return instruction
@send_errors
def do_inference( def do_inference(
*, *,
cfg: DictDefault, cfg: DictDefault,
@@ -135,6 +137,7 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@send_errors
def do_inference_gradio( def do_inference_gradio(
*, *,
cfg: DictDefault, cfg: DictDefault,

View File

@@ -1,5 +1,4 @@
"""Click CLI definitions for various axolotl commands.""" """Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import logging import logging
@@ -25,7 +24,7 @@ from axolotl.cli.utils import (
) )
from axolotl.integrations.lm_eval.cli import lm_eval from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.schemas.config import AxolotlInputConfig from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group() @click.group()

View File

@@ -12,11 +12,13 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@send_errors
def do_merge_lora(*, cfg: DictDefault) -> None: def do_merge_lora(*, cfg: DictDefault) -> None:
""" """
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config

View File

@@ -27,6 +27,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.telemetry.errors import send_errors
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -120,6 +121,7 @@ def _distributed_checkpoint_to_merged_weights(
return save_path_ return save_path_
@send_errors
def merge_fsdp_weights( def merge_fsdp_weights(
checkpoint_dir: str, checkpoint_dir: str,
output_path: str, output_path: str,

View File

@@ -18,12 +18,14 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@send_errors
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
""" """
Preprocesses dataset specified in axolotl config. Preprocesses dataset specified in axolotl config.

View File

@@ -1,7 +1,6 @@
"""CLI to run training on a model.""" """CLI to run training on a model."""
import logging import logging
import os
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -23,7 +22,7 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
""" """
Trains a `transformers` model by first loading the dataset(s) specified in the Trains a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
@@ -35,22 +34,23 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
""" """
print_axolotl_text_art() print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0: check_user_token()
check_user_token()
if cfg.rl: if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta) model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
del model, tokenizer, trainer
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg) plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
""" """
Parses `axolotl` config, CLI args, and calls `do_train`. Parses `axolotl` config, CLI args, and calls `do_train`.

View File

@@ -5,6 +5,7 @@ import dataclasses
import hashlib import hashlib
import json import json
import logging import logging
import typing
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from types import NoneType from types import NoneType
@@ -23,7 +24,7 @@ configure_logging()
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def strip_optional_type(field_type: type | str | None): def strip_optional_type(field_type: type | typing._SpecialForm | None):
""" """
Extracts the non-`None` type from an `Optional` / `Union` type. Extracts the non-`None` type from an `Optional` / `Union` type.

View File

@@ -10,6 +10,7 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.telemetry.errors import send_errors
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -24,8 +25,8 @@ class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata.""" """Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset train_dataset: Dataset
eval_dataset: Dataset | None = None eval_dataset: Optional[Dataset] = None
total_num_steps: int | None = None total_num_steps: Optional[int] = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
@@ -44,6 +45,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
) )
@send_errors
def load_datasets( def load_datasets(
*, *,
cfg: DictDefault, cfg: DictDefault,
@@ -103,6 +105,7 @@ def load_datasets(
) )
@send_errors
def load_preference_datasets( def load_preference_datasets(
*, *,
cfg: DictDefault, cfg: DictDefault,

View File

@@ -1,5 +1,6 @@
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes""" """Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
import json import json
import sys import sys

View File

@@ -1,7 +1,6 @@
""" """
ChatML transformation functions for MessageContents ChatML transformation functions for MessageContents
""" """
from typing import Optional from typing import Optional
from ..messages import MessageContents, Messages from ..messages import MessageContents, Messages

View File

@@ -1,7 +1,6 @@
""" """
Llama 3.x chat formatting functions for MessageContents Llama 3.x chat formatting functions for MessageContents
""" """
from typing import Optional from typing import Optional
from ..messages import MessageContents, Messages from ..messages import MessageContents, Messages

View File

@@ -1,7 +1,6 @@
""" """
shared functions for format transforms shared functions for format transforms
""" """
from axolotl.core.chat.messages import MessageContents, Messages from axolotl.core.chat.messages import MessageContents, Messages

View File

@@ -1,7 +1,6 @@
""" """
internal message representations of chat messages internal message representations of chat messages
""" """
import json import json
from enum import Enum from enum import Enum
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union

View File

@@ -1,7 +1,6 @@
""" """
chat dataset module chat dataset module
""" """
import os import os
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
@@ -44,7 +43,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = ( process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment] process_count or os.cpu_count() # type: ignore[assignment]
) )
num_proc = min(32, process_or_cpu_count) num_proc = min(64, process_or_cpu_count)
features = data.features.keys() features = data.features.keys()
tokenized_data = data.map( tokenized_data = data.map(
map_fn, map_fn,

View File

@@ -1,7 +1,6 @@
""" """
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
""" """
from typing import Any, Mapping, Union from typing import Any, Mapping, Union

View File

@@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
"""Builder for the training args and trainer""" """
Builder for the training args and trainer
"""
import abc import abc
import importlib import importlib
@@ -33,10 +35,9 @@ from transformers import (
EarlyStoppingCallback, EarlyStoppingCallback,
TrainerCallback, TrainerCallback,
) )
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers import ( from axolotl.core.trainers.base import (
AxolotlCPOTrainer, AxolotlCPOTrainer,
AxolotlKTOTrainer, AxolotlKTOTrainer,
AxolotlMambaTrainer, AxolotlMambaTrainer,
@@ -60,6 +61,8 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.telemetry.callbacks import TelemetryCallback
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
@@ -84,18 +87,19 @@ from axolotl.utils.collators import (
) )
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype from axolotl.utils.models import ensure_dtype
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
try: try:
import torch._dynamo # pylint: disable=ungrouped-imports import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError: except ImportError:
pass pass
LOG = logging.getLogger(__name__) LOG = logging.getLogger("axolotl.core.trainer_builder")
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
"""Base class for trainer builder.""" """
Base class for trainer builder
"""
_train_dataset = None _train_dataset = None
_eval_dataset = None _eval_dataset = None
@@ -108,9 +112,9 @@ class TrainerBuilderBase(abc.ABC):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.processor = processor self.processor = processor
# If the model supports tagging, add the axolotl tag. # in case the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls # This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instead of trainer.push_to_hub. # model.push_to_hub instad of trainer.push_to_hub.
if hasattr(model, "add_model_tags"): if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"]) model.add_model_tags(["axolotl"])
@@ -174,10 +178,8 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoMlflowCallback, SaveAxolotlConfigtoMlflowCallback,
) )
callbacks.extend( callbacks.append(
[ SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
]
) )
if self.cfg.use_comet and is_comet_available(): if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
@@ -186,6 +188,10 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
) )
telemetry_manager = TelemetryManager.get_instance()
if telemetry_manager.enabled:
callbacks.append(TelemetryCallback())
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
@@ -225,8 +231,8 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase): class HFCausalTrainerBuilder(TrainerBuilderBase):
""" """
Build the HuggingFace training args/trainer for causal models and reward modeling Build the HuggingFace training args/trainer for causal models
using TRL. and reward modelling using TRL.
""" """
def get_callbacks(self): def get_callbacks(self):
@@ -330,9 +336,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs = {} training_arguments_kwargs = {}
if self.cfg.include_tokens_per_second is not None: if self.cfg.include_tokens_per_second is not None:
training_arguments_kwargs["include_tokens_per_second"] = ( training_arguments_kwargs[
self.cfg.include_tokens_per_second "include_tokens_per_second"
) ] = self.cfg.include_tokens_per_second
if self.cfg.bf16 == "full": if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True training_arguments_kwargs["bf16_full_eval"] = True
@@ -349,13 +355,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["seed"] = self.cfg.seed training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing: if self.cfg.gradient_checkpointing:
training_arguments_kwargs["gradient_checkpointing"] = ( training_arguments_kwargs[
self.cfg.gradient_checkpointing "gradient_checkpointing"
) ] = self.cfg.gradient_checkpointing
if self.cfg.gradient_checkpointing_kwargs is not None: if self.cfg.gradient_checkpointing_kwargs is not None:
training_arguments_kwargs["gradient_checkpointing_kwargs"] = ( training_arguments_kwargs[
self.cfg.gradient_checkpointing_kwargs "gradient_checkpointing_kwargs"
) ] = self.cfg.gradient_checkpointing_kwargs
if self.cfg.fsdp: if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config: if self.cfg.fsdp_config:
@@ -371,9 +377,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None: if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = ( training_arguments_kwargs[
self.cfg.lr_quadratic_warmup "lr_quadratic_warmup"
) ] = self.cfg.lr_quadratic_warmup
if self.cfg.adam_beta1: if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1 training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
@@ -397,28 +403,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.dataloader_pin_memory is not None: if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs["dataloader_pin_memory"] = ( training_arguments_kwargs[
self.cfg.dataloader_pin_memory "dataloader_pin_memory"
) ] = self.cfg.dataloader_pin_memory
if self.cfg.dataloader_num_workers is not None: if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs["dataloader_num_workers"] = ( training_arguments_kwargs[
self.cfg.dataloader_num_workers "dataloader_num_workers"
) ] = self.cfg.dataloader_num_workers
if self.cfg.dataloader_prefetch_factor is not None: if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs["dataloader_prefetch_factor"] = ( training_arguments_kwargs[
self.cfg.dataloader_prefetch_factor "dataloader_prefetch_factor"
) ] = self.cfg.dataloader_prefetch_factor
if self.cfg.dataloader_drop_last is not None: if self.cfg.dataloader_drop_last is not None:
training_arguments_kwargs["dataloader_drop_last"] = ( training_arguments_kwargs[
self.cfg.dataloader_drop_last "dataloader_drop_last"
) ] = self.cfg.dataloader_drop_last
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False: elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.remove_unused_columns is not None: if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs["remove_unused_columns"] = ( training_arguments_kwargs[
self.cfg.remove_unused_columns "remove_unused_columns"
) ] = self.cfg.remove_unused_columns
if not self.cfg.test_datasets and self.cfg.val_set_size == 0: if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval # no eval set, so don't eval
@@ -450,9 +456,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_causal_lm_eval: if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model: if self.cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = ( training_arguments_kwargs[
self.cfg.metric_for_best_model "metric_for_best_model"
) ] = self.cfg.metric_for_best_model
if self.cfg.greater_is_better: if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
@@ -465,13 +471,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
) )
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend: if self.cfg.torch_compile_backend:
training_arguments_kwargs["torch_compile_backend"] = ( training_arguments_kwargs[
self.cfg.torch_compile_backend "torch_compile_backend"
) ] = self.cfg.torch_compile_backend
if self.cfg.torch_compile_mode: if self.cfg.torch_compile_mode:
training_arguments_kwargs["torch_compile_mode"] = ( training_arguments_kwargs[
self.cfg.torch_compile_mode "torch_compile_mode"
) ] = self.cfg.torch_compile_mode
# DDP Config # DDP Config
if self.cfg.ddp_timeout: if self.cfg.ddp_timeout:
@@ -480,32 +486,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.ddp_bucket_cap_mb: if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None: if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = ( training_arguments_kwargs[
self.cfg.ddp_broadcast_buffers "ddp_broadcast_buffers"
) ] = self.cfg.ddp_broadcast_buffers
# these are all the "standard" kwargs that are def used # these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = ( training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1 total_num_steps if self.cfg.max_steps else -1
) )
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs["per_device_train_batch_size"] = ( training_arguments_kwargs[
self.cfg.micro_batch_size "per_device_train_batch_size"
) ] = self.cfg.micro_batch_size
if self.cfg.eval_batch_size: if self.cfg.eval_batch_size:
training_arguments_kwargs["per_device_eval_batch_size"] = ( training_arguments_kwargs[
self.cfg.eval_batch_size "per_device_eval_batch_size"
) ] = self.cfg.eval_batch_size
if self.cfg.auto_find_batch_size is not None: if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs["auto_find_batch_size"] = ( training_arguments_kwargs[
self.cfg.auto_find_batch_size "auto_find_batch_size"
) ] = self.cfg.auto_find_batch_size
training_arguments_kwargs["gradient_accumulation_steps"] = ( training_arguments_kwargs[
self.cfg.gradient_accumulation_steps "gradient_accumulation_steps"
) ] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs["eval_accumulation_steps"] = ( training_arguments_kwargs[
self.cfg.gradient_accumulation_steps "eval_accumulation_steps"
) ] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
training_arguments_kwargs["output_dir"] = self.cfg.output_dir training_arguments_kwargs["output_dir"] = self.cfg.output_dir
@@ -549,12 +555,34 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else: else:
training_arguments_kwargs["run_name"] = None training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]: if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine" training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs["alternate_lr_scheduler_type"] = ( training_arguments_kwargs[
self.cfg.lr_scheduler "alternate_lr_scheduler_type"
) ] = self.cfg.lr_scheduler
else: else:
training_arguments_kwargs["lr_scheduler_type"] = ( training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
@@ -563,9 +591,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
) )
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs["cosine_constant_lr_ratio"] = ( training_arguments_kwargs[
self.cfg.cosine_constant_lr_ratio "cosine_constant_lr_ratio"
) ] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["weight_decay"] = ( training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
) )
@@ -578,40 +606,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.eval_sample_packing self.cfg.eval_sample_packing
) )
if self.cfg.sample_packing_bin_size is not None: if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs["sample_packing_bin_size"] = ( training_arguments_kwargs[
self.cfg.sample_packing_bin_size "sample_packing_bin_size"
) ] = self.cfg.sample_packing_bin_size
if self.cfg.sample_packing_group_size is not None: if self.cfg.sample_packing_group_size is not None:
training_arguments_kwargs["sample_packing_group_size"] = ( training_arguments_kwargs[
self.cfg.sample_packing_group_size "sample_packing_group_size"
) ] = self.cfg.sample_packing_group_size
if self.cfg.sample_packing_eff_est: if self.cfg.sample_packing_eff_est:
training_arguments_kwargs["sample_packing_efficiency"] = ( training_arguments_kwargs[
self.cfg.sample_packing_eff_est "sample_packing_efficiency"
) ] = self.cfg.sample_packing_eff_est
if self.cfg.relora_steps: if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = ( training_arguments_kwargs[
self.cfg.relora_warmup_steps "relora_warmup_steps"
) ] = self.cfg.relora_warmup_steps
if self.cfg.relora_anneal_steps: if self.cfg.relora_anneal_steps:
training_arguments_kwargs["relora_anneal_steps"] = ( training_arguments_kwargs[
self.cfg.relora_anneal_steps "relora_anneal_steps"
) ] = self.cfg.relora_anneal_steps
if self.cfg.relora_prune_ratio: if self.cfg.relora_prune_ratio:
training_arguments_kwargs["relora_prune_ratio"] = ( training_arguments_kwargs[
self.cfg.relora_prune_ratio "relora_prune_ratio"
) ] = self.cfg.relora_prune_ratio
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers: if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs["lisa_step_interval"] = ( training_arguments_kwargs[
self.cfg.lisa_step_interval "lisa_step_interval"
) ] = self.cfg.lisa_step_interval
training_arguments_kwargs["lisa_layers_attribute"] = ( training_arguments_kwargs[
self.cfg.lisa_layers_attribute "lisa_layers_attribute"
) ] = self.cfg.lisa_layers_attribute
training_arguments_kwargs = self.hook_pre_create_training_args( training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs training_arguments_kwargs
@@ -625,127 +653,59 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
) )
if self.cfg.neftune_noise_alpha is not None: if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs["neftune_noise_alpha"] = ( training_arguments_kwargs[
self.cfg.neftune_noise_alpha "neftune_noise_alpha"
) ] = self.cfg.neftune_noise_alpha
trainer_kwargs = {} trainer_kwargs = {}
if self.cfg.reward_model: if self.cfg.reward_model:
training_arguments_kwargs["max_length"] = self.cfg.sequence_len training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# Handle custom optimizer # pylint: disable=duplicate-code
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers] if self.cfg.optimizer in [
if self.cfg.optimizer in custom_supported_optimizers: "optimi_adamw",
# Common optimizer kwargs "ao_adamw_4bit",
optimizer_kwargs = { "ao_adamw_8bit",
"lr": training_arguments_kwargs.get("learning_rate"), "ao_adamw_fp8",
"weight_decay": training_arguments_kwargs.get("weight_decay"), "adopt_adamw",
} ]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
# Adam-specific kwargs if self.cfg.optimizer == "lion_pytorch":
adam_kwargs = {} from lion_pytorch import Lion
if training_arguments_kwargs.get(
"adam_beta1"
) and training_arguments_kwargs.get("adam_beta2"):
adam_kwargs["betas"] = (
training_arguments_kwargs.get("adam_beta1"),
training_arguments_kwargs.get("adam_beta2"),
)
if training_arguments_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon": lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module if "weight_decay" in training_arguments_kwargs:
MuonOptimizerFactory, lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
) )
optimizer_cls = MuonOptimizerFactory trainer_kwargs["optimizers"] = (
optimizer_kwargs.update(adam_kwargs) Lion(params=self.model.parameters(), **lion_kwargs),
elif self.cfg.optimizer == "optimi_adamw": None,
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
) )
else: # Set default so transformers doesn't throw
# Use transformers' optimizer training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optimizer == "adamw_anyprecision": if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists(): if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path) sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx") importlib.import_module("torchdistx")
if self.cfg.optim_target_modules:
training_arguments_kwargs["optim_target_modules"] = (
self.cfg.optim_target_modules
)
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs["loraplus_lr_embedding"] = (
self.cfg.loraplus_lr_embedding
)
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config: if self.cfg.accelerator_config:
training_arguments_kwargs["accelerator_config"] = ( training_arguments_kwargs[
self.cfg.accelerator_config "accelerator_config"
) ] = self.cfg.accelerator_config
if self.cfg.kd_ce_alpha is not None: if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
@@ -754,17 +714,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_temperature is not None: if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None: if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs["kd_zscore_base_temp"] = ( training_arguments_kwargs[
self.cfg.kd_zscore_base_temp "kd_zscore_base_temp"
) ] = self.cfg.kd_zscore_base_temp
if self.cfg.kd_top_k_before_softmax is not None: if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs["kd_top_k_before_softmax"] = ( training_arguments_kwargs[
self.cfg.kd_top_k_before_softmax "kd_top_k_before_softmax"
) ] = self.cfg.kd_top_k_before_softmax
training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
if self.cfg.reward_model: if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig training_args_cls = AxolotlRewardConfig
@@ -849,10 +805,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
): ):
if training_args.pretraining: if training_args.pretraining:
if ( if self.cfg.pretraining_sample_concatenation is False:
self.cfg.pretraining_sample_concatenation is False return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
or self.cfg.micro_batch_size > 1 if self.cfg.micro_batch_size > 1:
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None return None
@@ -880,7 +835,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs: if "max_length" in kwargs:
kwargs.pop("max_length") kwargs.pop("max_length")
elif use_batch_sampler_collator: elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or ( if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"] self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True and self.cfg.flash_attention is not True
): ):
@@ -911,8 +868,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt" kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
return collator( return collator(
*collator_args, *collator_args,
@@ -921,7 +876,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase): class HFRLTrainerBuilder(TrainerBuilderBase):
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)""" """
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""
def get_callbacks(self): def get_callbacks(self):
callbacks = super().get_callbacks() callbacks = super().get_callbacks()
@@ -975,32 +932,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
) )
if self.cfg.remove_unused_columns is not None: if self.cfg.remove_unused_columns is not None:
training_args_kwargs["remove_unused_columns"] = ( training_args_kwargs[
self.cfg.remove_unused_columns "remove_unused_columns"
) ] = self.cfg.remove_unused_columns
else: else:
training_args_kwargs["remove_unused_columns"] = False training_args_kwargs["remove_unused_columns"] = False
if self.cfg.dataloader_pin_memory is not None: if self.cfg.dataloader_pin_memory is not None:
training_args_kwargs["dataloader_pin_memory"] = ( training_args_kwargs[
self.cfg.dataloader_pin_memory "dataloader_pin_memory"
) ] = self.cfg.dataloader_pin_memory
if self.cfg.dataloader_num_workers is not None: if self.cfg.dataloader_num_workers is not None:
training_args_kwargs["dataloader_num_workers"] = ( training_args_kwargs[
self.cfg.dataloader_num_workers "dataloader_num_workers"
) ] = self.cfg.dataloader_num_workers
if self.cfg.dataloader_prefetch_factor is not None: if self.cfg.dataloader_prefetch_factor is not None:
training_args_kwargs["dataloader_prefetch_factor"] = ( training_args_kwargs[
self.cfg.dataloader_prefetch_factor "dataloader_prefetch_factor"
) ] = self.cfg.dataloader_prefetch_factor
if self.cfg.gradient_checkpointing: if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = ( training_args_kwargs[
self.cfg.gradient_checkpointing "gradient_checkpointing"
) ] = self.cfg.gradient_checkpointing
if self.cfg.gradient_checkpointing_kwargs is not None: if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs["gradient_checkpointing_kwargs"] = ( training_args_kwargs[
self.cfg.gradient_checkpointing_kwargs "gradient_checkpointing_kwargs"
) ] = self.cfg.gradient_checkpointing_kwargs
else: else:
training_args_kwargs["gradient_checkpointing_kwargs"] = { training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False "use_reentrant": False
@@ -1074,9 +1031,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dpo_use_weighting is not None: if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None: if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = ( training_args_kwargs[
self.cfg.dpo_use_logits_to_keep "use_logits_to_keep"
) ] = self.cfg.dpo_use_logits_to_keep
for blocklist_key in blocklist_args_kwargs: for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs: if blocklist_key in training_args_kwargs:
@@ -1111,9 +1068,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.adapter and self.peft_config: if self.cfg.adapter and self.peft_config:
dpo_trainer_kwargs["peft_config"] = self.peft_config dpo_trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None: if self.cfg.precompute_ref_log_probs is not None:
dpo_trainer_kwargs["precompute_ref_log_probs"] = ( dpo_trainer_kwargs[
self.cfg.precompute_ref_log_probs "precompute_ref_log_probs"
) ] = self.cfg.precompute_ref_log_probs
if self.cfg.rl == "grpo": if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class() trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model] trainer_cls_args = [self.model]

View File

@@ -1,18 +0,0 @@
"""Init for axolotl.core.trainers"""
# pylint: disable=unused-import
# flake8: noqa
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
TRLPPOTrainer,
)

View File

@@ -1,47 +1,163 @@
"""Module for customized trainers""" """
module for customized trainers
# pylint: disable=too-many-lines """
from __future__ import annotations from __future__ import annotations
# pylint: disable=too-many-lines
import logging import logging
import os import os
from collections import defaultdict from collections import defaultdict
from functools import wraps from functools import wraps
from typing import Any, Literal from typing import Dict, Literal, Optional
import datasets
import torch import torch
from datasets import Dataset from datasets import Dataset
from torch import nn from peft.optimizers import create_loraplus_optimizer
from torch.utils.data import ( from torch.optim.lr_scheduler import OneCycleLR
BatchSampler, from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
DataLoader,
RandomSampler,
Sampler,
SequentialSampler,
)
from transformers import Trainer from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import ( from axolotl.monkeypatch.relora import ReLoRAScheduler
OptimizerMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
LOG = logging.getLogger(__name__) if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger("axolotl.core.trainer_builder")
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer): def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
"""Extend the base Trainer for axolotl helpers""" if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"] tag_names = ["axolotl"]
@@ -58,18 +174,12 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
self.eval_data_collator = eval_data_collator self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs) super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list)) self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha: if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None): def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile: if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
@@ -82,247 +192,316 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
) )
return super()._wrap_model(model, training=training, dataloader=dataloader) return super()._wrap_model(model, training=training, dataloader=dataloader)
def _create_multipack_sampler( def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
self, base_sampler: Sampler, dataset: Dataset decay_parameters = self.get_decay_parameter_names(opt_model)
) -> MultipackBatchSampler:
"""
Helper method to create a `MultipackBatchSampler` for multipacking sequences
for training.
Args:
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
dataset: Dataset to sample from.
Returns:
Multipack (sample packing) batch sampler.
"""
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
return MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
drop_last=True,
)
def _get_train_sampler(self) -> Sampler | None:
"""
Helper method to get the sampler for training. Handles cases for sequence
parallelism, sample packing, and curriculum sampling (sequential).
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
use_sample_packing = self.args.sample_packing and not self.args.pretraining
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
base_sampler = RandomSampler(self.train_dataset)
else:
# Default to parent class implementation for standard random sampling
return super()._get_train_sampler()
# Apply multipack wrapper if needed
if use_sample_packing:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=self.train_dataset,
)
return base_sampler
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
"""
Helper method to get the sampler for evaluation. Handles sequence parallelism
and sample packing cases.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Multipacking enabled if training is enabled and eval is not explicitly disabled
use_multipack = (
self.args.sample_packing and self.args.eval_sample_packing is not False
)
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset)
elif use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
return super()._get_eval_sampler(eval_dataset)
# Apply multipack wrapper if needed
if use_multipack:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
)
return base_sampler
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
params = { params = {
"batch_size": batch_size, "to_weight_decay": {}, # LayerNorm and bias
"collate_fn": self.data_collator, "embeddings": {}, # lm_head, embed_tokens,
"num_workers": self.args.dataloader_num_workers, "no_weight_decay": {},
"pin_memory": self.args.dataloader_pin_memory,
} }
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
# Add persistent workers only for training for name, param in opt_model.named_parameters():
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"): if not param.requires_grad:
params["persistent_workers"] = self.args.dataloader_persistent_workers continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
# Add prefetch factor if specified return optimizer_grouped_parameters
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return params def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()
def _prepare_dataloader( opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
self, dataset, sampler, is_eval=False, custom_batch_size=None if self.optimizer is None: # pylint: disable=access-member-before-definition
): optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
"""Prepare a dataloader with the given dataset and sampler.""" self.args,
# Get base parameters opt_model,
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size) )
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
# Add sampler configuration if self.args.loraplus_lr_ratio is not None:
if not isinstance(dataset, torch.utils.data.IterableDataset): loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
sampler,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_eval_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_eval_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
lengths=get_dataset_lengths(self.eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler): if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"] del dataloader_params["batch_size"]
else: else:
dataloader_params["sampler"] = sampler dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
# Return unprepared dataloader if using sequence parallelism DataLoader(train_dataset, **dataloader_params)
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
) )
return super().get_train_dataloader()
# Get sampler and create dataloader def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
sampler = self._get_train_sampler()
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""Get dataloader for evaluation"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle special case: sample packing is enabled but eval_sample_packing is False
if self.args.sample_packing and self.args.eval_sample_packing is False: if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator self.eval_data_collator
) )
if "length" in eval_dataset.column_names: if eval_dataset:
eval_dataset = eval_dataset.remove_columns(["length"]) eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset) dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator self.train_data_collator
) )
return dataloader return dataloader
# Handle sample packing or sequence parallelism if self.args.sample_packing and self.args.eval_sample_packing is not False:
if ( eval_dataset = (
self.args.sample_packing eval_dataset if eval_dataset is not None else self.eval_dataset
and self.args.eval_sample_packing is not False
or self.args.sequence_parallel_degree > 1
):
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
if hasattr(self, "eval_data_collator") and self.eval_data_collator
else self.data_collator
)
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
# Handle dataset preprocessing for SP
if self.args.sequence_parallel_degree > 1:
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
sampler = self._get_eval_sampler(eval_dataset)
dataloader = self._prepare_dataloader(
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
) )
return dataloader eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler( def _get_bench_sampler(
self, bench_dataset: Dataset self, bench_dataset: Dataset
) -> torch.utils.data.Sampler | None: ) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1: if self.args.world_size <= 1:
return SequentialSampler(bench_dataset) return SequentialSampler(bench_dataset)
return None return None
@@ -347,7 +526,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
return DataLoader(bench_dataset, **dataloader_params) return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
@override
def compute_loss( def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None self, model, inputs, return_outputs=False, num_items_in_batch=None
): ):
@@ -364,7 +542,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
return_outputs=return_outputs, return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch, num_items_in_batch=num_items_in_batch,
) )
return super().compute_loss( return super().compute_loss(
model, model,
inputs, inputs,
@@ -539,10 +716,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
""" """
kwargs = sanitize_kwargs_for_ds_tagging( kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs dataset_tags=self.dataset_tags, kwargs=kwargs
) )
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs) return super().push_to_hub(*args, **kwargs)
@@ -559,13 +736,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
return res return res
def log(self, logs: dict[str, float], start_time: float | None = None) -> None: def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
""" """
Log `logs` on the various objects watching training, including stored metrics. Log `logs` on the various objects watching training, including stored metrics.
Args: Args:
logs: The values to log. logs (`Dict[str, float]`):
start_time: The start of training. The values to log.
start_time (`Optional[float]`):
The start of training.
""" """
# logs either has 'loss' or 'eval_loss' # logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval" train_eval = "train" if "loss" in logs else "eval"
@@ -577,7 +756,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
return super().log(logs, start_time) return super().log(logs, start_time)
def store_metrics( def store_metrics(
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train" self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None: ) -> None:
for key, value in metrics.items(): for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value) self._stored_metrics[train_eval][key].append(value)
@@ -590,26 +769,110 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs) return super()._save_checkpoint(model, trial, **kwargs)
def training_step(
class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""
tag_names = ["axolotl", "mamba"]
def compute_loss(
self, self,
model: nn.Module, model,
inputs: dict[str, torch.Tensor | Any], inputs,
num_items_in_batch: int | None = None, return_outputs=False, # pylint: disable=unused-argument
) -> torch.Tensor: num_items_in_batch=None, # pylint: disable=unused-argument
""" ):
Perform a training step on a batch of inputs. Overrides the input_ids = inputs.pop("input_ids")
`transformers.trainer.Trainer` method to handle sequence parallelism if lm_logits = model(input_ids).logits
enabled.
Args: labels = input_ids.to(lm_logits.device)
model: Model to perform training step for. shift_logits = lm_logits[:, :-1, :].contiguous()
inputs: Dictionary mapping. labels = labels[:, 1:].contiguous()
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal training step loss_fct = torch.nn.CrossEntropyLoss()
loss = super().training_step(model, inputs, num_items_in_batch) lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return loss return lm_loss
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]

View File

@@ -1,7 +1,6 @@
""" """
DPO Specific Strategy for training DPO Specific Strategy for training
""" """
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer

View File

@@ -1,7 +1,6 @@
""" """
Axolotl specific DPO args Axolotl specific DPO args
""" """
from dataclasses import dataclass from dataclasses import dataclass
from trl import DPOConfig from trl import DPOConfig

View File

@@ -1,7 +1,6 @@
""" """
DPO trainer for axolotl DPO trainer for axolotl
""" """
import gc import gc
from functools import wraps from functools import wraps
from typing import Any, Dict, Union from typing import Any, Dict, Union
@@ -13,10 +12,10 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer from trl import DPOTrainer
from axolotl.core.trainers.mixins import SchedulerMixin from axolotl.core.trainers.base import (
from axolotl.core.trainers.utils import ( SchedulerMixin,
sanitize_kwargs_for_ds_tagging, _sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging, _sanitize_kwargs_for_tagging,
) )
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
@@ -74,10 +73,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
""" """
kwargs = sanitize_kwargs_for_ds_tagging( kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs dataset_tags=self.dataset_tags, kwargs=kwargs
) )
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs) return super().push_to_hub(*args, **kwargs)

View File

@@ -9,7 +9,6 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.schemas.trl import TRLConfig
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -32,44 +31,30 @@ class GRPOStrategy:
@classmethod @classmethod
def set_training_args_kwargs(cls, cfg): def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {} grpo_args_kwargs = {}
if cfg.trl and cfg.trl.use_vllm:
if not hasattr(cfg, "trl") or not cfg.trl: grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
return grpo_args_kwargs if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
trl: TRLConfig = cfg.trl # type: ignore else:
grpo_args_kwargs["vllm_device"] = "auto"
if trl.use_vllm: if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs["use_vllm"] = trl.use_vllm grpo_args_kwargs[
grpo_args_kwargs["vllm_device"] = ( "vllm_gpu_memory_utilization"
trl.vllm_device if trl.vllm_device else "auto" ] = cfg.trl.vllm_gpu_memory_utilization
) if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if trl.vllm_gpu_memory_utilization: if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["vllm_gpu_memory_utilization"] = ( grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
trl.vllm_gpu_memory_utilization if cfg.trl and cfg.trl.sync_ref_model:
) grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
if trl.vllm_max_model_len: grpo_args_kwargs[
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len "ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if trl.num_generations: if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["num_generations"] = trl.num_generations grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
if trl.sync_ref_model: grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
if trl.ref_model_mixup_alpha:
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
if trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
return grpo_args_kwargs return grpo_args_kwargs
@classmethod @classmethod
@@ -86,9 +71,9 @@ class GRPOStrategy:
def set_trainer_kwargs(cls, cfg): def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {} trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes: if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs["reward_processing_classes"] = ( trainer_kwargs[
cfg.trl.reward_processing_classes "reward_processing_classes"
) ] = cfg.trl.reward_processing_classes
return trainer_kwargs return trainer_kwargs
@classmethod @classmethod

View File

@@ -1,7 +1,6 @@
""" """
Axolotl Specific Training Args Axolotl Specific Training Args
""" """
from dataclasses import dataclass from dataclasses import dataclass
from trl import GRPOConfig from trl import GRPOConfig

View File

@@ -1,7 +1,6 @@
""" """
Axolotl GRPO trainer Axolotl GRPO trainer
""" """
from accelerate.utils import is_peft_model from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel from transformers import PreTrainedModel

View File

@@ -1,32 +0,0 @@
"""Module for mamba trainer"""
import torch
from axolotl.core.trainers.base import AxolotlTrainer
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""
tag_names = ["axolotl", "mamba"]
def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss

View File

@@ -1,8 +0,0 @@
"""Init for axolotl.core.trainers.mixins"""
# pylint: disable=unused-import
# flake8: noqa
from .optimizer import OptimizerMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,201 +0,0 @@
"""Module for Axolotl trainer optimizer mixin"""
import logging
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from axolotl.integrations.base import BaseOptimizerFactory
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger(__name__)
class OptimizerMixin(Trainer):
"""Mixin class for shared handling of building custom optimizers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_optimizer_grouped_parameters(
self, opt_model, optimizer_kwargs
) -> list[dict]:
decay_parameters = self.get_decay_parameter_names(opt_model)
params: dict = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer
and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
)
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
else:
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer

View File

@@ -1,113 +0,0 @@
"""Module for Axolotl trainer scheduler mixin"""
import logging
import torch
from torch.optim.lr_scheduler import OneCycleLR
from transformers.trainer import Trainer
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
LOG = logging.getLogger(__name__)
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler

View File

@@ -1,131 +0,0 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
import logging
from typing import Any
import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
LOG = logging.getLogger(__name__)
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
This mixin provides functionality for handling sequence parallelism,
including creating appropriate samplers, managing data partitioning,
and updating ring flash attention parameters during training.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
self.ring_attn_group = get_ring_attn_group()
def _create_sequence_parallel_sampler(
self,
dataset: Dataset,
shuffle: bool = True,
is_eval: bool = False,
) -> DistributedSampler:
"""
Helper method to create sampler for sequence parallelism (SP).
We create a distributed sampler with rank equal to the SP group ID, which
means that all ranks in the SP group receive the same sample / set of samples
per training step. We also set the number of replicas equal to the number of
SP groups, which is a bit of a hack / unintended use, but works!
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn. This is accomplished by using the passed
`input_ids`.
Args:
inputs: Current batch of inputs.
"""
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

View File

@@ -1,43 +0,0 @@
"""Module for ReLoRA trainer"""
import torch
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
class ReLoRATrainer(AxolotlTrainer):
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler

View File

@@ -1,23 +1,15 @@
"""Module for TRL PPO trainer""" """
module for TRL PPO training
"""
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from trl import ( from trl import PPOTrainer
CPOTrainer,
KTOTrainer,
ORPOTrainer,
PPOTrainer,
PRMTrainer,
RewardTrainer,
)
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TRLPPOTrainer(PPOTrainer): class TRLPPOTrainer(PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations""" """
wrapper for ppo trainer to handle customizations
tag_names = ["axolotl", "ppo"] """
def train( def train(
self, self,
@@ -38,7 +30,9 @@ class TRLPPOTrainer(PPOTrainer):
"batch_size": 16, "batch_size": 16,
} }
for _, batch in tqdm(enumerate(self.dataloader)): for epoch, batch in tqdm( # pylint: disable=unused-variable
enumerate(self.dataloader)
):
query_tensors = batch["input_ids"] query_tensors = batch["input_ids"]
# generate model response # generate model response
@@ -70,43 +64,3 @@ class TRLPPOTrainer(PPOTrainer):
rewards, rewards,
columns_to_log=["query", "response", "ref_response", "ref_rewards"], columns_to_log=["query", "response", "ref_response", "ref_rewards"],
) )
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]

View File

@@ -1,33 +0,0 @@
"""Utils for Axolotl trainers"""
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs

View File

@@ -1,7 +1,6 @@
""" """
extra axolotl specific training args extra axolotl specific training args
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
@@ -207,19 +206,14 @@ class AxolotlTrainingMixins:
}, },
) )
sequence_parallel_degree: Optional[int] = field(
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)
@dataclass @dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
""" """
Training arguments for Causal trainer Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
default value so it can't be used as a mixin. so it can't be used as a mixin.
""" """

View File

@@ -8,10 +8,9 @@ from typing import Dict, Optional
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.telemetry.errors import send_errors
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -27,18 +26,18 @@ LOG = get_logger("axolotl.evaluate")
def evaluate_dataset( def evaluate_dataset(
trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]: ) -> Optional[Dict[str, float]]:
"""Helper function to evaluate a single dataset. """Helper function to evaluate a single dataset safely.
Args: Args:
trainer: The trainer instance. trainer: The trainer instance
dataset: Dataset to evaluate. dataset: Dataset to evaluate
dataset_type: Type of dataset ('train' or 'eval'). dataset_type: Type of dataset ('train' or 'eval')
flash_optimum: Whether to use flash optimum. flash_optimum: Whether to use flash optimum
Returns: Returns:
Dictionary of metrics or None if dataset is None. Dictionary of metrics or None if dataset is None
""" """
if dataset is None: if dataset is None:
return None return None
@@ -63,16 +62,20 @@ def evaluate_dataset(
return metrics return metrics
@send_errors
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]: def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
""" """
Evaluate a model on training and validation datasets. Evaluate a model on training and validation datasets
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
dataset_meta: Dataset metadata containing training and evaluation datasets. dataset_meta: Dataset metadata containing training and evaluation datasets.
Returns: Returns:
Dictionary mapping metric names to their values. Tuple containing:
- The model (either PeftModel or PreTrainedModel)
- The tokenizer
- Dictionary of evaluation metrics
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage # Enable expandable segments for cuda allocation to improve VRAM usage

View File

@@ -23,8 +23,6 @@ import importlib
import logging import logging
from typing import OrderedDict from typing import OrderedDict
import torch
class BasePlugin: class BasePlugin:
""" """
@@ -471,14 +469,3 @@ class PluginManager:
""" """
for plugin in self.plugins.values(): for plugin in self.plugins.values():
plugin.post_train_unload(cfg) plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""
Base class for factories to create custom optimizers
"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> "torch.optim.Optimizer":
pass

View File

@@ -11,17 +11,19 @@
# the License. # the License.
""" """
Module to handle merging the plugins' input arguments with the base configurations. module to handle merging the plugins' input arguments with the base configurations.
This was moved here to prevent circular imports. this was moved here to prevent circular imports
""" """
from typing import Any, Dict, List from typing import Any, Dict, List
from axolotl.utils.schemas.config import ( from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
) )
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
)
def merge_input_args(): def merge_input_args():

View File

@@ -4,22 +4,6 @@ Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy o
See https://github.com/apple/ml-cross-entropy See https://github.com/apple/ml-cross-entropy
## Requirements
- PyTorch 2.4.0 or higher
## Installation
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
```bash
# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
```
## Usage ## Usage
```yaml ```yaml

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using " "Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`' '`pip install "cut-cross-entropy[transformers]==24.11.4"`'
) )

View File

@@ -1,7 +1,6 @@
""" """
Grokfast plugin for Axolotl Grokfast plugin for Axolotl
""" """
import logging import logging
from transformers.trainer_callback import TrainerCallback from transformers.trainer_callback import TrainerCallback

View File

@@ -1,7 +1,6 @@
""" """
config args for grokfast plugin config args for grokfast plugin
""" """
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
""" """
kd_trainer: Optional[bool] = None # whether to use KD trainer kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[float] = ( kd_ce_alpha: Optional[
None # loss coefficient for cross-entropy loss during KD float
) ] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[bool] = ( kd_top_k_before_softmax: Optional[
None # whether to sample top k before softmax during KD bool
) ] = None # whether to sample top k before softmax during KD

View File

@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
if "cross_entropy" in liger_fn_sig.parameters: if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters: if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs["fused_linear_cross_entropy"] = ( kwargs[
cfg.liger_fused_linear_cross_entropy "fused_linear_cross_entropy"
) ] = cfg.liger_fused_linear_cross_entropy
if "rms_norm" in liger_fn_sig.parameters: if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters: if "layer_norm" in liger_fn_sig.parameters:

View File

@@ -1,7 +1,6 @@
""" """
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union

View File

@@ -1,7 +1,6 @@
""" """
Jamba model with LigerFusedLinearCrossEntropyLoss Jamba model with LigerFusedLinearCrossEntropyLoss
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union

View File

@@ -1,7 +1,6 @@
""" """
Module for the Plugin for LM Eval Harness Module for the Plugin for LM Eval Harness
""" """
import subprocess # nosec import subprocess # nosec
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin

View File

@@ -1,7 +1,6 @@
""" """
Module for handling lm eval harness input arguments. Module for handling lm eval harness input arguments.
""" """
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel

View File

@@ -1,7 +1,6 @@
""" """
axolotl CLI for running lm_eval tasks axolotl CLI for running lm_eval tasks
""" """
import subprocess # nosec import subprocess # nosec
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime

View File

@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
""" """
from typing import Optional from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel
class SpectrumArgs(BaseModel): class SpectrumArgs(BaseModel):
@@ -27,20 +27,3 @@ class SpectrumArgs(BaseModel):
spectrum_top_fraction: Optional[float] = 0.5 spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None spectrum_model_name: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_fsdp_use_orig_params(cls, data):
if (
data.get("fsdp")
and data.get("fsdp_config")
and not data["fsdp_config"].get("use_orig_params")
and data.get("plugins")
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
):
# would otherwise raise
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
raise ValueError(
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
)
return data

View File

@@ -5,7 +5,6 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
""" """
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code # pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch import torch

View File

@@ -6,7 +6,6 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
""" """
# pylint: disable=invalid-name # pylint: disable=invalid-name
from typing import Callable from typing import Callable

View File

@@ -1,5 +1,4 @@
"""Dequantization utilities for `bitsandbytes` integration.""" """Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement # pylint: disable=invalid-name,global-statement
import ctypes import ctypes

View File

@@ -5,7 +5,6 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
""" """
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl

View File

@@ -1,7 +1,6 @@
""" """
HF Transformers MambaConfig HF Transformers MambaConfig
""" """
from transformers import PretrainedConfig from transformers import PretrainedConfig

Some files were not shown because too many files have changed in this diff Show More