Compare commits


1 Commit

| Author | SHA1 | Message | Date |
| ------ | ---- | ------- | ---- |
| Wing Lian | faed3905fd | version tag 0.13.2 | 2026-01-22 10:58:38 -05:00 |
175 changed files with 743 additions and 9161 deletions

View File

@@ -51,30 +51,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
- cuda: "129"
cuda_version: 12.9.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -91,14 +75,6 @@ jobs:
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -181,30 +157,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
- cuda: "129"
cuda_version: 12.9.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-uv-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -221,14 +181,6 @@ jobs:
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -34,30 +34,18 @@ jobs:
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
pytorch: 2.9.1
axolotl_extras: vllm
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -98,77 +86,6 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-uv:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
- name: Build and export to Docker
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -195,28 +112,16 @@ jobs:
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
@@ -254,73 +159,6 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-uv:
needs: build-axolotl-uv
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-cloud-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}

View File

@@ -45,7 +45,7 @@ jobs:
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
axolotl_extras: "fbgemm-gpu,vllm"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
- cuda: 130

View File

@@ -37,7 +37,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5

View File

@@ -54,13 +54,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20
steps:
@@ -75,7 +75,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
@@ -115,10 +115,10 @@ jobs:
- name: Pre-Download dataset fixture
run: |
hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Show HF cache
run: hf cache ls
run: hf cache scan
- name: Run tests
run: |
@@ -132,7 +132,7 @@ jobs:
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Show HF cache
run: hf cache ls
run: hf cache scan
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
@@ -149,13 +149,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20
steps:
@@ -170,7 +170,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
@@ -210,7 +210,7 @@ jobs:
axolotl --help
- name: Show HF cache
run: hf cache ls
run: hf cache scan
- name: Run tests
run: |
@@ -219,10 +219,10 @@ jobs:
pytest -v --durations=10 tests/cli/
- name: Show HF cache
run: hf cache ls
run: hf cache scan
gate-skip-e2e:
needs: [pre-commit]
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
@@ -258,18 +258,18 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest]
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
strategy:
fail-fast: false
matrix:
include:
- cuda: 130
cuda_version: 13.0.0
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
axolotl_extras: vllm
dockerfile: "Dockerfile-uv.jinja"
steps:
- name: Checkout
@@ -326,12 +326,6 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
@@ -377,7 +371,7 @@ jobs:
include:
- cuda: 129
cuda_version: 12.9.1
python_version: "3.11"
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:

View File

@@ -123,7 +123,7 @@ datasets:
| --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_num_proc` | `4` | Number of preprocessing processes |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |
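For reference, a minimal sketch of how these options sit together in a config file. Values are illustrative only, mirroring the defaults in the table above:
```yaml
dataset_prepared_path: data/last_run_prepared   # where the tokenized dataset is cached
dataset_processes: 4                            # number of preprocessing workers
dataset_keep_in_memory: false                   # write prepared data to disk instead of RAM
shuffle_merged_datasets: true                   # shuffle after merging all datasets
shuffle_before_merging_datasets: false          # optionally shuffle each dataset first
```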

View File

@@ -39,6 +39,7 @@
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
@@ -106,7 +107,7 @@
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
# dataset_num_proc: # defaults to os.cpu_count() if not set
# dataset_processes: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
@@ -223,6 +224,9 @@
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# # Save model as safetensors (require safetensors package)
# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
@@ -348,6 +352,8 @@
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
@@ -406,7 +412,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_num_proc: ${DATASET_NUM_PROC}
dataset_processes: ${DATASET_PROCESSES}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}
@@ -506,6 +512,7 @@ profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}

View File

@@ -1 +1 @@
0.15.0.dev0
0.13.2

View File

@@ -251,6 +251,7 @@ website:
- docs/models/olmo3.qmd
- docs/models/trinity.qmd
- docs/models/arcee.qmd
- docs/models/mistral.qmd
- section: "Ministral3"
contents:
- docs/models/ministral3.qmd
@@ -265,7 +266,6 @@ website:
- docs/models/mistral-small.qmd
- docs/models/voxtral.qmd
- docs/models/devstral.qmd
- docs/models/mistral.qmd
- docs/models/llama-4.qmd
- docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd
@@ -320,7 +320,6 @@ website:
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
- docs/attention.qmd
- section: "Advanced Features"
contents:

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v --durations=10 -n2 --maxfail=3 \
pytest -v --durations=10 -n2 --maxfail=4 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -59,18 +59,34 @@ RUN git lfs install --skip-repo && \
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
pip3 install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$CUDA" = "128" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
elif [ "$CUDA" = "130" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
fi \
;; \
esac

View File

@@ -1,30 +0,0 @@
ARG BASE_TAG=main
FROM axolotlai/axolotl-uv:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN uv pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

View File

@@ -1,47 +0,0 @@
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base-uv:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch && \
git config --global credential.helper store
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
RUN chmod +x /root/.axolotl-complete.bash && \
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc

View File

@@ -6,7 +6,6 @@ ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
@@ -40,18 +39,28 @@ RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
fi
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
uv pip install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$TARGETARCH" = "amd64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
elif [ "$TARGETARCH" = "arm64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
fi \
fi \
;; \
esac

View File

@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
Download a base model using the Hugging Face CLI:
```bash
hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
```
### 10. Create Axolotl Configuration

View File

@@ -1,140 +0,0 @@
---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
sdp_attention: true
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
Uses efficient kernels to compute attention.
```yaml
flash_attention: true
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
```bash
pip install flash-attn --no-build-isolation
```
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstalling flash-attn or downgrading to an earlier version.
:::
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### AMD
Requirements: ROCm 6.0 and above.
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
flex_attention: true
# recommended
torch_compile: true
```
::: {.callout-note}
We recommend using the latest stable version of PyTorch for best performance.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA is recommended at the moment. We found the loss drops to 0 with full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested in adding this or improving the SageAttention implementation, please open an issue.
:::
## xFormers
```yaml
xformers_attention: true
```
::: {.callout-tip}
We recommend using this with Turing GPUs or older (such as on Colab).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to one of the methods above.
:::
Requirements: LLaMA model architecture
```yaml
flash_attention: true
s2_attention: true
```
::: {.callout-tip}
No sample packing support!
:::
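The backends above are generally alternatives, so enable only one of them per config (pairing Flex Attention with `torch_compile` as recommended). A minimal sketch, assuming a recent NVIDIA GPU:
```yaml
# Enable exactly one attention backend.
flash_attention: true        # Ampere, Ada, or Hopper GPUs
# sdp_attention: true        # PyTorch built-in default
# xformers_attention: true   # Turing GPUs or older
# flex_attention: true       # pair with torch_compile: true
# sage_attention: true       # LoRA/QLoRA only (see warning above)
```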

View File

@@ -210,8 +210,6 @@ axolotl lm-eval config.yml
Configuration options:
```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -220,7 +218,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
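A hypothetical filled-in example using the options above (the model path and the second task are placeholders, not recommendations):
```yaml
lm_eval_model: ./outputs/my-finetune   # local path or HF repo id (placeholder)
lm_eval_tasks:
  - arc_challenge
  - hellaswag                          # placeholder task
lm_eval_batch_size: 8
output_dir: ./outputs/lm-eval-results
```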
### delinearize-llama4

View File

@@ -165,7 +165,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
```
4. (Optional) Login to Hugging Face:
```{.bash}
hf auth login
huggingface-cli login
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -89,10 +89,6 @@ lora_o_kernel: true
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
::: {.callout-warning}
LoRA kernels do not support remote modeling code.
:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
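For reference, a minimal SFT LoRA sketch with the kernels enabled. The three `lora_*_kernel` flags are as documented here; the surrounding values are illustrative:
```yaml
adapter: lora
lora_r: 16        # illustrative
lora_alpha: 32    # illustrative
# Triton-backed LoRA kernels (SFT only, per the note above)
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```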

View File

@@ -19,7 +19,6 @@ format:
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl)
@@ -184,18 +183,6 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
```yaml
# GLM-4.6V (106B MoE version)
base_model: zai-org/GLM-4.6V
# OR GLM-4.6V-Flash (9B version)
base_model: zai-org/GLM-4.6V-Flash
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2\""
]
},
{

View File

@@ -1,77 +0,0 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/eaft-gemma-3-1b
use_eaft: true
eaft_alpha: 1.0
eaft_k: 20
sequence_len: 1024
sample_packing: false
adapter:
lora_model_dir:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
max_steps: 1000
evaluation_strategy: "no"
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
debug:
deepspeed:
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,77 +0,0 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# QLoRA
# - no target experts (1x48GB @ ~24GiB/GPU)
# - target experts (1x48GB @ ~34GiB/GPU)
axolotl train examples/glm4.7-flash/qlora.yaml
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
```
```bash
# LoRA
# - no target experts (1x48GB @ ~35GiB/GPU)
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
axolotl train examples/glm4.7-flash/lora.yaml
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
axolotl train examples/glm4.7-flash/lora_fsdp.yaml
```
### Expert LoRA
To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router, untested but should work, not normally targeted
```
## Limitations
- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single GPU training. We suspect not all layers are properly sharded across ranks.
- **FSDP initial spike**: FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2; it causes a hang otherwise.
- **lora_target_linear**: Incompatible with this model.
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
### TIPS
- For inference, the official Z.ai team recommends these default settings (most tasks):
- `temperature: 1.0`
- `top_p: 0.95`
- `max_new_tokens: 131072`
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config (see the sketch after these tips). This is heavy, so we have not tested it.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
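Following the full-finetuning tip, a sketch of the config delta relative to `examples/glm4.7-flash/qlora.yaml` (untested, as noted above):
```yaml
# Remove (or comment out) the quantized-LoRA settings to train all weights:
# adapter: qlora
# load_in_4bit: true
# quantize_moe_experts: true
```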
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,65 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -1,75 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,65 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -1,75 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,44 +0,0 @@
# Finetune GLM-4.6V with Axolotl
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the fine-tuning:
GLM-4.6V-Flash (9B):
```bash
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
## Tips
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
- You can run a **full finetuning** by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Supported Models
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,53 +0,0 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
ddp_find_unused_parameters: true
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,50 +0,0 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -19,6 +19,7 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true

View File

@@ -12,6 +12,7 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true
adapter: qlora

View File

@@ -47,5 +47,6 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
tokens:
save_safetensors: False
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -6,13 +6,30 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main, since Qwen3-Next support is only available on nightly, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
Here is an example of how to install from main with pip:
```bash
# Ensure you have PyTorch installed (2.6.0 minimum)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Install the transformers commit with Qwen3-Next support:
```bash
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```
3. Install FLA for improved performance
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
```
4. Run the finetuning example:
@@ -21,7 +38,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```
This config uses about ~47 GiB (no target experts) or ~71 GiB (target experts) of VRAM.
This config uses about 45.62 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀

View File

@@ -9,8 +9,6 @@ plugins:
load_in_8bit: false
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
@@ -27,7 +25,7 @@ sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0
lora_dropout: 0.05
lora_target_modules:
- linear_attn.in_proj_ba
- linear_attn.in_proj_qkvz
@@ -36,19 +34,12 @@ lora_target_modules:
- shared_expert.down_proj
- shared_expert.gate_proj
- shared_expert_gate
- mlp.gate
- q_proj
- v_proj
- k_proj
- o_proj
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:

View File

@@ -8,15 +8,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
1. Install Axolotl from main following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
2. Run the finetuning example:
```bash
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
```
This config uses about 24.9 GiB VRAM (w/o CCE).
This config uses about 24.9 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
@@ -31,6 +29,10 @@ Let us know how it goes. Happy finetuning! 🚀
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
## Related Resources
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)

View File

@@ -1,4 +1,5 @@
base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
revision_of_model: 2ee94b0
# Automatically upload checkpoint and final model to HF

View File

@@ -60,8 +60,3 @@ indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]
flash-attn = [{ requirement = "torch", match-runtime = true }]
deepspeed = [{ requirement = "torch", match-runtime = true }]

View File

@@ -2,25 +2,25 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1
triton>=3.4.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.7.0
liger-kernel==0.6.4
# END section
packaging==26.0
huggingface_hub>=1.1.7
huggingface_hub>=0.36.0
peft>=0.18.1
tokenizers>=0.22.1
transformers==5.2.0
transformers==4.57.6
accelerate==1.12.0
datasets==4.5.0
deepspeed>=0.18.3
trl==0.28.0
trl==0.27.0
hf_xet==1.2.0
kernels==0.12.1
trackio>=0.16.1
kernels==0.11.5
trackio>=0.13.0
typing-extensions>=4.15.0
optimum==1.16.2
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.16.0
torchao==0.13.0
openenv-core==0.1.0
schedulefree==1.4.1

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"'
)

View File

@@ -26,11 +26,6 @@ def parse_requirements(extras_require_map):
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip torchao on ARM64
_install_requires = [
req for req in _install_requires if "torchao" not in req
]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [

View File

@@ -5,6 +5,6 @@ import os
from axolotl.logging_config import configure_logging
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
configure_logging()

View File

@@ -44,7 +44,7 @@ def check_user_token() -> bool:
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except HTTPError:

View File

@@ -5,7 +5,7 @@ import os
import tempfile
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Optional, Union
from typing import Union
from urllib.parse import urlparse
import requests
@@ -32,63 +32,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__)
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
"""Coerce a string CLI value to its most likely Python type.
If an existing value is present in the config, its type is used to guide
casting. Otherwise, YAML-style inference is applied: booleans, ints,
floats, and None literals are recognised automatically.
Args:
value: The raw value (typically a string from the CLI).
existing: An optional existing config value whose type guides coercion.
Returns:
The value cast to the inferred or expected type.
"""
if not isinstance(value, str):
return value
# If the config already has a typed value, cast to match
if existing is not None:
if isinstance(existing, bool):
return value.lower() in ("true", "1", "yes")
if isinstance(existing, int):
try:
return int(value)
except (ValueError, TypeError):
return value
if isinstance(existing, float):
try:
return float(value)
except (ValueError, TypeError):
return value
# For other types (str, list, dict, etc.), return as-is
return value
# No existing value -- use YAML-style inference
lower = value.lower()
if lower in ("true", "yes"):
return True
if lower in ("false", "no"):
return False
if lower in ("null", "none", "~"):
return None
# Try int then float
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
API_KEY_FIELDS = {"comet_api_key"}
TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -265,37 +208,13 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
# Separate nested (dot-notation) kwargs from flat kwargs
nested_kwargs: dict[str, dict[str, Any]] = {}
flat_kwargs: dict[str, Any] = {}
for key, value in kwargs.items():
if "__" in key:
parent, child = key.split("__", 1)
nested_kwargs.setdefault(parent, {})[child] = value
else:
flat_kwargs[key] = value
# Apply flat kwargs
for key, value in flat_kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict:
cfg[key] = _coerce_value(value, cfg.get(key))
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
for parent, children in nested_kwargs.items():
if parent not in cfg_keys and cfg.strict:
continue
if cfg[parent] is None:
cfg[parent] = {}
if not isinstance(cfg[parent], dict):
LOG.warning(
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
)
cfg[parent] = {}
for child_key, child_value in children.items():
existing_child = cfg[parent].get(child_key)
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
if isinstance(cfg[key], bool):
cfg[key] = bool(value)
else:
cfg[key] = value
try:
device_props = torch.cuda.get_device_properties("cuda")

View File

@@ -24,6 +24,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True)
@@ -41,6 +42,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(

View File

@@ -14,6 +14,8 @@ from accelerate import PartialState
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_torch_version,
)
from huggingface_hub import split_torch_state_dict_into_shards
@@ -38,15 +40,17 @@ class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
def _distributed_checkpoint_to_merged_weights(
checkpoint_dir: Union[str, Path],
save_path: str,
safe_serialization: bool = False,
max_shard_size: str = "5GB",
) -> Path:
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
save under `save_path` as `model.safetensors`.
save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
Args:
checkpoint_dir: Directory where distributed checkpoint is saved.
save_path: Path to save model to.
safe_serialization: Whether to save in safetensors format.
max_shard_size: Max size of model shards to save.
Returns:
@@ -72,7 +76,11 @@ def _distributed_checkpoint_to_merged_weights(
if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
state_dict[key] = value.to(torch.bfloat16)
filename_pattern = SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
@@ -90,12 +98,19 @@ def _distributed_checkpoint_to_merged_weights(
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors}
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
if safe_serialization:
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_path_, shard_file))
if index is not None:
save_index_file = os.path.join(save_path_, SAFE_WEIGHTS_INDEX_NAME)
save_index_file = (
SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
)
save_index_file = os.path.join(save_path_, save_index_file)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as fout:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -108,11 +123,13 @@ def _distributed_checkpoint_to_merged_weights(
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,
safe_serialization: bool = False,
remove_checkpoint_dir: bool = False,
):
"""
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors`.
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
`safe_serialization` else `pytorch_model.bin`.
Note: this is a CPU-bound process.
@@ -121,6 +138,8 @@ def merge_fsdp_weights(
The directory containing the FSDP checkpoints (can be either the model or optimizer).
output_path (`str`):
The path to save the merged checkpoint.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
@@ -158,7 +177,7 @@ def merge_fsdp_weights(
if state.is_main_process:
LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
save_path = _distributed_checkpoint_to_merged_weights(
checkpoint_dir_, output_path
checkpoint_dir_, output_path, safe_serialization
)
LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
if remove_checkpoint_dir:
@@ -191,6 +210,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=output_path,
safe_serialization=True,
)
state = PartialState()
state.wait_for_everyone()

View File

@@ -102,10 +102,12 @@ def do_quantize(
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
@@ -119,7 +121,7 @@ def do_quantize(
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
if processor:
processor.push_to_hub(hub_model_id)

View File

@@ -2,7 +2,7 @@
import dataclasses
from functools import wraps
from types import NoneType, UnionType
from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
@@ -20,8 +20,7 @@ def _strip_optional_type(field_type: type | str | None):
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)
if is_union and type(None) in get_args(field_type):
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
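# Illustrative behavior of _strip_optional_type (comments only):
#   _strip_optional_type(Optional[int])    -> int
#   _strip_optional_type(Union[str, None]) -> str
#   _strip_optional_type(int)              -> int  (unchanged)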
@@ -88,70 +87,10 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
return decorator
def _is_pydantic_model(field_type: type) -> bool:
"""Check if a type is a Pydantic BaseModel subclass."""
try:
return isinstance(field_type, type) and issubclass(field_type, BaseModel)
except TypeError:
return False
def _get_field_description(field) -> str | None:
"""Get description from a Pydantic field, checking both .description and json_schema_extra."""
if field.description:
return field.description
if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
return field.json_schema_extra.get("description")
return None
def _add_nested_model_options(
function: Callable, parent_name: str, model_class: Type[BaseModel]
) -> Callable:
"""
Add Click options for all fields of a nested Pydantic model using dot-notation.
Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.
Args:
function: Click command function to add options to.
parent_name: Parent field name (e.g., "trl").
model_class: Nested Pydantic model class.
Returns:
Function with added Click options.
"""
for sub_name, sub_field in reversed(model_class.model_fields.items()):
sub_type = _strip_optional_type(sub_field.annotation)
# Use dot notation: --parent.sub_field
cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
# The kwarg name uses double-underscore as separator
param_name = f"{parent_name}__{sub_name}"
description = _get_field_description(sub_field)
if sub_type is bool:
option_name = f"--{cli_name}/--no-{cli_name}"
function = click.option(
option_name, param_name, default=None, help=description
)(function)
else:
option_name = f"--{cli_name}"
click_type = {str: str, int: int, float: float}.get(sub_type)
function = click.option(
option_name, param_name, default=None, type=click_type, help=description
)(function)
return function
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
generated for each sub-field (e.g., ``--trl.beta=0.1``).
Args:
config_class: Pydantic model with fields to parse from the CLI
@@ -164,11 +103,6 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation)
# Handle nested Pydantic models with dot-notation options
if _is_pydantic_model(field_type):
function = _add_nested_model_options(function, name, field_type)
continue
if field_type is bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"

View File

@@ -18,7 +18,4 @@ MOE_ARCH_BLOCK = {
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",
"glm4_moe": "Glm4MoeDecoderLayer",
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
}

View File

@@ -216,7 +216,7 @@ class TrainerBuilderBase(abc.ABC):
def _configure_warmup_and_logging(
self, total_num_steps: int, training_args_kwargs: dict
):
warmup_steps: int | float = 0
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
@@ -230,10 +230,6 @@ class TrainerBuilderBase(abc.ABC):
else:
warmup_ratio = 0.03
# transformers v5
if warmup_ratio > 0.0 and warmup_steps == 0:
warmup_steps = warmup_ratio
if warmup_steps == 1:
warmup_steps = 2
@@ -246,6 +242,7 @@ class TrainerBuilderBase(abc.ABC):
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps
def _configure_precision_settings(self, training_args_kwargs: dict):
@@ -409,9 +406,6 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.hub_revision:
training_args_kwargs["hub_revision"] = self.cfg.hub_revision
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps
if self.cfg.save_steps:
@@ -536,7 +530,9 @@ class TrainerBuilderBase(abc.ABC):
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
@@ -549,7 +545,6 @@ class TrainerBuilderBase(abc.ABC):
arg_map = {
"dion_learning_rate": "dion_lr",
"include_num_input_tokens_seen": "include_tokens_per_second",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:

View File

@@ -122,12 +122,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
if getattr(self.cfg, "generate_samples", False):
from axolotl.utils.callbacks.generation import SFTGenerationCallback
callbacks.append(SFTGenerationCallback(trainer))
LOG.info("SFT sample generation enabled")
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
@@ -252,8 +246,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ddp_find_unused_parameters
)
if self.cfg.group_by_length:
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
@@ -380,18 +373,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple
if self.cfg.use_eaft:
from functools import partial
from axolotl.monkeypatch.loss.eaft import eaft_loss
configured_eaft_loss = partial(
eaft_loss,
alpha=self.cfg.eaft_alpha if self.cfg.eaft_alpha is not None else 1.0,
k=self.cfg.eaft_k if self.cfg.eaft_k is not None else 20,
)
trainer_kwargs["compute_loss_func"] = configured_eaft_loss
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
@@ -456,9 +437,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn) or (
self.cfg.micro_batch_size == 1 and is_eval is False
):
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
return None
if self.cfg.model_config_type == "mamba":

View File

@@ -11,6 +11,7 @@ from axolotl.core.trainers import (
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
@@ -52,8 +53,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model]
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
)
@@ -134,17 +133,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
blocklist_args_kwargs.append("max_prompt_length")
# Handle when max_prompt_length == max_length from defaults
# CPOTrainer requires max_prompt_length strictly less than max_length
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs.append("max_prompt_length")
blocklist_args_kwargs = ["max_prompt_length"]
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
@@ -154,8 +157,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
)
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()

View File

@@ -25,7 +25,7 @@ from torch.utils.data import (
from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from typing_extensions import override
@@ -719,20 +719,6 @@ class AxolotlTrainer(
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
# fix for Context Parallel save: CP eval invalidates tensor storage
# pointers, so clone to CPU to get fresh valid storage for safetensors
if (
state_dict is not None
and self.axolotl_cfg
and self.axolotl_cfg.context_parallel_size
and self.axolotl_cfg.context_parallel_size > 1
):
state_dict = {
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
supported_classes = (
(PreTrainedModel,)
if not is_peft_available()
@@ -743,7 +729,6 @@ class AxolotlTrainer(
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
supported_classes,
@@ -753,35 +738,43 @@ class AxolotlTrainer(
).save_pretrained(
output_dir,
state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
safe_serialization=self.args.save_safetensors,
)
else:
LOG.info(
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
)
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -57,18 +57,16 @@ class AxolotlDPOTrainer(
def tokenize_row(
features,
processing_class,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
is_chat: bool = False,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length=max_prompt_length,
max_completion_length=max_completion_length,
add_special_tokens=add_special_tokens,
is_chat=is_chat,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:

View File

@@ -126,6 +126,9 @@ class GRPOStrategy:
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
@@ -151,8 +154,6 @@ class GRPOStrategy:
trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes
)
if cfg.trl and cfg.trl.rollout_func:
trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func)
return trainer_kwargs
@@ -163,12 +164,7 @@ class GRPOStrategy:
@classmethod
def get_blocklist_args_kwargs(cls) -> list[str]:
return [
"dataset_num_proc",
"max_length",
"include_tokens_per_second",
"max_prompt_length",
]
return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:

View File

@@ -25,7 +25,7 @@ class SchedulerMixin(Trainer):
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
) -> LRScheduler:
"""
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
@@ -45,13 +45,6 @@ class SchedulerMixin(Trainer):
and self.args.cosine_min_lr_ratio is not None
)
if optimizer is None:
if self.optimizer is None:
raise ValueError(
"Optimizer must be set before calling create_scheduler or passed as an argument."
)
optimizer = self.optimizer
# fmt: off
if self.lr_scheduler is None: # type: ignore
# fmt: on

View File

@@ -1,10 +1,12 @@
"""Module for TRL RL trainers"""
from trl import RewardTrainer
from trl.experimental.cpo import CPOTrainer
from trl.experimental.kto import KTOTrainer
from trl.experimental.orpo import ORPOTrainer
from trl.experimental.prm import PRMTrainer
from trl import (
CPOTrainer,
KTOTrainer,
ORPOTrainer,
PRMTrainer,
RewardTrainer,
)
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin

View File

@@ -8,11 +8,7 @@ from dataclasses import dataclass, field
from typing import Optional, Type
from transformers import TrainingArguments
from trl import RewardConfig
from trl.experimental.cpo import CPOConfig
from trl.experimental.kto import KTOConfig
from trl.experimental.orpo import ORPOConfig
from trl.experimental.prm import PRMConfig
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.integrations.config import merge_training_args

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
```
## Usage
@@ -31,13 +31,11 @@ plugins:
## Supported Models
- afmoe
- apertus
- arcee
- cohere
- cohere2
- deepseek_v3
- exaone4
- gemma
- gemma2
- gemma3
@@ -47,17 +45,13 @@ plugins:
- glm
- glm4
- glm4_moe
- glm4_moe_lite
- glm46v
- glm4v
- glm4v_moe
- glm_image
- glm_moe_dsa
- gpt_oss
- granite
- granitemoe
- granitemoehybrid
- granitemoeshared
- granitemoehybrid
- hunyuan_v1_dense
- hunyuan_v1_moe
- internvl
@@ -78,26 +72,20 @@ plugins:
- olmo
- olmo2
- olmo3
- olmoe
- phi
- phi3
- phi4_multimodal
- qwen2
- qwen2_5_vl
- qwen2_moe
- qwen2_vl
- qwen2_moe
- qwen2_5_vl
- qwen3
- qwen3_5
- qwen3_5_text
- qwen3_5_moe
- qwen3_5_moe_text
- qwen3_moe
- qwen3_next
- qwen3_vl
- qwen3_vl_moe
- seed_oss
- qwen3_next
- smollm3
- step3p5
- seed_oss
- voxtral
## Citation

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
)
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
def patch_llama_like(
self,
model_type_to_patch: str,
model_type: str,
) -> None:
"""
Generic patch for model architectures with causal lm similar to llama
@@ -112,10 +112,7 @@ class CutCrossEntropyPlugin(BasePlugin):
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model,
patch_options,
remote_model_id: str | None,
model_type: str,
maybe_model, patch_options, model_type: str, remote_model_id: str | None
):
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
@@ -139,13 +136,11 @@ class CutCrossEntropyPlugin(BasePlugin):
f"Error: {str(e)}"
) from e
if model_type_to_patch not in PATCH_FNS:
if model_type not in PATCH_FNS:
LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type_to_patch
"Setting up generic cce patch for model type: %s", model_type
)
LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
)
PATCH_FNS[model_type_to_patch] = partial(
patch_generic, model_type=model_type_to_patch
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
)
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)

View File

@@ -1,46 +0,0 @@
# Kernels Integration
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
}
```
In our custom integration, we add support for **ScatterMoE**, which is faster and more efficient than `grouped_mm`.
## Usage
Add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
```
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
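For illustration, here is a minimal sketch of that flow using the `kernels` library API referenced above. The repo id matches this integration, but the kernel name, layer name, and exact call sites inside the plugin are assumptions, not the plugin's literal code:
```python
from kernels import LayerRepository, register_kernel_mapping, replace_kernel_forward_from_hub
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock

# 1. Map a kernel name to the Hub repo providing the layer implementation
#    (the layer name "ScatterMoE" is assumed here).
register_kernel_mapping(
    {
        "ScatterMoE": {
            "cuda": LayerRepository(
                repo_id="axolotl-ai-co/scattermoe",
                layer_name="ScatterMoE",
            )
        }
    }
)

# 2. Point the model's MoE block class at the registered kernel so its
#    forward method is swapped in; depending on the `kernels` version, a
#    kernelize step at model load applies the mapping.
replace_kernel_forward_from_hub(OlmoeSparseMoeBlock, "ScatterMoE")
```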
## Limitations
ScatterMoE uses softmax -> top-k routing, so results may differ from the baseline for some model architectures (GPT-OSS, GLM_MOE_DSA).
ScatterMoE does not currently work for GLM4.7 Flash (glm4_moe_lite).
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.

View File

@@ -1,7 +0,0 @@
from .args import KernelsArgs
from .plugin import KernelsPlugin
__all__ = [
"KernelsArgs",
"KernelsPlugin",
]

View File

@@ -1,48 +0,0 @@
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class KernelsArgs(BaseModel):
use_scattermoe: bool | None = True
@model_validator(mode="before")
@classmethod
def check_use_kernels(cls, data):
if data.get("use_kernels") is not True:
LOG.warning(
"`use_kernels` must be set to True to use this. Automatically setting it to True."
)
data["use_kernels"] = True
return data
@model_validator(mode="before")
@classmethod
def check_experts_implementation(cls, data):
experts_implementation = data.get("experts_implementation")
if experts_implementation is None:
# transformers may default to batched_mm when unset
data["experts_implementation"] = "eager"
elif experts_implementation != "eager":
LOG.warning(
"`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'."
)
data["experts_implementation"] = "eager"
return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel_scattermoe(cls, data):
if data.get("use_scattermoe") is True:
if data.get("lora_mlp_kernel") is True:
LOG.warning(
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
)
data["lora_mlp_kernel"] = False
data["mlp_kernel"] = False
return data
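# Net effect on the YAML config (illustrative, per the validators above):
#   plugins: [axolotl.integrations.kernels.KernelsPlugin]
#   use_scattermoe: true
# forces use_kernels: true, experts_implementation: "eager", and turns off
# lora_mlp_kernel / mlp_kernel for compatibility.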

View File

@@ -1,18 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
__all__ = [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
]

View File

@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
#
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
# Adapted from https://github.com/shawntan/scattermoe
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
#
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import lora_ops, ops
__all__ = ["ops", "lora_ops"]

View File

@@ -1,645 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
from typing import Optional
import torch
import triton
import triton.language as tl
BLOCK_M = 128
ALLOW_TF32 = True
@triton.jit
def _compute_expert_block(
E_idx,
E_mask,
M_in_idx,
N_block,
N_mask,
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
K,
acc,
no_k_mask,
BLOCK_K,
allow_tf32=True,
):
K_block = tl.arange(0, BLOCK_K)
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
W_blk_ptrs = (
W_ptr
+ K_block[:, None] * stride_wk
+ N_block[None, :] * stride_wn
+ E_idx * stride_we
)
iters = tl.cdiv(K, BLOCK_K)
for K_block_id in range(iters):
if no_k_mask:
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
else:
K_mask = (K_block_id * BLOCK_K + K_block) < K
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
return acc
def _scatter2scatter_configs():
return [
triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4),
]
@triton.autotune(
configs=_scatter2scatter_configs(),
key=["M", "N", "K"],
)
@triton.heuristics(
{
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
}
)
@triton.jit
def _scatter2scatter(
X_ptr,
stride_xm: tl.constexpr,
stride_xk: tl.constexpr,
W_ptr,
stride_we,
stride_wk: tl.constexpr,
stride_wn: tl.constexpr,
Y_ptr,
stride_ym: tl.constexpr,
stride_yn: tl.constexpr,
B_ptr,
stride_be: tl.constexpr,
stride_bn: tl.constexpr,
grouped_idx_ptr,
expert_idxs_ptr,
# block_start_idx_ptr,
FAN_OUT: tl.constexpr,
M,
K: tl.constexpr,
N: tl.constexpr,
E: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
# OUT_M,
allow_tf32: tl.constexpr,
x_grouped: tl.constexpr,
y_grouped: tl.constexpr,
NO_K_MASK: tl.constexpr,
NO_N_MASK: tl.constexpr,
):
pid = tl.program_id(axis=0)
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
M_block_id = pid // N_BLOCK_COUNT
N_block_id = pid % N_BLOCK_COUNT
M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
M_boundary_mask = M_block < (FAN_OUT * M)
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
no_k_mask = K % BLOCK_K == 0
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
E_first_idx = tl.min(E_idxs)
E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
for E_idx in range(E_first_idx, E_last_idx + 1):
E_mask = E_idxs == E_idx
E_M_idx = M_idx
if x_grouped:
M_in_idx = M_block
else:
M_in_idx = E_M_idx // FAN_OUT
acc = _compute_expert_block(
E_idx,
E_mask,
M_in_idx,
N_block,
N_mask,
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
K,
acc,
no_k_mask,
BLOCK_K,
allow_tf32=allow_tf32,
)
if B_ptr is not None:
B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
if y_grouped:
M_out_idx = M_block
else:
M_out_idx = M_idx
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def scatter2scatter(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
b=None,
x_grouped=False,
y_grouped=False,
out=None,
):
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
# Pre-kernel setup
y_dim = W.size(-1)
L_scattered = sorted_expert_idxs.size(0)
if out is None:
output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
else:
assert out.size(0) == L_scattered and out.size(1) == y_dim
output = out
scatter2scatter_compileable(
output,
W,
X,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
b,
x_grouped,
y_grouped,
)
return output
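# Illustrative shapes (comments only): X is [M, K], W is [E, K, N], and k is
# the per-token expert fan-out. Both index tensors have length M * k, and the
# returned output is [M * k, N]: one row per (token, expert) pair, combined
# downstream with the routing weights.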
@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
def scatter2scatter_compileable(
output: torch.Tensor,
W: torch.Tensor,
X: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
b: Optional[torch.Tensor],
x_grouped: bool,
y_grouped: bool,
) -> None:
def grid(META):
grid_num = (
triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"])
* triton.cdiv(META["N"], META["BLOCK_N"]),
)
return grid_num
if b is None:
b = None
stride_be = stride_bn = 0
else:
stride_be, stride_bn = b.stride()
_scatter2scatter[grid](
# X_ptr, stride_xm, stride_xk,
X,
X.stride(0),
X.stride(1),
# W_ptr, stride_we, stride_wk, stride_wn,
W,
W.stride(0),
W.stride(1),
W.stride(2),
# Y_ptr, stride_ym, stride_yn,
output,
output.stride(0),
output.stride(1),
# B_ptr, stride_be, stride_bn
b,
stride_be,
stride_bn,
grouped_idx_ptr=sorted_scattered_idxs,
expert_idxs_ptr=sorted_expert_idxs,
# block_start_idx_ptr=padded_block_idxs,
FAN_OUT=k,
M=X.size(0),
K=X.size(1),
N=output.size(1),
E=W.size(0),
BLOCK_M=BLOCK_M,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
x_grouped=x_grouped,
y_grouped=y_grouped,
)
def _config_XtY():
return [
triton.Config(
{"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4
),
]
def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
DW = DWt.permute(0, 2, 1)
if has_bias:
Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
else:
Db = None
groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
return DW, Db
@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"})
def groupXtY_compileable(
E: int,
DW: torch.Tensor,
Db: Optional[torch.Tensor],
DY: torch.Tensor,
X: torch.Tensor,
expert_offsets: torch.Tensor,
) -> None:
def grid(META):
grid = (
E * triton.cdiv(META["K"], META["BLOCK_K"]),
triton.cdiv(META["N"], META["BLOCK_N"]),
)
return grid
if Db is None:
stride_dbe = 0
stride_dbn = 0
else:
stride_dbe, stride_dbn = Db.stride()
_groupXtY[grid](
# DY_ptr, stride_dym, stride_dyk,
DY,
DY.stride(0),
DY.stride(1),
# X_ptr, stride_xm, stride_xn,
X,
X.stride(0),
X.stride(1),
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
DW,
DW.stride(0),
DW.stride(1),
DW.stride(2),
# Db_ptr, stride_dwe, stride_dbn,
Db,
stride_dbe,
stride_dbn,
# expert_offsets_ptr,
expert_offsets,
# K: tl.constexpr, N: tl.constexpr,
M=DY.size(0),
N=DY.size(-1),
K=X.size(-1),
# ACC_TYPE: tl.constexpr,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
@triton.autotune(
configs=_config_XtY(),
key=["M", "N", "K"],
)
@triton.heuristics(
{
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
}
)
@triton.jit
def _groupXtY(
DY_ptr,
stride_dym,
stride_dyk,
X_ptr,
stride_xm,
stride_xn,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
expert_offsets_ptr,
M,
K: tl.constexpr,
N: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_K_MASK: tl.constexpr,
NO_N_MASK: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
num0 = tl.num_programs(0)
num1 = tl.num_programs(1)
# pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
E_idx = pid0 // K_BLOCK_COUNT
K_block_id = pid0 % K_BLOCK_COUNT
N_block_id = pid1
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
if end_idx > start_idx:
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
K_mask = K_block < K
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
M_idxs = M_block
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
dy_blk_ptrs = (
DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
)
if (Db_ptr is not None) and (K_block_id == 0):
_xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias=True,
)
else:
_xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias=False,
)
@triton.jit
def _xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias: tl.constexpr,
):
if compute_bias:
db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
else:
db_acc = None
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
for i in range(0, iters):
M_mask = (i * BLOCK_M + M_block) < end_idx
if NO_K_MASK:
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
else:
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
if NO_N_MASK:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
else:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
xt_blk_ptrs += BLOCK_M * stride_xm
dy_blk_ptrs += BLOCK_M * stride_dym
if compute_bias:
db_acc += tl.sum(dy, axis=0)
DW_blk_ptrs = (
DW_ptr
+ E_idx * stride_dwe
+ K_block[:, None] * stride_dwk
+ N_block[None, :] * stride_dwn
)
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
if compute_bias:
Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
def _config_grouping():
return [
triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
]
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
N = sorted_expert_idxs.size(0)
K = A.size(1)
assert A.size(0) * fan_out == N
if out is not None:
Y = out
else:
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
return Y
@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
def group_compileable(
A: torch.Tensor,
K: int,
N: int,
Y: torch.Tensor,
coeff: Optional[torch.Tensor],
has_coeff: bool,
fan_out: int,
sorted_expert_idxs: torch.Tensor,
) -> None:
def grid(META):
grid_num = (triton.cdiv(META["N"], META["BLOCK_N"]),)
return grid_num
_group[grid](
# A_ptr, stride_an, stride_ai,
A,
A.stride(0),
A.stride(1),
has_coeff,
coeff,
fan_out,
# Y_ptr, stride_yn, stride_yk,
Y,
Y.stride(0),
Y.stride(1),
# grouped_idx_ptr,
sorted_expert_idxs,
# N: tl.constexpr, K: tl.constexpr,
N,
K,
)
@triton.autotune(configs=_config_grouping(), key=["K"])
@triton.heuristics({"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0})
@triton.jit
def _group(
src_ptr,
stride_sn,
stride_sk,
has_coeff: tl.constexpr,
coeff_ptr,
FAN_OUT: tl.constexpr,
tgt_ptr,
stride_tn,
stride_ti,
grouped_idx_ptr,
N,
K: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
NO_K_MASK: tl.constexpr,
):
pid = tl.program_id(axis=0)
N_block_id = pid
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_blk < N
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
K_blk = tl.arange(0, BLOCK_K)
src_blk_ptrs = (
src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
)
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
if has_coeff:
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
iters = tl.cdiv(K, BLOCK_K)
for i in range(0, iters):
if NO_K_MASK or i < iters - 1:
block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
else:
K_mask = (i * BLOCK_K + K_blk) < K
mask = N_mask[:, None] & K_mask[None, :]
block = tl.load(src_blk_ptrs, mask=mask)
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=mask)
src_blk_ptrs += BLOCK_K * stride_sk
tgt_blk_ptrs += BLOCK_K * stride_ti

View File

@@ -1,98 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
import torch
import triton
import triton.language as tl
@triton.jit
def _single2scatter(
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
Y_ptr,
stride_ym,
stride_yn,
expert_idxs_ptr,
FAN_OUT: tl.constexpr,
K: tl.constexpr,
N: tl.constexpr,
E: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
N_block_id = pid0
if FAN_OUT == 1:
in_idx = pid1
else:
in_idx = 0
out_idx = pid1
K_block = tl.arange(0, BLOCK_K)
N_block = tl.max_contiguous(
tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N),
BLOCK_N,
)
E_idx = tl.load(expert_idxs_ptr + pid1)
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
W_blk_ptrs = (
W_ptr
+ E_idx * stride_we
+ K_block[:, None] * stride_wk
+ N_block[None, :] * stride_wn
)
N_mask = N_block < N
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
K_mask = K_block < K
x = tl.load(X_blk_ptrs, mask=K_mask[:, None], other=0.0)
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0)
acc += tl.sum(x * w, axis=0)[None, :]
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
K_block += BLOCK_K
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
tl.store(Y_blk_ptrs, acc, mask=N_mask[None, :])
def single2scatter(X, W, expert_idxs):
E, xdim, ydim = W.size()
k = expert_idxs.size(1)
assert X.size(0) == k or X.size(0) == 1
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
BLOCK_N = 128
BLOCK_K = 128
grid = triton.cdiv(ydim, BLOCK_N), k
_single2scatter[grid](
X,
X.stride(0),
X.stride(1),
W,
W.stride(0),
W.stride(1),
W.stride(2),
Y,
Y.stride(0),
Y.stride(1),
expert_idxs,
FAN_OUT=Y.size(0) // X.size(0),
K=xdim,
N=ydim,
E=E,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
ACC_TYPE=tl.float32,
)
return Y

View File

@@ -1,439 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
#
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
# Adapted from https://github.com/shawntan/scattermoe
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
#
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ScatterMoE layer replacements for HuggingFace MoE architectures.
Provides drop-in forward replacements that use ScatterMoE kernels for
acceleration. When used via the HF ``kernels`` library
(``replace_kernel_forward_from_hub``), these classes replace the forward
method of the original MoE block.
LoRA support
------------
When peft wraps parameters via ``target_parameters``, the ``self.experts``
submodule becomes a chain of ``ParamWrapper`` objects and the ``self.gate``
router may also become a ``ParamWrapper``. The ``HFScatterMoEGatedMLP``
forward detects this and automatically:
1. Unwraps ``self.gate`` to the base router, applying gate LoRA delta
2. Unwraps ``self.experts`` to the base ``OlmoeExperts`` module
3. Extracts LoRA A/B weights and scaling from each wrapper
4. Converts B layout from peft rank-major to scattermoe expert-major
5. Routes to ``parallel_linear_lora`` for fused LoRA computation
6. Passes through ``self.shared_expert`` / ``self.shared_expert_gate``
(peft wraps their linear layers with standard LoRA, no special handling)
"""
import torch
from torch import nn
from torch.nn import functional as F
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
# =============================================================================
# LoRA layout conversion utilities (peft <-> scattermoe)
# =============================================================================
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
expert-major ``[N, r*E]``.
peft reshapes B to ``[out, r, E]`` (rank-major).
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
"""
N = peft_B.shape[0]
return (
peft_B.reshape(N, rank, num_experts)
.permute(0, 2, 1)
.contiguous()
.reshape(N, num_experts * rank)
)
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
are swapped relative to scattermoe's convention.
peft gives:
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
scattermoe needs:
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
This function swaps A<->B and converts B from rank-major to expert-major.
Uses vectorized tensor operations (no Python loop over experts).
Works for **both** gate_up_proj and down_proj since the transposition
issue is the same for any parameter.
"""
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
smoe_A = (
peft_B_em.reshape(dim2, num_experts, rank)
.permute(1, 2, 0)
.contiguous()
.reshape(rank * num_experts, dim2)
)
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
smoe_B = (
peft_A.reshape(num_experts, rank, dim1)
.permute(2, 0, 1)
.contiguous()
.reshape(dim1, num_experts * rank)
)
return smoe_A, smoe_B
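# Illustrative shape check for the conversion above (hypothetical sizes):
#   E, r, dim1, dim2 = 4, 8, 64, 128
#   peft_A = torch.randn(E * r, dim1)  # peft lora_A: [E*r, dim1]
#   peft_B = torch.randn(dim2, E * r)  # peft lora_B: [dim2, E*r], rank-major
#   smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
#   assert smoe_A.shape == (E * r, dim2)  # scattermoe lora_A: [E*r, K=dim2]
#   assert smoe_B.shape == (dim1, E * r)  # scattermoe lora_B: [N=dim1, E*r]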
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
# =============================================================================
# ParamWrapper unwrapping
# =============================================================================
def _unwrap_gate_lora(gate_module):
"""Unwrap peft ``ParamWrapper`` on the router gate.
When peft targets ``gate.weight``, ``self.gate`` becomes::
ParamWrapper(weight)
-> base_layer: OlmoeTopKRouter (the real module)
This function detects the wrapping and returns the base router, its
weight tensor, and an optional LoRA delta tensor.
Returns:
(base_gate, gate_weight, gate_lora_delta_or_None)
``base_gate`` is the original router module (with ``.top_k``,
``.num_experts``, ``.norm_topk_prob``).
``gate_weight`` is the base router weight (may be a DTensor under FSDP).
``gate_lora_delta_or_None`` is the LoRA delta tensor if LoRA is active,
else ``None``. Kept separate to avoid mixing DTensor + Tensor in an add.
"""
if hasattr(gate_module, "base_layer") and hasattr(gate_module, "lora_A"):
base_gate = gate_module.base_layer
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)
if lora_A is not None:
# gate weight: [num_experts, hidden_size]
# lora_A: [r, hidden_size], lora_B: [num_experts, r]
# delta = scaling * B @ A = [num_experts, hidden_size]
delta = scaling * (lora_B @ lora_A)
return base_gate, base_gate.weight, delta
else:
return base_gate, base_gate.weight, None
else:
# No wrapping — gate is the original module
return gate_module, gate_module.weight, None
def _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling):
"""Convert peft LoRA weights to scattermoe layout."""
smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
return (smoe_A, smoe_B, scaling)
def _unwrap_experts_lora(experts_module):
"""Walk a peft ``ParamWrapper`` chain on ``self.experts``.
When peft targets ``experts.gate_up_proj`` and ``experts.down_proj`` via
``target_parameters``, ``self.experts`` becomes a nested chain::
ParamWrapper(down_proj)
-> base_layer: ParamWrapper(gate_up_proj)
-> base_layer: OlmoeExperts (the real module)
This function walks the chain, collects LoRA params keyed by
``parameter_name``, and returns the base experts module.
Returns:
(base_experts, gup_lora, down_lora)
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
A/B are already in scattermoe layout.
"""
# Collect ParamWrapper layers by their parameter_name
wrappers = {}
module = experts_module
while hasattr(module, "base_layer") and hasattr(module, "lora_A"):
param_name = getattr(module, "parameter_name", None)
if param_name is not None:
wrappers[param_name] = module
module = module.base_layer
base_experts = module
if not wrappers:
return base_experts, None, None
# Determine num_experts from base module
num_experts = getattr(base_experts, "num_experts", None)
if num_experts is None:
# Fallback: infer from parameter shape
gup = getattr(base_experts, "gate_up_proj", None)
if gup is not None:
num_experts = gup.shape[0]
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
gup_lora = None
gup_wrapper = wrappers.get("gate_up_proj")
if gup_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
if lora_A is not None:
rank = lora_A.shape[0] // num_experts
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
# Extract down_proj LoRA (needs A<->B swap due to transposition)
down_lora = None
down_wrapper = wrappers.get("down_proj")
if down_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper)
if lora_A is not None:
rank = lora_A.shape[0] // num_experts
down_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
return base_experts, gup_lora, down_lora
# =============================================================================
# Layer classes
# =============================================================================
class ScatterMoEGatedMLP(nn.Module):
def forward(self, layer_input):
"""
Forward pass of the mixture of experts layer.
Args:
layer_input (Tensor):
Input tensor.
Returns:
Tensor:
Output tensor.
"""
bsz, length, emb_size = layer_input.size()
layer_input = layer_input.reshape(-1, emb_size)
# compute the top_k routing decision
router_logits = self.router.layer(layer_input)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.router.top_k, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(layer_input.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
selected_experts, num_experts=self.router.num_experts
)
# compute experts
gates, h = parallel_linear(
layer_input,
self.input_linear.weight.transpose(2, 1),
self.router.top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=False,
grouped_out=True,
).chunk(2, dim=-1)
h = self.activation(gates) * h
layer_output = parallel_linear(
h,
self.output_linear.weight.transpose(2, 1),
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=True,
grouped_out=False,
gates=routing_weights,
)
layer_output = layer_output.view(bsz, length, emb_size)
return layer_output
class HFScatterMoEGatedMLP(nn.Module):
"""
ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
Used as a kernel layer via the HF ``kernels`` library. The ``forward``
method replaces the original ``OlmoeSparseMoeBlock.forward``.
Supports both full-parameter training and LoRA fine-tuning:
* **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
* **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
"""
@staticmethod
def forward(self: nn.Module, layer_input: torch.Tensor):
"""
Forward pass using ScatterMoE kernels.
Args:
self: The MoeSparseMoeBlock module containing:
- self.gate: Router (or peft ParamWrapper wrapping it)
- self.experts: Experts module (or peft ParamWrapper chain)
- self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
- self.shared_expert_gate: Optional shared expert gate
layer_input: Input tensor [batch_size, seq_len, hidden_size]
Returns:
Tensor: [batch_size, seq_len, hidden_size]
"""
batch_size, sequence_length, hidden_dim = layer_input.shape
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
# Shared Expert (if present, e.g. Qwen2MoE)
# ====================================================================
# peft wraps individual linear layers inside shared_expert with
# standard LoRA — calling forward() handles this transparently.
if hasattr(self, "shared_expert") and self.shared_expert is not None:
shared_expert_output = self.shared_expert(hidden_states_flat)
# shared_expert_gate may also be peft-wrapped (standard LoRA
# on nn.Linear), its forward() applies LoRA automatically.
shared_expert_gate_output = F.sigmoid(
self.shared_expert_gate(hidden_states_flat)
)
shared_expert_output = shared_expert_output * shared_expert_gate_output
else:
shared_expert_output = None
# ====================================================================
# Router Computation (with optional gate LoRA)
# ====================================================================
base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
router_logits = F.linear(hidden_states_flat, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states_flat, gate_lora_delta
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if base_gate.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states_flat.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
selected_experts, num_experts=num_experts
)
# ====================================================================
# Detect LoRA (peft ParamWrapper) and extract adapter weights
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Gate + Up projection
# ====================================================================
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=gup_A,
lora_B=gup_B,
scaling=gup_scaling,
grouped_in=False,
grouped_out=True,
use_fused_dX=True,
use_fused_gather=True,
)
else:
gup = parallel_linear(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=False,
grouped_out=True,
)
gates, h = gup.chunk(2, dim=-1)
h = experts.act_fn(gates) * h
# ====================================================================
# Down projection
# ====================================================================
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
expert_output = parallel_linear_lora(
h,
down_W,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=down_A,
lora_B=down_B,
scaling=down_scaling,
gates=routing_weights,
grouped_in=True,
grouped_out=False,
use_fused_dX=True,
use_fused_gather=True,
)
else:
expert_output = parallel_linear(
h,
down_W,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=True,
grouped_out=False,
gates=routing_weights,
)
# ====================================================================
# Combine with shared expert and reshape
# ====================================================================
if shared_expert_output is not None:
expert_output = expert_output + shared_expert_output
expert_output = expert_output.view(batch_size, sequence_length, hidden_dim)
return expert_output

View File

@@ -1,99 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ParallelExperts module with LoRA support.
Provides a drop-in replacement for ScatterMoE's ParallelExperts that
uses the fused LoRA kernel when adapter weights are attached.
"""
from typing import Optional
import torch
import torch.nn as nn
from .parallel_linear_lora import parallel_linear_lora
class ParallelExperts(nn.Module):
"""
Parallel Experts with fused LoRA support.
Drop-in replacement for the original ParallelExperts. When LoRA parameters
are attached via set_lora(), the forward pass uses a fused kernel:
Y = X @ W + scaling * (X @ A^T) @ B^T
"""
def __init__(
self,
num_experts: int,
input_size: int,
output_size: int,
bias: bool = False,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
if bias:
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
else:
self.bias = None
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
self._lora_A: torch.Tensor | None = None
self._lora_B: torch.Tensor | None = None
self._lora_scaling: float | None = None
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.02)
if self.bias is not None:
nn.init.zeros_(self.bias)
def extra_repr(self) -> str:
return (
f"num_experts={self.num_experts}, "
f"input_size={self.input_size}, "
f"output_size={self.output_size}"
)
def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float):
"""Attach LoRA parameters for fused computation."""
self._lora_A = lora_A
self._lora_B = lora_B
self._lora_scaling = scaling
def clear_lora(self):
"""Remove LoRA parameters."""
self._lora_A = None
self._lora_B = None
self._lora_scaling = None
def forward(
self,
inputs: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
) -> torch.Tensor:
return parallel_linear_lora(
inputs,
self.weight.permute(0, 2, 1), # [E, input, output]
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=self._lora_A,
lora_B=self._lora_B,
scaling=self._lora_scaling if self._lora_scaling is not None else 1.0,
expert_biases=self.bias,
gates=gates,
grouped_in=grouped_in,
grouped_out=grouped_out,
)
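# Minimal usage sketch (shapes inferred from the per-expert slicing in
# _compute_lora_input_grad further below; illustrative only):
#
# experts = ParallelExperts(num_experts=8, input_size=512, output_size=1024)
# r = 16
# lora_A = torch.zeros(8 * r, 512)   # stacked per-expert A: [E*r, input]
# lora_B = torch.zeros(1024, 8 * r)  # stacked per-expert B: [output, E*r]
# experts.set_lora(lora_A, lora_B, scaling=2.0)
# # forward() now takes the fused LoRA path; clear_lora() restores the base path.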

View File

@@ -1,253 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
from typing import Optional
import torch
import torch.nn as nn
from . import kernels
@torch.library.custom_op("scattermoe::bincount", mutates_args={})
def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
return x.bincount(minlength=minlength)
@compileable_bincount.register_fake
def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
return torch.empty(minlength, dtype=torch.long, device=x.device)
@torch.compile
def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
with torch.no_grad():
flattened_expert_idxs = expert_idxs.flatten()
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
expert_counts = compileable_bincount(
flattened_expert_idxs, minlength=num_experts
)
expert_offsets = expert_counts.cumsum(-1)
return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
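# Worked example (illustrative, not part of the original source):
# expert_idxs = [[2, 0], [1, 2]] with num_experts=3 flattens to [2, 0, 1, 2].
# Sorting gives sorted_expert_idxs = [0, 1, 2, 2] and
# sorted_scattered_idxs = [1, 2, 0, 3] (positions in the flat list);
# bincount -> [1, 1, 2], cumsum -> expert_offsets = [1, 2, 4], so expert e
# owns rows [offsets[e-1], offsets[e]) of the expert-grouped buffer.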
class ParallelLinear(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
):
with torch.device(x.device):
output = kernels.ops.scatter2scatter(
X=x,
W=expert_weights,
b=expert_biases,
k=k,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
x_grouped=grouped_in,
y_grouped=grouped_out,
)
if gates is not None:
output_expanded = output.view(
gates.size(0), gates.size(1), output.size(-1)
)
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x,
expert_weights,
expert_biases,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
)
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
with torch.device(grad_out.device):
(
x,
expert_weights,
expert_biases,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
) = ctx.saved_tensors
k = ctx.k
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
if gates is not None:
# calculate gates gradient
# d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
grouped_grad_out = output_expanded.flatten(
0, 1
) # reuse expanded buffer later
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = kernels.ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x
d_weights, d_biases = kernels.ops.group_bwd_W(
DY=grouped_grad_out,
X=grouped_x,
expert_offsets=expert_offsets,
E=expert_weights.size(0),
has_bias=expert_biases is not None,
)
d_expanded_input = kernels.ops.scatter2scatter(
X=grouped_grad_out,
x_grouped=True,
W=expert_weights.permute(0, 2, 1),
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input, # Reuse grouped_x buffer
)
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(
x.size(0), k, d_expanded_input.size(-1)
).sum(-2)
return (
# x, expert_weights,
d_input,
d_weights,
# k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
None,
None,
None,
None,
# bias, gates
d_biases,
d_gates,
# grouped_in, grouped_out,
None,
None,
)
def parallel_linear(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases=None,
gates=None,
grouped_in=False,
grouped_out=False,
):
results = ParallelLinear.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases,
gates,
grouped_in,
grouped_out,
)
return results
class ParallelExperts(nn.Module):
def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
if bias:
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
else:
self.bias = None
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
self.reset_parameters()
def extra_repr(self):
return "num_experts={}, input_size={}, output_size={}".format(
self.num_experts, self.input_size, self.output_size
)
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.02)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(
self,
inputs,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates=None,
grouped_in=False,
grouped_out=False,
):
results = parallel_linear(
inputs,
self.weight.permute(0, 2, 1),
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases=self.bias,
gates=gates,
grouped_in=grouped_in,
grouped_out=grouped_out,
)
return results

View File

@@ -1,480 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ScatterMoE + LoRA Autograd Function
====================================
Provides the autograd function and Python interface for fused ScatterMoE + LoRA.
Key design for LoRA training:
- Expert weights W are FROZEN (no gradient computed for W).
- Only LoRA adapter weights (A, B) receive gradients.
- The input gradient dX is still computed (needed for upstream layers).
- This avoids the expensive group_bwd_W computation entirely.
Forward:
Y = X @ W + scaling * (X @ A^T) @ B^T
Backward (W frozen):
dX = dY @ W^T + scaling * (dY @ B) @ A (via scatter2scatter for base, separate for LoRA)
dA = scaling * (dY @ B)^T @ X (per-expert, on grouped data)
dB = scaling * dY^T @ (X @ A^T) (per-expert, on grouped data)
"""
from typing import Optional
import torch
from .kernels import ops as base_ops
from .kernels.lora_ops import (
group_bwd_lora,
group_bwd_lora_fused,
scatter2scatter_lora,
scatter2scatter_lora_dX,
)
class ScatterMoELoRA(torch.autograd.Function):
"""
Autograd function for fused ScatterMoE + LoRA with frozen expert weights.
This function is optimized for the LoRA fine-tuning scenario where:
- Expert weights W are frozen (requires_grad=False)
- Only LoRA A and B matrices receive gradients
- Input gradients are computed for upstream layer backprop
"""
@staticmethod
def forward(
ctx,
x: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
use_fused_dX: bool = False,
use_fused_gather: bool = False,
):
with torch.device(x.device):
# Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
output = scatter2scatter_lora(
X=x,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
b=expert_biases,
x_grouped=grouped_in,
y_grouped=grouped_out,
)
# Handle gating (weighted combination of top-k expert outputs)
if gates is not None:
output_expanded = output.view(
gates.size(0), gates.size(1), output.size(-1)
)
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x,
lora_A,
lora_B,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
)
# Store frozen weights as plain Python attributes instead of
# save_for_backward. This avoids:
# 1. Version-check conflicts with FSDP unshard/reshard
# 2. Pinning all-gathered parameters via saved_tensors hooks
# 3. Interfering with activation offloading pack/unpack hooks
# Safe because expert_weights are frozen (requires_grad=False).
ctx.expert_weights = expert_weights
ctx.expert_biases = expert_biases
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
ctx.scaling = scaling
ctx.use_fused_dX = use_fused_dX
ctx.use_fused_gather = use_fused_gather
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
with torch.device(grad_out.device):
(
x,
lora_A,
lora_B,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
) = ctx.saved_tensors
expert_weights = ctx.expert_weights
k = ctx.k
scaling = ctx.scaling
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
E = expert_weights.size(0)
# ------------------------------------------------------------------
# Gate gradients (if using top-k gating with routing weights)
# ------------------------------------------------------------------
if gates is not None:
# d_gates[t, j] = dot(output_expanded[t, j, :], grad_out[t, :])
d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
# Reuse output_expanded buffer for grouped_grad_out
grouped_grad_out = output_expanded.flatten(0, 1)
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
# ------------------------------------------------------------------
# LoRA gradients (dA, dB) and setup for dX
# ------------------------------------------------------------------
# Fused gather uses sorted_scattered_idxs for indirect X access
# in the Triton kernel, avoiding the group(x) allocation.
#
# can_fuse_gather: X is ungrouped and not too large for scatter loads
# - When gates is None and grouped_out=False: both DY and X ungrouped
# - When grouped_out=True (gate_up_proj): DY already grouped, X ungrouped
# -> use dy_grouped=True in the fused kernel
M_total = sorted_scattered_idxs.size(0)
K_dim = x.size(-1)
N_dim = expert_weights.size(-1)
fuse_gather_workload = M_total * max(K_dim, N_dim)
_FUSE_GATHER_THRESHOLD = 2**24 # ~16M elements
can_fuse_gather = (
ctx.use_fused_gather
and not grouped_in # X must be ungrouped for scatter access
and gates is None # gate coeff requires multiplicative gather
and fuse_gather_workload < _FUSE_GATHER_THRESHOLD
)
if can_fuse_gather:
# ------------------------------------------------------------------
# Fused path: skip group(x) entirely
# ------------------------------------------------------------------
d_expanded_input = None
d_lora_A, d_lora_B = group_bwd_lora_fused(
DY=grad_out,
X=x,
lora_A=lora_A,
lora_B=lora_B,
expert_offsets=expert_offsets,
sorted_scattered_idxs=sorted_scattered_idxs,
E=E,
k=k,
scaling=scaling,
dy_grouped=grouped_out,
)
# Prepare grouped_grad_out for the dX path (needed by both
# the fused dX kernel when grouped_out=True, and the non-fused path)
if grouped_out:
grouped_grad_out = grad_out
elif not ctx.use_fused_dX:
grouped_grad_out = base_ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
else:
# ------------------------------------------------------------------
# Original path: explicit group() calls
# ------------------------------------------------------------------
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = base_ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = base_ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x # Will be overwritten; reuse buffer
d_lora_A, d_lora_B = group_bwd_lora(
DY=grouped_grad_out,
X=grouped_x,
lora_A=lora_A,
lora_B=lora_B,
expert_offsets=expert_offsets,
E=E,
scaling=scaling,
)
# ------------------------------------------------------------------
# Input gradient: dX = dY @ W^T + scaling * (dY @ B) @ A
# ------------------------------------------------------------------
if ctx.use_fused_dX:
if can_fuse_gather and not grouped_out:
# Fully fused: read ungrouped DY via scatter pattern
d_expanded_input = scatter2scatter_lora_dX(
DY=grad_out,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
dy_grouped=False,
dx_grouped=grouped_in,
out=d_expanded_input,
)
else:
# Fused dX only: read from pre-grouped DY
d_expanded_input = scatter2scatter_lora_dX(
DY=grouped_grad_out,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
dy_grouped=True,
dx_grouped=grouped_in,
out=d_expanded_input,
)
else:
# Original path: separate base scatter2scatter + LoRA Python loop
d_expanded_input = base_ops.scatter2scatter(
X=grouped_grad_out,
x_grouped=True,
W=expert_weights.permute(0, 2, 1), # [E, N, K]
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input,
)
# LoRA part: dX_lora = scaling * (dY @ B) @ A
if scaling != 0.0:
d_input_lora_grouped = _compute_lora_input_grad(
grouped_grad_out,
lora_A,
lora_B,
expert_offsets,
E,
scaling,
)
if grouped_in:
d_expanded_input.add_(d_input_lora_grouped)
else:
# Scatter-add LoRA gradient directly into d_expanded_input.
# Avoids allocating a zeros_like + add result
d_expanded_input[sorted_scattered_idxs] += d_input_lora_grouped
# Reduce over top-k if k > 1
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(
x.size(0), k, d_expanded_input.size(-1)
).sum(-2)
# W is frozen during LoRA training -- skip weight gradient
d_weights = (
torch.zeros_like(expert_weights)
if expert_weights.requires_grad
else None
)
d_biases = None
return (
d_input,
d_weights,
None,
None,
None,
None, # k, sorted indices, offsets
d_lora_A,
d_lora_B,
None, # lora_A, lora_B, scaling
d_biases,
d_gates,
None,
None, # grouped_in, grouped_out
None, # use_fused_dX
None, # use_fused_gather
)
def _compute_lora_input_grad(
grouped_grad_out: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
expert_offsets: torch.Tensor,
E: int,
scaling: float,
) -> torch.Tensor:
"""
Compute the LoRA contribution to the input gradient:
dX_lora = scaling * (dY @ B) @ A
Uses PyTorch ops on expert-grouped data.
Each expert e: dX_e = scaling * (dY_e @ B_e) @ A_e
"""
R = lora_A.size(0) // E
K = lora_A.size(1)
M_total = grouped_grad_out.size(0)
d_input_lora = torch.zeros(
(M_total, K), device=grouped_grad_out.device, dtype=grouped_grad_out.dtype
)
compute_dtype = grouped_grad_out.dtype
prev_offset = 0
for e in range(E):
curr_offset = expert_offsets[e].item()
if curr_offset > prev_offset:
dy_e = grouped_grad_out[prev_offset:curr_offset] # [M_e, N]
a_e = lora_A[e * R : (e + 1) * R, :].to(compute_dtype) # [r, K]
b_e = lora_B[:, e * R : (e + 1) * R].to(compute_dtype) # [N, r]
# dX_e = scaling * (dY_e @ B_e) @ A_e
dy_b = dy_e @ b_e # [M_e, r]
dx_e = scaling * (dy_b @ a_e) # [M_e, K]
d_input_lora[prev_offset:curr_offset] = dx_e
prev_offset = curr_offset
return d_input_lora
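# Sanity check (illustrative): because matmul is associative, the two-step
# product above equals the dense composition
#     scaling * (dY_e @ B_e) @ A_e == dY_e @ (scaling * B_e @ A_e),
# i.e. dX_e is dY_e applied to the effective LoRA delta weight in [N, K] layout.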
# =============================================================================
# Helper: Extract LoRA params from PEFT ParamWrapper
# =============================================================================
def get_lora_params_from_wrapper(module) -> tuple:
"""
Extract LoRA parameters from a PEFT ParamWrapper.
Returns:
(lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)
"""
if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"):
return None, None, None
active_adapters = getattr(module, "active_adapters", ["default"])
if not active_adapters:
return None, None, None
adapter_name = active_adapters[0]
lora_A_dict = getattr(module, "lora_A", {})
lora_B_dict = getattr(module, "lora_B", {})
scaling_dict = getattr(module, "scaling", {})
if adapter_name not in lora_A_dict:
return None, None, None
lora_A = lora_A_dict[adapter_name].weight
lora_B = lora_B_dict[adapter_name].weight
scaling = scaling_dict[adapter_name]
return lora_A, lora_B, scaling
# =============================================================================
# Drop-in replacement for parallel_linear
# =============================================================================
def parallel_linear_lora(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
lora_A: Optional[torch.Tensor] = None,
lora_B: Optional[torch.Tensor] = None,
scaling: float = 1.0,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
use_fused_dX: bool = False,
use_fused_gather: bool = False,
):
"""
Drop-in replacement for parallel_linear that supports LoRA.
If lora_A and lora_B are provided, uses fused LoRA kernel.
Otherwise falls back to standard scatter2scatter.
"""
if lora_A is not None and lora_B is not None:
return ScatterMoELoRA.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A,
lora_B,
scaling,
expert_biases,
gates,
grouped_in,
grouped_out,
use_fused_dX,
use_fused_gather,
)
else:
from .parallel_experts import ParallelLinear
return ParallelLinear.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases,
gates,
grouped_in,
grouped_out,
)
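# Illustrative call (argument names assumed; scaling follows the usual LoRA
# convention alpha / r):
#
# out = parallel_linear_lora(
#     hidden_states_flat, expert_W,  # expert_W: [E, K, N]
#     top_k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
#     lora_A=A, lora_B=B, scaling=alpha / r,
#     use_fused_dX=True, use_fused_gather=True,
# )
#
# With lora_A/lora_B left as None, the same call runs the stock ParallelLinear
# autograd path, so call sites need no branching.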

View File

@@ -1,66 +0,0 @@
from pathlib import Path
from kernels import (
LocalLayerRepository,
Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
)
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
class KernelsPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
def _register_kernels(self):
plugin_root = Path(__file__).parent
register_kernel_mapping(
{
"HFScatterMoEParallelExperts": {
"cuda": {
Mode.TRAINING: LocalLayerRepository(
repo_path=plugin_root / "libs" / "scattermoe_lora",
package_name="scattermoe_lora",
layer_name="HFScatterMoEGatedMLP",
),
Mode.INFERENCE: LocalLayerRepository(
repo_path=plugin_root / "libs" / "scattermoe_lora",
package_name="scattermoe_lora",
layer_name="HFScatterMoEGatedMLP",
),
},
}
}
)
def _kernelize_model(self, model_type: str):
if model_type == "olmoe":
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
replace_kernel_forward_from_hub(
OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
)
else:
try:
model_moe_cls = get_model_moe_block(model_type)
replace_kernel_forward_from_hub(
model_moe_cls, "HFScatterMoEParallelExperts"
)
except Exception as err:
raise ValueError(f"Unsupported model type: {model_type}") from err
def get_model_moe_block(model_type: str):
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"])
model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock")
return model_cls
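# Example resolution (hypothetical model type, assuming the usual transformers
# naming scheme): get_model_moe_block("qwen2_moe") imports
# transformers.models.qwen2_moe.modeling_qwen2_moe and returns
# Qwen2MoeSparseMoeBlock, which _kernelize_model then rebinds to the
# HFScatterMoEParallelExperts kernel.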

View File

@@ -12,6 +12,7 @@ def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
@@ -21,6 +22,7 @@ def save_compressed_model(
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
@@ -32,6 +34,7 @@ def save_compressed_model(
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -6,12 +6,6 @@ See https://github.com/EleutherAI/lm-evaluation-harness
## Usage
There are two ways to use the LM Eval integration:
### 1. Post-Training Evaluation
When training with the plugin enabled, evaluation runs automatically after training completes:
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
@@ -22,50 +16,9 @@ lm_eval_tasks:
- arc_easy
lm_eval_batch_size: # Batch size for evaluation
# Directory to save evaluation results.
# The final model is loaded from this directory
# unless specified otherwise (see below)
output_dir:
output_dir: # Directory to save evaluation results
```
Run training as usual:
```bash
axolotl train config.yml
```
### 2. Standalone CLI Evaluation
Evaluate any model directly without training:
```yaml
lm_eval_model: meta-llama/Llama-2-7b-hf
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: 8
output_dir: ./outputs
```
Run evaluation:
```bash
axolotl lm-eval config.yml
```
## Model Selection Priority
The model to evaluate is selected in the following priority order:
1. **`lm_eval_model`** - Explicit model path or HuggingFace repo (highest priority)
2. **`hub_model_id`** - Trained model pushed to HuggingFace Hub
3. **`output_dir`** - Local checkpoint directory containing trained model weights
## Citation
```bib

View File

@@ -5,7 +5,7 @@ Module for the Plugin for LM Eval Harness
import subprocess # nosec
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command, get_model_path
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs as LMEvalArgs
@@ -29,7 +29,7 @@ class LMEvalPlugin(BasePlugin):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=get_model_path(cfg),
model=cfg.lm_eval_model or cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,

View File

@@ -13,21 +13,6 @@ import yaml
from axolotl.utils.dict import DictDefault
def get_model_path(cfg: DictDefault) -> str | None:
"""
Determine which model path to use for evaluation.
Priority order (highest to lowest):
1. lm_eval_model - Explicit model path override
2. hub_model_id - Model pushed to HuggingFace Hub
3. None - Falls back to output_dir in build_lm_eval_command
Returns:
Model path string or None to use output_dir fallback
"""
return cfg.lm_eval_model or cfg.hub_model_id or None
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
@@ -123,7 +108,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=get_model_path(cfg),
model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,

View File

@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
return self
for param in model.parameters():
if isinstance(param, Params4bit) and param.quant_state is not None:
if isinstance(param, Params4bit):
param.quant_state._orig_to = param.quant_state.to
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)

View File

@@ -26,6 +26,7 @@ from torch.distributed import DeviceMesh
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
@@ -172,10 +173,7 @@ class ModelLoader:
# Build the model
PLUGIN_MANAGER.pre_model_load(self.cfg)
self.patch_manager.apply_post_plugin_pre_model_load_patches()
skip_move_to_device = self._build_model()
self.patch_manager.apply_post_model_build_patches(self.model)
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
# Post-build model configuration
@@ -228,7 +226,6 @@ class ModelLoader:
):
self.model = self.model.merge_and_unload()
self._configure_experts_implementation()
self._apply_activation_checkpointing()
self._resize_token_embeddings()
self._adjust_model_config()
@@ -236,10 +233,6 @@ class ModelLoader:
self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _configure_experts_implementation(self):
if self.cfg.experts_implementation is not None:
self.model.set_experts_implementation(self.cfg.experts_implementation)
def _apply_activation_checkpointing(self):
if self.cfg.activation_offloading is True:
from axolotl.core.trainers.mixins.activation_checkpointing import (
@@ -341,12 +334,7 @@ class ModelLoader:
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
(
(
needs_fa2_dtype
or self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.sage_attention
)
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
and not self.is_qlora_and_fsdp_enabled
)
or (
@@ -446,7 +434,7 @@ class ModelLoader:
"""
if self.cfg.is_multimodal:
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForImageTextToText
self.model_config.model_type, AutoModelForVision2Seq
)
if isinstance(self.auto_model_loader, str):
self.auto_model_loader = AutoModelForImageTextToText
@@ -488,7 +476,6 @@ class ModelLoader:
max_memory = None
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
self.model_kwargs["dtype"] = self.cfg.torch_dtype
is_ds_zero3 = is_deepspeed_zero3_enabled()
@@ -620,10 +607,6 @@ class ModelLoader:
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = "sdpa"
elif self.cfg.sage_attention:
# sets FA2 attention to re-use same internal handling like masking
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = "eager"
@@ -687,7 +670,7 @@ class ModelLoader:
Uses the selected loader when provided; otherwise falls back to the auto loader.
"""
loader = model_loader_class or self.auto_model_loader
if loader in [AutoModelForCausalLM, AutoModelForImageTextToText]:
if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
model = loader.from_config(
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -805,7 +788,6 @@ class ModelLoader:
# Use auto model loader (handles gptq and default cases)
model_loader_class = self.auto_model_loader
self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
if self.cfg.reinit_weights:
self.model = self._load_model_from_config(model_loader_class)
else:
@@ -863,10 +845,6 @@ class ModelLoader:
# Make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
if getattr(self.model, "_moe_experts_quantized", False):
# Parametrized expert tensors dequantize on access — would OOM.
skip_prepare_model_for_kbit_training = True
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]

View File

@@ -10,7 +10,6 @@ from functools import cached_property
import addict
import transformers
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import (
@@ -97,7 +96,6 @@ class PatchManager:
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
@@ -118,7 +116,6 @@ class PatchManager:
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_moe_expert_quantization_patch()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
@@ -136,10 +133,6 @@ class PatchManager:
patch_prepare_context_parallel_inputs()
def apply_post_model_build_patches(self, model: PreTrainedModel):
"""Apply patches right after model build, before post-load setup."""
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
@@ -175,14 +168,9 @@ class PatchManager:
patch_parallelism_config()
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import (
patch_accelerate_fsdp2,
patch_tied_keys_for_meta_device,
)
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
patch_accelerate_fsdp2()
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
patch_tied_keys_for_meta_device()
if self.cfg.rl:
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
@@ -213,13 +201,6 @@ class PatchManager:
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
def _apply_sageattn_patches(self):
"""Apply patches for SageAttention."""
if self.cfg.sage_attention:
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
patch_sageattn()
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
if (
@@ -239,6 +220,13 @@ class PatchManager:
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
apply_mistral_tokenizer_image_patch()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,
@@ -339,7 +327,7 @@ class PatchManager:
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is not None:
if has_remote_code and self.cfg.trust_remote_code is False:
# If explicitly set in YAML, prefer that
has_remote_code = self.cfg.trust_remote_code
@@ -362,54 +350,15 @@ class PatchManager:
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
and self.cfg.adapter == "qlora"
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_dtype_attrs_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
apply_linear8bitlt_save_patch,
)
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
apply_init_dtype_attrs_patch()
if self.cfg.load_in_8bit:
apply_linear8bitlt_save_patch()
def _apply_moe_expert_quantization_patch(self):
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
if not self.cfg.quantize_moe_experts:
return
from axolotl.monkeypatch.moe_quant import (
patch_moe_quantization_on_load,
patch_peft_target_parameters_matching,
)
patch_moe_quantization_on_load(self.cfg)
patch_peft_target_parameters_matching()
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
"""Log quantization results and set model flag for downstream use."""
import torch
model._moe_experts_quantized = False
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
count = get_moe_quantized_count()
if count > 0:
import gc
model._moe_experts_quantized = True
LOG.info(
"Quantized %d MoE expert parameter(s) to %s during model loading",
count,
"4-bit" if self.cfg.load_in_4bit else "8-bit",
)
gc.collect()
torch.cuda.empty_cache()
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
@@ -550,7 +499,6 @@ class PatchManager:
and not self.cfg.trust_remote_code
and not self.cfg.gptq
and self.cfg.flash_attention
and is_flash_attn_available()
and not self.inference
):
# TODO(MengqingCao): split these patches separately

View File

@@ -19,11 +19,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
# Build common kwargs for processor loading
processor_kwargs = {}
if cfg.revision_of_model:
processor_kwargs["revision"] = cfg.revision_of_model
if cfg.tokenizer_use_mistral_common:
def _patch_mistralcommontokenizer():
@@ -36,7 +31,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
from axolotl.utils.mistral import HFMistralTokenizer
tokenization_mistral_common.MistralCommonBackend = HFMistralTokenizer
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
_patch_mistralcommontokenizer()
@@ -45,7 +40,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if processor_cls == VoxtralProcessor:
return VoxtralProcessor.from_pretrained(
cfg.processor_config,
**processor_kwargs,
)
from axolotl.utils.mistral import Mistral3Processor
@@ -54,12 +48,10 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
tokenizer=tokenizer,
)
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
processor_kwargs["tokenizer"] = tokenizer
processor = processor_cls.from_pretrained(
cfg.processor_config,
**processor_kwargs,
trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
)
# Attempt to load image size from processor if available

View File

@@ -28,10 +28,7 @@ PLUGIN_MANAGER = PluginManager.get_instance()
def modify_tokenizer_files(
tokenizer_path: str,
token_mappings: dict[int, str],
output_dir: str,
revision: str = "main",
tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
) -> str:
"""
Modify tokenizer files to replace added_tokens strings, save to output directory,
@@ -44,7 +41,6 @@ def modify_tokenizer_files(
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
revision: Model revision/branch/tag/commit to load from (HF Hub)
Returns:
Path to the modified tokenizer directory
@@ -57,9 +53,7 @@ def modify_tokenizer_files(
if is_local_main_process():
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, use_fast=True, revision=revision
)
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
@@ -140,10 +134,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
from axolotl.utils.mistral import HFMistralTokenizer
# Load the HF-compatible wrapper around MistralTokenizer
kwargs = {}
if cfg.revision_of_model:
kwargs["revision"] = cfg.revision_of_model
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config, **kwargs)
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
return tokenizer
@@ -159,8 +150,6 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if cfg.revision_of_model:
tokenizer_kwargs["revision"] = cfg.revision_of_model
tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
@@ -172,11 +161,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
modify_kwargs = {"output_dir": cfg.output_dir}
if cfg.revision_of_model:
modify_kwargs["revision"] = cfg.revision_of_model
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, **modify_kwargs
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
)
tokenizer = tokenizer_cls.from_pretrained(

View File

@@ -111,7 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
**kwargs,
safe_serialization: Optional[bool] = None,
):
if state_dict is None:
state_dict = self.state_dict()

View File

@@ -150,17 +150,13 @@ def get_state_dict(self, model, unwrap=True):
)
elif self.is_fsdp2:
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
from torch.distributed.tensor import DTensor
state_dict = {}
sharded_state_dict = model.state_dict()
for param_name, param in sharded_state_dict.items():
if param.is_cpu:
param = param.to(torch.device("cuda"))
if isinstance(param, DTensor):
param = param.full_tensor()
param = param.full_tensor()
if torch.distributed.get_rank() == 0:
state_dict[param_name] = param.cpu()
torch.distributed.barrier()
@@ -186,56 +182,10 @@ def get_state_dict(self, model, unwrap=True):
return state_dict
def patch_peft_param_wrapper_for_fsdp2():
"""Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
delta_weight to the base weight W inside _LoraParameterProxy.forward().
Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
regular Tensor (or vice versa), causing a RuntimeError on mixed types.
This patch promotes the non-DTensor operand to match the DTensor's spec
using DTensor.from_local(), which is free for Replicate placement (just
metadata wrapping, no communication).
"""
from peft.tuners.lora.layer import _LoraParameterProxy
if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
return
_original_forward = _LoraParameterProxy.forward
# NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
def _patched_forward(self, W):
from torch.distributed.tensor import DTensor
delta = self.delta_weight
w_is_dt = isinstance(W, DTensor)
d_is_dt = isinstance(delta, DTensor)
with torch.nn.utils.parametrize.cached():
if w_is_dt == d_is_dt:
return W + delta
if w_is_dt:
return W + DTensor.from_local(delta, W.device_mesh, W.placements)
return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
_LoraParameterProxy.forward = _patched_forward
_LoraParameterProxy._axolotl_fsdp2_patched = True
LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
from peft.tuners.lora.layer import ParamWrapper
from torch.distributed.fsdp import fully_shard
# Skip ParamWrapper — its lora_A/B must not be independently sharded.
# The parent decoder layer's FSDP wrapper handles unsharding them.
# TODO: review whether we even need to shard them separately in the first place.
if isinstance(module, ParamWrapper):
return False
log_bias_dtype_mismatch = False
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
@@ -377,14 +327,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
is_peft_model = isinstance(model, PeftModel)
# Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
# ParamWrapper modules exist (used for target_parameters / 3D expert params).
if is_peft_model:
from peft.tuners.lora.layer import ParamWrapper
if any(isinstance(m, ParamWrapper) for m in model.modules()):
patch_peft_param_wrapper_for_fsdp2()
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
if auto_wrap_policy is not None:
@@ -434,43 +376,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
return model
def patch_tied_keys_for_meta_device():
"""Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
grouped as "tied". Skipping them is safe since they have no real storage.
"""
from collections import defaultdict
from transformers import PreTrainedModel
def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
param_pointers = defaultdict(list)
for param_name, param_value in self.state_dict().items():
if param_value.is_meta:
continue
param_pointers[param_value.data_ptr()].append(param_name)
tied_param_names = [
names
for names in param_pointers.values()
if len(names) > 1
and not any(name in self.all_tied_weights_keys.keys() for name in names)
and not all(name in missing_keys for name in names)
]
tied_weights_keys_by_pointers = {
param_name: group[0]
for group in tied_param_names
for param_name in group[1:]
}
self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
_patched_adjust_tied_keys_with_tied_pointers
)
def patch_accelerate_fsdp2():
import accelerate

View File

@@ -1,211 +0,0 @@
"""
Monkeypatch for SageAttention for use with transformers.
https://github.com/thu-ml/SageAttention/
"""
import torch
from transformers.integrations.sdpa_attention import repeat_kv
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
sageattn = None # pylint: disable=invalid-name
sageattn_varlen = None # pylint: disable=invalid-name
def _is_sageattn_available():
"""Determine if SageAttention is available"""
try:
import sageattention # noqa: F401 # pylint: disable=unused-import
return True
except ImportError:
return False
if _is_sageattn_available():
# import sageattn here if available
from sageattention import sageattn, sageattn_varlen
def _check_sageattn_imported():
"""Check if SageAttention is imported. Raises an ImportError if not."""
if sageattn is None:
raise ImportError(
"SageAttention is not installed. Please install it from source: "
"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
)
def sage_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None = None,
dropout: float = 0.0,
scaling: float | None = None,
is_causal: bool | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
"""
Forward pass for SageAttention compatible with transformers attention interfaces.
https://github.com/thu-ml/SageAttention/
"""
_check_sageattn_imported()
if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
raise NotImplementedError(
"SageAttention does not support `output_attentions=True` or `head_mask`."
)
# The base sageattn API does not support dropout.
if dropout > 0.0:
raise NotImplementedError("SageAttention does not support dropout.")
# Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
# Calculate is_causal following transformers
assert is_causal is not False, "is_causal must be True or None"
is_causal = True
position_ids = kwargs.get("position_ids", None)
query_length = query.shape[2]
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
max_length_q = kwargs.get("max_length_q", None)
max_length_k = kwargs.get("max_length_k", None)
# Sample packing uses position_ids, so we check for it first
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.size(0)
from transformers.modeling_flash_attention_utils import (
prepare_fa2_from_position_ids,
)
if cu_seqlens_q is None or cu_seqlens_k is None:
query, key, value, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query, key, value, position_ids)
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
is_causal=is_causal,
sm_scale=scaling,
smooth_k=False,  # avoids zero loss / NaN grad norms
tensor_layout="NHD",
)
attn_output = attn_output_unpad.view(
batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)
)
elif attention_mask is not None:
# NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.
assert attention_mask.ndim == 2, "Attention mask must be 2D"
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.shape[0]
query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_q, max_seqlen_k = max_seq_lens
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scaling,
tensor_layout="NHD",
)
from flash_attn.bert_padding import pad_input
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
# Use standard sageattn
# The input layout for transformers models is (batch_size, num_heads, seq_len, head_dim),
# which corresponds to SageAttention's "HND" layout.
attn_output = sageattn(
q=query,
k=key,
v=value,
tensor_layout="HND",
is_causal=is_causal,
sm_scale=scaling,
)
# SageAttention with "HND" returns (batch, heads, seq_len, head_dim)
# Transformers expects (batch, seq_len, heads, head_dim) for the output
# So we need to transpose dimensions 1 and 2
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
def patch_sageattn():
"""Patch SageAttention for use with transformers."""
_check_sageattn_imported()
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
# Replace flash attention with sage attention
ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
# Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS
# Register sage_attention with the global attention interface
# ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
# from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
# ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
LOG.info("SageAttention patched successfully")

View File

@@ -1,10 +1,9 @@
"""
Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2 and
allows our LoRA / QLoRA Triton kernels to work with FSDP2.
This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
metadata through the FSDP2 shard/unshard cycle.
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
Params4bit parameters.
"""
import importlib
@@ -18,8 +17,6 @@ LOG = get_logger(__name__)
def apply_init_sharded_param_patch():
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
@@ -44,20 +41,9 @@ def apply_init_sharded_param_patch():
bnb_quantized=param.bnb_quantized,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
elif isinstance(param, bnb.nn.modules.Int8Params):
self.sharded_param = bnb.nn.modules.Int8Params(
data=sharded_param,
requires_grad=param.requires_grad,
has_fp16_weights=param.has_fp16_weights,
CB=None,
SCB=param.SCB,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
else:
self.sharded_param = nn.Parameter(
self.to_sharded_dtensor(sharded_param),
requires_grad=param.requires_grad,
)"""
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
# Apply the replacement
if original_param_creation in original_source:
@@ -87,7 +73,6 @@ def apply_init_sharded_param_patch():
# Replace the method
FSDPParam._init_sharded_param = patched_init_sharded_param
apply_init_sharded_param_patch._axolotl_patched = True
LOG.info("Successfully applied FSDP _init_sharded_param patch")
else:
LOG.warning("Could not find target code for _init_sharded_param patching")
@@ -95,8 +80,6 @@ def apply_init_sharded_param_patch():
def apply_init_unsharded_param_patch():
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
@@ -122,14 +105,6 @@ def apply_init_unsharded_param_patch():
module=local_tensor.module,
bnb_quantized=local_tensor.bnb_quantized,
)
elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
self._unsharded_param = bnb.nn.modules.Int8Params(
data=unsharded_param,
requires_grad=self.sharded_param.requires_grad,
has_fp16_weights=local_tensor.has_fp16_weights,
CB=unsharded_param,
SCB=local_tensor.SCB,
)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
@@ -163,74 +138,6 @@ def apply_init_unsharded_param_patch():
# Replace the method
FSDPParam.init_unsharded_param = patched_init_unsharded_param
apply_init_unsharded_param_patch._axolotl_patched = True
LOG.info("Successfully applied FSDP init_unsharded_param patch")
else:
LOG.warning("Could not find target code for patching")
def apply_linear8bitlt_save_patch():
"""Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
doesn't proxy custom attribute access to its _local_tensor. This patch
temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
"""
if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
return
import bitsandbytes as bnb
from torch.distributed.tensor import DTensor
original_save = bnb.nn.Linear8bitLt._save_to_state_dict
def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
# Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
weight = self._parameters["weight"]
unwrapped = False
if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
self._parameters["weight"] = weight._local_tensor
unwrapped = True
try:
original_save(self, destination, prefix, keep_vars)
finally:
if unwrapped:
self._parameters["weight"] = weight
bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
apply_linear8bitlt_save_patch._axolotl_patched = True
LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
def apply_init_dtype_attrs_patch():
"""Prevent FSDP2 mixed precision from casting non-float quantized params.
When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
int8 quantized) without FSDP2 extensions, this destroys the quantized data.
Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
without extensions.
"""
if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
original_init_dtype_attrs = FSDPParam.init_dtype_attrs
def patched_init_dtype_attrs(self, mp_policy):
original_init_dtype_attrs(self, mp_policy)
# Skip casting non-float quantized params (uint8/int8) without FSDP2
# extensions — the parametrization chain handles dequantization.
if self.param_dtype is not None and not self.sharded_param.is_floating_point():
local = self.sharded_param
if hasattr(local, "_local_tensor"):
local = local._local_tensor
if not hasattr(local, "fsdp_pre_all_gather"):
self.param_dtype = None
FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
apply_init_dtype_attrs_patch._axolotl_patched = True
LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")

View File

@@ -59,12 +59,7 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
output = ctx.forward_function(hidden_states, *ctx.args)
# Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
# return a plain tensor, not a tuple. Older models return tuples
# like (hidden_states, present_kv, ...). Unwrap if needed.
if isinstance(output, (tuple, list)):
(output,) = output
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (
None,

View File

@@ -169,8 +169,7 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return attention_cls
except (ImportError, AttributeError) as e:
raise ValueError(
f"Axolotl could not import attention class for model_type: {model_type}. "
"Please raise an Issue and turn off lora kernels to continue training. "
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e

View File

@@ -1,51 +0,0 @@
"""
eaft (entropy-aware focal training) loss implementation
weights the per-token loss by an entropy approximation computed from the top-k logits
Reference: https://github.com/ymxyll/LlamaFactory-EAFT/blob/e2ce19e8efcc226450ee8f2b81dfe4e69f1f945d/src/llamafactory/train/trainer_utils.py
"""
import torch
import torch.nn.functional as F
def eaft_loss(outputs, labels, num_items_in_batch=None, alpha=1.0, k=20):
"""
compute eaft loss with entropy weighting
args:
outputs: model outputs containing logits
labels: target labels for computing loss
num_items_in_batch: for sample packing support
alpha: exponent for entropy weighting (default 1.0)
k: number of top logits for entropy approximation (default 20)
"""
logits = outputs.logits
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
vocab_size = shift_logits.size(-1)
shift_logits_view = shift_logits.view(-1, vocab_size)
shift_labels_view = shift_labels.view(-1)
mask = shift_labels_view != -100
with torch.no_grad():
top_k_logits, _ = torch.topk(
shift_logits_view[mask].float(), k=min(k, vocab_size), dim=-1
)
top_k_probs = F.softmax(top_k_logits, dim=-1)
entropy = -(top_k_probs * torch.log(top_k_probs + 1e-10)).sum(dim=-1)
weights = torch.pow(entropy, alpha)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
per_token_loss = loss_fct(shift_logits_view[mask], shift_labels_view[mask])
weighted_loss = per_token_loss * weights
if num_items_in_batch is not None:
loss = weighted_loss.sum() / num_items_in_batch
else:
loss = weighted_loss.mean()
return loss
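# Usage sketch (illustrative; any object exposing a .logits attribute works):
#
# from types import SimpleNamespace
# logits = torch.randn(2, 8, 32000)  # [batch, seq, vocab]
# labels = torch.randint(0, 32000, (2, 8))
# labels[:, :4] = -100               # ignore prompt tokens
# loss = eaft_loss(SimpleNamespace(logits=logits), labels, alpha=1.0, k=20)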

View File

@@ -1,5 +1,5 @@
"""
Monkeypatch to fix inefficient tensor conversion in MistralCommonBackend.apply_chat_template
Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
"""
import importlib
@@ -12,11 +12,11 @@ LOG = get_logger(__name__)
def apply_mistral_tokenizer_image_patch():
"""Apply patch to MistralCommonBackend.apply_chat_template to fix image tensor conversion."""
from transformers.tokenization_mistral_common import MistralCommonBackend
"""Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion."""
from transformers.tokenization_mistral_common import MistralCommonTokenizer
# Get original source
original_source = inspect.getsource(MistralCommonBackend.apply_chat_template)
original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
original_source, _ = detab_code(original_source)
# Define the replacement
@@ -41,7 +41,7 @@ def apply_mistral_tokenizer_image_patch():
)
# Load necessary imports from the module
module_name = MistralCommonBackend.__module__
module_name = MistralCommonTokenizer.__module__
module = importlib.import_module(module_name)
# Detect what needs to be imported
@@ -79,7 +79,7 @@ def apply_mistral_tokenizer_image_patch():
exec(patched_source, globals()) # nosec B102
# Replace the method
MistralCommonBackend.apply_chat_template = patched_apply_chat_template
LOG.info("Successfully applied MistralCommonBackend tensor conversion patch")
MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template
LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")
else:
LOG.warning("Could not find target code for MistralCommonBackend patching")
LOG.warning("Could not find target code for MistralCommonTokenizer patching")

View File

@@ -9,11 +9,6 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
except ImportError:
fla_causal_conv1d = None
def get_cu_seqlens(position_ids):
"""
@@ -142,11 +137,6 @@ def patch_qwen3_next_gateddelta_layer():
and cache_position is not None
)
# Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
cu_seqlens = None
if not use_precomputed_states and position_ids is not None:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
# getting projected states from cache if it exists
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
@@ -161,11 +151,12 @@ def patch_qwen3_next_gateddelta_layer():
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1) # [B, T, D]
mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv = mixed_qkv.transpose(1, 2)
if use_precomputed_states:
# Inference single-token path: causal_conv1d_update expects [B, D, T]
mixed_qkv = mixed_qkv.transpose(1, 2)
# 2. Convolution sequence transformation
# NOTE: the conv state is updated in `causal_conv1d_update`
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
@@ -173,41 +164,24 @@ def patch_qwen3_next_gateddelta_layer():
self.conv1d.bias,
self.activation,
)
mixed_qkv = mixed_qkv.transpose(1, 2)
else:
if cache_params is not None:
# Cache state expects [B, D, T] for the inference update path
mixed_qkv_t = mixed_qkv.transpose(1, 2)
conv_state = F.pad(
mixed_qkv_t,
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx] = conv_state
if fla_causal_conv1d is not None:
# FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
mixed_qkv, _ = fla_causal_conv1d(
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
cu_seqlens=cu_seqlens,
seq_idx=None,
)
else:
# PyTorch fallback (no cu_seqlens support)
if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
raise RuntimeError(
"Packed sequences require fla.modules.convolution.causal_conv1d "
"(cu_seqlens support). Install flash-linear-attention or disable packing."
)
LOG.warning_once(
"FLA causal_conv1d not available. Falling back to PyTorch conv1d."
)
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = mixed_qkv.transpose(1, 2)
# mixed_qkv is [B, T, D] in all paths
mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
mixed_qkv,
[
@@ -229,6 +203,7 @@ def patch_qwen3_next_gateddelta_layer():
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
if not use_precomputed_states:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,
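get_cu_seqlens, used throughout this patch, converts packed-sequence position_ids into flash-attention-style cumulative boundaries. One common derivation, assuming position_ids restart at 0 at every packed sequence, is sketched below; the helper in this file may differ in edge-case handling.

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids: [1, total_tokens], e.g. [[0, 1, 2, 0, 1]] for lengths (3, 2)
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten()  # each sequence begins at pos == 0
    end = torch.tensor([pos.numel()], device=pos.device)
    return torch.cat([starts, end]).to(torch.int32)

cu = cu_seqlens_from_position_ids(torch.tensor([[0, 1, 2, 0, 1]]))
# cu == tensor([0, 3, 5], dtype=torch.int32); sequence i spans cu[i]:cu[i+1]
```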

View File

@@ -1,188 +0,0 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""
import bitsandbytes as bnb
import torch
import torch.nn.utils.parametrize as P
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Module-level state for the loading-time quantization patch.
_moe_load_state = {
"count": 0,
"mode": "4bit",
"quant_type": "nf4",
"compress_statistics": True,
"patched": False,
}
class Bnb8bitParametrization(torch.nn.Module):
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
def __init__(self, row_stats: torch.Tensor):
super().__init__()
self.register_buffer("row_stats", row_stats)
@torch.no_grad()
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
orig_shape = quantized_param.shape
if quantized_param.ndim > 2:
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
return result.reshape(orig_shape)
def _enable_parametrization_cache(module, inputs):
P._cache_enabled += 1
def _disable_parametrization_cache(module, inputs, output):
P._cache_enabled -= 1
if not P._cache_enabled:
P._cache = {}
def replace_parameter_8bit(module, param_name):
"""Replace a module parameter with an 8-bit quantized version using parametrization."""
original_param = getattr(module, param_name)
int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
original_param.data.to(torch.float16)
)
setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
del original_param
P.register_parametrization(
module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
)
# Cache dequantized values during forward to avoid redundant dequantization.
if not getattr(module, "_axolotl_8bit_hooks_registered", False):
module.register_forward_pre_hook(_enable_parametrization_cache)
module.register_forward_hook(_disable_parametrization_cache)
module._axolotl_8bit_hooks_registered = True
def patch_moe_quantization_on_load(cfg):
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
"""
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
_moe_load_state["mode"] = mode
_moe_load_state["count"] = 0
if _moe_load_state["patched"]:
LOG.debug("MoE loading-time quantization patch already active")
return
import transformers.core_model_loading
import transformers.modeling_utils
if mode == "4bit":
from bitsandbytes.nn.parametrize import replace_parameter_4bit
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
if compress_statistics is None:
compress_statistics = True
_moe_load_state["quant_type"] = quant_type
_moe_load_state["compress_statistics"] = compress_statistics
# Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
# size for all params, defeating our on-load quantization VRAM savings.
def _noop_warmup(*args, **kwargs):
pass
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
original_set_param = transformers.core_model_loading.set_param_for_module
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
if "expert" not in target_name.lower():
LOG.debug(
"Skipping non-expert 3D param: %s (shape=%s)",
target_name,
list(param_value.shape),
)
return
if _moe_load_state["mode"] == "4bit":
replace_parameter_4bit(
mod,
pname,
compress_statistics=_moe_load_state["compress_statistics"],
quant_type=_moe_load_state["quant_type"],
)
else:
replace_parameter_8bit(mod, pname)
_moe_load_state["count"] += 1
# Release the bf16 tensor so CUDA memory is freed immediately.
param_value.data = torch.empty(0, device="cpu")
torch.cuda.empty_cache()
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
_moe_load_state["patched"] = True
def get_moe_quantized_count():
"""Return the number of expert parameters quantized during loading."""
return _moe_load_state["count"]
def patch_peft_target_parameters_matching():
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
return
from peft.tuners.tuners_utils import BaseTuner
original_inject = BaseTuner._inject_parameters
def _patched_inject_parameters(
self, peft_config, model, adapter_name, low_cpu_mem_usage
):
# Patch target_parameters to use full paths for parametrized modules
original_targets = list(peft_config.target_parameters)
expanded = set(original_targets)
for module_name, module in model.named_modules():
if not hasattr(module, "parametrizations"):
continue
for target in original_targets:
mod_path, _, param_name = target.rpartition(".")
if (
module_name == mod_path or module_name.endswith("." + mod_path)
) and hasattr(module, param_name):
expanded.add(f"{module_name}.{param_name}")
peft_config.target_parameters = sorted(expanded)
try:
return original_inject(
self, peft_config, model, adapter_name, low_cpu_mem_usage
)
finally:
peft_config.target_parameters = original_targets
BaseTuner._inject_parameters = _patched_inject_parameters
patch_peft_target_parameters_matching._axolotl_patched = True
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")

View File

@@ -155,6 +155,7 @@ class ReLoRACallback(TrainerCallback):
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"adapter",
),
safe_serialization=True,
)
with torch.no_grad():
merge_and_save(
@@ -213,7 +214,7 @@ class ReLoRACallback(TrainerCallback):
self.last_full_model = checkpoint_folder
else:
model.model.save_pretrained(checkpoint_folder)
model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
return control

View File

@@ -52,15 +52,9 @@ def patch_prepare_context_parallel_inputs() -> None:
if item in patched_source:
items_to_import.append(item)
# Use a separate namespace to capture the exec'd function
namespace = {}
exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
exec(patched_source, namespace)
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
exec(patched_source, globals())
# Explicitly get the function from the namespace
axolotl_prepare_context_parallel_inputs = namespace[
"axolotl_prepare_context_parallel_inputs"
]
Trainer._original_prepare_context_parallel_inputs = (
Trainer._prepare_context_parallel_inputs
)
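The namespace-dict idiom visible in this hunk keeps exec-ed definitions out of module globals and then fetches the patched function explicitly. In miniature:

```python
namespace: dict = {}
exec("def answer():\n    return 42", namespace)  # nosec B102
answer = namespace["answer"]
assert answer() == 42
```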

View File

@@ -28,12 +28,8 @@ PATCHED_EVAL_CODE = {
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
}
ORIGINAL_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()"
)
PATCHED_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
)
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
def check_evaluation_loop_is_patchable() -> bool:
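Context for the mean-to-nanmean swap: a single NaN training-loss step (for example, from an all-masked batch) would otherwise contaminate the gathered running loss permanently. Concretely:

```python
import torch

tr_loss = torch.tensor([0.7, float("nan"), 0.9])
print(tr_loss.mean())     # tensor(nan), one NaN poisons the plain mean
print(tr_loss.nanmean())  # tensor(0.8000), NaN entries are skipped
```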

View File

@@ -14,6 +14,7 @@ from transformers.models.voxtral import VoxtralProcessor
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
LOG = get_logger(__name__)
@@ -429,7 +430,7 @@ class Mistral3ProcessingStrategy(ProcessingStrategy):
def __init__(
self,
processor,
processor: Mistral3Processor,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
@@ -485,58 +486,6 @@ class InternVLProcessingStrategy(ProcessingStrategy):
return labels
class Glm4vProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for GLM4V and GLM4V-MoE vision models."""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.tokenizer = getattr(processor, "tokenizer", processor)
self.image_token = "<|image|>" # nosec
self.begin_image_token = "<|begin_of_image|>" # nosec
self.end_image_token = "<|end_of_image|>" # nosec
self.video_token = "<|video|>" # nosec
self.begin_video_token = "<|begin_of_video|>" # nosec
self.end_video_token = "<|end_of_video|>" # nosec
self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
self.begin_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_image_token
)
self.end_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_image_token
)
self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
self.begin_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_video_token
)
self.end_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_video_token
)
def process_labels(self, input_ids):
labels = input_ids.clone()
labels[labels == self.tokenizer.pad_token_id] = -100
labels[labels == self.image_token_id] = -100
labels[labels == self.begin_image_token_id] = -100
labels[labels == self.end_image_token_id] = -100
labels[labels == self.video_token_id] = -100
labels[labels == self.begin_video_token_id] = -100
labels[labels == self.end_video_token_id] = -100
return labels
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -544,8 +493,6 @@ def get_processing_strategy(
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
processing_kwargs = {
"processor": processor,
"chat_template": chat_template,
@@ -553,10 +500,10 @@ def get_processing_strategy(
"image_resize_algorithm": image_resize_algorithm,
}
if chat_template_type in [None, "tokenizer_default"]:
tokenizer = getattr(processor, "tokenizer", processor)
if hasattr(tokenizer, "chat_template"):
processing_kwargs["chat_template"] = tokenizer.chat_template
if chat_template_type in [None, "tokenizer_default"] and hasattr(
processor.tokenizer, "chat_template"
):
processing_kwargs["chat_template"] = processor.tokenizer.chat_template
if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy(
@@ -585,15 +532,6 @@ def get_processing_strategy(
return Mistral3ProcessingStrategy(
**processing_kwargs,
)
try:
from transformers.models.glm46v.processing_glm46v import Glm46VProcessor
if isinstance(processor, Glm46VProcessor):
return Glm4vProcessingStrategy(
**processing_kwargs,
)
except ImportError:
pass
if isinstance(processor, InternVLProcessor):
return InternVLProcessingStrategy(

View File

@@ -150,8 +150,6 @@ class ChatTemplatePrompter(Prompter):
return self.tokenizer.apply_chat_template(
conversation,
tokenize=True,
return_dict=False,
**chat_template_kwargs,
)

View File

@@ -153,27 +153,13 @@ class TelemetryCallback(TrainerCallback):
self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss, learning_rate, grad_norm, and token metrics from log history."""
"""Extract last loss, learning_rate, and grad_norm from log history."""
if not state.log_history:
return {
"loss": 0,
"ppl": 0,
"learning_rate": 0,
"grad_norm": 0,
"tokens/total": 0,
"tokens/trainable": 0,
"tokens/train_per_sec_per_gpu": 0,
}
return {"loss": 0, "learning_rate": 0, "grad_norm": 0}
last_log = state.log_history[-1]
return {
"loss": last_log.get("loss", 0),
"ppl": last_log.get("ppl", 0),
"learning_rate": last_log.get("learning_rate", 0),
"grad_norm": last_log.get("grad_norm", 0),
"tokens/total": last_log.get("tokens/total", 0),
"tokens/trainable": last_log.get("tokens/trainable", 0),
"tokens/train_per_sec_per_gpu": last_log.get(
"tokens/train_per_sec_per_gpu", 0
),
}

View File

@@ -155,10 +155,6 @@ def send_errors(func: Callable) -> Callable:
},
)
LOG.error(
f"Error captured in telemetry. Run ID: {telemetry_manager.run_id}"
)
raise
return wrapper

Some files were not shown because too many files have changed in this diff.