Compare commits
1 Commits
feat/glm45
...
fix/hpc-ro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
83ff8bfa1a |
6
.github/FUNDING.yml
vendored
6
.github/FUNDING.yml
vendored
@@ -1,13 +1,13 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
ko_fi: axolotl_ai # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
otechie: # Replace with a single Otechie username
|
||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
custom: ['https://quickchart.io/qr?text=bitcoin%3Abc1qxlgwlqwfea5s2cxm42xqsfmwjct0rj8w8ea5np&size=480¢erImageUrl=https%3A%2F%2Fupload.wikimedia.org%2Fwikipedia%2Fcommons%2Fthumb%2F4%2F46%2FBitcoin.svg%2F64px-Bitcoin.svg.png'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
|
||||
9
.github/workflows/base.yml
vendored
9
.github/workflows/base.yml
vendored
@@ -57,14 +57,14 @@ jobs:
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
# - cuda: "128"
|
||||
@@ -90,6 +90,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-base
|
||||
axolotlai/axolotl-base
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
@@ -146,14 +147,14 @@ jobs:
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
steps:
|
||||
|
||||
3
.github/workflows/docs.yml
vendored
3
.github/workflows/docs.yml
vendored
@@ -12,9 +12,6 @@ jobs:
|
||||
build-deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Quarto
|
||||
|
||||
27
.github/workflows/main.yml
vendored
27
.github/workflows/main.yml
vendored
@@ -25,6 +25,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -35,17 +36,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -55,6 +45,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl
|
||||
axolotlai/axolotl
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
@@ -108,6 +99,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -118,17 +110,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -138,6 +119,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud
|
||||
axolotlai/axolotl-cloud
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
@@ -197,6 +179,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud-term
|
||||
axolotlai/axolotl-cloud-term
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
|
||||
2
.github/workflows/nightlies.yml
vendored
2
.github/workflows/nightlies.yml
vendored
@@ -31,6 +31,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl
|
||||
axolotlai/axolotl
|
||||
tags: |
|
||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||
@@ -83,6 +84,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud
|
||||
axolotlai/axolotl-cloud
|
||||
tags: |
|
||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||
|
||||
5
.github/workflows/preview-docs.yml
vendored
5
.github/workflows/preview-docs.yml
vendored
@@ -11,7 +11,6 @@ on:
|
||||
- '_quarto.yml'
|
||||
- docs/scripts/generate_config_docs.py
|
||||
- src/axolotl/utils/schemas/**.py
|
||||
- .github/workflows/preview-docs.yml
|
||||
|
||||
permissions:
|
||||
checks: write
|
||||
@@ -28,10 +27,6 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
58
.github/workflows/tests.yml
vendored
58
.github/workflows/tests.yml
vendored
@@ -59,19 +59,15 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# - name: Restore Cache from S3
|
||||
# id: hf-cache-restore-s3
|
||||
# run: |
|
||||
# mkdir -p ~/.cache/huggingface/hub
|
||||
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
#
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -95,10 +91,6 @@ jobs:
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
@@ -113,13 +105,9 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
df -h
|
||||
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
@@ -130,6 +118,10 @@ jobs:
|
||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||
fail_ci_if_error: false
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
pytest-sdist:
|
||||
name: PyTest from Source Dist
|
||||
runs-on: ubuntu-latest
|
||||
@@ -142,19 +134,15 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# - name: Restore Cache from S3
|
||||
# id: hf-cache-restore-s3
|
||||
# run: |
|
||||
# mkdir -p ~/.cache/huggingface/hub
|
||||
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
#
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -179,10 +167,6 @@ jobs:
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
@@ -192,14 +176,18 @@ jobs:
|
||||
axolotl --help
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
run: huggingface-cli scan-cache
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
gate-skip-e2e:
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -11,13 +11,13 @@ repos:
|
||||
- id: no-commit-to-branch
|
||||
args: ['--branch', 'main']
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.7
|
||||
rev: v0.14.3
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.19.0
|
||||
rev: v1.18.2
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
@@ -26,7 +26,7 @@ repos:
|
||||
'pydantic>=2.5.3',
|
||||
]
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.9.2
|
||||
rev: 1.8.6
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: [
|
||||
|
||||
@@ -10,7 +10,6 @@ ARG BASE_VOLUME="/runpod-volume"
|
||||
ENV BASE_VOLUME=$BASE_VOLUME
|
||||
ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
|
||||
ENV HF_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
|
||||
ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
|
||||
|
||||
COPY .runpod/src /src
|
||||
|
||||
13
README.md
13
README.md
@@ -29,10 +29,6 @@
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
|
||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
|
||||
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
||||
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
||||
- 2025/07:
|
||||
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
||||
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
|
||||
@@ -40,12 +36,12 @@
|
||||
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Expand older updates</summary>
|
||||
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||
@@ -158,13 +154,6 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
|
||||
|
||||
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
|
||||
|
||||
## 📈 Telemetry
|
||||
|
||||
Axolotl has opt-out telemetry that helps us understand how the project is being used
|
||||
and prioritize improvements. We collect basic system information, model types, and
|
||||
error rates—never personal data or file paths. Telemetry is enabled by default. To
|
||||
disable it, set AXOLOTL_DO_NOT_TRACK=1. For more details, see our [telemetry documentation](https://docs.axolotl.ai/docs/telemetry.html).
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
|
||||
|
||||
@@ -241,7 +241,6 @@ website:
|
||||
- docs/installation.qmd
|
||||
- docs/inference.qmd
|
||||
- docs/cli.qmd
|
||||
- docs/telemetry.qmd
|
||||
- docs/config-reference.qmd
|
||||
- text: "API Reference"
|
||||
href: docs/api
|
||||
|
||||
@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.10"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
@@ -24,14 +24,14 @@ RUN apt-get update \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& mkdir -p /workspace/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
@@ -51,7 +51,7 @@ RUN git lfs install --skip-repo && \
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
|
||||
@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="next"
|
||||
@@ -19,12 +19,12 @@ RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& mkdir -p /workspace/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="nightly"
|
||||
@@ -19,14 +19,14 @@ RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& mkdir -p /workspace/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
||||
@@ -218,13 +218,6 @@ If you have tool arguments with same name but different dtypes (like `"time": st
|
||||
```
|
||||
"arguments": "{\"...\": \"...\"}"
|
||||
```
|
||||
|
||||
The same is applicable for tool parameters.
|
||||
|
||||
```
|
||||
"parameters": "{\"...\": \"...\"}"
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
Example config for Llama4:
|
||||
|
||||
@@ -4,7 +4,7 @@ format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
# number-sections: true
|
||||
number-sections: true
|
||||
code-tools: true
|
||||
execute:
|
||||
enabled: false
|
||||
@@ -14,18 +14,12 @@ This guide covers advanced training configurations for multi-GPU setups using Ax
|
||||
|
||||
## Overview {#sec-overview}
|
||||
|
||||
When training on multiple GPUs, Axolotl supports 3 sharding/parallelism strategies. Additionally, you can layer specific optimization features on top of that strategy.
|
||||
Axolotl supports several methods for multi-GPU training:
|
||||
|
||||
You generally cannot combine these strategies; they are mutually exclusive.
|
||||
|
||||
1. **DeepSpeed**: Powerful optimization library, supports ZeRO stages 1-3.
|
||||
2. **FSDP (Fully Sharded Data Parallel)**: PyTorch's native sharding implementation (Recommended).
|
||||
3. **DDP (Distributed Data Parallel)**: PyTorch's native parallelism implementation (Default if neither of the above are selected).
|
||||
|
||||
These features can often be combined with the strategies above:
|
||||
|
||||
* **Sequence Parallelism**: Splits long sequences across GPUs (Compatible with DDP, DeepSpeed, and FSDP).
|
||||
* **FSDP + QLoRA**: Combines 4-bit quantization with FSDP (Specific to FSDP).
|
||||
- DeepSpeed (recommended)
|
||||
- FSDP (Fully Sharded Data Parallel)
|
||||
- Sequence parallelism
|
||||
- FSDP + QLoRA
|
||||
|
||||
## DeepSpeed {#sec-deepspeed}
|
||||
|
||||
@@ -71,18 +65,12 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
|
||||
|
||||
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
|
||||
|
||||
FSDP allows you to shard model parameters, gradients, and optimizer states across data parallel workers.
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.
|
||||
|
||||
:::
|
||||
|
||||
### FSDP + QLoRA {#sec-fsdp-qlora}
|
||||
|
||||
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
|
||||
|
||||
### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}
|
||||
|
||||
To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and
|
||||
@@ -157,6 +145,10 @@ single sequence causes OOM errors during model training.
|
||||
|
||||
See our [dedicated guide](sequence_parallelism.qmd) for more information.
|
||||
|
||||
### FSDP + QLoRA {#sec-fsdp-qlora}
|
||||
|
||||
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
|
||||
|
||||
## Performance Optimization {#sec-performance}
|
||||
|
||||
### Liger Kernel Integration {#sec-liger}
|
||||
|
||||
@@ -124,8 +124,6 @@ Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Voxtral-Mini-3B-2507
|
||||
|
||||
processor_type: VoxtralProcessor
|
||||
```
|
||||
|
||||
### Gemma-3 {#sec-gemma-3}
|
||||
|
||||
110
docs/rlhf.qmd
110
docs/rlhf.qmd
@@ -597,116 +597,6 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
|
||||
|
||||
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
|
||||
|
||||
#### OpenEnv Rollout Functions
|
||||
|
||||
GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.
|
||||
|
||||
For example, to implement a simple math-solving environment with step-by-step verification:
|
||||
|
||||
```python
|
||||
# math_env.py
|
||||
import re
|
||||
|
||||
def math_solver_rollout(model, processing_class, prompts, generation_config=None):
|
||||
"""
|
||||
Custom rollout function that generates step-by-step math solutions.
|
||||
|
||||
Args:
|
||||
model: The language model
|
||||
processing_class: The tokenizer/processing_class
|
||||
prompts: List of prompt dicts (with 'messages' key for chat format)
|
||||
generation_config: Optional generation configuration
|
||||
|
||||
Returns:
|
||||
List of completion strings
|
||||
"""
|
||||
completions = []
|
||||
|
||||
for prompt in prompts:
|
||||
# Apply chat template to prompt
|
||||
messages = prompt.get("messages", [])
|
||||
formatted_prompt = processing_class.apply_chat_template(
|
||||
messages, processing_class=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Generate step-by-step solution
|
||||
full_response = ""
|
||||
for step in range(5): # Max 5 reasoning steps
|
||||
current_input = formatted_prompt + full_response + "\nNext step:"
|
||||
inputs = processing_class(current_input, return_tensors="pt").to(model.device)
|
||||
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=100,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
step_text = processing_class.decode(
|
||||
outputs[0][inputs.input_ids.shape[1]:],
|
||||
skip_special_tokens=True
|
||||
)
|
||||
|
||||
# Check if solution is complete
|
||||
if "FINAL ANSWER:" in step_text:
|
||||
full_response += step_text
|
||||
break
|
||||
full_response += step_text + "\n"
|
||||
|
||||
completions.append(full_response)
|
||||
|
||||
return completions
|
||||
|
||||
def math_reward(prompts, completions, answers, **kwargs):
|
||||
"""Reward function that checks mathematical correctness"""
|
||||
rewards = []
|
||||
for completion, correct_answer in zip(completions, answers):
|
||||
# Extract predicted answer
|
||||
match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
|
||||
predicted = match.group(1).strip() if match else ""
|
||||
|
||||
# Compare with correct answer
|
||||
reward = 1.0 if predicted == str(correct_answer) else 0.0
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
def math_transform(cfg, *args, **kwargs):
|
||||
"""Transform dataset to GRPO format with answer field"""
|
||||
def transform_fn(example, processing_class=None):
|
||||
return {
|
||||
"prompt": [{"role": "user", "content": example["question"]}],
|
||||
"answer": str(example["answer"]),
|
||||
}
|
||||
return transform_fn, {"remove_columns": ["question"]}
|
||||
```
|
||||
|
||||
```yaml
|
||||
rl: grpo
|
||||
|
||||
trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 512
|
||||
num_generations: 4
|
||||
rollout_func: "math_env.math_solver_rollout" # Custom rollout function
|
||||
reward_funcs: ["math_env.math_reward"]
|
||||
reward_weights: [1.0]
|
||||
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
type: math_env.math_transform
|
||||
```
|
||||
|
||||
The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:
|
||||
|
||||
- `model`: The language model
|
||||
- `processing_class`: The tokenizer/processing class
|
||||
- `prompts`: List of prompt dictionaries
|
||||
- `generation_config` (optional): Generation configuration
|
||||
|
||||
And should return a list of completion strings.
|
||||
|
||||
For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).
|
||||
|
||||
#### GRPO with DAPO/Dr. GRPO loss
|
||||
|
||||
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
---
|
||||
title: Telemetry
|
||||
description: A description of the telemetry implementation in Axolotl.
|
||||
---
|
||||
|
||||
# Telemetry in Axolotl
|
||||
|
||||
Axolotl implements anonymous telemetry to help maintainers understand how the library
|
||||
is used and where users encounter issues. This data helps prioritize features, optimize
|
||||
performance, and fix bugs.
|
||||
|
||||
## Data Collection
|
||||
|
||||
We collect:
|
||||
|
||||
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
|
||||
version, etc.
|
||||
- Hardware info: CPU count, memory, GPU count and models
|
||||
- Runtime metrics: Training progress, memory usage, timing information
|
||||
- Usage patterns: Models (from a whitelist) and configurations used
|
||||
- Error tracking: Stack traces and error messages (sanitized to remove personal
|
||||
information)
|
||||
|
||||
Personally identifiable information (PII) is not collected.
|
||||
|
||||
## Implementation
|
||||
|
||||
Telemetry is implemented using PostHog and consists of:
|
||||
|
||||
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
|
||||
telemetry system and provides methods for tracking events.
|
||||
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
|
||||
sends sanitized stack traces.
|
||||
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
|
||||
runtime metrics during training.
|
||||
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
|
||||
runtime metrics telemetry.
|
||||
|
||||
The telemetry system will block training startup for 10 seconds to ensure users are
|
||||
aware of data collection, unless telemetry is explicitly enabled or disabled.
|
||||
|
||||
## Opt-Out Mechanism
|
||||
|
||||
Telemetry is **enabled by default** on an opt-out basis. To disable it, set
|
||||
`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1`.
|
||||
|
||||
A warning message will be logged on start to clearly inform users about telemetry.
|
||||
We will remove this after some period.
|
||||
|
||||
To hide the warning message about telemetry that is displayed on train, etc. startup,
|
||||
explicitly set: `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1`
|
||||
(explicitly disable telemetry).
|
||||
|
||||
## Privacy
|
||||
|
||||
- All path-like config information is automatically redacted from telemetry data
|
||||
- Model information is only collected for whitelisted organizations
|
||||
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
|
||||
- Each run generates a unique anonymous ID
|
||||
- This allows us to link different telemetry events in a single same training run
|
||||
- Telemetry is only sent from the main process to avoid duplicate events
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -253,6 +253,7 @@
|
||||
"source": [
|
||||
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
|
||||
"\n",
|
||||
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
|
||||
"set_pytorch_cuda_alloc_conf()"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
# Finetune GLM4.5 with Axolotl
|
||||
|
||||
[UNSTABLE]
|
||||
|
||||
```bash
|
||||
# LoRA SFT (4xH200 @ 84GB/GPU)
|
||||
axolotl train examples/glm45/glm4.5-lora-fsdp2.yaml
|
||||
|
||||
# FFT SFT (4xH200)
|
||||
# Checkpointing error on backward pass
|
||||
# Without checkpointing => OOM
|
||||
axolotl train examples/glm45/glm4.5-fft-fsdp2.yaml
|
||||
```
|
||||
|
||||
## Dataset
|
||||
|
||||
In addition to normal OpenAI Messages format, GLM4.5 support an extra parameter for thinking in assistant section.
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning_content": "...", // or have </think>...</think> in `content`
|
||||
"content": "...",
|
||||
}
|
||||
```
|
||||
|
||||
Note:
|
||||
- The role name for tools in this template is `tool`.
|
||||
- You will see this Axolotl WARNING. This is to be as expected as the template does not use EOS.
|
||||
```bash
|
||||
EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct.
|
||||
```
|
||||
- Make sure you set the below extra attributes if needed
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
# tool_calls: tool_calls # uncomment if using tools
|
||||
# reasoning_content: reasoning_content # uncomment if have reasoning
|
||||
|
||||
# Uncomment if training on tool role (you would rarely if ever need this)
|
||||
# eot_tokens:
|
||||
# - <|observation|>
|
||||
```
|
||||
@@ -1,59 +0,0 @@
|
||||
base_model: zai-org/GLM-4.5-Air
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: winglian/pirate-ultrachat-10k
|
||||
type: chat_template
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/qlora-out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
# gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Glm4MoeDecoderLayer
|
||||
state_dict_type: SHARDED_STATE_DICT
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -1,74 +0,0 @@
|
||||
base_model: zai-org/GLM-4.5-Air
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: winglian/pirate-ultrachat-10k
|
||||
type: chat_template
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/qlora-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
# gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Glm4MoeDecoderLayer
|
||||
state_dict_type: SHARDED_STATE_DICT
|
||||
reshard_after_forward: true
|
||||
# activation_checkpointing: false
|
||||
@@ -32,10 +32,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -28,10 +28,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -29,10 +29,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -28,10 +28,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -41,10 +41,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -41,10 +41,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
# Finetune IBM's Granite 4.0 with Axolotl
|
||||
|
||||
[Granite 4.0](https://huggingface.co/collections/ibm-granite/granite-40-language-models) are a family of open source models trained by IBM Research.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Granite4 is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.7.1 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/granite4/granite-4.0-tiny-fft.yaml
|
||||
```
|
||||
|
||||
This config uses about 40.8GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
### Limitation
|
||||
|
||||
Adapter finetuning does not work at the moment. It would error with
|
||||
|
||||
```bash
|
||||
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x3072 and 1x1179648)
|
||||
```
|
||||
|
||||
In addition, if adapter training works, `lora_target_linear: true` will not work due to:
|
||||
```bash
|
||||
ValueError: Target module GraniteMoeHybridParallelExperts() is not supported.
|
||||
```
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Granite Docs](https://www.ibm.com/granite/docs/models/granite)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -1,45 +0,0 @@
|
||||
base_model: ibm-granite/granite-4.0-tiny-preview
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/model-out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -29,6 +29,7 @@ flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
save_strategy: no
|
||||
torch_compile: true
|
||||
|
||||
wandb_project:
|
||||
|
||||
@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
# Finetune Ministral with Axolotl
|
||||
|
||||
Ministral is a family of openweight models from MistralAI found on [HuggingFace](mistralai/Ministral-8B-Instruct-2410). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/ministral/ministral-small-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 8.76 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||
|
||||
In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MistralAI Ministral Blog](https://mistral.ai/news/ministraux)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
@@ -1,67 +0,0 @@
|
||||
base_model: mistralai/Ministral-8B-Instruct-2410
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,79 +0,0 @@
|
||||
# Finetune Ministral3 with Axolotl
|
||||
|
||||
Ministral3 is a family of open-weight models from MistralAI found on [HuggingFace](https://huggingface.co/collections/mistralai/ministral-3). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
Please see [Thinking](#thinking) and [Vision](#vision) for their respective fine-tuning.
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
|
||||
|
||||
Note: This is still experimental given it is based on transformers v5 RC.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Swap to the Axolotl transformers v5 branch
|
||||
|
||||
```bash
|
||||
cp examples/ministral3/ministral3-3b-qlora.yaml ministral3-3b-qlora.yaml
|
||||
|
||||
git fetch
|
||||
git checkout transformers-v5
|
||||
|
||||
# Install packages for transformers v5
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
4. Run the fine-tuning:
|
||||
|
||||
```bash
|
||||
axolotl train ministral3-3b-qlora.yaml
|
||||
```
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
|
||||
### Tips
|
||||
|
||||
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
### Thinking
|
||||
|
||||
Ministral3 2512 model supports thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
|
||||
|
||||
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
|
||||
|
||||
### Vision
|
||||
|
||||
Ministral3 2512 model also supports vision capabilities.
|
||||
|
||||
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||
|
||||
In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MistralAI Mistral3 Blog](https://mistral.ai/news/mistral-3)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
@@ -1,67 +0,0 @@
|
||||
base_model: mistralai/Ministral-3-3B-Reasoning-2512
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,73 +0,0 @@
|
||||
# Ministral3 2512 Thinking Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
Run the thinking model fine-tuning:
|
||||
|
||||
```bash
|
||||
axolotl train examples/ministral3/think/ministral3-3b-think-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 4.76 GiB VRAM.
|
||||
|
||||
### Tips
|
||||
|
||||
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
|
||||
- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
|
||||
Example format:
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Options
|
||||
|
||||
The `thinking` section supports an optional `closed` parameter:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "Internal reasoning here...",
|
||||
"closed": true // Default: true, controls adding the closing [/THINK] tag
|
||||
}
|
||||
```
|
||||
@@ -1,67 +0,0 @@
|
||||
base_model: mistralai/Ministral-3-3B-Reasoning-2512
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: Nanobit/text-think-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,57 +0,0 @@
|
||||
# Ministral3 2512 Vision Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with vision capabilities using Axolotl.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.6'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
```
|
||||
|
||||
3. Run the fine-tuning:
|
||||
```bash
|
||||
axolotl train examples/ministral3/vision/ministral3-3b-vision-qlora.yml
|
||||
```
|
||||
|
||||
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
|
||||
|
||||
### Tips
|
||||
|
||||
Key differences from text-only model:
|
||||
- Multi-modal dataset format required
|
||||
- Sample packing not supported
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [
|
||||
{ "type": "text", "text": "What's in this image?"},
|
||||
{"type": "image", "path": "path/to/image.jpg" }
|
||||
]},
|
||||
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Sample Packing is not supported for multi-modality training currently.
|
||||
@@ -1,64 +0,0 @@
|
||||
base_model: mistralai/Ministral-3-3B-Reasoning-2512
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
# sample dataset below requires downloading image in advance
|
||||
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
datasets:
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,38 +0,0 @@
|
||||
# Finetune Allenai's Olmo 3 with Axolotl
|
||||
|
||||
[Olmo 3](https://huggingface.co/collections/allenai/olmo-3) are a family of 7B and 32B models open source models trained by The Allen Institute for Artificial Intelligence.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
|
||||
```
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- The example config can be re-used for Olmo and Olmo 2.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Olmo 3 Blog](https://allenai.org/blog/olmo3)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -1,64 +0,0 @@
|
||||
base_model: allenai/Olmo-3-7B-Instruct-SFT
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,67 +0,0 @@
|
||||
base_model: google/gemma-3-12b-it
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/out_gemma/
|
||||
|
||||
sequence_len: 8096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 4e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,72 +0,0 @@
|
||||
base_model: google/gemma-3-12b-it
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/qat_out_gemma/
|
||||
|
||||
sequence_len: 8096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 4e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,67 +0,0 @@
|
||||
base_model: google/gemma-3-12b-it
|
||||
# Math finetuning configuration for Gemma3-12B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/out_math_gemma/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 8
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 3e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,72 +0,0 @@
|
||||
base_model: google/gemma-3-12b-it
|
||||
# Math finetuning configuration for Gemma3-12B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/qat_out_math_gemma/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 8
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 3e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,68 +0,0 @@
|
||||
base_model: google/gemma-3-27b-it
|
||||
# Math finetuning configuration for Gemma3-27B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/out_math_gemma27/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-6
|
||||
eta_min: 7e-7
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,73 +0,0 @@
|
||||
base_model: google/gemma-3-27b-it
|
||||
# Math finetuning configuration for Gemma3-27B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/qat_out_math_gemma27/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-6
|
||||
eta_min: 7e-7
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,67 +0,0 @@
|
||||
base_model: Qwen/Qwen2.5-72B
|
||||
# Math finetuning configuration for Qwen2.5-72B (non-instruct)
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: qwen_25
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/out_math_72b/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 8
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-6
|
||||
eta_min: 7e-7
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,72 +0,0 @@
|
||||
base_model: Qwen/Qwen2.5-72B
|
||||
# Math finetuning configuration for Qwen2.5-72B (non-instruct)
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: qwen_25
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/qat_out_math_72b/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 8
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-6
|
||||
eta_min: 7e-7
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,67 +0,0 @@
|
||||
base_model: Qwen/Qwen2.5-72B
|
||||
# Alpaca finetuning configuration for Qwen2.5-72B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: qwen_25
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/out_qwen72b/
|
||||
|
||||
sequence_len: 8096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,72 +0,0 @@
|
||||
base_model: Qwen/Qwen2.5-72B
|
||||
# Alpaca finetuning configuration for Qwen2.5-72B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: qwen_25
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/qat_out_qwen72b/
|
||||
|
||||
sequence_len: 8096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,70 +0,0 @@
|
||||
base_model: Qwen/Qwen2.5-0.5B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
# Use random initialization for fair comparison
|
||||
reinit_weights: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
# Pretraining dataset
|
||||
pretraining_dataset:
|
||||
- path: allenai/c4
|
||||
name: en
|
||||
type: pretrain
|
||||
split: train
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/compare-adamw-pretrain
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project: dist_muon
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name: adamw
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
max_steps: 305
|
||||
|
||||
# AdamW optimizer settings (standard LR for AdamW)
|
||||
optimizer: adamw_torch_fused
|
||||
learning_rate: 0.0002
|
||||
weight_decay: 0.01
|
||||
lr_scheduler: cosine
|
||||
|
||||
train_on_inputs: true
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
|
||||
# Reproducibility
|
||||
seed: 42
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
fsdp_reshard_after_forward: true
|
||||
|
||||
special_tokens:
|
||||
@@ -1,70 +0,0 @@
|
||||
base_model: Qwen/Qwen2.5-0.5B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
# Use random initialization for fair comparison
|
||||
reinit_weights: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
# Pretraining dataset
|
||||
pretraining_dataset:
|
||||
- path: allenai/c4
|
||||
name: en
|
||||
type: pretrain
|
||||
split: train
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/compare-muon-pretrain
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project: dist_muon
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name: muon
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
max_steps: 305
|
||||
|
||||
# Muon optimizer settings
|
||||
optimizer: muon
|
||||
learning_rate: 0.02
|
||||
weight_decay: 0.01
|
||||
lr_scheduler: cosine
|
||||
|
||||
train_on_inputs: true
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
|
||||
# Reproducibility
|
||||
seed: 42
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
fsdp_reshard_after_forward: true
|
||||
|
||||
special_tokens:
|
||||
@@ -1,46 +0,0 @@
|
||||
# Finetune Qwen3 with Axolotl
|
||||
|
||||
[Qwen3](https://huggingface.co/collections/Qwen/qwen3) are a family of open source models trained by Alibaba.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/qwen3/32b-qlora.yaml
|
||||
```
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Chat template masking a few tokens off
|
||||
|
||||
If you notice that the `chat_template` masking for assistant prompts are off by a few tokens, please ensure that you are adding the below to the yaml.
|
||||
|
||||
```yaml
|
||||
chat_template: qwen3
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, please check the official model card as it depends on your reasoning mode.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Qwen3 Blog](https://qwenlm.github.io/blog/qwen3/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -6,17 +6,21 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Seed-OSS is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
# Install Cut Cross Entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install Cut Cross Entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
@@ -37,7 +41,9 @@ Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
|
||||
@@ -37,7 +37,9 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
# Finetune ArceeAI's Trinity with Axolotl
|
||||
|
||||
[Trinity](https://huggingface.co/collections/arcee-ai/trinity) is a family of open weight MoE models trained by Arcee.ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 24.9 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official Arcee.ai team recommends `top_p: 0.75`, `temperature: 0.15`, `top_k: 50`, and `min_p: 0.06`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -1,67 +0,0 @@
|
||||
base_model: arcee-ai/Trinity-Nano-Preview
|
||||
trust_remote_code: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# CCE - N/A as of now
|
||||
# plugins:
|
||||
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
# flash_attention: true # Not supported
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,5 +1,5 @@
|
||||
base_model: mistralai/Voxtral-Mini-3B-2507
|
||||
processor_type: VoxtralProcessor
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -1,35 +1,34 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.48.2
|
||||
bitsandbytes==0.47.0
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
liger-kernel==0.6.4
|
||||
liger-kernel==0.6.3
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub>=0.36.0
|
||||
peft>=0.18.0
|
||||
tokenizers>=0.22.1
|
||||
peft>=0.17.1
|
||||
tokenizers>=0.21.1
|
||||
transformers==4.57.1
|
||||
accelerate==1.11.0
|
||||
datasets==4.4.1
|
||||
accelerate==1.10.1
|
||||
datasets==4.3.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.25.0
|
||||
trl==0.24.0
|
||||
hf_xet==1.2.0
|
||||
kernels>=0.9.0
|
||||
trackio>=0.13.0
|
||||
typing_extensions>=4.14.0
|
||||
trackio
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio>=6.2.0,<7.0
|
||||
gradio==5.49.1
|
||||
|
||||
modal==1.0.2
|
||||
pydantic>=2.10.6,<2.12
|
||||
pydantic>=2.10.6
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
@@ -43,6 +42,7 @@ numpy>=2.2.6
|
||||
# qlora things
|
||||
evaluate==0.4.1
|
||||
scipy
|
||||
scikit-learn==1.4.2
|
||||
nvidia-ml-py==12.560.30
|
||||
art
|
||||
tensorboard
|
||||
@@ -64,12 +64,9 @@ immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.13.0
|
||||
openenv-core==0.1.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.7
|
||||
axolotl-contribs-mit==0.0.6
|
||||
# telemetry
|
||||
posthog==6.7.11
|
||||
axolotl-contribs-mit==0.0.5
|
||||
|
||||
mistral-common==1.8.6
|
||||
mistral-common==1.8.5
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"'
|
||||
)
|
||||
|
||||
3
setup.py
3
setup.py
@@ -66,6 +66,7 @@ def parse_requirements(extras_require_map):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
elif (major, minor) >= (2, 8):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
||||
@@ -129,7 +130,7 @@ extras_require = {
|
||||
"ring-flash-attn>=0.1.7",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.18.2",
|
||||
"deepspeed==0.17.5",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
|
||||
@@ -14,8 +14,6 @@ import yaml
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
@@ -26,7 +24,6 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.tee import prepare_debug_log
|
||||
from axolotl.utils.trackio_ import setup_trackio_env_vars
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
@@ -34,8 +31,6 @@ LOG = get_logger(__name__)
|
||||
|
||||
API_KEY_FIELDS = {"comet_api_key"}
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
|
||||
|
||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||
"""
|
||||
@@ -169,7 +164,6 @@ def plugin_set_cfg(cfg: DictDefault):
|
||||
plugin_manager.cfg = cfg
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_cfg(
|
||||
config: str | Path | DictDefault = Path("examples/"), **kwargs
|
||||
) -> DictDefault:
|
||||
@@ -203,8 +197,6 @@ def load_cfg(
|
||||
temp_file.close()
|
||||
cfg.axolotl_config_path = temp_file.name
|
||||
|
||||
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
|
||||
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
@@ -228,7 +220,6 @@ def load_cfg(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": is_torch_bf16_gpu_available(),
|
||||
"fp8": compute_supports_fp8(),
|
||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
@@ -247,10 +238,8 @@ def load_cfg(
|
||||
setup_wandb_env_vars(cfg)
|
||||
setup_mlflow_env_vars(cfg)
|
||||
setup_comet_env_vars(cfg)
|
||||
setup_trackio_env_vars(cfg)
|
||||
plugin_set_cfg(cfg)
|
||||
|
||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||
cfg_to_log = {
|
||||
k: "[REDACTED]" if k in API_KEY_FIELDS else v
|
||||
for k, v in cfg.items()
|
||||
@@ -262,11 +251,3 @@ def load_cfg(
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def compute_supports_fp8() -> bool:
|
||||
try:
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
return compute_capability >= (9, 0)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
@@ -19,10 +19,7 @@ from axolotl.cli.utils.diffusion import (
|
||||
launch_diffusion_gradio_ui,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.chat_templates import (
|
||||
get_chat_template_from_config,
|
||||
)
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -46,7 +43,6 @@ def get_multi_line_input() -> str:
|
||||
return instruction
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_inference(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
@@ -164,7 +160,6 @@ def do_inference(
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_inference_gradio(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
@@ -288,8 +283,8 @@ def do_inference_gradio(
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
|
||||
demo.launch(
|
||||
footer_links=["gradio", "settings"],
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
|
||||
@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
|
||||
launch_training,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_misc_env, set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
@@ -45,7 +45,6 @@ def cli():
|
||||
print_axolotl_text_art()
|
||||
load_dotenv()
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
set_misc_env()
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -7,14 +7,12 @@ import fire
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
"""
|
||||
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
|
||||
|
||||
@@ -23,7 +23,6 @@ from safetensors.torch import save_file as safe_save_file
|
||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
|
||||
@@ -119,7 +118,6 @@ def _distributed_checkpoint_to_merged_weights(
|
||||
return save_path_
|
||||
|
||||
|
||||
@send_errors
|
||||
def merge_fsdp_weights(
|
||||
checkpoint_dir: str,
|
||||
output_path: str,
|
||||
|
||||
@@ -17,7 +17,6 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
@@ -25,7 +24,6 @@ from axolotl.utils.trainer import disable_datasets_caching
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
"""
|
||||
Preprocesses dataset specified in axolotl config.
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Union
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.quantization import (
|
||||
TorchAOQuantDType,
|
||||
@@ -66,11 +66,6 @@ def do_quantize(
|
||||
|
||||
LOG.info(f"Loading model from {model_path}.")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -112,10 +107,6 @@ def do_quantize(
|
||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||
)
|
||||
|
||||
if processor:
|
||||
LOG.info(f"Saving processor to: {str(Path(output_dir) / 'quantized')}.")
|
||||
processor.save_pretrained(str(Path(output_dir) / "quantized"))
|
||||
|
||||
if hub_model_id:
|
||||
hub_model_id = (
|
||||
hub_model_id.rstrip("-")
|
||||
@@ -123,8 +114,6 @@ def do_quantize(
|
||||
)
|
||||
model.push_to_hub(hub_model_id, safe_serialization=False)
|
||||
tokenizer.push_to_hub(hub_model_id)
|
||||
if processor:
|
||||
processor.push_to_hub(hub_model_id)
|
||||
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
|
||||
|
||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
|
||||
|
||||
@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
|
||||
outputs=[masked_preview, html_out],
|
||||
)
|
||||
|
||||
demo.launch(
|
||||
footer_links=["gradio", "settings"],
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
|
||||
@@ -14,9 +14,7 @@ MOE_ARCH_BLOCK = {
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
"glm4_moe": "Glm4MoeMoE",
|
||||
"deepseek_v3": "DeepseekV3MoE",
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||
"afmoe": "AfmoeMoE",
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ from datasets import Dataset
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401
|
||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -35,7 +34,6 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||
)
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
@@ -98,7 +96,6 @@ def load_datasets(
|
||||
)
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_preference_datasets(
|
||||
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
|
||||
) -> TrainDatasetMeta:
|
||||
|
||||
@@ -29,13 +29,10 @@ from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||
from axolotl.telemetry.callbacks import TelemetryCallback
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils import (
|
||||
is_comet_available,
|
||||
is_mlflow_available,
|
||||
is_opentelemetry_available,
|
||||
is_trackio_available,
|
||||
)
|
||||
from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
@@ -121,13 +118,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
|
||||
if self.cfg.dynamic_checkpoint and self.cfg.dynamic_checkpoint.enabled:
|
||||
from axolotl.utils.callbacks.dynamic_checkpoint import (
|
||||
DynamicCheckpointCallback,
|
||||
)
|
||||
|
||||
callbacks.append(DynamicCheckpointCallback(self.cfg))
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
@@ -148,14 +138,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_trackio and is_trackio_available():
|
||||
from axolotl.utils.callbacks.trackio_ import (
|
||||
SaveAxolotlConfigtoTrackioCallback,
|
||||
)
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OpenTelemetryMetricsCallback,
|
||||
@@ -173,10 +155,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
)
|
||||
)
|
||||
|
||||
telemetry_manager = TelemetryManager.get_instance()
|
||||
if telemetry_manager.enabled:
|
||||
callbacks.append(TelemetryCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -218,9 +196,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
):
|
||||
warmup_steps = 0
|
||||
warmup_ratio = 0.0
|
||||
if self.cfg.warmup_steps is not None:
|
||||
if self.cfg.warmup_steps:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
elif self.cfg.warmup_ratio:
|
||||
if total_num_steps:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
@@ -290,22 +268,11 @@ class TrainerBuilderBase(abc.ABC):
|
||||
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
||||
|
||||
if self.cfg.optimizer == "muon":
|
||||
_, device_mesh = build_parallelism_config(self.cfg)
|
||||
|
||||
if device_mesh is not None:
|
||||
from axolotl.contribs.mit.muon.dist_muon import (
|
||||
DistMuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = DistMuonOptimizerFactory
|
||||
optimizer_kwargs["device_mesh"] = device_mesh
|
||||
else:
|
||||
from axolotl.contribs.mit.muon import (
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
from axolotl.contribs.mit.muon import (
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "dion":
|
||||
from axolotl.contribs.mit.dion import (
|
||||
@@ -443,8 +410,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
if self.cfg.use_trackio:
|
||||
report_to.append("trackio")
|
||||
|
||||
training_args_kwargs["report_to"] = report_to
|
||||
|
||||
@@ -452,8 +417,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
elif self.cfg.use_mlflow:
|
||||
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
elif self.cfg.use_trackio:
|
||||
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
|
||||
else:
|
||||
training_args_kwargs["run_name"] = None
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial, wraps
|
||||
@@ -44,7 +43,7 @@ from axolotl.core.trainers.utils import (
|
||||
from axolotl.utils import get_not_null
|
||||
from axolotl.utils.bench import get_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_distributed, is_main_process
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
@@ -351,11 +350,6 @@ class AxolotlTrainer(
|
||||
# track number of tokens for tokens per second calculation
|
||||
if self.args.include_tkps:
|
||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||
num_tokens = (inputs[inputs_key] != -100).sum()
|
||||
if is_distributed():
|
||||
torch.distributed.all_reduce(
|
||||
num_tokens, op=torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
if hasattr(self.state, "num_tokens"):
|
||||
self.state.num_tokens = (
|
||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
|
||||
@@ -363,11 +357,6 @@ class AxolotlTrainer(
|
||||
else:
|
||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
|
||||
|
||||
if hasattr(self.state, "total_tokens"):
|
||||
self.state.total_tokens += num_tokens
|
||||
else:
|
||||
self.state.total_tokens = num_tokens
|
||||
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(
|
||||
model,
|
||||
@@ -604,7 +593,6 @@ class AxolotlTrainer(
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5"))
|
||||
|
||||
for key, metric_data in self._stored_metrics[train_eval].items():
|
||||
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
||||
@@ -615,18 +603,7 @@ class AxolotlTrainer(
|
||||
raise NotImplementedError(
|
||||
"Metric reduction must be one of [mean, min, max, sum]"
|
||||
)
|
||||
logs[key] = round(fn(values).item(), metric_ndigits)
|
||||
|
||||
if "loss" in logs:
|
||||
try:
|
||||
logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits)
|
||||
except OverflowError:
|
||||
logs["ppl"] = float("inf")
|
||||
if "eval_loss" in logs:
|
||||
try:
|
||||
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits)
|
||||
except OverflowError:
|
||||
logs["eval_ppl"] = float("inf")
|
||||
logs[key] = round(fn(values).item(), 4)
|
||||
|
||||
if is_main_process():
|
||||
# Add memory usage
|
||||
@@ -644,11 +621,6 @@ class AxolotlTrainer(
|
||||
logs["tokens_per_second_per_gpu"] = round(
|
||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||
)
|
||||
if (
|
||||
hasattr(self.state, "total_tokens")
|
||||
and self.state.total_tokens is not None
|
||||
):
|
||||
logs["total_tokens"] = int(self.state.total_tokens.item())
|
||||
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
|
||||
@@ -36,6 +36,4 @@ class DPOStrategy:
|
||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||
if cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||
if cfg.dpo_use_liger_kernel is not None:
|
||||
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
||||
return training_args_kwargs
|
||||
|
||||
@@ -126,9 +126,6 @@ class GRPOStrategy:
|
||||
if trl.use_liger_loss is not None:
|
||||
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
|
||||
|
||||
if trl.rollout_func:
|
||||
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
@@ -204,32 +201,3 @@ class GRPOStrategy:
|
||||
raise ValueError(
|
||||
f"Reward function {reward_func_fqn} not found."
|
||||
) from exc
|
||||
|
||||
@classmethod
|
||||
def get_rollout_func(cls, rollout_func_fqn: str):
|
||||
"""
|
||||
Returns the rollout function from the given fully qualified name.
|
||||
|
||||
Args:
|
||||
rollout_func_fqn (str): Fully qualified name of the rollout function
|
||||
(e.g. my_module.my_rollout_func)
|
||||
|
||||
Returns:
|
||||
Callable rollout function
|
||||
"""
|
||||
try:
|
||||
rollout_func_module_name = rollout_func_fqn.split(".")[-1]
|
||||
rollout_func_module = importlib.import_module(
|
||||
".".join(rollout_func_fqn.split(".")[:-1])
|
||||
)
|
||||
rollout_func = getattr(rollout_func_module, rollout_func_module_name)
|
||||
|
||||
if not callable(rollout_func):
|
||||
raise ValueError(
|
||||
f"Rollout function {rollout_func_fqn} must be callable"
|
||||
)
|
||||
|
||||
return rollout_func
|
||||
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(f"Rollout function {rollout_func_fqn} not found.") from exc
|
||||
|
||||
@@ -10,7 +10,6 @@ import torch
|
||||
from datasets import Dataset
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.train import (
|
||||
TrainDatasetMeta,
|
||||
setup_model_and_tokenizer,
|
||||
@@ -64,7 +63,6 @@ def evaluate_dataset(
|
||||
return metrics
|
||||
|
||||
|
||||
@send_errors
|
||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||
"""
|
||||
Evaluate a model on training and validation datasets.
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -44,7 +44,6 @@ plugins:
|
||||
- gemma3n_text
|
||||
- glm
|
||||
- glm4
|
||||
- glm_moe
|
||||
- glm4_moe
|
||||
- glm4v
|
||||
- glm4v_moe
|
||||
@@ -62,15 +61,10 @@ plugins:
|
||||
- llama4
|
||||
- llama4_text
|
||||
- llava
|
||||
- ministral
|
||||
- ministral3
|
||||
- mistral
|
||||
- mistral3
|
||||
- mixtral
|
||||
- mllama
|
||||
- olmo
|
||||
- olmo2
|
||||
- olmo3
|
||||
- phi
|
||||
- phi3
|
||||
- phi4_multimodal
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class DenseMixerPlugin(BasePlugin):
|
||||
if cfg.dense_mixer:
|
||||
if not importlib.util.find_spec("densemixer"):
|
||||
raise RuntimeError(
|
||||
"DenseMixer is not installed. Install it with `pip install densemixer`"
|
||||
"DenseMixer is not installed. Install it with `pip install densemizer`"
|
||||
)
|
||||
|
||||
from densemixer.patching import (
|
||||
|
||||
@@ -179,17 +179,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
|
||||
# let subclasses add fields before transform
|
||||
tokenized_prompt = self._prepare_kd_fields(tokenized_prompt, prompt)
|
||||
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
return tokenized_prompt
|
||||
|
||||
def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
|
||||
"""
|
||||
Hook for subclasses to prepare additional KD fields before transform
|
||||
"""
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
@@ -292,13 +283,14 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
|
||||
|
||||
return sample
|
||||
|
||||
def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
|
||||
"""
|
||||
Add pre-tokenized target_token_ids for v2 format
|
||||
"""
|
||||
target_token_ids = original_prompt.pop("target_token_ids", None)
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
target_token_ids = prompt.get("target_token_ids", None)
|
||||
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
|
||||
if target_token_ids is not None:
|
||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
KD trainer
|
||||
"""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
||||
@@ -62,7 +60,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
if columns_to_add:
|
||||
self._signature_columns += columns_to_add
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
@@ -82,22 +79,10 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
):
|
||||
del inputs["attention_mask"]
|
||||
|
||||
if num_items_in_batch is None and "labels" in inputs:
|
||||
num_items_in_batch = (inputs["labels"] != -100).sum().item()
|
||||
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
if num_items_in_batch is not None:
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
if isinstance(outputs, dict):
|
||||
loss = outputs["loss"]
|
||||
elif isinstance(outputs, tuple):
|
||||
loss = outputs[0]
|
||||
else:
|
||||
loss = outputs.loss if hasattr(outputs, "loss") else outputs
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
return outputs[0]
|
||||
|
||||
@@ -18,9 +18,6 @@ liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
# FLCE-specific
|
||||
liger_use_token_scaling: true
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
Module for handling LIGER input arguments.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -35,15 +35,6 @@ class LigerArgs(BaseModel):
|
||||
liger_glu_activation: bool | None = None
|
||||
liger_cross_entropy: bool | None = None
|
||||
liger_fused_linear_cross_entropy: bool | None = None
|
||||
liger_use_token_scaling: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Enables use_token_scaling in fused_linear_cross_entropy. "
|
||||
"When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -84,18 +75,6 @@ class LigerArgs(BaseModel):
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_liger_use_token_scaling_flce(cls, data):
|
||||
if data.get("liger_use_token_scaling") and not data.get(
|
||||
"liger_fused_linear_cross_entropy"
|
||||
):
|
||||
raise ValueError(
|
||||
"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
|
||||
# TODO @SalmanMohammadi this is a larger fix - investigate
|
||||
|
||||
@@ -48,33 +48,6 @@ class LigerPlugin(BasePlugin):
|
||||
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
||||
)
|
||||
|
||||
if cfg.liger_use_token_scaling:
|
||||
# Patch FLCE to set token_scaling=True for function and class API
|
||||
from liger_kernel.transformers import functional
|
||||
from liger_kernel.transformers.fused_linear_cross_entropy import (
|
||||
LigerFusedLinearCrossEntropyLoss,
|
||||
)
|
||||
|
||||
old_liger_fused_linear_cross_entropy = (
|
||||
functional.liger_fused_linear_cross_entropy
|
||||
)
|
||||
|
||||
def patched_liger_fused_linear_cross_entropy(*args, **kwargs):
|
||||
kwargs["use_token_scaling"] = True
|
||||
return old_liger_fused_linear_cross_entropy(*args, **kwargs)
|
||||
|
||||
functional.liger_fused_linear_cross_entropy = (
|
||||
patched_liger_fused_linear_cross_entropy
|
||||
)
|
||||
|
||||
old_init = LigerFusedLinearCrossEntropyLoss.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
kwargs["use_token_scaling"] = True
|
||||
return old_init(self, *args, **kwargs)
|
||||
|
||||
LigerFusedLinearCrossEntropyLoss.__init__ = patched_init
|
||||
|
||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||
|
||||
@@ -20,7 +20,6 @@ from peft import (
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from axolotl.loaders.utils import get_linear_embedding_layers
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -102,8 +101,6 @@ def load_lora(
|
||||
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
|
||||
if cfg.peft_trainable_token_indices:
|
||||
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
|
||||
if cfg.peft_ensure_weight_tying is not None:
|
||||
lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying
|
||||
|
||||
# Determine the correct PEFT task type
|
||||
model_cls = type(model).__name__
|
||||
@@ -142,12 +139,9 @@ def load_lora(
|
||||
):
|
||||
setup_quantized_meta_for_peft(model)
|
||||
|
||||
model_kwargs: Any = {}
|
||||
if cfg.peft_autocast_adapter_dtype is not None:
|
||||
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
|
||||
|
||||
if cfg.lora_model_dir:
|
||||
LOG.debug("Loading pretrained PEFT - LoRA")
|
||||
model_kwargs: Any = {}
|
||||
if cfg.lora_on_cpu:
|
||||
model_kwargs["max_memory"] = {"cpu": "256GiB"}
|
||||
model_kwargs["device_map"] = {"": "cpu"}
|
||||
@@ -158,7 +152,7 @@ def load_lora(
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = get_peft_model(model, lora_config, **model_kwargs)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
if rank == 0:
|
||||
try:
|
||||
@@ -178,7 +172,6 @@ def load_lora(
|
||||
return model, lora_config
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_adapter(
|
||||
model: PreTrainedModel,
|
||||
cfg: DictDefault,
|
||||
|
||||
@@ -49,7 +49,6 @@ from axolotl.loaders.utils import (
|
||||
load_model_config,
|
||||
)
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import (
|
||||
@@ -159,7 +158,6 @@ class ModelLoader:
|
||||
"""Property that determines if FSDP with QLoRA is enabled."""
|
||||
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
|
||||
|
||||
@send_errors
|
||||
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
||||
"""Load and prepare the model with all configurations and patches.
|
||||
|
||||
|
||||
@@ -457,7 +457,7 @@ class PatchManager:
|
||||
and self.cfg.flash_attention
|
||||
and not self.inference
|
||||
):
|
||||
# TODO(MengqingCao): split these patches separately
|
||||
# TODO(MengqingCao): split these patches seperately
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
is_xformers_swiglu_available,
|
||||
replace_llama_mlp_with_swiglu,
|
||||
|
||||
@@ -1,47 +1,27 @@
|
||||
"""Processor loading functionality for multi-modal models"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
processor_kwargs: dict[str, Any] = {} # Do we actually need this?
|
||||
|
||||
processor_cls = AutoProcessor
|
||||
if cfg.processor_type:
|
||||
processor_cls = getattr(transformers, cfg.processor_type)
|
||||
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
|
||||
def _patch_mistralcommontokenizer():
|
||||
"""
|
||||
Transformers v5 stops reading the sub-processor.
|
||||
|
||||
We need to patch this, so both processors use this.
|
||||
"""
|
||||
import transformers.tokenization_mistral_common as tokenization_mistral_common
|
||||
|
||||
from axolotl.utils.mistral import HFMistralTokenizer
|
||||
|
||||
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
|
||||
|
||||
_patch_mistralcommontokenizer()
|
||||
|
||||
from transformers import VoxtralProcessor
|
||||
|
||||
if processor_cls == VoxtralProcessor:
|
||||
return VoxtralProcessor.from_pretrained(
|
||||
cfg.processor_config,
|
||||
)
|
||||
|
||||
from axolotl.utils.mistral import Mistral3Processor
|
||||
|
||||
return Mistral3Processor(
|
||||
@@ -52,6 +32,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
cfg.processor_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
tokenizer=tokenizer,
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
# Attempt to load image size from processor if available
|
||||
|
||||
@@ -13,7 +13,6 @@ from transformers import (
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import (
|
||||
@@ -120,7 +119,6 @@ def modify_tokenizer_files(
|
||||
return tokenizer_dir
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
|
||||
|
||||
@@ -37,12 +37,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"deepseek_v3",
|
||||
"glm",
|
||||
"glm4",
|
||||
"glm4_moe",
|
||||
"smollm3",
|
||||
"granite",
|
||||
"granitemoe",
|
||||
"granitemoeshared",
|
||||
"granitemoehybrid",
|
||||
"hunyuan_v1_dense",
|
||||
"hunyuan_v1_moe",
|
||||
"gpt_oss",
|
||||
@@ -50,12 +47,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"seed_oss",
|
||||
"lfm2",
|
||||
"lfm2_moe",
|
||||
"olmo",
|
||||
"olmo2",
|
||||
"olmo3",
|
||||
"ministral",
|
||||
"ministral3",
|
||||
"afmoe",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -95,7 +95,6 @@ class ChatTemplatePrompter(Prompter):
|
||||
add_generation_prompt=False,
|
||||
images=None,
|
||||
tools=None,
|
||||
real_last_index=None,
|
||||
):
|
||||
"""
|
||||
Build a prompt from a conversation.
|
||||
@@ -115,9 +114,6 @@ class ChatTemplatePrompter(Prompter):
|
||||
if tools:
|
||||
chat_template_kwargs["tools"] = tools
|
||||
|
||||
if real_last_index:
|
||||
chat_template_kwargs["real_last_index"] = real_last_index
|
||||
|
||||
if self.processor:
|
||||
if not callable(self.processor):
|
||||
raise TypeError("Processor must be callable")
|
||||
@@ -635,17 +631,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
turns_with_empty = turns[:turn_idx] + [empty_turn]
|
||||
turns_with_content = turns[: turn_idx + 1]
|
||||
|
||||
real_last_index = len(turns) - 1
|
||||
|
||||
# Generate the conversation up to the turn, with final turn replaced with dummy content
|
||||
dummy_ids = self.prompter.build_prompt(
|
||||
turns_with_empty, tools=tools, real_last_index=real_last_index
|
||||
) # type: ignore
|
||||
dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore
|
||||
|
||||
# Generate the conversation up to the turn, with final turn included
|
||||
full_ids = self.prompter.build_prompt(
|
||||
turns_with_content, tools=tools, real_last_index=real_last_index
|
||||
) # type: ignore
|
||||
full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore
|
||||
|
||||
if not full_ids or not dummy_ids:
|
||||
LOG.warning(f"Empty template generated for turn {turn_idx}")
|
||||
@@ -833,23 +823,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return None
|
||||
|
||||
if isinstance(tools, list):
|
||||
# Process each tool to handle JSON string parameters
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict) and "function" in tool:
|
||||
function = tool["function"]
|
||||
if "parameters" in function:
|
||||
params = function["parameters"]
|
||||
if isinstance(params, str):
|
||||
try:
|
||||
function["parameters"] = json.loads(params)
|
||||
except json.JSONDecodeError as e:
|
||||
LOG.error(
|
||||
f"Error parsing tool parameters as JSON. "
|
||||
f"Function: {function.get('name', 'unknown')}, "
|
||||
f"Parameters string: {params!r}, "
|
||||
f"Error: {e}"
|
||||
)
|
||||
raise
|
||||
return tools
|
||||
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
"""Trainer callbacks for reporting runtime metrics at regular intervals."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
TIME_SINCE_LAST = 60
|
||||
|
||||
|
||||
class TelemetryCallback(TrainerCallback):
|
||||
"""
|
||||
Trainer callback for tracking and reporting runtime metrics.
|
||||
|
||||
This callback tracks training progress, runtime, and memory usage,
|
||||
sending telemetry at configurable intervals.
|
||||
"""
|
||||
|
||||
report_interval_steps: int = 100
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the metrics callback."""
|
||||
self.tracker = RuntimeMetricsTracker()
|
||||
self.telemetry_manager = TelemetryManager.get_instance()
|
||||
self.current_epoch = -1
|
||||
self.start_time = time.time()
|
||||
self.last_report_time = None
|
||||
self.last_report_step = 0
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Handle training start."""
|
||||
self.telemetry_manager.send_event(event_type="train-start")
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_train_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Handle training end."""
|
||||
# Send training completion event
|
||||
self.telemetry_manager.send_event(
|
||||
event_type="train-end",
|
||||
properties=self._extract_last_metrics(state)
|
||||
| self.tracker.metrics.to_dict(),
|
||||
)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_epoch_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Handle epoch start."""
|
||||
self.current_epoch += 1
|
||||
self.tracker.start_epoch(self.current_epoch)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_epoch_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Handle epoch end."""
|
||||
self.tracker.end_epoch(self.current_epoch)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Handle step end."""
|
||||
step = state.global_step
|
||||
self.tracker.update_step(step)
|
||||
|
||||
# Check if we should report metrics
|
||||
should_report = (
|
||||
step % self.report_interval_steps == 0
|
||||
or step == 1 # Always report first step
|
||||
or step - self.last_report_step >= self.report_interval_steps
|
||||
)
|
||||
|
||||
if should_report:
|
||||
current_time = time.time()
|
||||
if self.last_report_time is not None:
|
||||
time_since_last_report = current_time - self.last_report_time
|
||||
else:
|
||||
time_since_last_report = current_time - self.start_time
|
||||
steps_since_last_report = step - self.last_report_step
|
||||
|
||||
# Only report if enough time has passed
|
||||
if (
|
||||
step == 1
|
||||
or time_since_last_report >= TIME_SINCE_LAST
|
||||
or steps_since_last_report >= self.report_interval_steps
|
||||
):
|
||||
# Calculate steps per second for this interval
|
||||
if time_since_last_report > 0 and steps_since_last_report > 0:
|
||||
steps_per_second = steps_since_last_report / time_since_last_report
|
||||
else:
|
||||
steps_per_second = 0
|
||||
|
||||
# Update memory metrics
|
||||
self.tracker.update_memory_metrics()
|
||||
|
||||
# Prepare metrics to report
|
||||
metrics = self._extract_last_metrics(state) | {
|
||||
"step": step,
|
||||
"epoch": self.current_epoch,
|
||||
"progress": state.epoch, # Fractional epoch progress
|
||||
"steps_per_second": steps_per_second,
|
||||
"elapsed_time": current_time - self.start_time,
|
||||
"time_since_last_report": time_since_last_report,
|
||||
}
|
||||
|
||||
# Add memory metrics
|
||||
memory_metrics = self.tracker.get_memory_metrics()
|
||||
metrics.update({"memory": memory_metrics})
|
||||
|
||||
# Send telemetry
|
||||
self.telemetry_manager.send_event(
|
||||
event_type="train-progress", properties=metrics
|
||||
)
|
||||
|
||||
# Update last report time and step
|
||||
self.last_report_time = current_time
|
||||
self.last_report_step = step
|
||||
|
||||
def _extract_last_metrics(self, state: TrainerState) -> dict:
|
||||
"""Extract last loss, learning_rate, and grad_norm from log history."""
|
||||
if not state.log_history:
|
||||
return {"loss": 0, "learning_rate": 0, "grad_norm": 0}
|
||||
|
||||
last_log = state.log_history[-1]
|
||||
return {
|
||||
"loss": last_log.get("loss", 0),
|
||||
"learning_rate": last_log.get("learning_rate", 0),
|
||||
"grad_norm": last_log.get("grad_norm", 0),
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
"""Telemetry utilities for exception and traceback information."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from functools import wraps
|
||||
from inspect import getmodule
|
||||
from typing import Any, Callable
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
ERROR_HANDLED = False
|
||||
|
||||
|
||||
def sanitize_stack_trace(stack_trace: str) -> str:
|
||||
"""
|
||||
Remove personal information from stack trace messages while keeping Python package codepaths.
|
||||
|
||||
This function identifies Python packages by looking for common patterns in virtual environment
|
||||
and site-packages directories, preserving the package path while removing user-specific paths.
|
||||
|
||||
Args:
|
||||
stack_trace: The original stack trace string.
|
||||
|
||||
Returns:
|
||||
A sanitized version of the stack trace with Python package paths preserved.
|
||||
"""
|
||||
# Split the stack trace into lines to process each file path separately
|
||||
lines = stack_trace.split("\n")
|
||||
sanitized_lines = []
|
||||
|
||||
# Regular expression to find file paths in the stack trace
|
||||
path_pattern = re.compile(r'(?:File ")(.*?)(?:")')
|
||||
|
||||
# Regular expression to identify paths in site-packages or dist-packages
|
||||
# This matches path segments like "site-packages/package_name" or "dist-packages/package_name"
|
||||
site_packages_pattern = re.compile(
|
||||
r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
|
||||
)
|
||||
|
||||
# Additional common virtual environment patterns
|
||||
venv_lib_pattern = re.compile(
|
||||
r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
|
||||
)
|
||||
|
||||
for line in lines:
|
||||
# Check if this line contains a file path
|
||||
path_match = path_pattern.search(line)
|
||||
|
||||
if path_match:
|
||||
full_path = path_match.group(1)
|
||||
sanitized_path = ""
|
||||
|
||||
# Try to match site-packages pattern
|
||||
site_packages_match = site_packages_pattern.search(full_path)
|
||||
venv_lib_match = venv_lib_pattern.search(full_path)
|
||||
|
||||
if site_packages_match:
|
||||
# Find the index where the matched pattern starts
|
||||
idx = full_path.find("site-packages")
|
||||
if idx == -1:
|
||||
idx = full_path.find("dist-packages")
|
||||
|
||||
# Keep from 'site-packages' onward
|
||||
if idx >= 0:
|
||||
sanitized_path = full_path[idx:]
|
||||
elif venv_lib_match:
|
||||
# For other virtual environment patterns, find the package directory
|
||||
match_idx = venv_lib_match.start(1)
|
||||
if match_idx > 0:
|
||||
# Keep from the package name onward
|
||||
package_name = venv_lib_match.group(1)
|
||||
idx = full_path.rfind(
|
||||
package_name, 0, match_idx + len(package_name)
|
||||
)
|
||||
if idx >= 0:
|
||||
sanitized_path = full_path[idx:]
|
||||
|
||||
# If we couldn't identify a package pattern but path contains 'axolotl'
|
||||
elif "axolotl" in full_path:
|
||||
idx = full_path.rfind("axolotl")
|
||||
if idx >= 0:
|
||||
sanitized_path = full_path[idx:]
|
||||
|
||||
# Apply the sanitization to the line
|
||||
if sanitized_path:
|
||||
line = line.replace(full_path, sanitized_path)
|
||||
else:
|
||||
# If we couldn't identify a package pattern, just keep the filename
|
||||
filename = os.path.basename(full_path)
|
||||
if filename:
|
||||
line = line.replace(full_path, filename)
|
||||
else:
|
||||
line = line.replace(full_path, "")
|
||||
|
||||
sanitized_lines.append(line)
|
||||
|
||||
return "\n".join(sanitized_lines)
|
||||
|
||||
|
||||
def send_errors(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to send exception info in a function. If an exception is raised, we send
|
||||
telemetry containing the stack trace and error message.
|
||||
|
||||
If an error occurs in a decorated function that is called by another decorated
|
||||
function, we'll only send telemetry corresponding to the lower-level function.
|
||||
|
||||
Args:
|
||||
func: Function to decorate.
|
||||
|
||||
Returns:
|
||||
Decorated function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
telemetry_manager = TelemetryManager.get_instance()
|
||||
|
||||
if not telemetry_manager.enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as exception:
|
||||
# Only track if we're not already handling an error. This prevents us from
|
||||
# capturing an error more than once in nested decorated function calls.
|
||||
global ERROR_HANDLED # pylint: disable=global-statement
|
||||
if not ERROR_HANDLED:
|
||||
ERROR_HANDLED = True
|
||||
|
||||
# Get function module path
|
||||
module = getmodule(func)
|
||||
module_path = (
|
||||
f"{module.__name__}.{func.__name__}" if module else func.__name__
|
||||
)
|
||||
|
||||
# Get stack trace
|
||||
stack_trace = "".join(
|
||||
traceback.format_exception(
|
||||
type(exception), exception, exception.__traceback__
|
||||
)
|
||||
)
|
||||
stack_trace = sanitize_stack_trace(stack_trace)
|
||||
|
||||
# Send error telemetry
|
||||
telemetry_manager.send_event(
|
||||
event_type=f"{module_path}-error",
|
||||
properties={
|
||||
"exception": str(exception),
|
||||
"stack_trace": stack_trace,
|
||||
},
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
@@ -1,416 +0,0 @@
|
||||
"""Telemetry manager and associated utilities."""
|
||||
|
||||
import atexit
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import posthog
|
||||
import psutil
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
POSTHOG_HOST = "https://app.posthog.com"
|
||||
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
|
||||
|
||||
OPT_OUT_WARNING_SLEEP_SECONDS = 10
|
||||
OPT_OUT_WARNING = (
|
||||
"\nTelemetry is now enabled by default to help improve Axolotl. "
|
||||
"If you'd like to disable it, set AXOLOTL_DO_NOT_TRACK=1 in your environment.\n\n"
|
||||
"Telemetry data helps us understand:\n"
|
||||
"- Which features are most used\n"
|
||||
"- What hardware configurations to prioritize\n"
|
||||
"- Where users encounter errors\n\n"
|
||||
"Personally identifiable information (PII) is not collected.\n\n"
|
||||
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
|
||||
"or AXOLOTL_DO_NOT_TRACK=1 (disable telemetry).\n\n"
|
||||
"For details, see: https://docs.axolotl.ai/docs/telemetry.html\n\n"
|
||||
f"Sleeping for {OPT_OUT_WARNING_SLEEP_SECONDS}s..."
|
||||
)
|
||||
|
||||
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
|
||||
|
||||
# NOTE: Need to keep these up to date with any config schema changes
|
||||
FIELDS_TO_REDACT = {
|
||||
"base_model",
|
||||
"tokenizer_config",
|
||||
"base_model_config",
|
||||
"pretraining_dataset", # NOTE: this field may be a string or a dictionary
|
||||
"resume_from_checkpoint",
|
||||
"hub_model_id",
|
||||
}
|
||||
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
|
||||
PATH_INDICATORS = {"path", "dir"}
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
RELEVANT_PACKAGES = {
|
||||
"torch",
|
||||
"transformers",
|
||||
"trl",
|
||||
"datasets",
|
||||
"peft",
|
||||
"bitsandbytes",
|
||||
"accelerate",
|
||||
"optimum",
|
||||
"deepspeed",
|
||||
"ray",
|
||||
"axolotl",
|
||||
"triton",
|
||||
"mamba-ssm",
|
||||
"flash-attn",
|
||||
"xformers",
|
||||
"autoawq",
|
||||
"tokenizers",
|
||||
"sentencepiece",
|
||||
"torchao",
|
||||
"lm_eval",
|
||||
}
|
||||
|
||||
|
||||
def is_main_process() -> bool:
|
||||
"""
|
||||
Check whether we're running in the main process.
|
||||
|
||||
Note:
|
||||
We're using this function instead of `torch.utils.distributed.is_main_process`
|
||||
causes issues with DeepSpeed world_size since. This function avoids that issue
|
||||
by checking env vars that are set by various launchers.
|
||||
|
||||
Returns:
|
||||
Whether we're running in the main process.
|
||||
"""
|
||||
# If PyTorch distributed is already initialized, use it
|
||||
if torch.distributed.is_initialized():
|
||||
return torch.distributed.get_rank() == 0
|
||||
|
||||
# Otherwise check environment variables for global rank
|
||||
# NOTE: need to verify this in SLURM / OpenMPI environments
|
||||
global_rank = int(
|
||||
os.environ.get(
|
||||
"RANK",
|
||||
os.environ.get(
|
||||
"GLOBAL_RANK",
|
||||
os.environ.get(
|
||||
"SLURM_PROCID",
|
||||
os.environ.get(
|
||||
"OMPI_COMM_WORLD_RANK",
|
||||
"0",
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return global_rank == 0
|
||||
|
||||
|
||||
class TelemetryManager:
|
||||
"""Manages telemetry collection and transmission"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
"""
|
||||
Telemetry manager constructor. Creates the singleton instance of this class if
|
||||
it doesn't already exist.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(TelemetryManager, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Telemetry manager initializer"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.enabled = self._check_telemetry_enabled()
|
||||
|
||||
if self.enabled:
|
||||
self.run_id = str(uuid.uuid4())
|
||||
self.whitelist = self._load_whitelist()
|
||||
|
||||
try:
|
||||
self.system_info = self._get_system_info()
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
LOG.warning(f"Error during system info collection: {e}")
|
||||
self.system_info = None
|
||||
|
||||
self._init_posthog()
|
||||
|
||||
# Register shutdown method to flush posthog telemetry
|
||||
atexit.register(self.shutdown)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "TelemetryManager":
|
||||
if cls._instance is None:
|
||||
cls._instance = TelemetryManager()
|
||||
|
||||
return cls._instance
|
||||
|
||||
def _check_telemetry_enabled(self) -> bool:
|
||||
"""
|
||||
Check if telemetry is enabled based on environment variables. We also check
|
||||
whether this is the main process (for the distributed setting and to avoid
|
||||
sending duplicate PostHog events per GPU).
|
||||
|
||||
Note: This is enabled by default on an opt-out basis. Set
|
||||
`AXOLOTL_DO_NOT_TRACK=1` to disable telemetry. For more details, see
|
||||
https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
|
||||
|
||||
Returns:
|
||||
Boolean denoting whether telemetry is enabled or not.
|
||||
"""
|
||||
# Parse relevant env vars
|
||||
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
|
||||
do_not_track = os.getenv("DO_NOT_TRACK")
|
||||
|
||||
# Default to enabled (opt-out model)
|
||||
if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
|
||||
"0",
|
||||
"1",
|
||||
"false",
|
||||
"true",
|
||||
):
|
||||
# Print opt-out info message for main process only
|
||||
if is_main_process():
|
||||
LOG.warning(OPT_OUT_WARNING)
|
||||
time.sleep(OPT_OUT_WARNING_SLEEP_SECONDS)
|
||||
|
||||
return True
|
||||
|
||||
# Only rank 0 will send telemetry
|
||||
if not is_main_process():
|
||||
return False
|
||||
|
||||
if do_not_track is None:
|
||||
do_not_track = "0"
|
||||
|
||||
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
|
||||
enabled = axolotl_do_not_track.lower() not in (
|
||||
"1",
|
||||
"true",
|
||||
) and do_not_track.lower() not in ("1", "true")
|
||||
|
||||
return enabled
|
||||
|
||||
def _load_whitelist(self) -> dict:
|
||||
"""Load HuggingFace Hub organization whitelist"""
|
||||
with open(WHITELIST_PATH, encoding="utf-8") as f:
|
||||
whitelist = yaml.safe_load(f)
|
||||
|
||||
# Send org strings to lowercase since model names are case insensitive
|
||||
whitelist["organizations"] = {
|
||||
org.lower() for org in whitelist["organizations"]
|
||||
}
|
||||
|
||||
return whitelist
|
||||
|
||||
def _is_whitelisted(self, value: str) -> bool:
|
||||
"""
|
||||
Check if model / dataset / etc. org is in whitelist.
|
||||
|
||||
Args:
|
||||
value: Value for one of `axolotl.telemetry.manager.FIELDS_WITH_ORGS`
|
||||
("base_model", etc.).
|
||||
|
||||
Returns:
|
||||
Boolean indicating whitelist membership.
|
||||
"""
|
||||
# NOTE: This membership-checking logic can be improved.
|
||||
# What happens when a local model path matches a whitelisted org?
|
||||
parts = value.split("/")
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
org = parts[0]
|
||||
whitelisted = org.lower() in self.whitelist["organizations"]
|
||||
|
||||
return whitelisted
|
||||
|
||||
def _init_posthog(self):
|
||||
"""Initialize PostHog client"""
|
||||
posthog.api_key = POSTHOG_WRITE_KEY
|
||||
posthog.project_api_key = POSTHOG_WRITE_KEY
|
||||
posthog.host = POSTHOG_HOST
|
||||
|
||||
def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Redact properties to remove any paths, so as to avoid inadvertently collecting
|
||||
private or personally identifiable information (PII). We also remove
|
||||
information related to Wandb, MLflow, etc. configuration.
|
||||
|
||||
Args:
|
||||
properties: Dictionary of properties to redact.
|
||||
|
||||
Returns:
|
||||
Properties dictionary with redaction applied.
|
||||
"""
|
||||
if not properties:
|
||||
return {}
|
||||
|
||||
def redact_value(value: Any, key: str = "") -> Any:
|
||||
"""Recursively sanitize values, redacting those with path-like keys"""
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
# Other redaction special cases
|
||||
if (
|
||||
key in FIELDS_TO_REDACT
|
||||
or any(prefix in key for prefix in PREFIXES_TO_REDACT)
|
||||
or any(indicator in key.lower() for indicator in PATH_INDICATORS)
|
||||
):
|
||||
# Fields with whitelisted orgs don't need to be redacted
|
||||
if not self._is_whitelisted(value):
|
||||
return "[REDACTED]"
|
||||
|
||||
# Handle nested values
|
||||
if isinstance(value, dict):
|
||||
return {k: redact_value(v, k) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [redact_value(item) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
# Create new dict with redacted values
|
||||
redacted = {k: redact_value(v, k) for k, v in properties.items()}
|
||||
|
||||
return redacted
|
||||
|
||||
def _get_system_info(self) -> dict[str, Any]:
|
||||
"""Collect system information for various hardware accelerators"""
|
||||
gpu_info = []
|
||||
accelerator_type = "none"
|
||||
|
||||
# NVIDIA GPUs
|
||||
if torch.cuda.is_available():
|
||||
accelerator_type = "cuda"
|
||||
for i in range(torch.cuda.device_count()):
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": torch.cuda.get_device_name(i),
|
||||
"memory": torch.cuda.get_device_properties(i).total_memory,
|
||||
}
|
||||
)
|
||||
|
||||
# AMD GPUs
|
||||
elif hasattr(torch, "hip") and torch.hip.is_available():
|
||||
accelerator_type = "hip"
|
||||
for i in range(torch.hip.device_count()):
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": torch.hip.get_device_name(i),
|
||||
"memory": (
|
||||
torch.hip.get_device_properties(i).total_memory
|
||||
if hasattr(torch.hip, "get_device_properties")
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Apple Silicon
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
accelerator_type = "mps"
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": "Apple Silicon",
|
||||
# NOTE: this is memory allocated to this process, not total memory
|
||||
"memory": torch.mps.driver_allocated_memory(),
|
||||
}
|
||||
)
|
||||
|
||||
# Intel GPUs
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
accelerator_type = "xpu"
|
||||
for i in range(torch.xpu.device_count()):
|
||||
memory = None
|
||||
if hasattr(torch.xpu, "get_device_properties"):
|
||||
memory = torch.xpu.get_device_properties(i).total_memory
|
||||
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": torch.xpu.get_device_name(i),
|
||||
"memory": memory,
|
||||
}
|
||||
)
|
||||
|
||||
# NPUs
|
||||
elif hasattr(torch, "npu") and torch.npu.is_available():
|
||||
accelerator_type = "npu"
|
||||
for i in range(torch.npu.device_count()):
|
||||
memory = None
|
||||
if hasattr(torch.npu, "get_device_properties"):
|
||||
memory = torch.npu.get_device_properties(i).total_memory
|
||||
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": torch.npu.get_device_name(i),
|
||||
"memory": memory,
|
||||
}
|
||||
)
|
||||
|
||||
# Get relevant package versions
|
||||
installed_packages = {}
|
||||
for package in RELEVANT_PACKAGES:
|
||||
try:
|
||||
version = importlib.metadata.version(package)
|
||||
installed_packages[f"{package}_version"] = version
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
pass
|
||||
|
||||
return {
|
||||
"os": platform.system(),
|
||||
"python_version": platform.python_version(),
|
||||
"cpu_count": psutil.cpu_count(),
|
||||
"memory_total": psutil.virtual_memory().total,
|
||||
"accelerator_type": accelerator_type,
|
||||
"accelerator_count": len(gpu_info),
|
||||
"accelerator_info": gpu_info,
|
||||
**installed_packages,
|
||||
}
|
||||
|
||||
def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
|
||||
"""Send a telemetry event"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
if properties is None:
|
||||
properties = {}
|
||||
|
||||
# Sanitize properties to remove PII
|
||||
properties = self._redact_paths(properties)
|
||||
|
||||
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
|
||||
try:
|
||||
# Send event via PostHog
|
||||
posthog.capture(
|
||||
distinct_id=self.run_id,
|
||||
event=event_type,
|
||||
properties=properties,
|
||||
disable_geoip=True,
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
LOG.warning(f"Failed to send telemetry event: {e}")
|
||||
|
||||
# Additionally, send system info telemetry when loading config.
|
||||
# NOTE: Is this the best place for this?
|
||||
if event_type == "config-loaded":
|
||||
self.send_system_info()
|
||||
|
||||
def send_system_info(self):
|
||||
"""Helper method for sending system info"""
|
||||
if self.system_info is not None:
|
||||
self.send_event(event_type="system-info", properties=self.system_info)
|
||||
|
||||
def shutdown(self):
|
||||
"""Ensure all queued events are processed before shutdown"""
|
||||
if self.enabled:
|
||||
posthog.shutdown()
|
||||
@@ -1,210 +0,0 @@
|
||||
"""Telemetry utilities for runtime and memory metrics."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeMetrics:
|
||||
"""Container for runtime metrics to be tracked throughout training."""
|
||||
|
||||
# Timing metrics
|
||||
start_time: float
|
||||
epoch_start_times: dict[int, float] = field(init=False)
|
||||
epoch_end_times: dict[int, float] = field(init=False)
|
||||
|
||||
# Memory metrics
|
||||
peak_cpu_memory: int = 0
|
||||
peak_gpu_memory: dict[int, int] = field(init=False)
|
||||
|
||||
# Progress metrics
|
||||
total_steps: int = 0
|
||||
current_epoch: int = 0
|
||||
current_step: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize empty metric mappings."""
|
||||
self.epoch_start_times = {}
|
||||
self.epoch_end_times = {}
|
||||
self.peak_gpu_memory = {}
|
||||
|
||||
@property
|
||||
def elapsed_time(self) -> float:
|
||||
"""Calculate total elapsed time in seconds."""
|
||||
return time.time() - self.start_time
|
||||
|
||||
def epoch_time(self, epoch: int) -> float | None:
|
||||
"""Calculate time taken for a specific epoch in seconds."""
|
||||
if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
|
||||
return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]
|
||||
|
||||
return None
|
||||
|
||||
def average_epoch_time(self) -> float | None:
|
||||
"""Calculate average time per epoch in seconds."""
|
||||
completed_epochs = [
|
||||
epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
|
||||
]
|
||||
if not completed_epochs:
|
||||
return None
|
||||
|
||||
total_time = 0.0
|
||||
for epoch in completed_epochs:
|
||||
epoch_time = self.epoch_time(epoch)
|
||||
if epoch_time is not None: # Check to avoid mypy warning
|
||||
total_time += epoch_time
|
||||
|
||||
return total_time / len(completed_epochs)
|
||||
|
||||
def steps_per_second(self) -> float | None:
|
||||
"""Calculate average steps per second across all training."""
|
||||
if self.total_steps == 0 or self.elapsed_time == 0:
|
||||
return None
|
||||
|
||||
return self.total_steps / self.elapsed_time
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert metrics to a dictionary for telemetry reporting."""
|
||||
metrics = {
|
||||
"total_time_seconds": self.elapsed_time,
|
||||
"total_steps": self.total_steps,
|
||||
"steps_per_second": self.steps_per_second(),
|
||||
"epochs_completed": len(
|
||||
[
|
||||
epoch
|
||||
for epoch in self.epoch_start_times
|
||||
if epoch in self.epoch_end_times
|
||||
]
|
||||
),
|
||||
"peak_cpu_memory_bytes": self.peak_cpu_memory,
|
||||
}
|
||||
|
||||
# Add per-epoch timing if available
|
||||
epoch_times: dict[str, float] = {}
|
||||
for epoch in sorted(self.epoch_end_times.keys()):
|
||||
time_taken = self.epoch_time(epoch)
|
||||
if time_taken is not None:
|
||||
epoch_times[f"epoch_{epoch}_seconds"] = time_taken
|
||||
|
||||
if epoch_times:
|
||||
metrics["epoch_times"] = epoch_times # type: ignore
|
||||
metrics["average_epoch_time_seconds"] = self.average_epoch_time()
|
||||
|
||||
# Add GPU memory metrics if available
|
||||
if self.peak_gpu_memory:
|
||||
gpu_metrics: dict[str, int] = {}
|
||||
for gpu_id, memory in self.peak_gpu_memory.items():
|
||||
gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
|
||||
metrics["gpu_memory"] = gpu_metrics # type: ignore
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
class RuntimeMetricsTracker:
|
||||
"""Tracker for runtime metrics during training."""
|
||||
|
||||
update_interval = 100
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the runtime metrics tracker."""
|
||||
self.metrics = RuntimeMetrics(start_time=time.time())
|
||||
self.telemetry_manager = TelemetryManager.get_instance()
|
||||
self._process = psutil.Process()
|
||||
|
||||
def start_epoch(self, epoch: int):
|
||||
"""Record the start of a new epoch."""
|
||||
self.metrics.current_epoch = epoch
|
||||
self.metrics.epoch_start_times[epoch] = time.time()
|
||||
self.update_memory_metrics()
|
||||
|
||||
def end_epoch(self, epoch: int):
|
||||
"""Record the end of an epoch."""
|
||||
self.metrics.epoch_end_times[epoch] = time.time()
|
||||
|
||||
def update_step(self, step: int):
|
||||
"""Update the current step count."""
|
||||
self.metrics.current_step = step
|
||||
self.metrics.total_steps += 1
|
||||
|
||||
# Periodically update memory metrics
|
||||
if step % self.update_interval == 0:
|
||||
self.update_memory_metrics()
|
||||
|
||||
def _get_allocated_memory(self) -> dict[int, int]:
|
||||
"""
|
||||
Helper function for getting accelerator-agnostic allocated memory.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping device IDs to allocated memory in bytes
|
||||
"""
|
||||
memory_used: dict[int, int] = {}
|
||||
|
||||
# NVIDIA GPUs
|
||||
if torch.cuda.is_available():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
memory_used[i] = torch.cuda.memory_allocated(i)
|
||||
|
||||
# AMD GPUs
|
||||
elif hasattr(torch, "hip") and torch.hip.is_available():
|
||||
for i in range(torch.hip.device_count()):
|
||||
if hasattr(torch.hip, "memory_allocated"):
|
||||
memory_used[i] = torch.hip.memory_allocated(i)
|
||||
|
||||
# Apple Silicon
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
# MPS doesn't have per-device memory stats since there's only one device
|
||||
if hasattr(torch.mps, "current_allocated_memory"):
|
||||
memory_used[0] = torch.mps.current_allocated_memory()
|
||||
|
||||
# Intel GPUs
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
if hasattr(torch.xpu, "memory_allocated"):
|
||||
memory_used[i] = torch.xpu.memory_allocated(i)
|
||||
|
||||
# NPUs
|
||||
elif hasattr(torch, "npu") and torch.npu.is_available():
|
||||
for i in range(torch.npu.device_count()):
|
||||
if hasattr(torch.npu, "memory_allocated"):
|
||||
memory_used[i] = torch.npu.memory_allocated(i)
|
||||
|
||||
return memory_used
|
||||
|
||||
def update_memory_metrics(self):
|
||||
"""Update peak memory usage metrics."""
|
||||
# CPU memory
|
||||
cpu_memory = self._process.memory_info().rss
|
||||
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
|
||||
|
||||
# GPU memory (if available)
|
||||
memory_used = self._get_allocated_memory()
|
||||
for i, memory in memory_used.items():
|
||||
self.metrics.peak_gpu_memory[i] = max(
|
||||
self.metrics.peak_gpu_memory.get(i, 0), memory
|
||||
)
|
||||
|
||||
def get_memory_metrics(self) -> dict[str, Any]:
|
||||
"""Get the current memory metrics as a dictionary."""
|
||||
memory_metrics = {
|
||||
"cpu_memory_bytes": self._process.memory_info().rss,
|
||||
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
|
||||
}
|
||||
|
||||
# GPU memory (if available)
|
||||
memory_used = self._get_allocated_memory()
|
||||
for i, memory in memory_used.items():
|
||||
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
|
||||
memory_metrics[f"gpu_{i}_peak_memory_bytes"] = (
|
||||
self.metrics.peak_gpu_memory.get(i, 0)
|
||||
)
|
||||
|
||||
return memory_metrics
|
||||
@@ -1,33 +0,0 @@
|
||||
organizations:
|
||||
- "axolotl-ai-co"
|
||||
- "meta-llama"
|
||||
- "huggingface"
|
||||
- "nvidia"
|
||||
- "facebook"
|
||||
- "google"
|
||||
- "microsoft"
|
||||
- "deepseek-ai"
|
||||
- "HuggingFaceTB"
|
||||
- "mistralai"
|
||||
- "Qwen"
|
||||
- "unsloth"
|
||||
- "NousResearch"
|
||||
- "allenai"
|
||||
- "amd"
|
||||
- "tiiuae"
|
||||
- "tencent"
|
||||
- "zai-org"
|
||||
- "openai"
|
||||
- "ibm-granite"
|
||||
- "arcee-ai"
|
||||
- "swiss-ai"
|
||||
- "CohereForAI"
|
||||
- "deepcogito"
|
||||
- "THUDM"
|
||||
- "ai21labs"
|
||||
- "LiquidAI"
|
||||
- "canopylabs"
|
||||
- "state-spaces"
|
||||
- "mistral-community"
|
||||
- "llava-hf"
|
||||
- "ByteDance-Seed"
|
||||
@@ -31,8 +31,6 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders import ModelLoader, load_processor, load_tokenizer
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
@@ -47,9 +45,6 @@ if typing.TYPE_CHECKING:
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||
|
||||
|
||||
def setup_model_and_tokenizer(
|
||||
cfg: DictDefault,
|
||||
@@ -67,10 +62,7 @@ def setup_model_and_tokenizer(
|
||||
`None`), and processor (if multimodal, else `None`).
|
||||
"""
|
||||
# Load tokenizer
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
main_process_only=True,
|
||||
)
|
||||
LOG.debug(f"Loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
# Load processor for multimodal models if needed
|
||||
@@ -86,14 +78,6 @@ def setup_model_and_tokenizer(
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
TELEMETRY_MANAGER.send_event(
|
||||
event_type="model-load", properties=model.config.to_dict()
|
||||
)
|
||||
if peft_config:
|
||||
TELEMETRY_MANAGER.send_event(
|
||||
event_type="peft-config-load", properties=peft_config.to_dict()
|
||||
)
|
||||
|
||||
# Apply freezing if specified
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||
@@ -212,7 +196,8 @@ def execute_training(
|
||||
LOG.info("Starting trainer...")
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
PLUGIN_MANAGER.post_train(cfg, trainer.model)
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_train(cfg, trainer.model)
|
||||
|
||||
|
||||
def save_trained_model(
|
||||
@@ -536,7 +521,9 @@ def setup_model_and_trainer(
|
||||
model_ref=model_ref,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
PLUGIN_MANAGER.post_trainer_create(cfg, trainer)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_trainer_create(cfg, trainer)
|
||||
|
||||
if cfg.use_ray:
|
||||
try:
|
||||
@@ -558,7 +545,6 @@ def setup_model_and_trainer(
|
||||
)
|
||||
|
||||
|
||||
@send_errors
|
||||
def train(
|
||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
|
||||
@@ -609,6 +595,5 @@ def train(
|
||||
create_model_card(cfg, trainer)
|
||||
if not cfg.use_ray:
|
||||
cleanup_distributed()
|
||||
PLUGIN_MANAGER.post_train(cfg, model)
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user