diff --git a/.axolotl-complete.bash b/.axolotl-complete.bash new file mode 100644 index 000000000..9a51399e6 --- /dev/null +++ b/.axolotl-complete.bash @@ -0,0 +1,41 @@ +#!/bin/bash + +_axolotl_completions() { + local cur prev + COMPREPLY=() + cur="${COMP_WORDS[COMP_CWORD]}" + prev="${COMP_WORDS[COMP_CWORD-1]}" + + # If we're completing the first argument (the command) + if [[ $COMP_CWORD -eq 1 ]]; then + mapfile -t COMPREPLY < <(compgen -W "delinearize-llama4 fetch lm-eval merge-sharded-fsdp-weights quantize vllm-serve evaluate inference merge-lora preprocess train" -- "$cur") + return 0 + fi + + # Commands that should complete with directories and YAML files + local -a yaml_commands=("merge-sharded-fsdp-weights" "quantize" "vllm-serve" "evaluate" "inference" "merge-lora" "preprocess" "train") + + # Check if previous word is in our list + if [[ " ${yaml_commands[*]} " =~ (^|[[:space:]])$prev($|[[:space:]]) ]]; then + # Use filename completion which handles directories properly + compopt -o filenames + mapfile -t COMPREPLY < <(compgen -f -- "$cur") + + # Filter to only include directories and YAML files + local -a filtered=() + for item in "${COMPREPLY[@]}"; do + if [[ -d "$item" ]] || [[ "$item" == *.yaml ]] || [[ "$item" == *.yml ]]; then + filtered+=("$item") + fi + done + COMPREPLY=("${filtered[@]}") + + return 0 + fi + + # Default: no completion + return 0 +} + +# Remove the -o nospace option - let filenames handle it +complete -F _axolotl_completions axolotl diff --git a/.bandit b/.bandit index 2d81286ae..b81428751 100644 --- a/.bandit +++ b/.bandit @@ -1,3 +1,3 @@ [bandit] exclude = tests -skips = B101 +skips = B101,B615,B102,B110 diff --git a/.coderabbit.yaml b/.coderabbit.yaml new file mode 100644 index 000000000..821d6bd5b --- /dev/null +++ b/.coderabbit.yaml @@ -0,0 +1,17 @@ +# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json +language: "en-US" +early_access: false +reviews: + profile: "chill" + request_changes_workflow: false + high_level_summary: true + review_status: true + collapse_walkthrough: true + poem: false + sequence_diagrams: false + auto_review: + enabled: true + drafts: false + auto_incremental_review: false +chat: + auto_reply: true diff --git a/.flake8 b/.flake8 deleted file mode 100644 index fd69af775..000000000 --- a/.flake8 +++ /dev/null @@ -1,5 +0,0 @@ -[flake8] -max-line-length = 88 - -select = C,E,F,W,B,B950 -extend-ignore = E203, E501, W503 diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 8f67908e8..fcfd96891 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -57,6 +57,13 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o 5. Push your branch to your fork on GitHub. 6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues. +#### Skipping CI Checks + +You can skip certain CI checks by including specific keywords in your commit messages: + +- `[skip ci]` or `skip ci` - Skips all CI checks for that commit +- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR. + ## Style Guidelines ### Code Style diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 6b750fc5a..7af6059c8 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -5,44 +5,26 @@ on: branches: - "main" paths: - - 'Dockerfile-base' + - 'docker/Dockerfile-base' + - 'docker/Dockerfile-uv-base' - '.github/workflows/base.yml' pull_request: paths: - - 'Dockerfile-base' + - 'docker/Dockerfile-base' + - 'docker/Dockerfile-uv-base' - '.github/workflows/base.yml' workflow_dispatch: jobs: build-base: - if: github.repository_owner == 'axolotl-ai-cloud' + if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }} + timeout-minutes: 480 # this job needs to be run on self-hosted GPU runners... runs-on: ubuntu-latest-m strategy: fail-fast: false matrix: include: - - cuda: "124" - cuda_version: 12.4.1 - cudnn_version: "" - python_version: "3.11" - pytorch: 2.5.1 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - dockerfile: "Dockerfile-base" - - cuda: "124" - cuda_version: 12.4.1 - cudnn_version: "" - python_version: "3.11" - pytorch: 2.6.0 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - dockerfile: "Dockerfile-base" - - cuda: "126" - cuda_version: 12.6.3 - cudnn_version: "" - python_version: "3.11" - pytorch: 2.6.0 - torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - dockerfile: "Dockerfile-base" - cuda: "126" cuda_version: 12.6.3 cudnn_version: "" @@ -50,20 +32,34 @@ jobs: pytorch: 2.7.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-base" - - cuda: "128" + - cuda: "126" cuda_version: 12.6.3 cudnn_version: "" python_version: "3.11" - pytorch: 2.7.0 + pytorch: 2.7.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-base" - cuda: "128" cuda_version: 12.8.1 cudnn_version: "" python_version: "3.11" - pytorch: nightly + pytorch: 2.7.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" - dockerfile: "Dockerfile-base-nightly" + dockerfile: "Dockerfile-base" + - cuda: "128" + cuda_version: 12.8.1 + cudnn_version: "" + python_version: "3.11" + pytorch: 2.8.0 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + dockerfile: "Dockerfile-base" +# - cuda: "128" +# cuda_version: 12.8.1 +# cudnn_version: "" +# python_version: "3.11" +# pytorch: nightly +# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" +# dockerfile: "Dockerfile-base-nightly" # # "next" is for release candidates of pytorch # - cuda: "128" # cuda_version: 12.8.1 @@ -105,7 +101,8 @@ jobs: PYTORCH_VERSION=${{ matrix.pytorch }} TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }} build-base-uv: - if: github.repository_owner == 'axolotl-ai-cloud' + if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }} + timeout-minutes: 480 runs-on: ubuntu-latest-m strategy: fail-fast: false @@ -115,14 +112,21 @@ jobs: cuda_version: 12.6.3 cudnn_version: "" python_version: "3.11" - pytorch: 2.6.0 + pytorch: 2.7.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-uv-base" - cuda: "128" cuda_version: 12.8.1 cudnn_version: "" python_version: "3.11" - pytorch: 2.7.0 + pytorch: 2.7.1 + torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" + dockerfile: "Dockerfile-uv-base" + - cuda: "128" + cuda_version: 12.8.1 + cudnn_version: "" + python_version: "3.11" + pytorch: 2.8.0 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-uv-base" steps: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 2d3c209cc..5b5cc5489 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python3 -m pip install jupyter quartodoc - python3 -m pip install -e . --no-deps + python3 -m pip install -e . - name: Build autodoc run: quartodoc build - name: Publish to GitHub Pages (and render) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d04450428..cf322f105 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -3,6 +3,7 @@ on: # check on PRs, and manual triggers merge_group: pull_request: + types: [opened, synchronize, reopened, ready_for_review] paths: - '**.py' - 'requirements.txt' @@ -16,6 +17,7 @@ jobs: pre-commit: name: pre-commit runs-on: ubuntu-latest + if: ${{ !github.event.pull_request.draft }} steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 01606f902..4040ccdc9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -15,26 +15,26 @@ jobs: fail-fast: false matrix: include: - - cuda: 124 - cuda_version: 12.4.1 - python_version: "3.11" - pytorch: 2.5.1 - axolotl_extras: - - cuda: 124 - cuda_version: 12.4.1 - python_version: "3.11" - pytorch: 2.6.0 - axolotl_extras: vllm - is_latest: true - cuda: 126 cuda_version: 12.6.3 python_version: "3.11" pytorch: 2.7.0 axolotl_extras: + - cuda: 126 + cuda_version: 12.6.3 + python_version: "3.11" + pytorch: 2.7.1 + axolotl_extras: vllm + is_latest: true - cuda: 128 cuda_version: 12.8.1 python_version: "3.11" - pytorch: 2.7.0 + pytorch: 2.7.1 + axolotl_extras: + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.8.0 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -83,26 +83,32 @@ jobs: strategy: matrix: include: - - cuda: 124 - cuda_version: 12.4.1 - python_version: "3.11" - pytorch: 2.5.1 - axolotl_extras: - - cuda: 124 - cuda_version: 12.4.1 - python_version: "3.11" - pytorch: 2.6.0 - axolotl_extras: - is_latest: true - cuda: 126 cuda_version: 12.6.3 python_version: "3.11" pytorch: 2.7.0 axolotl_extras: + - cuda: 126 + cuda_version: 12.6.3 + python_version: "3.11" + pytorch: 2.7.1 + axolotl_extras: + is_latest: + - cuda: 126 + cuda_version: 12.6.3 + python_version: "3.11" + pytorch: 2.7.1 + axolotl_extras: vllm + is_latest: true - cuda: 128 cuda_version: 12.8.1 python_version: "3.11" - pytorch: 2.7.0 + pytorch: 2.7.1 + axolotl_extras: + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.8.0 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -146,11 +152,24 @@ jobs: strategy: matrix: include: - - cuda: 124 - cuda_version: 12.4.1 + - cuda: 126 + cuda_version: 12.6.3 python_version: "3.11" - pytorch: 2.6.0 + pytorch: 2.7.1 axolotl_extras: + is_latest: + - cuda: 126 + cuda_version: 12.6.3 + python_version: "3.11" + pytorch: 2.7.1 + axolotl_extras: vllm + is_latest: true + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.8.0 + axolotl_extras: + is_latest: runs-on: axolotl-gpu-runner steps: - name: Checkout diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index 0167df67a..6a92de352 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -21,30 +21,23 @@ concurrency: jobs: test-axolotl-multigpu: - if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }} + if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }} strategy: fail-fast: false matrix: include: - - cuda: 124 - cuda_version: 12.4.1 - python_version: "3.11" - pytorch: 2.6.0 - axolotl_extras: vllm - num_gpus: 2 - nightly_build: "true" - - cuda: 124 - cuda_version: 12.4.1 - python_version: "3.11" - pytorch: 2.5.1 - axolotl_extras: - num_gpus: 2 - nightly_build: "true" - cuda: 126 cuda_version: 12.6.3 python_version: "3.11" - pytorch: 2.7.0 - axolotl_extras: + pytorch: 2.7.1 + axolotl_extras: vllm + num_gpus: 2 + nightly_build: "true" + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.8.0 + axolotl_extras: fbgemm-gpu num_gpus: 2 nightly_build: "true" runs-on: [self-hosted, modal] diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 4e61984fb..18b036a0d 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -12,15 +12,15 @@ jobs: fail-fast: false matrix: include: - - cuda: 124 - cuda_version: 12.4.1 + - cuda: 126 + cuda_version: 12.6.3 python_version: "3.11" - pytorch: 2.5.1 + pytorch: 2.7.1 axolotl_extras: - - cuda: 124 - cuda_version: 12.4.1 + - cuda: 128 + cuda_version: 12.8.1 python_version: "3.11" - pytorch: 2.6.0 + pytorch: 2.8.0 axolotl_extras: runs-on: axolotl-gpu-runner steps: @@ -65,15 +65,15 @@ jobs: strategy: matrix: include: - - cuda: 124 - cuda_version: 12.4.1 + - cuda: 126 + cuda_version: 12.6.3 python_version: "3.11" - pytorch: 2.5.1 + pytorch: 2.7.1 axolotl_extras: - - cuda: 124 - cuda_version: 12.4.1 + - cuda: 128 + cuda_version: 12.8.1 python_version: "3.11" - pytorch: 2.6.0 + pytorch: 2.8.0 axolotl_extras: runs-on: axolotl-gpu-runner steps: diff --git a/.github/workflows/preview-docs.yml b/.github/workflows/preview-docs.yml index 5af70b0dc..db4abddce 100644 --- a/.github/workflows/preview-docs.yml +++ b/.github/workflows/preview-docs.yml @@ -2,13 +2,15 @@ name: Preview on: workflow_dispatch: pull_request: - types: [opened, synchronize, reopened] + types: [opened, synchronize, reopened, ready_for_review] # Run the workflow only when one of these files changes paths: - '**/*.md' # any Markdown file - '**/*.qmd' # any Quarto file - - '_quarto.yaml' + - '_quarto.yml' + - docs/scripts/generate_config_docs.py + - src/axolotl/utils/schemas/**.py permissions: checks: write @@ -23,9 +25,12 @@ permissions: jobs: preview: runs-on: ubuntu-latest + if: ${{ !github.event.pull_request.draft }} steps: - name: Check out repository uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.sha }} - name: Set up Quarto uses: quarto-dev/quarto-actions/setup@v2 @@ -38,7 +43,7 @@ jobs: - name: Install dependencies run: | python3 -m pip install jupyter quartodoc - python3 -m pip install -e . --no-deps + python3 -m pip install -e . - name: Build autodoc run: quartodoc build @@ -48,10 +53,12 @@ jobs: - name: Netlify Publish uses: nwtgck/actions-netlify@v3.0 + if: ${{ github.event.pull_request.head.repo.full_name == github.repository }} + id: netlify with: publish-dir: './_site' - enable-pull-request-comment: true - enable-github-deployment: true + enable-pull-request-comment: false + enable-github-deployment: false github-token: ${{ secrets.GITHUB_TOKEN }} deploy-message: "Deployed On Netlify" github-deployment-environment: 'preview' @@ -59,3 +66,13 @@ jobs: env: NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }} NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }} + + - name: Update PR with preview link + if: ${{ steps.netlify.outcome == 'success' }} + uses: marocchino/sticky-pull-request-comment@v2 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + message: | + 📖 **Documentation Preview**: ${{ steps.netlify.outputs.deploy-url }} + + Deployed on Netlify from commit ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml index 539f7f71b..35cb707eb 100644 --- a/.github/workflows/tests-nightly.yml +++ b/.github/workflows/tests-nightly.yml @@ -18,116 +18,26 @@ jobs: env: SKIP: no-commit-to-branch - preload-cache: - name: Preload HF cache - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python_version: ["3.11"] - pytorch_version: ["2.6.0"] - timeout-minutes: 20 - - env: - AXOLOTL_IS_CI_CACHE_PRELOAD: "1" - - steps: - - name: Check out repository code - uses: actions/checkout@v4 - - - name: Restore HF cache - id: hf-cache-restore - uses: actions/cache/restore@v4 - with: - path: | - /home/runner/.cache/huggingface/hub/datasets--* - /home/runner/.cache/huggingface/hub/models--* - key: ${{ runner.os }}-hf-hub-cache-v2 - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python_version }} - cache: 'pip' # caching pip dependencies - - - name: upgrade pip - run: | - pip3 install --upgrade pip - pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel - - - name: Install PyTorch - run: | - pip3 install torch==${{ matrix.pytorch_version }} - - - name: Install dependencies - run: | - pip3 show torch - pip3 install --no-build-isolation -U -e . - python scripts/unsloth_install.py | sh - python scripts/cutcrossentropy_install.py | sh - pip3 install -r requirements-dev.txt -r requirements-tests.txt - - - name: Make sure PyTorch version wasn't clobbered - run: | - python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__" - - - name: Ensure axolotl CLI was installed - run: | - axolotl --help - - - name: Pre-Download dataset fixture - run: | - huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures - - - name: Run tests - run: | - pytest -v tests/conftest.py - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v5 - with: - token: ${{ secrets.CODECOV_TOKEN }} - files: ./coverage.xml - flags: unittests,pytorch-${{ matrix.pytorch_version }} - fail_ci_if_error: false - - - name: cleanup pip cache - run: | - find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; - - - name: Save HF cache - id: hf-cache - uses: actions/cache/save@v4 - with: - path: | - /home/runner/.cache/huggingface/hub/datasets--* - /home/runner/.cache/huggingface/hub/models--* - key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }} - pytest: name: PyTest runs-on: ubuntu-latest - needs: [preload-cache] strategy: fail-fast: false max-parallel: 2 matrix: python_version: ["3.11"] - pytorch_version: ["2.5.1", "2.6.0", "2.7.0"] + pytorch_version: ["2.7.1", "2.8.0"] timeout-minutes: 20 steps: - name: Check out repository code uses: actions/checkout@v4 - - name: Restore HF cache - id: hf-cache-restore - uses: actions/cache/restore@v4 - with: - path: | - /home/runner/.cache/huggingface/hub/datasets--* - /home/runner/.cache/huggingface/hub/models--* - key: ${{ runner.os }}-hf-hub-cache-v2 + - name: Restore Cache from S3 + id: hf-cache-restore-s3 + run: | + mkdir -p /home/runner/.cache/huggingface/hub + curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd - name: Setup Python uses: actions/setup-python@v5 @@ -142,7 +52,7 @@ jobs: - name: Install PyTorch run: | - pip3 install torch==${{ matrix.pytorch_version }} + pip3 install torch==${{ matrix.pytorch_version }} torchvision - name: Update requirements.txt run: | @@ -168,15 +78,11 @@ jobs: run: | axolotl --help - - name: Pre-Download dataset fixture - run: | - huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures - - name: Run tests run: | - pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ - pytest -v tests/patched/ - pytest -v tests/cli/ + pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ + pytest -v --durations=10 tests/patched/ + pytest -v --durations=10 tests/cli/ - name: cleanup pip cache run: | @@ -186,24 +92,24 @@ jobs: if: github.repository_owner == 'axolotl-ai-cloud' # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] - timeout-minutes: 60 + timeout-minutes: 120 needs: [pre-commit, pytest] strategy: fail-fast: false matrix: include: - - cuda: 124 - cuda_version: 12.4.1 + - cuda: 126 + cuda_version: 12.6.3 python_version: "3.11" - pytorch: 2.5.1 + pytorch: 2.7.1 num_gpus: 1 axolotl_extras: nightly_build: "true" - - cuda: 124 - cuda_version: 12.4.1 + - cuda: 128 + cuda_version: 12.8.1 python_version: "3.11" - pytorch: 2.6.0 + pytorch: 2.8.0 num_gpus: 1 axolotl_extras: nightly_build: "true" @@ -217,7 +123,7 @@ jobs: - name: Install Modal run: | python -m pip install --upgrade pip - pip install modal==0.71.8 jinja2 + pip install modal==1.0.2 jinja2 - name: Update env vars run: | echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV @@ -231,3 +137,45 @@ jobs: - name: Run tests job on Modal run: | modal run cicd.e2e_tests + docker-e2e-multigpu-tests: + if: github.repository_owner == 'axolotl-ai-cloud' + # this job needs to be run on self-hosted GPU runners... + runs-on: [self-hosted, modal] + timeout-minutes: 120 + needs: [pre-commit, pytest, docker-e2e-tests] + + strategy: + fail-fast: false + matrix: + include: + - cuda: 126 + cuda_version: 12.6.3 + python_version: "3.11" + pytorch: 2.7.1 + num_gpus: 2 + axolotl_extras: + nightly_build: "true" + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install Modal + run: | + python -m pip install --upgrade pip + pip install modal==1.0.2 jinja2 + - name: Update env vars + run: | + echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV + echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV + echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV + echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV + echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV + echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV + echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV + echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV + - name: Run tests job on Modal + run: | + modal run cicd.multigpu diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ddbd25291..8f368b517 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,6 +13,7 @@ on: - 'cicd/cicd.sh' - 'cicd/Dockerfile.jinja' pull_request: + types: [opened, synchronize, reopened, ready_for_review] paths: - '**.py' - 'requirements.txt' @@ -34,6 +35,7 @@ jobs: pre-commit: name: pre-commit runs-on: ubuntu-latest + if: ${{ !github.event.pull_request.draft }} steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 @@ -47,12 +49,13 @@ jobs: pytest: name: PyTest runs-on: ubuntu-latest + if: ${{ !github.event.pull_request.draft }} # needs: [preload-cache] strategy: fail-fast: false matrix: python_version: ["3.11"] - pytorch_version: ["2.5.1", "2.6.0", "2.7.0"] + pytorch_version: ["2.7.1", "2.8.0"] timeout-minutes: 20 steps: @@ -78,12 +81,12 @@ jobs: - name: Install PyTorch run: | - pip3 install torch==${{ matrix.pytorch_version }} + pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision - name: Install dependencies run: | pip3 show torch - pip3 install --no-build-isolation -U -e . + pip3 install --no-cache-dir --no-build-isolation -U -e . python scripts/unsloth_install.py | sh python scripts/cutcrossentropy_install.py | sh pip3 install -r requirements-dev.txt -r requirements-tests.txt @@ -102,9 +105,10 @@ jobs: - name: Run tests run: | - pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml - pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml - pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml + pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml + pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml + pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml + pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 @@ -121,11 +125,12 @@ jobs: pytest-sdist: name: PyTest from Source Dist runs-on: ubuntu-latest + if: ${{ !github.event.pull_request.draft }} strategy: fail-fast: false matrix: python_version: ["3.11"] - pytorch_version: ["2.5.1", "2.6.0", "2.7.0"] + pytorch_version: ["2.7.1", "2.8.0"] timeout-minutes: 20 steps: @@ -151,13 +156,13 @@ jobs: - name: Install PyTorch run: | - pip3 install torch==${{ matrix.pytorch_version }} + pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision - name: Install dependencies run: | pip3 show torch python -m build --no-isolation --sdist - pip3 install --no-build-isolation dist/axolotl*.tar.gz + pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz python scripts/unsloth_install.py | sh python scripts/cutcrossentropy_install.py | sh pip3 install -r requirements-dev.txt -r requirements-tests.txt @@ -175,36 +180,67 @@ jobs: - name: Run tests run: | - pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ - pytest -v tests/patched/ - pytest -v tests/cli/ + pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml + pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml + pytest -v --durations=10 tests/cli/ - name: cleanup pip cache run: | find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; + gate-skip-e2e: + needs: [pre-commit, pytest, pytest-sdist] + runs-on: ubuntu-latest + outputs: + skip: ${{ steps.compute.outputs.skip }} + steps: + - uses: actions/github-script@v7 + id: compute + with: + script: | + const token = /\[skip-e2e\]/i; + let msg = ''; + if (context.eventName === 'push') { + msg = context.payload.head_commit?.message || ''; + } else if (context.eventName === 'pull_request') { + const { owner, repo } = context.repo; + const prNumber = context.payload.pull_request.number; + const commits = await github.paginate( + github.rest.pulls.listCommits, + { owner, repo, pull_number: prNumber, per_page: 100 } + ); + msg = commits.at(-1)?.commit?.message || ''; + } + const title = context.payload.pull_request?.title || ''; + const body = context.payload.pull_request?.body || ''; + const skip = token.test(msg) || token.test(title) || token.test(body); + core.setOutput('skip', String(skip)); + docker-e2e-tests-1st: # Run this job first as a gate for running the remainder of the test matrix - if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }} + if: > + github.repository_owner == 'axolotl-ai-cloud' && + (github.event_name != 'pull_request' || !github.event.pull_request.draft) && + needs.gate-skip-e2e.outputs.skip != 'true' # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] - timeout-minutes: 90 - needs: [pre-commit, pytest, pytest-sdist] + timeout-minutes: 120 + needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e] strategy: fail-fast: false matrix: include: - - cuda: 124 - cuda_version: 12.4.1 - python_version: "3.11" - pytorch: 2.6.0 - num_gpus: 1 - axolotl_extras: vllm - cuda: 126 cuda_version: 12.6.3 python_version: "3.11" - pytorch: 2.6.0 + pytorch: 2.7.1 + num_gpus: 1 + axolotl_extras: + - cuda: 126 + cuda_version: 12.6.3 + python_version: "3.11" + pytorch: 2.7.1 num_gpus: 1 axolotl_extras: dockerfile: "Dockerfile-uv.jinja" @@ -235,42 +271,34 @@ jobs: modal run cicd.e2e_tests docker-e2e-tests: - if: github.repository_owner == 'axolotl-ai-cloud' + if: > + github.repository_owner == 'axolotl-ai-cloud' && + (github.event_name != 'pull_request' || !github.event.pull_request.draft) && + needs.gate-skip-e2e.outputs.skip != 'true' # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] - timeout-minutes: 90 + timeout-minutes: 120 # Only run the remainder of the matrix if the first e2e check passed; # this is to save on wasted compute costs for known failures that get caught in the first run - needs: [pre-commit, pytest, docker-e2e-tests-1st] + needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st] strategy: fail-fast: false matrix: include: - - cuda: 124 - cuda_version: 12.4.1 + - cuda: 128 + cuda_version: 12.8.1 python_version: "3.11" - pytorch: 2.6.0 - num_gpus: 1 - axolotl_extras: llmcompressor - - cuda: 124 - cuda_version: 12.4.1 - python_version: "3.11" - pytorch: 2.5.1 - num_gpus: 1 - axolotl_extras: - - cuda: 126 - cuda_version: 12.6.3 - python_version: "3.11" - pytorch: 2.7.0 + pytorch: 2.7.1 num_gpus: 1 axolotl_extras: - cuda: 128 cuda_version: 12.8.1 python_version: "3.11" - pytorch: 2.7.0 + pytorch: 2.8.0 num_gpus: 1 - axolotl_extras: + gpu_type: "B200" + axolotl_extras: fbgemm-gpu steps: - name: Checkout uses: actions/checkout@v4 @@ -291,6 +319,7 @@ jobs: echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV + echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV - name: Run tests job on Modal @@ -301,17 +330,18 @@ jobs: runs-on: [self-hosted, modal] timeout-minutes: 90 needs: [docker-e2e-tests] + if: ${{ !github.event.pull_request.draft }} strategy: fail-fast: false matrix: include: - - cuda: 124 - cuda_version: 12.4.1 + - cuda: 126 + cuda_version: 12.6.3 python_version: "3.11" - pytorch: 2.6.0 + pytorch: 2.7.1 num_gpus: 1 - axolotl_extras: vllm + axolotl_extras: steps: - name: Checkout uses: actions/checkout@v4 diff --git a/.gitignore b/.gitignore index 40084b408..b75becc7c 100644 --- a/.gitignore +++ b/.gitignore @@ -190,3 +190,6 @@ out/ # vim *.swp + +# scm auto-versioning +src/axolotl/_version.py diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index bf9afe319..000000000 --- a/.isort.cfg +++ /dev/null @@ -1,4 +0,0 @@ -[settings] -profile=black -known_third_party=wandb,comet_ml -known_local_folder=src,tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 195746d2d..e853243cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,31 +3,21 @@ default_language_version: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - id: no-commit-to-branch args: ['--branch', 'main'] -- repo: https://github.com/psf/black - rev: 25.1.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.13.3 hooks: - - id: black -- repo: https://github.com/pycqa/isort - rev: 6.0.1 - hooks: - - id: isort -- repo: https://github.com/PyCQA/flake8 - rev: 7.2.0 - hooks: - - id: flake8 -- repo: https://github.com/pylint-dev/pylint - rev: v3.3.7 - hooks: - - id: pylint + - id: ruff + args: [--fix] + - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.16.0 + rev: v1.18.2 hooks: - id: mypy additional_dependencies: @@ -36,7 +26,7 @@ repos: 'pydantic>=2.5.3', ] - repo: https://github.com/PyCQA/bandit - rev: 1.8.3 + rev: 1.8.6 hooks: - id: bandit args: [ diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 208dd32b6..000000000 --- a/.pylintrc +++ /dev/null @@ -1,15 +0,0 @@ -[MASTER] -init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())" - -[TYPECHECK] - -# List of members which are set dynamically and missed by Pylint inference -# system, and so shouldn't trigger E1101 when accessed. -generated-members=numpy.*, torch.* - - -[pylint.messages_control] -disable=missing-function-docstring, line-too-long, import-error, - too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, - too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation, - too-many-positional-arguments, possibly-used-before-assignment diff --git a/.runpod/README.md b/.runpod/README.md index a631c3937..8042f4f91 100644 --- a/.runpod/README.md +++ b/.runpod/README.md @@ -119,14 +119,15 @@ datasets: ## Dataset Processing -| Option | Default | Description | -| ----------------------------- | -------------------------- | --------------------------------- | -| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset | -| `push_dataset_to_hub` | `""` | Push dataset to HF hub | -| `dataset_processes` | `4` | Number of preprocessing processes | -| `dataset_keep_in_memory` | `false` | Keep dataset in memory | -| `shuffle_merged_datasets` | `true` | Shuffle merged datasets | -| `dataset_exact_deduplication` | `true` | Deduplicate datasets | +| Option | Default | Description | +| --------------------------------- | -------------------------- | ----------------------------------- | +| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset | +| `push_dataset_to_hub` | `""` | Push dataset to HF hub | +| `dataset_processes` | `4` | Number of preprocessing processes | +| `dataset_keep_in_memory` | `false` | Keep dataset in memory | +| `shuffle_merged_datasets` | `true` | Shuffle merged datasets | +| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging | +| `dataset_exact_deduplication` | `true` | Deduplicate datasets | ## LoRA Configuration @@ -184,7 +185,6 @@ datasets: | `flash_attention` | `false` | Use flash attention | | `flash_attn_cross_entropy` | `false` | Flash attention cross entropy | | `flash_attn_rms_norm` | `false` | Flash attention RMS norm | -| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations | | `flash_attn_fuse_mlp` | `false` | Fuse MLP operations | | `sdp_attention` | `false` | Use scaled dot product | | `s2_attention` | `false` | Use shifted sparse attention | @@ -328,7 +328,7 @@ The following optimizers are supported: - Use `gradient_checkpointing: true` to reduce memory usage - Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory -For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html). +For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config-reference.html). ### Errors: diff --git a/.runpod/src/config/config.yaml b/.runpod/src/config/config.yaml index 42c5978d5..f482a7331 100644 --- a/.runpod/src/config/config.yaml +++ b/.runpod/src/config/config.yaml @@ -97,7 +97,7 @@ # # 'no_input_format' cannot include {input} # no_input_format: "{instruction} " -# # For `completion` datsets only, uses the provided field instead of `text` column +# # For `completion` datasets only, uses the provided field instead of `text` column # field: # # Axolotl attempts to save the dataset as an arrow after packing the data together so @@ -296,7 +296,6 @@ # flash_attention: # flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only # flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only -# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation # flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation # # Whether to use scaled-dot-product attention # # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html @@ -541,7 +540,6 @@ xformers_attention: ${XFORMERS_ATTENTION} flash_attention: ${FLASH_ATTENTION} flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY} flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM} -flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV} flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP} sdp_attention: ${SDP_ATTENTION} s2_attention: ${S2_ATTENTION} diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 000000000..7bbfeec64 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,10 @@ +cff-version: 1.2.0 +type: software +title: "Axolotl: Open Source LLM Post-Training" +message: "If you use this software, please cite it as below." +authors: + - name: "Axolotl maintainers and contributors" +repository-code: "https://github.com/axolotl-ai-cloud/axolotl" +url: "https://axolotl.ai/" +license: Apache-2.0 +date-released: "2023-05-30" diff --git a/MANIFEST.in b/MANIFEST.in index 99324be3c..3fbb0edca 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,4 +2,5 @@ include requirements.txt include README.md include LICENSE include src/setuptools_axolotl_dynamic_dependencies.py +include src/axolotl/utils/chat_templates/templates/*.jinja recursive-include axolotl *.py diff --git a/README.md b/README.md index 06b2bcab1..f4df750b4 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,9 @@ Axolotl

+

+ A Free and Open Source LLM Fine-tuning Framework
+

GitHub License @@ -17,44 +20,69 @@
discord twitter + google-colab
tests-nightly multigpu-semi-weekly tests

-Axolotl is a tool designed to streamline post-training for various AI models. -Post-training refers to any modifications or additional training performed on -pre-trained models - including full model fine-tuning, parameter-efficient tuning (like -LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment -techniques. With support for multiple model architectures and training configurations, -Axolotl makes it easy to get started with these techniques. -Axolotl is designed to work with YAML config files that contain everything you need to -preprocess a dataset, train or fine-tune a model, run model inference or evaluation, -and much more. +## 🎉 Latest Updates + +- 2025/07: + - ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info. + - Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm). + - FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)! + - [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl! + - TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl! +- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more! +- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning. + +
+ +Expand older updates + +- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl! +- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version! +- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own! +- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try. +- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun! +- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html). + +
+ +## ✨ Overview + +Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs). Features: -- Train various Huggingface models such as llama, pythia, falcon, mpt -- Supports fullfinetune, lora, qlora, relora, and gptq -- Customize configurations using a simple yaml file or CLI overwrite -- Load different dataset formats, use custom formats, or bring your own tokenized datasets -- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking -- Works with single GPU or multiple GPUs via FSDP or Deepspeed -- Easily run with Docker locally or on the cloud -- Log results and optionally checkpoints to wandb, mlflow or Comet -- And more! +- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub. +- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support. +- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM). +- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference. +- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more! +- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets. +- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware. -## 🚀 Quick Start + + +## 🚀 Quick Start - LLM Fine-tuning in Minutes **Requirements**: - NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU - Python 3.11 -- PyTorch ≥2.5.1 +- PyTorch ≥2.7.1 + +### Google Colab + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa) ### Installation +#### Using pip + ```bash pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja pip3 install --no-build-isolation axolotl[flash-attn,deepspeed] @@ -64,8 +92,29 @@ axolotl fetch examples axolotl fetch deepspeed_configs # OPTIONAL ``` +#### Using Docker + +Installing with Docker can be less error prone than installing in your own environment. +```bash +docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest +``` + Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html). +#### Cloud Providers + +
+ +- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz) +- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=github&utm_medium=developer_community&utm_campaign=template_launch_axolotl&utm_content=readme) +- [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true) +- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) +- [Novita](https://novita.ai/gpus-console?templateId=311) +- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl) +- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c) + +
+ ### Your First Fine-tune ```bash @@ -81,19 +130,12 @@ axolotl train examples/llama-3/lora-1b.yml That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/getting-started.html) for a more detailed walkthrough. -## ✨ Key Features - -- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more -- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more -- **Easy Configuration**: Simple YAML files to control your training setup -- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training -- **Flexible Dataset Handling**: Use various formats and custom datasets -- **Cloud Ready**: Run on cloud platforms or local hardware ## 📚 Documentation - [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments -- [Configuration Guide](https://docs.axolotl.ai/docs/config.html) - Full configuration options and examples +- [Configuration Guide](https://docs.axolotl.ai/docs/config-reference.html) - Full configuration options and examples +- [Dataset Loading](https://docs.axolotl.ai/docs/dataset_loading.html) - Loading datasets from various sources - [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them - [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) - [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) @@ -146,14 +188,22 @@ enable it, set AXOLOTL_DO_NOT_TRACK=0. For more details, see our [telemetry docu ## ❤️ Sponsors -Thank you to our sponsors who help make Axolotl possible: - -- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run -jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, -fine-tune large language models, run protein folding simulations, and much more. - Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai) +## 📝 Citing Axolotl + +If you use Axolotl in your research or projects, please cite it as follows: + +```bibtex +@software{axolotl, + title = {Axolotl: Open Source LLM Post-Training}, + author = {{Axolotl maintainers and contributors}}, + url = {https://github.com/axolotl-ai-cloud/axolotl}, + license = {Apache-2.0}, + year = {2023} +} +``` + ## 📜 License This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details. diff --git a/TODO.md b/TODO.md deleted file mode 100644 index 2002bbbaf..000000000 --- a/TODO.md +++ /dev/null @@ -1,10 +0,0 @@ -# todo list - -- [] Validation of parameters for combinations that won't work - - - -## things that are known not to work - -- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203 -- adamw_bnb_8bit doesn't play well with FSDP offload diff --git a/_quarto.yml b/_quarto.yml index a7b07a8f4..c97b9838e 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -1,5 +1,6 @@ project: type: website + pre-render: docs/scripts/generate_config_docs.py quartodoc: dir: docs/api @@ -34,25 +35,30 @@ quartodoc: - cli.train - cli.evaluate - cli.args + - cli.art - cli.checks - cli.config + - cli.delinearize_llama4 - cli.inference - cli.merge_lora - cli.merge_sharded_fsdp_weights - cli.preprocess - - cli.sweeps - - cli.utils + - cli.quantize - cli.vllm_serve - cli.cloud.base - cli.cloud.modal_ - - cli.quantize + - cli.utils + - cli.utils.args + - cli.utils.fetch + - cli.utils.load + - cli.utils.sweeps + - cli.utils.train - title: Trainers desc: Training implementations contents: - core.trainers.base - core.trainers.trl - core.trainers.mamba - - core.trainers.relora - core.trainers.dpo.trainer - core.trainers.grpo.trainer - core.trainers.grpo.sampler @@ -147,7 +153,7 @@ quartodoc: - utils.distributed - utils.dict - utils.optimizers.adopt - - utils.data.pretraining + - utils.data.streaming - utils.data.sft - utils.quantization - title: Schemas @@ -235,8 +241,8 @@ website: - docs/installation.qmd - docs/inference.qmd - docs/cli.qmd - - docs/config.qmd - docs/telemetry.qmd + - docs/config-reference.qmd - text: "API Reference" href: docs/api @@ -262,12 +268,16 @@ website: - docs/dataset_loading.qmd - docs/qat.qmd - docs/quantize.qmd + - docs/optimizations.qmd - section: "Core Concepts" contents: - docs/batch_vs_grad.qmd - docs/dataset_preprocessing.qmd + - docs/streaming.qmd - docs/multipack.qmd + - docs/mixed_precision.qmd + - docs/optimizers.qmd - section: "Advanced Features" contents: @@ -276,6 +286,8 @@ website: - docs/torchao.qmd - docs/custom_integrations.qmd - docs/sequence_parallelism.qmd + - docs/gradient_checkpointing.qmd + - docs/nd_parallelism.qmd - section: "Troubleshooting" contents: diff --git a/cicd/Dockerfile-uv.jinja b/cicd/Dockerfile-uv.jinja index 84527274d..6a4d8a7d3 100644 --- a/cicd/Dockerfile-uv.jinja +++ b/cicd/Dockerfile-uv.jinja @@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}" ENV HF_HOME="{{ HF_HOME }}" RUN apt-get update && \ - apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev + apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm WORKDIR /workspace @@ -32,6 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \ fi RUN uv pip install packaging==23.2 setuptools==75.8.0 +RUN uv pip install torchvision RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 6988e092b..6a1ddb66d 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -9,9 +9,10 @@ ENV GITHUB_REF="{{ GITHUB_REF }}" ENV GITHUB_SHA="{{ GITHUB_SHA }}" ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}" ENV HF_HOME="{{ HF_HOME }}" +ENV AXOLOTL_DATASET_NUM_PROC="8" RUN apt-get update && \ - apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev + apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm WORKDIR /workspace diff --git a/cicd/e2e_tests.py b/cicd/e2e_tests.py index ce9c605c7..5d2b6fed1 100644 --- a/cicd/e2e_tests.py +++ b/cicd/e2e_tests.py @@ -6,7 +6,7 @@ from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd @app.function( image=cicd_image, gpu=GPU_CONFIG, - timeout=90 * 60, # 90 min + timeout=120 * 60, # 90 min cpu=8.0, memory=131072, volumes=VOLUME_CONFIG, diff --git a/cicd/multigpu.py b/cicd/multigpu.py index a2dd8d0b3..5bd8d3c04 100644 --- a/cicd/multigpu.py +++ b/cicd/multigpu.py @@ -2,8 +2,6 @@ modal application to run axolotl gpu tests in Modal """ -# pylint: disable=duplicate-code - import os import pathlib import tempfile @@ -24,9 +22,9 @@ df_template = template_env.get_template("Dockerfile.jinja") df_args = { "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""), "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""), - "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"), - "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"), - "CUDA": os.environ.get("CUDA", "124"), + "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"), + "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"), + "CUDA": os.environ.get("CUDA", "126"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""), @@ -63,13 +61,13 @@ def run_cmd(cmd: str, run_folder: str): # Propagate errors from subprocess. if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec - exit(exit_code) # pylint: disable=consider-using-sys-exit + exit(exit_code) @app.function( image=cicd_image, gpu=GPU_CONFIG, - timeout=90 * 60, + timeout=120 * 60, cpu=16.0, memory=131072 * N_GPUS, volumes=VOLUME_CONFIG, diff --git a/cicd/multigpu.sh b/cicd/multigpu.sh index 1f74cd67d..3ec4456b9 100755 --- a/cicd/multigpu.sh +++ b/cicd/multigpu.sh @@ -2,7 +2,7 @@ set -e # Only run two tests at a time to avoid OOM on GPU (with coverage collection) -pytest -v -n2 \ +pytest -v --durations=10 -n2 \ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \ --ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \ /workspace/axolotl/tests/e2e/multigpu/ \ @@ -19,5 +19,7 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \ --cov-append \ --cov-report=xml:multigpu-coverage.xml -# Upload coverage to Codecov -codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true +# Upload coverage to Codecov if CODECOV_TOKEN is available +if [ -n "$CODECOV_TOKEN" ]; then + codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true +fi diff --git a/cicd/single_gpu.py b/cicd/single_gpu.py index 2ce3b0662..cd73f60b8 100644 --- a/cicd/single_gpu.py +++ b/cicd/single_gpu.py @@ -1,7 +1,5 @@ """Modal app to run axolotl GPU tests""" -# pylint: disable=duplicate-code - import os import pathlib import tempfile @@ -24,14 +22,16 @@ df_template = template_env.get_template(dockerfile) df_args = { "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""), "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""), - "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"), - "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"), - "CUDA": os.environ.get("CUDA", "124"), + "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"), + "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"), + "CUDA": os.environ.get("CUDA", "126"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""), "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""), "HF_HOME": "/workspace/data/huggingface-cache/hub", + "PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"), + "DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"), } dockerfile_contents = df_template.render(**df_args) @@ -57,12 +57,21 @@ VOLUME_CONFIG = { } N_GPUS = int(os.environ.get("N_GPUS", 1)) -GPU_CONFIG = f"L40S:{N_GPUS}" +GPU_TYPE = os.environ.get("GPU_TYPE", "L40S") +GPU_CONFIG = f"{GPU_TYPE}:{N_GPUS}" def run_cmd(cmd: str, run_folder: str): import subprocess # nosec + sp_env = os.environ.copy() + sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8" + # Propagate errors from subprocess. - if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec - exit(exit_code) # pylint: disable=consider-using-sys-exit + try: + exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec + if exit_code: + print(f"Command '{cmd}' failed with exit code {exit_code}") + return exit_code + except Exception as e: # pylint: disable=broad-except + print(f"Command '{cmd}' failed with exception {e}") diff --git a/codecov.yml b/codecov.yml index 2741b1758..fa3ad3073 100644 --- a/codecov.yml +++ b/codecov.yml @@ -12,7 +12,7 @@ coverage: default: # basic target: auto - threshold: 0% + threshold: 1% base: auto # advanced branches: null @@ -22,11 +22,12 @@ coverage: only_pulls: true flags: null paths: null + informational: true patch: default: # basic target: auto - threshold: 0% + threshold: 1% base: auto # advanced branches: null diff --git a/deepspeed_configs/zero2_torch_compile.json b/deepspeed_configs/zero2_torch_compile.json new file mode 100644 index 000000000..c3bcf98cf --- /dev/null +++ b/deepspeed_configs/zero2_torch_compile.json @@ -0,0 +1,31 @@ +{ + "compile": { + "disable": false, + "backend": "inductor" + }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu" + }, + "contiguous_gradients": true, + "overlap_comm": true + }, + "bf16": { + "enabled": "auto" + }, + "fp16": { + "enabled": "auto", + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 32, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/deepspeed_configs/zero3.json b/deepspeed_configs/zero3.json index 90ec3677e..f8c9cdfe0 100644 --- a/deepspeed_configs/zero3.json +++ b/deepspeed_configs/zero3.json @@ -7,9 +7,9 @@ "reduce_bucket_size": "auto", "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", - "stage3_max_live_parameters": 0, - "stage3_max_reuse_distance": 0, - "stage3_gather_16bit_weights_on_model_save": true + "max_live_parameters": 0, + "max_reuse_distance": 0, + "gather_16bit_weights_on_model_save": true }, "bf16": { "enabled": "auto" diff --git a/deepspeed_configs/zero3_bf16.json b/deepspeed_configs/zero3_bf16.json index 49fb75755..a69e13cf7 100644 --- a/deepspeed_configs/zero3_bf16.json +++ b/deepspeed_configs/zero3_bf16.json @@ -7,9 +7,9 @@ "reduce_bucket_size": "auto", "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", - "stage3_max_live_parameters": 0, - "stage3_max_reuse_distance": 0, - "stage3_gather_16bit_weights_on_model_save": true + "max_live_parameters": 0, + "max_reuse_distance": 0, + "gather_16bit_weights_on_model_save": true }, "bf16": { "enabled": true diff --git a/deepspeed_configs/zero3_bf16_cpuoffload_all.json b/deepspeed_configs/zero3_bf16_cpuoffload_all.json index 3ccc66db4..5112c570b 100644 --- a/deepspeed_configs/zero3_bf16_cpuoffload_all.json +++ b/deepspeed_configs/zero3_bf16_cpuoffload_all.json @@ -17,9 +17,9 @@ "reduce_bucket_size": "auto", "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", - "stage3_max_live_parameters": 0, - "stage3_max_reuse_distance": 0, - "stage3_gather_16bit_weights_on_model_save": true + "max_live_parameters": 0, + "max_reuse_distance": 0, + "gather_16bit_weights_on_model_save": true }, "bf16": { "enabled": true diff --git a/deepspeed_configs/zero3_bf16_cpuoffload_params.json b/deepspeed_configs/zero3_bf16_cpuoffload_params.json index fe21d35f8..a2ac82341 100644 --- a/deepspeed_configs/zero3_bf16_cpuoffload_params.json +++ b/deepspeed_configs/zero3_bf16_cpuoffload_params.json @@ -13,9 +13,9 @@ "reduce_bucket_size": "auto", "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", - "stage3_max_live_parameters": 0, - "stage3_max_reuse_distance": 0, - "stage3_gather_16bit_weights_on_model_save": true + "max_live_parameters": 0, + "max_reuse_distance": 0, + "gather_16bit_weights_on_model_save": true }, "bf16": { "enabled": true diff --git a/devtools/dev_chat_template.yml b/devtools/dev_chat_template.yml index 27dc9be1a..32d5e56a0 100644 --- a/devtools/dev_chat_template.yml +++ b/devtools/dev_chat_template.yml @@ -13,7 +13,7 @@ datasets: val_set_size: 0 output_dir: temp_debug/axolotl_outputs/model dataset_prepared_path: temp_debug/axolotl_outputs/data -dataset_processes: 1 +dataset_num_proc: 1 sequence_len: 4096 sample_packing: false diff --git a/docker/Dockerfile b/docker/Dockerfile index e23a729d4..116361dcd 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,7 +10,9 @@ ARG PYTORCH_VERSION="2.1.2" ENV PYTORCH_VERSION=$PYTORCH_VERSION RUN apt-get update && \ - apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs + apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \ + rm -rf /var/cache/apt/archives && \ + rm -rf /var/lib/apt/lists/* WORKDIR /workspace @@ -23,17 +25,17 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ else \ pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \ - fi + fi && \ + python scripts/unsloth_install.py | sh && \ + python scripts/cutcrossentropy_install.py | sh && \ + pip install pytest && \ + pip cache purge -RUN python scripts/unsloth_install.py | sh -RUN python scripts/cutcrossentropy_install.py | sh - -# So we can test the Docker image -RUN pip install pytest - -# fix so that git fetch/pull from remote works +# fix so that git fetch/pull from remote works with shallow clone RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ - git config --get remote.origin.fetch + git config --get remote.origin.fetch && \ + git config --global credential.helper store -# helper for huggingface-login cli -RUN git config --global credential.helper store +COPY .axolotl-complete.bash /root/.axolotl-complete.bash +RUN chmod +x /root/.axolotl-complete.bash && \ + echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index cf1af9682..87918cc41 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -16,12 +16,19 @@ ENV PYTHON_VERSION=$PYTHON_VERSION ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST RUN apt-get update \ - && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \ + && apt-get install -y --no-install-recommends \ + wget git build-essential ninja-build git-lfs libaio-dev pkg-config \ + ibverbs-providers ibverbs-utils infiniband-diags \ + librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \ + && rm -rf /var/cache/apt/archives \ + && rm -rf /var/lib/apt/lists/* \ && wget \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ && mkdir /root/.conda \ && bash Miniconda3-latest-Linux-x86_64.sh -b \ && rm -f Miniconda3-latest-Linux-x86_64.sh \ + && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \ + && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \ && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}" ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}" @@ -30,14 +37,16 @@ WORKDIR /workspace RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \ python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \ - python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \ - python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" + CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \ + python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \ + python3 -m pip cache purge RUN git lfs install --skip-repo && \ pip3 install awscli && \ # The base image ships with `pydantic==1.8.2` which is not working - pip3 install -U --no-cache-dir pydantic==1.10.10 + pip3 install -U --no-cache-dir pydantic==1.10.10 && \ + pip3 cache purge -RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \ - pip3 install flash-attn==2.7.4.post1; \ +RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \ + FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \ fi diff --git a/docker/Dockerfile-base-next b/docker/Dockerfile-base-next index a968b5913..85bac2516 100644 --- a/docker/Dockerfile-base-next +++ b/docker/Dockerfile-base-next @@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}" WORKDIR /workspace RUN python3 -m pip install --upgrade pip && pip3 install packaging && \ - python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \ + python3 -m pip install --no-cache-dir -U torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \ python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \ python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" diff --git a/docker/Dockerfile-base-nightly b/docker/Dockerfile-base-nightly index 85805ea41..cc74e6bb9 100644 --- a/docker/Dockerfile-base-nightly +++ b/docker/Dockerfile-base-nightly @@ -22,18 +22,22 @@ RUN apt-get update \ && mkdir /root/.conda \ && bash Miniconda3-latest-Linux-x86_64.sh -b \ && rm -f Miniconda3-latest-Linux-x86_64.sh \ + && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \ + && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \ && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}" ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}" WORKDIR /workspace -RUN python3 -m pip install --upgrade pip && pip3 install packaging && \ +RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \ python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \ python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \ - python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" + python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \ + python3 -m pip cache purge RUN git lfs install --skip-repo && \ pip3 install awscli && \ # The base image ships with `pydantic==1.8.2` which is not working - pip3 install -U --no-cache-dir pydantic==1.10.10 + pip3 install -U --no-cache-dir pydantic==1.10.10 && \ + pip3 cache purge diff --git a/docker/Dockerfile-cloud b/docker/Dockerfile-cloud index c84ea1dca..6ab090826 100644 --- a/docker/Dockerfile-cloud +++ b/docker/Dockerfile-cloud @@ -14,7 +14,10 @@ COPY scripts/motd /etc/motd RUN pip install jupyterlab notebook ipywidgets && \ jupyter lab clean -RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \ +RUN apt update && \ + apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \ + rm -rf /var/cache/apt/archives && \ + rm -rf /var/lib/apt/lists/* && \ mkdir -p ~/.ssh && \ chmod 700 ~/.ssh && \ printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \ diff --git a/docker/Dockerfile-cloud-no-tmux b/docker/Dockerfile-cloud-no-tmux index 165063105..594559cfd 100644 --- a/docker/Dockerfile-cloud-no-tmux +++ b/docker/Dockerfile-cloud-no-tmux @@ -9,13 +9,15 @@ ENV HF_HUB_ENABLE_HF_TRANSFER="1" EXPOSE 8888 EXPOSE 22 -COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh +COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh COPY scripts/motd /etc/motd RUN pip install jupyterlab notebook ipywidgets && \ jupyter lab clean -RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \ - pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \ +RUN apt update && \ + apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \ + rm -rf /var/cache/apt/archives && \ + rm -rf /var/lib/apt/lists/* && \ mkdir -p ~/.ssh && \ chmod 700 ~/.ssh && \ printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \ diff --git a/docker/Dockerfile-uv-base b/docker/Dockerfile-uv-base index 5ac8d86c7..eaa49b9e9 100644 --- a/docker/Dockerfile-uv-base +++ b/docker/Dockerfile-uv-base @@ -29,8 +29,8 @@ RUN uv venv --no-project --relocatable axolotl-venv ENV PATH="/workspace/axolotl-venv/bin:${PATH}" -RUN uv pip install packaging setuptools wheel \ - && uv pip install torch==${PYTORCH_VERSION} \ +RUN uv pip install packaging setuptools wheel psutil \ + && uv pip install torch==${PYTORCH_VERSION} torchvision \ && uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \ && uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \ && uv pip install awscli pydantic diff --git a/docs/.gitignore b/docs/.gitignore index 6c3cb2070..89407326f 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -2,3 +2,4 @@ _site/ /api/*.qmd /api/*.html +config-reference.qmd diff --git a/docs/cli.qmd b/docs/cli.qmd index f6f9b3481..d9f26dbf8 100644 --- a/docs/cli.qmd +++ b/docs/cli.qmd @@ -23,6 +23,20 @@ axolotl [config.yml] [options] The config file can be local or a URL to a raw YAML file. +### Launcher Arguments + +For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator: + +```bash +# Pass torchrun arguments +axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1 + +# Pass accelerate arguments +axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4 +``` + +Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.). + ## Command Reference ### fetch @@ -80,7 +94,11 @@ axolotl train config.yml \ --num-epochs 3 # Training without accelerate -axolotl train config.yml --no-accelerate +axolotl train config.yml --launcher python + +# Pass launcher-specific arguments using -- separator +axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1 +axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml # Resume training from checkpoint axolotl train config.yml --resume-from-checkpoint path/to/checkpoint @@ -175,6 +193,9 @@ Evaluates a model's performance (loss etc) on the train and eval datasets. ```bash # Basic evaluation axolotl evaluate config.yml + +# Evaluation with launcher arguments +axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2 ``` ### lm-eval @@ -287,9 +308,6 @@ axolotl preprocess config.yml --cloud cloud_config.yml # Train on cloud axolotl train config.yml --cloud cloud_config.yml -# Train without accelerate on cloud -axolotl train config.yml --cloud cloud_config.yml --no-accelerate - # Run lm-eval on cloud axolotl lm-eval config.yml --cloud cloud_config.yml ``` diff --git a/docs/config.qmd b/docs/config.qmd deleted file mode 100644 index 519065554..000000000 --- a/docs/config.qmd +++ /dev/null @@ -1,795 +0,0 @@ ---- -title: Config Reference -description: A complete list of all configuration options. ---- - -```yaml -# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files -# This can also be a relative path to a model on disk -base_model: ./llama-7b-hf -# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc) -base_model_ignore_patterns: -# If the base_model repo on hf hub doesn't include configuration .json files, -# You can set that here, or leave this empty to default to base_model -base_model_config: ./llama-7b-hf -# You can specify to choose a specific model revision from huggingface hub -revision_of_model: -# Optional tokenizer configuration path in case you want to use a different tokenizer -# than the one defined in the base model -tokenizer_config: -# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too -model_type: AutoModelForCausalLM -# Corresponding tokenizer for the model AutoTokenizer is a good choice -tokenizer_type: AutoTokenizer -# Trust remote code for untrusted source -trust_remote_code: -# use_fast option for tokenizer loading from_pretrained, default to True -tokenizer_use_fast: -# Whether to use the legacy tokenizer setting, defaults to True -tokenizer_legacy: -# Resize the model embeddings when new tokens are added to multiples of 32 -# This is reported to improve training speed on some models -resize_token_embeddings_to_32x: -# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink. -shrink_embeddings: -# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs -embeddings_skip_upcast: -# Whether to load the model with randomly initialized weights. Useful for -# pre-training a model from scratch or debugging purposes. -random_init_weights: - -# (Internal use only) -# Used to identify which the model is based on -is_falcon_derived_model: -is_llama_derived_model: -is_qwen_derived_model: -# Please note that if you set this to true, `padding_side` will be set to "left" by default -is_mistral_derived_model: - -# optional overrides to the base model configuration -overrides_of_model_config: - # RoPE Scaling https://github.com/huggingface/transformers/pull/24653 - rope_scaling: - type: # linear | dynamic - factor: # float - -# optional overrides the base model loading from_pretrained -overrides_of_model_kwargs: - # use_cache: False - -# optional overrides to the bnb 4bit quantization configuration -# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig -bnb_config_kwargs: - # These are default values - llm_int8_has_fp16_weight: false - bnb_4bit_quant_type: nf4 - bnb_4bit_use_double_quant: true - -# quantization aware training -qat: - activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8" - weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8" - group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization - fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after - -# post-training quantization -quantization: - weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8 - activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8" - group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization - quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer. - - -# Whether you are training a 4-bit GPTQ quantized model -gptq: true - -# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer -load_in_8bit: true -# Use bitsandbytes 4 bit -load_in_4bit: - -# Use CUDA bf16 -bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere -# Use CUDA fp16 -fp16: true -# Use CUDA tf32 -tf32: true # require >=ampere -# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting - -# No AMP (automatic mixed precision) -bfloat16: true # require >=ampere -float16: true - -# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset -gpu_memory_limit: 20GiB -# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge -lora_on_cpu: true - -# List[str]. Add plugins to extend the pipeline. -# See `src/axolotl/integrations` for the available plugins or doc below for more details. -# https://docs.axolotl.ai/docs/custom_integrations.html -plugins: - # - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin - -# A list of one or more datasets to finetune the model with -# See https://docs.axolotl.ai/docs/dataset_loading.html for guide on loading datasets -# See https://docs.axolotl.ai/docs/dataset-formats/ for guide on dataset formats -datasets: - # HuggingFace dataset repo | s3:// | gs:// | path to local file or directory - - path: vicgalle/alpaca-gpt4 - # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection] - type: alpaca # format | format: (chat/instruct) | .load_ - ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file - data_files: # Optional[str] path to source data files - - shards: # Optional[int] split dataset into N pieces (use with shards_idx) - shards_idx: # Optional[int] = 0 the index of sharded dataset to use - - preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`) - - name: # Optional[str] name of dataset configuration to load - split: train # Optional[str] name of dataset split to load from - revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets. - trust_remote_code: # Optional[bool] Trust remote code for untrusted source - - # Custom user instruction prompt - - path: repo - type: - # The below are defaults. only set what's needed if you use a different column name. - system_prompt: "" - system_format: "{system}" - field_system: system - field_instruction: instruction - field_input: input - field_output: output - - # Customizable to be single line or multi-line - # Use {instruction}/{input} as key to be replaced - # 'format' can include {input} - format: |- - User: {instruction} {input} - Assistant: - # 'no_input_format' cannot include {input} - no_input_format: "{instruction} " - - # For `completion` datsets only, uses the provided field instead of `text` column - field: - - # Using chat template - - path: ... - # Set type to `chat_template` to use this strategy - type: chat_template - # Specify the name of the chat template to use - # The name of the chat template to use for training, following values are supported: - # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. - # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py - # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. - # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. - chat_template: tokenizer_default - - # Custom jinja chat template. Used only if `chat_template: jinja` or empty. - chat_template_jinja: - - # Key containing the messages (default: "messages") - field_messages: messages - - # Key containing the system message (default: "system") - # If the system message is not present in the dataset sample, it will be loaded from the field_system property. - field_system: system - - # Mapping of properties from the input dataset to the chat template. - # (default: message_property_mappings={'role':'role', 'content':'content'}) - # If a property exists in the template but not in this mapping, the system will attempt - # to load it directly from the message using the property name as the key. - # Example: In the mapping below, 'from' is loaded from input dataset and used as 'role', - # while 'value' is loaded and used as 'content' in the chat template. - message_property_mappings: - role: from - content: value - # ... - - # Optional[Dict[str, List]]. Roles mapping in the messages. - # The format is {target_role: [source_roles]}. All source roles will be mapped to the target role. - # The default is: - roles: - user: ["human", "user"] - assistant: ["gpt", "assistant"] - system: ["system"] - tool: ["tool"] - - # Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template. - # This does not drop the default system message from chat_template if it exists. If you wish to, - # we recommend using a custom jinja template with the default system message removed or - # adding a system turn with empty content. - drop_system_message: - - # Optional[bool]. (for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags - # See example at `docs/dataset-formats/conversation.qmd` - split_thinking: - - # IMPORTANT: The following fields determine which parts of the conversation to train on. - # Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train - # See examples at `docs/dataset-formats/conversation.qmd` - # Note: If the below 5 fields are empty, defaults to training only on the last message. - - # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss. - roles_to_train: ["assistant"] # default - # Optional[str]. Which EOS tokens to train on in the conversation. Possible values are: - # - all: train on all EOS tokens - # - turn (default): train on the EOS token at the end of each trainable turn - # - last: train on the last EOS token in the conversation - # TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`. - train_on_eos: turn - # Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are: - # - all: train on all EOT tokens - # - turn: train on the EOT token at the end of each trainable turn - # - last: train on the last EOT token in the conversation - # If not specified, defaults to the value of train_on_eos for backward compatibility. - train_on_eot: - # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`. - message_field_training: training - # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. - # The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train). - message_field_training_detail: train_detail - - -# If false, the datasets will not be shuffled and will keep their original order in `datasets`. -# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true. -shuffle_merged_datasets: true - -# Deduplicates datasets and test_datasets with identical entries. -dataset_exact_deduplication: true - -# A list of one or more datasets to eval the model with. -# You can use either test_datasets, or val_set_size, but not both. -test_datasets: - - path: /workspace/data/eval.jsonl - ds_type: json - # You need to specify a split. For "json" datasets the default split is called "train". - split: train - type: completion - data_files: - - /workspace/data/eval.jsonl - -# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo' -rl: -rl_beta: # Optional[float]. The beta parameter for the RL training. - -# dpo -dpo_use_weighting: # Optional[bool]. Whether to perform weighting. -rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper. - -# orpo -orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping. - -# kto -kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss. -kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss. - -# simpo -cpo_alpha: 1.0 # Weight of the BC regularizer -simpo_gamma: 0.5 # Target reward margin for the SimPO loss - -# grpo -trl: - use_vllm: # Optional[bool]. Whether to use VLLM for RL training. - vllm_server_host: # Optional[str]. Host of the vLLM server to connect to. - vllm_server_port: # Optional[int]. Port of the vLLM server to connect to. - vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond. - vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding. - - beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use - max_completion_length: # Optional[int]. Maximum length of the completion for RL training. - - reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir. - reward_weights: # Optional[list[float]]. List of reward weights for the reward functions. - - num_generations: # Optional[int]. Number of generations to sample. - log_completions: # Optional[bool]. Whether to log completions. - num_completions_to_print: # Optional[int]. Number of completions to print when log_completions is True. - - sync_ref_model: # Optional[bool]. Whether to sync the reference model. - ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model. - ref_model_sync_steps: # Optional[int]. Sync steps for the reference model. - scale_rewards: # Optional[bool]. Whether to scale rewards by their standard deviation. - - temperature: # Optional[float]. Sampling temperature for the GRPO policy. - top_p: # Optional[float]. Top-p sampling probability for the generation policy. - top_k: # Optional[int]. Top-k sampling for the generation policy. - min_p: # Optional[float]. Minimum probability for the generation policy. - repetition_penalty: # Optional[float]. Penalty for tokens that appear in prompt and generated text. - - num_iterations: # Optional[int]. Number of iterations per batch (μ) for GRPO. - epsilon: # Optional[float]. Epsilon value for clipping in the GRPO algorithm. - epsilon_high: # Optional[float]. Upper-bound epsilon value for clipping in the GRPO algorithm. - use_liger_loss: # Optional[bool]. Whether to use Liger loss for GRPO. - loss_type: # Optional[str]. Loss formulation to use. Supported values: grpo, bnpo, dr_grpo. - mask_truncated_completions: # Optional[bool]. Whether to exclude truncated completions from loss calculation. - - -# reward modelling: `True` or `False` -reward_model: - -# process reward modelling: `True` or `False` -process_reward_model: - -# The name of the chat template to use for training, following values are supported: -# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. -# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py -# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer. -# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. -# The selected chat template will be saved to the tokenizer_config.json for easier inferencing -# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template. -chat_template: tokenizer_default -# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null. -chat_template_jinja: null -# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training. -# These tokens mark the boundaries between conversation turns. -# For example: ["/INST", "", "[/SYSTEM_PROMPT]"] -# If not specified, defaults to just the model's eos_token. -# This is useful for templates that use multiple delimiter tokens. -eot_tokens: - # - "" - # - "[/INST]" - # - "[/SYSTEM_PROMPT]" -# Changes the default system message -default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml. -# Axolotl attempts to save the dataset as an arrow after packing the data together so -# subsequent training attempts load faster, relative path -dataset_prepared_path: data/last_run_prepared -# Push prepared dataset to hub -push_dataset_to_hub: # Optional[str] repo_org/repo_name -# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` -# if not set. -dataset_processes: # defaults to os.cpu_count() if not set -# Keep dataset in memory while preprocessing -# Only needed if cached dataset is taking too much storage -dataset_keep_in_memory: -# push checkpoints to hub -hub_model_id: # private repo path to push finetuned model -# how to push checkpoints to hub -# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy -hub_strategy: -# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets -# Required to be true when used in combination with `push_dataset_to_hub` -hf_use_auth_token: # boolean -# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval. -val_set_size: 0.04 -# Num shards for whole dataset -dataset_shard_num: -# Index of shard to use for whole dataset -dataset_shard_idx: - -# The maximum length of an input to train with, this should typically be less than 2048 -# as most models have a token/context limit of 2048 -sequence_len: 2048 -# Pad inputs so each step uses constant sized buffers -# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently -pad_to_sequence_len: -# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' -sample_packing: -# Set to 'false' if getting errors during eval with sample_packing on. -eval_sample_packing: -# You can set these packing optimizations AFTER starting a training at least once. -# The trainer will provide recommended values for these values. -sample_packing_eff_est: -total_num_tokens: -# Increasing the following values helps with packing, but usually only slightly (<%1.) -# The number of samples packed at a time. -sample_packing_group_size: 100000 -# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples. -sample_packing_bin_size: 200 -sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially. - -# whether to concatenate samples during pretraining -pretraining_sample_concatenation: - -curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning - -# Use batch flattening for speedups when not using sample_packing -batch_flattening: - -# Passed through to transformers when loading the model when launched without accelerate -# Use `sequential` when training w/ model parallelism to limit memory -device_map: -# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model. -max_memory: - -# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model -adapter: lora -# If you already have a lora model trained that you want to load, put that here. -# This means after training, if you want to test the model, you should set this to the value of `output_dir`. -# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`. -lora_model_dir: - -# LoRA hyperparameters -# For more details about the following options, see: -# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2 -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: - - q_proj - - v_proj -# - k_proj -# - o_proj -# - gate_proj -# - down_proj -# - up_proj -lora_target_linear: # If true, will target all linear modules - -# List[int] | int. # The layer indices to transform, otherwise, apply to all layers -# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform -peft_layers_to_transform: - -# Optional[bool]. Whether to use DoRA. -# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora -peft_use_dora: - -# Optional[bool]. Whether to use RSLoRA. -# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora -peft_use_rslora: - -# Optional[list[tuple[int, int]]]. List of layer indices to replicate. -# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora -peft_layer_replication: - -# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"] -# How to initialize LoRA weights. Default to True which is MS original implementation. -# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization -peft_init_lora_weights: - -# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. -# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. -# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities. -# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994 -lora_modules_to_save: -# - embed_tokens -# - lm_head - -lora_fan_in_fan_out: false - -# Apply custom LoRA autograd functions and activation function Triton kernels for -# speed and memory savings -# See: https://docs.axolotl.ai/docs/lora_optims.html -lora_mlp_kernel: true -lora_qkv_kernel: true -lora_o_kernel: true - -# LoRA+ hyperparameters -# For more details about the following options, see: -# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py` -loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4. -loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6. - -peft: - # Configuration options for loftq initialization for LoRA - # https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization - loftq_config: - loftq_bits: # typically 4 bits - -# ReLoRA configuration -# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed -relora_steps: # Number of steps per ReLoRA restart -relora_warmup_steps: # Number of per-restart warmup steps -relora_anneal_steps: # Number of anneal steps for each relora cycle -relora_prune_ratio: # threshold for optimizer magnitude when pruning -relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings - -# wandb configuration if you're using it -# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`. -wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb -wandb_project: # Your wandb project name -wandb_entity: # A wandb Team name if using a Team -wandb_watch: -wandb_name: # Set the name of your wandb run -wandb_run_id: # Set the ID of your wandb run -wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training - -# mlflow configuration if you're using it -mlflow_tracking_uri: # URI to mlflow -mlflow_experiment_name: # Your experiment name -mlflow_run_name: # Your run name -hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry - -# Comet configuration if you're using it -# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`. -# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start -use_comet: # Enable or disable Comet integration. -comet_api_key: # API key for Comet. Recommended to set via `comet login`. -comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace. -comet_project_name: # Project name in Comet. Defaults to Uncategorized. -comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key. -comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration. -comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True. -comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details. - -# Tensorboard -use_tensorboard: # Optional[bool] - -# Where to save the full-finetuned model to -output_dir: ./completed-model - -# Whether to use torch.compile and which backend to use -# setting to `auto` will enable torch compile when torch>=2.5.1 -torch_compile: # Optional[Union[Literal["auto"], bool]] -torch_compile_backend: # Optional[str] -torch_compile_mode: # 'default' | 'reduce-overhead' | 'max-autotune' - -# Training hyperparameters - -# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps. -gradient_accumulation_steps: 1 -# The number of samples to include in each batch. This is the number of samples sent to each GPU. -# Batch size per gpu = micro_batch_size * gradient_accumulation_steps -micro_batch_size: 2 -eval_batch_size: -num_epochs: 4 -warmup_steps: 100 # cannot use with warmup_ratio -warmup_ratio: 0.05 # cannot use with warmup_steps -learning_rate: 0.00003 -lr_quadratic_warmup: -logging_steps: -eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps -evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps -eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`. -save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`. -save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps -saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps -save_total_limit: # Checkpoints saved at a time -save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints. -# Maximum number of iterations to train for. It precedes num_epochs which means that -# if both are set, num_epochs will not be guaranteed. -# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps -max_steps: - -# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time. -include_tokens_per_second: # Optional[bool] - -# whether to find batch size that fits in memory. Passed to underlying transformers Trainer -auto_find_batch_size: # Optional[bool] - -eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 -eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 -do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`. -eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"] - -profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir. - # see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information - # snapshots can be visualized @ https://pytorch.org/memory_viz - -loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) -loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) - -# Save model as safetensors (require safetensors package). Default True -save_safetensors: - -# Whether to mask out or include the human's prompt from the training labels -train_on_inputs: false -# Group similarly sized data to minimize padding. -# May be slower to start, as it must download and sort the entire dataset. -# Note that training loss may have an oscillating pattern with this enabled. -group_by_length: false - -# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk". -# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing -gradient_checkpointing: false -# additional kwargs to pass to the trainer for gradient checkpointing -# gradient_checkpointing_kwargs: -# use_reentrant: true - -# Stop training after this many evaluation losses have increased in a row -# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback -early_stopping_patience: 3 - -# Specify a scheduler and kwargs to use with the optimizer -# Valid values are driven by the Transformers SchedulerType class, see: -# https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/trainer_utils.py#L420 -# Valid values include -# - 'linear' -# - 'cosine' (default) -# - 'cosine_with_restarts' -# - 'polynomial' -# - 'constant' -# - 'constant_with_warmup' -# - 'inverse_sqrt' -# - 'reduce_lr_on_plateau' -# - 'cosine_with_min_lr' -# - 'warmup_stable_decay' - -# Additional schedulers include: -# - 'one_cycle' -# - 'rex' -lr_scheduler: -lr_scheduler_kwargs: -cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr -cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf) - -# For one_cycle optim -lr_div_factor: # Learning rate div factor - -# Specify optimizer -# Valid values are driven by the Transformers OptimizerNames class, see: -# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189 -# -# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of -# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used -# in the examples/ for your model and fine-tuning use case. -# -# Valid values for 'optimizer' include: -# - adamw_torch -# - adamw_torch_fused (default) -# - adamw_torch_xla -# - adamw_torch_npu_fused -# - adamw_apex_fused -# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1) -# - adafactor -# - adamw_anyprecision -# - adamw_torch_4bit -# - ademamix -# - sgd -# - adagrad -# - adamw_bnb_8bit -# - adamw_8bit # alias for adamw_bnb_8bit -# - ademamix_8bit -# - lion_8bit -# - lion_32bit -# - paged_adamw_32bit -# - paged_adamw_8bit -# - paged_ademamix_32bit -# - paged_ademamix_8bit -# - paged_lion_32bit -# - paged_lion_8bit -# - rmsprop -# - rmsprop_bnb -# - rmsprop_bnb_8bit -# - rmsprop_bnb_32bit -# - galore_adamw -# - galore_adamw_8bit -# - galore_adafactor -# - galore_adamw_layerwise -# - galore_adamw_8bit_layerwise -# - galore_adafactor_layerwise -# - lomo -# - adalomo -# - grokadamw -# - schedule_free_adamw -# - schedule_free_sgd -# - apollo_adamw -# - apollo_adamw_layerwise -# -# Additional custom optimizers include: -# - optimi_adamw -# - ao_adamw_8bit -# - ao_adamw_fp8 -# - came_pytorch -optimizer: -# Dictionary of arguments to pass to the optimizer -optim_args: -# For Galore Optimizers the following optim_args are available -# rank: # type: int -# update_proj_gap # type: int -# scale # type: float -# proj_type: # type: str, default = std - -# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm -optim_target_modules: -# - self_attn # for llama -# - mlp - -# Specify weight decay -weight_decay: -# adamw hyperparams -adam_beta1: -adam_beta2: -adam_beta3: # only used for CAME Optimizer -adam_epsilon: -adam_epsilon2: # only used for CAME Optimizer -# Gradient clipping max norm -max_grad_norm: - -# Augmentation techniques -# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings -# currently only supported on Llama and Mistral -neftune_noise_alpha: - -# Optional[bool]. Whether to bettertransformers -flash_optimum: - -# Note: Only one of the following attention patches can be used at a time. -# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`. - -# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers: -xformers_attention: -# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention: -flash_attention: -flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only -flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only -flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation -flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation -# Optional[bool]. Whether to use scaled-dot-product attention -# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html -sdp_attention: -# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf -s2_attention: - -# Optional[bool]. Whether to use low_cpu_mem_usage -low_cpu_mem_usage: -# Optional[str]. Resume from a specific checkpoint dir -resume_from_checkpoint: -# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off. -# Be careful with this being turned on between different models. -auto_resume_from_checkpoints: false - -## Multimodal section -# int | tuple[int, int] | None . Size to resize images to, width x height. -# Will read from model/processor config if not set. -image_size: -# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear". -image_resize_algorithm: 'bilinear' -## End of multimodal section - -# Don't mess with this, it's here for accelerate and torchrun -local_rank: - -# Add or change special tokens. -# If you add tokens here, you don't need to add them to the `tokens` list. -special_tokens: - # bos_token: "" - # eos_token: "" - # unk_token: "" - # pad_token: "[PAD]" - -# Optional[list[str]]. Add extra tokens to the tokenizer. -tokens: - # - "<|startoftext|>" - # - "<|endoftext|>" - -# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer. -# Only works for tokens that are not part of the base vocab (aka are added_tokens). -# Can be checked if they exist in tokenizer.json added_tokens. -added_tokens_overrides: # Dict[int, str] -# 128041: "<|im_start|>" -# 128042: "<|im_end|>" - -# FSDP -fsdp: -fsdp_config: - -# Deepspeed config path. e.g., deepspeed_configs/zero3.json -deepspeed: - -# Advanced DDP Arguments -ddp_timeout: -ddp_bucket_cap_mb: -ddp_broadcast_buffers: - -# Sequence parallelism -# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. -# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. -# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized -# subsequences, or set to 4 to split into four equal-sized subsequences. -# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details. -sequence_parallel_degree: -# Optional; strides across the key dimension. Larger values use more memory but should make training faster. -# Must evenly divide the number of KV heads in your model. -heads_k_stride: 1 -# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3" -# in the sample packing case, and "batch_ring" in the non-sample packing case. -ring_attn_func: - -# Path to torch distx for optim 'adamw_anyprecision' -torchdistx_path: - -# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize -pretraining_dataset: - -# Debug mode -debug: - -# Seed -seed: - -# Allow overwrite yml config using from cli -strict: -``` diff --git a/docs/custom_integrations.qmd b/docs/custom_integrations.qmd index 023f09732..8e1fdaa2e 100644 --- a/docs/custom_integrations.qmd +++ b/docs/custom_integrations.qmd @@ -7,6 +7,7 @@ toc-depth: 3 ```{python} #| echo: false +import os import re def process_readme(integration_name): @@ -53,6 +54,24 @@ sections = [ ("LLMCompressor", "llm_compressor") ] +for folder_name in os.listdir("../src/axolotl/integrations/"): + if folder_name in [path for name, path in sections]: + # skip if already in sections + continue + if os.path.exists(f"../src/axolotl/integrations/{folder_name}/README.md"): + # grab the first heading in README.md as the section name + with open(f"../src/axolotl/integrations/{folder_name}/README.md", "r") as f: + txt = f.read() + matches = re.search(r'^# (.*)\n?', txt, flags=re.MULTILINE) + if matches: + name = matches.group(1) + else: + continue + sections.append((name, folder_name)) + +# sort sections by name +sections = sorted(sections, key=lambda x: x[0]) + for section_name, folder_name in sections: print(print_section(section_name, folder_name)) ``` diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index 87c2941e6..870a2b67d 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -9,10 +9,10 @@ order: 3 Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2. ```{.json filename="data.jsonl"} -{"conversations": [{"role": "...", "content": "..."}]} +{"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]} ``` -See [configs](../config.qmd) for full configs and supported templates. +See [configs](../config-reference.qmd) for full configs and supported templates. ### Migrating from sharegpt @@ -52,7 +52,9 @@ We recommend checking the below examples for other usecases. ### Examples -1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message. +#### Training on last message + +(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message. ```yaml datasets: @@ -66,7 +68,9 @@ datasets: If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`. ::: -2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages. +#### Overriding default chat template + +Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages. ```yaml chat_template: gemma # this overwrites the tokenizer's chat_template @@ -76,7 +80,13 @@ datasets: roles_to_train: ["assistant"] # default value ``` -3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages. +::: {.callout-note} +If you want to use built-in chat_template, use `chat_template: tokenizer_default` (this is set by default). +::: + +#### Using default chat template with fallback + +Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages. ```yaml chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template @@ -85,7 +95,9 @@ datasets: type: chat_template ``` -4. Using a custom jinja template on OpenAI messages format, training on all assistant messages. +#### Custom Jinja template + +Using a custom jinja template on OpenAI messages format, training on all assistant messages. ```yaml # chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty @@ -100,7 +112,9 @@ datasets: Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `. ::: -5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn. +#### Using template with different token for EOT and EOS + +- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn. ```yaml eot_tokens: @@ -116,16 +130,16 @@ datasets: ``` ::: {.callout-tip} -See [config documentation](../config.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens. +See [config documentation](../config-reference.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens. ::: ::: {.callout-note} Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior. -You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details. +You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config-reference.qmd) for more details. ::: -6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`. +- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`. ```yaml eot_tokens: @@ -145,7 +159,84 @@ If EOS token only appears at the end of a prompt, `train_on_eos: last` is equiva ::: -7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation +#### Using tool use + +Instead of passing `tools` via the system prompt, an alternative method would be to have the `tools` in a separate column and loaded via `chat_template` to let the template dynamically build it. + +```json +{ + "tools": [ + { + "type": "...", + "function": { + "name": "...", + "description": "...", + "parameters": { + "type": "...", + "properties": { + // ... + }, + "required": ["..."], + }, + }, + }, + ], + "messages": [ + // ... + { + "role": "assistant", // call the function via assistant + "tool_calls": [ + { + "id": "...", // required only for mistral + "type": "function", + "function": { + "name": "...", + "arguments": { + "...": "...", + } + } + } + ] + }, + { + "role": "tool", + "tool_call_id": "...", // required only for mistral + "name": "...", + "content": "..." + }, + ], +} +``` + +::: {.callout-note} +Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step). +::: + +::: {.callout-warning} +If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues. + +``` +"arguments": "{\"...\": \"...\"}" +``` +::: + +Example config for Llama4: +```yaml +chat_template: llama4 +datasets: + - path: Nanobit/text-tools-2k-test + type: chat_template + # field_tools: tools # default is `tools` +``` + +::: {.callout-tip} +Look into the `chat_template` you are using to see if it supports `tools` and what the expected role is for the tool answer. In the example above, the tool answer is expected to be in the `tool` or `ipython` role for `llama4` template. +::: + + +#### Using fine-grained control over token masking + +(Advanced) Using fine-grained control over tokens and turns to train in a conversation For a data sample that looks like: @@ -196,7 +287,9 @@ datasets: It is not necessary to set both `message_field_training` and `message_field_training_detail` at once. ::: -8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template. +#### Reasoning split + +(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template. ```yaml datasets: diff --git a/docs/dataset-formats/index.qmd b/docs/dataset-formats/index.qmd index a0113db07..715e3ef20 100644 --- a/docs/dataset-formats/index.qmd +++ b/docs/dataset-formats/index.qmd @@ -61,7 +61,7 @@ While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet ### Pre-training without streaming -On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming. +In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming. One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs. diff --git a/docs/dataset-formats/inst_tune.qmd b/docs/dataset-formats/inst_tune.qmd index d89c6adaf..f5bd7ab8f 100644 --- a/docs/dataset-formats/inst_tune.qmd +++ b/docs/dataset-formats/inst_tune.qmd @@ -186,4 +186,4 @@ datasets: no_input_format: "[INST] {instruction} [/INST]" ``` -See full config options under [here](../config.qmd). +See full config options under [here](../config-reference.qmd). diff --git a/docs/dataset_loading.qmd b/docs/dataset_loading.qmd index b78f86a98..bcffe7f0f 100644 --- a/docs/dataset_loading.qmd +++ b/docs/dataset_loading.qmd @@ -36,7 +36,7 @@ This matches the API of [`datasets.load_dataset`](https://github.com/huggingface For HuggingFace's guide to load different dataset types, see [here](https://huggingface.co/docs/datasets/loading). -For full details on the config, see [config.qmd](config.qmd). +For full details on the config, see [config-reference.qmd](config-reference.qmd). ::: {.callout-note} diff --git a/docs/debugging.qmd b/docs/debugging.qmd index bf3c6fe7e..04b4faa64 100644 --- a/docs/debugging.qmd +++ b/docs/debugging.qmd @@ -29,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible. 1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`. 1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing: - Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`. - - Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`. + - Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`. 2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config): ```yaml @@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler "-m", "axolotl.cli.train", "dev_chat_template.yml", // The flags below simplify debugging by overriding the axolotl config // with the debugging tips above. Modify as needed. - "--dataset_processes=1", // limits data preprocessing to one process + "--dataset_num_proc=1", // limits data preprocessing to one process "--max_steps=1", // limits training to just one step "--batch_size=1", // minimizes batch size "--micro_batch_size=1", // minimizes batch size diff --git a/docs/docker.qmd b/docs/docker.qmd index 7b236b960..da6184394 100644 --- a/docs/docker.qmd +++ b/docs/docker.qmd @@ -9,7 +9,7 @@ format: This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai). ::: {.callout-important} -For Blackwell GPUs, please use the tags with Pytorch 2.7.0 and CUDA 12.8. +For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8. ::: ## Base @@ -32,10 +32,11 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version} Tags examples: -- `main-base-py3.11-cu128-2.7.0` +- `main-base-py3.11-cu128-2.7.1` +- `main-base-py3.11-cu126-2.7.1` - `main-base-py3.11-cu126-2.7.0` +- `main-base-py3.11-cu126-2.6.0` - `main-base-py3.11-cu124-2.6.0` -- `main-base-py3.11-cu124-2.5.1` ## Main @@ -73,13 +74,15 @@ There may be some extra tags appended to the image, like `-vllm` which installs Tags examples: +- `main-py3.11-cu128-2.7.1` +- `main-py3.11-cu126-2.7.1` - `main-py3.11-cu126-2.7.0` +- `main-py3.11-cu126-2.6.0` - `main-py3.11-cu124-2.6.0` -- `main-py3.11-cu124-2.5.1` - `main-latest` - `main-20250303-py3.11-cu124-2.6.0` -- `main-20250303-py3.11-cu124-2.5.1` -- `0.9.2` +- `main-20250303-py3.11-cu126-2.6.0` +- `0.10.1` ## Cloud diff --git a/docs/faq.qmd b/docs/faq.qmd index f2744caba..ffc29d35d 100644 --- a/docs/faq.qmd +++ b/docs/faq.qmd @@ -9,11 +9,11 @@ description: Frequently asked questions > A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd) -**Q: Exitcode -9** +**Q: exitcode: -9** > A: This usually happens when you run out of system RAM. -**Q: Exitcode -7 while using deepspeed** +**Q: exitcode: -7 while using deepspeed** > A: Try upgrading deepspeed w: `pip install -U deepspeed` @@ -51,6 +51,18 @@ description: Frequently asked questions > pad_token: "..." > ``` +**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI** + +> A: This is because you may be using `preprocess` CLI with `pretraining_dataset:` or `skip_prepare_dataset: true` respectively. Please use `axolotl train` CLI directly instead as these datasets are prepared on demand. + +**Q: vLLM is not working with Axolotl** + +> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag. + +**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4** + +> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717. + ### Chat templates **Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`** @@ -124,3 +136,11 @@ description: Frequently asked questions > dynamic: false > mode: max-autotune-no-cudagraphs > ``` + +**Q: `ValueError("Backward pass should have cleared tracker of all tensors")` + +> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML. + +**Q: `Error parsing tool_calls arguments as JSON.` + +> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details. diff --git a/docs/fsdp_qlora.qmd b/docs/fsdp_qlora.qmd index 7af2a3eba..01f57e627 100644 --- a/docs/fsdp_qlora.qmd +++ b/docs/fsdp_qlora.qmd @@ -1,5 +1,5 @@ --- -title: "FDSP + QLoRA" +title: "FSDP + QLoRA" description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs. format: html: @@ -20,9 +20,15 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps: > See the [example config](#example-config) file in addition to reading these instructions. 1. Set `adapter: qlora` in your axolotl config file. -2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp). +2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp). 3. Use one of the supported model types: `llama`, `mistral` or `mixtral`. +## Enabling Swap for FSDP2 + +If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config. + +This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems. + ## Example Config [examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl. diff --git a/docs/getting-started.qmd b/docs/getting-started.qmd index 6f1b54348..de059c397 100644 --- a/docs/getting-started.qmd +++ b/docs/getting-started.qmd @@ -55,7 +55,7 @@ output_dir: ./outputs/lora-out - To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`. ::: -See our [Config options](config.qmd) for more details. +See our [config options](config-reference.qmd) for more details. ### Training {#sec-training} @@ -179,7 +179,7 @@ Now that you have the basics, you might want to: Check our other guides for details on these topics: -- [Configuration Guide](config.qmd) - Full configuration options +- [Configuration Guide](config-reference.qmd) - Full configuration options - [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources - [Dataset Formats](dataset-formats) - Working with different data formats - [Multi-GPU Training](multi-gpu.qmd) diff --git a/docs/gradient_checkpointing.qmd b/docs/gradient_checkpointing.qmd new file mode 100644 index 000000000..25a887999 --- /dev/null +++ b/docs/gradient_checkpointing.qmd @@ -0,0 +1,29 @@ +--- +title: Gradient Checkpointing and Activation Offloading +--- + +Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning +models by reducing the memory footprint and improving computational efficiency. + +### Enabling Gradient Checkpointing + +```yaml +gradient_checkpointing: true +``` + +### Enabling Activation Offloading + +```yaml +gradient_checkpointing: true # required for activation offloading +activation_offloading: true +``` + +Activation offloading variants: + +The default `activation_offloading: true` offloads activations to CPU and uses CUDA streams +to overlap the communications and computations when offloading. + +The `activation_offloading: legacy` naively offloads activations to CPU and without additional optimizations. + +For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads +activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory. diff --git a/docs/installation.qmd b/docs/installation.qmd index 15f2db57b..265ff238c 100644 --- a/docs/installation.qmd +++ b/docs/installation.qmd @@ -14,8 +14,8 @@ This guide covers all the ways you can install and set up Axolotl for your envir ## Requirements {#sec-requirements} - NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU -- Python ≥3.10 -- PyTorch ≥2.5.1 +- Python ≥3.11 +- PyTorch ≥2.6.0 ## Installation Methods {#sec-installation-methods} @@ -124,14 +124,17 @@ For providers supporting Docker: - Use `axolotlai/axolotl-cloud:main-latest` - Available on: - - [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c) - - [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl) - - [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz) - - [Novita](https://novita.ai/gpus-console?templateId=311) + - [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz) + - [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link) + - [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true) + - [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) + - [Novita](https://novita.ai/gpus-console?templateId=311) + - [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl) + - [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c) ### Google Colab {#sec-colab} -Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb). +[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa) ## Platform-Specific Instructions {#sec-platform-specific} @@ -153,7 +156,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker. ### Conda/Pip venv {#sec-conda} -1. Install Python ≥3.10 +1. Install Python ≥3.11 2. Install PyTorch: https://pytorch.org/get-started/locally/ 3. Install Axolotl: ```{.bash} diff --git a/docs/lora_optims.qmd b/docs/lora_optims.qmd index 7cdf53975..40893387b 100644 --- a/docs/lora_optims.qmd +++ b/docs/lora_optims.qmd @@ -5,10 +5,11 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU -(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function -Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was -to leverage operator fusion and tensor re-use in order to improve speed and reduce -memory usage during the forward and backward passes of these calculations. +(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU +and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom +autograd functions. Our goal was to leverage operator fusion and tensor re-use in order +to improve speed and reduce memory usage during the forward and backward passes of +these calculations. We currently support several common model architectures, including (but not limited to): @@ -131,6 +132,5 @@ computation path. ## Future Work - Support for additional model architectures -- Support for the FSDP setting - Support for dropout and bias - Additional operator fusions diff --git a/docs/mixed_precision.qmd b/docs/mixed_precision.qmd new file mode 100644 index 000000000..7b77cd4bb --- /dev/null +++ b/docs/mixed_precision.qmd @@ -0,0 +1,149 @@ +--- +title: "Mixed Precision Training" +format: + html: + toc: true + toc-depth: 3 + number-sections: true + code-tools: true +execute: + enabled: false +--- + +Mixed precision training uses lower precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats: + +- **FP16** - Half precision 16-bit (Pascal generation+) +- **BF16** - Brain Float 16-bit (Ampere generation+) +- **FP8** - 8-bit floating point (Hopper generation+) + +## FP16 Mixed Precision {#sec-fp16} + +### Overview {#sec-fp16-overview} + +FP16 is the traditional half-precision format, supported on older GPUs but can be less numerically stable than BF16. + +### Configuration {#sec-fp16-config} + +```{.yaml} +fp16: true +``` + +### FP16 Considerations {#sec-fp16-considerations} + +- May require gradient scaling to prevent underflow +- Less numerically stable than BF16 +- Can cause training instability with some model architectures +- Consider using BF16 if your hardware supports it + +## BF16 Mixed Precision {#sec-bf16} + +### Overview {#sec-bf16-overview} + +BF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory. + +### Configuration {#sec-bf16-config} + +```{.yaml} +# Automatic BF16 detection (recommended) +bf16: auto + +# Or explicitly enable +bf16: true + +# For evaluation with BF16 +bf16: full # Equivalent to bf16_full_eval in the HF trainer +``` + +## FP8 Mixed Precision {#sec-fp8} + +::: {.callout-note} +FP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO. +::: + +### What is FP8? {#sec-fp8-overview} + +FP8 (8-bit floating point) can provide significant time savings compared to FP16/BF16 while maintaining training stability. Axolotl's implementation uses PyTorch's TorchAO library with "tensorwise" scaling strategy. + +### Requirements {#sec-fp8-software} + +- Hopper+ GPUs (H100/H200) +- PyTorch 2.7+ (+ compatible TorchAO version) +- CUDA 12.4+ + +### Configuration {#sec-fp8-config} + +Add to your YAML config: + +```{.yaml} +# Enable FP8 mixed precision +fp8: true + +# Optional: Enable FP8 for FSDP all-gather operations +fp8_enable_fsdp_float8_all_gather: true + +# Enable torch.compile (almost always necessary for FP8 speedups) +torch_compile: true +``` + +::: {.callout-important} +**torch.compile is critical for FP8 performance** + +FP8 training requires `torch_compile: true` to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16. +::: + +### Advanced FP8 Configs {#sec-fp8-advanced} + +For [FSDP](multi-gpu.qmd#sec-fsdp) (Fully Sharded Data Parallel) training: + +```{.yaml} +fp8: true +fp8_enable_fsdp_float8_all_gather: true + +torch_compile: true + +# FSDP configuration +fsdp_version: 2 +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: LlamaDecoderLayer + state_dict_type: FULL_STATE_DICT + reshard_after_forward: true +``` + +## Best Practices {#sec-best-practices} + +### Choosing Precision Format {#sec-choosing-format} + +- **Start with automatic detection**: `bf16: auto` +- **For Hopper+ (H100/H200)**: Try FP8 + torch.compile for maximum speed +- **For Ampere (A100/RTX 30/40)**: Use BF16 +- **For older Pascal/Turing GPUs**: Use FP16 with caution +- **For very old or unsupported GPUs**: Use FP32 + +### Validation and Testing {#sec-validation} + +Always validate your mixed precision setup: + +- **Start with a small dataset** to verify stability +- **Monitor loss curves** for irregularities +- **Compare with FP32 baseline** when possible +- **Test evaluation metrics** match expectations + +### FP8 Particulars {#sec-fp8-details} + +- Use cases + - Single GPU training + - Multi GPU training with FSDP2 or Deepspeed +- Speedups + - Please refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups for different (M, K, N) settings + - Concrete number for LLaMA 3 8B training can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks) +- Known issues: + - FP8 + DDP + `torch.compile` (causes [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4)) + - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing tends to be _slower_ than the BF16 equivalent training + - Flash Attention 2 does not play nicely with `torch.compile` + +See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision + FP8 all-gather training results in ~10% faster iterations per second vs. BF16 for a relatively small (3B param) model + +For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd). diff --git a/docs/multi-gpu.qmd b/docs/multi-gpu.qmd index fee7d17e5..57a941b04 100644 --- a/docs/multi-gpu.qmd +++ b/docs/multi-gpu.qmd @@ -23,8 +23,6 @@ Axolotl supports several methods for multi-GPU training: ## DeepSpeed {#sec-deepspeed} -DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages. - ### Configuration {#sec-deepspeed-config} Add to your YAML config: @@ -32,7 +30,6 @@ Add to your YAML config: ```{.yaml} deepspeed: deepspeed_configs/zero1.json ``` - ### Usage {#sec-deepspeed-usage} ```{.bash} @@ -66,9 +63,67 @@ Start from Stage 1 -> Stage 2 -> Stage 3. ::: -## FSDP {#sec-fsdp} +## Fully Sharded Data Parallel (FSDP) {#sec-fsdp} -### Basic FSDP Configuration {#sec-fsdp-config} +::: {.callout-note} + +FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl. + +::: + +### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2} + +To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and +also follow the config field mapping below to update field names. + +#### Config mapping + +FSDP1 | FSDP2 +-------- | -------- +fsdp_sharding_strategy | reshard_after_forward +fsdp_backward_prefetch_policy | **REMOVED** +fsdp_backward_prefetch | **REMOVED** +fsdp_forward_prefetch | **REMOVED** +fsdp_sync_module_states | **REMOVED** +fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading +fsdp_state_dict_type | state_dict_type +fsdp_use_orig_params | **REMOVED** +fsdp_activation_checkpointing | activation_checkpointing + +For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl, +if you were using the following FSDP1 config: + +```{.yaml} +fsdp_version: 1 +fsdp_config: + fsdp_offload_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD +``` + +You can migrate to the following FSDP2 config: + +```{.yaml} +fsdp_version: 2 +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Qwen3DecoderLayer + state_dict_type: FULL_STATE_DICT + reshard_after_forward: true +``` + +### FSDP1 (deprecated) {#sec-fsdp-config} + +::: {.callout-note} + +Using `fsdp` to configure FSDP is deprecated and will be removed in an upcoming release of Axolotl. Please use `fsdp_config` as above instead. + +::: ```{.yaml} fsdp: @@ -80,6 +135,7 @@ fsdp_config: fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer ``` + ## Sequence parallelism {#sec-sequence-parallelism} We support sequence parallelism (SP) via the diff --git a/docs/multi-node.qmd b/docs/multi-node.qmd index cec8ff45d..16196a2d7 100644 --- a/docs/multi-node.qmd +++ b/docs/multi-node.qmd @@ -40,13 +40,13 @@ use_cpu: false Configure your model to use FSDP in the Axolotl yaml. For example: ```yaml -fsdp: - - full_shard - - auto_wrap +fsdp_version: 2 fsdp_config: - fsdp_offload_params: true - fsdp_state_dict_type: FULL_STATE_DICT - fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + offload_params: true + state_dict_type: FULL_STATE_DICT + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: LlamaDecoderLayer + reshard_after_forward: true ``` All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine. @@ -69,11 +69,19 @@ export NCCL_BUFFSIZE=2097152 Run the following on each node: +### Option 1: New Axolotl CLI with launcher args (Recommended) + +```bash +axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" +``` + +### Option 2: Direct torchrun (Legacy) + ```bash torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml ``` -Please make sure to substitute the placeholder variables. +Please make sure to substitute the placeholder variables: - `num_nodes`: Number of nodes (containing GPUs) - `gpu_per_node`: Number of gpus per node @@ -81,8 +89,6 @@ Please make sure to substitute the placeholder variables. - `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400) - `rdzv_id`: A unique job ID that is used by the job across nodes. -::: {.callout-note} -You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood -::: +The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features. More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html) diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index ec51a8ec3..3a28b579a 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -13,9 +13,14 @@ format: - [Pixtral](#sec-pixtral) - [Llava-1.5](#sec-llava-15) - [Mistral-Small-3.1](#sec-mistral-small-31) +- [Magistral-Small-2509](#sec-magistral-small-2509) +- [Voxtral](#sec-voxtral) - [Gemma-3](#sec-gemma-3) +- [Gemma-3n](#sec-gemma-3n) - [Qwen2-VL](#sec-qwen2-vl) - [Qwen2.5-VL](#sec-qwen25-vl) +- [SmolVLM2](#sec-smolvlm2) +- [LFM2-VL](#sec-lfm2-vl) ## Usage @@ -30,14 +35,13 @@ skip_prepare_dataset: true remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training sample_packing: false # not yet supported with multimodal -chat_template: # see in next section +chat_template: # see in next section if specified # example dataset datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template split: train[:1%] - field_messages: messages # (optional) if doing lora, only finetune the Language model, # leave the vision model and vision tower frozen @@ -90,10 +94,32 @@ chat_template: llava ### Mistral-Small-3.1 {#sec-mistral-small-31} +::: {.callout-tip} +Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'` +::: + ```yaml base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503 +``` -chat_template: mistral_v7_tekken +### Magistral-Small-2509 {#sec-magistral-small-2509} + +::: {.callout-tip} +Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'` +::: + +```yaml +base_model: mistralai/Magistral-Small-2509 +``` + +### Voxtral {#sec-voxtral} + +::: {.callout-tip} +Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'` +::: + +```yaml +base_model: mistralai/Voxtral-Mini-3B-2507 ``` ### Gemma-3 {#sec-gemma-3} @@ -110,6 +136,22 @@ base_model: google/gemma-3-4b-it chat_template: gemma3 ``` +### Gemma-3n {#sec-gemma-3n} + +::: {.callout-warning} +The model's initial loss and grad norm will be very high. We suspect this to be due to the Conv in the vision layers. +::: + +::: {.callout-tip} +Please make sure to install `timm` via `pip3 install timm==1.0.17` +::: + +```yaml +base_model: google/gemma-3n-E2B-it + +chat_template: gemma3n +``` + ### Qwen2-VL {#sec-qwen2-vl} ```yaml @@ -126,13 +168,35 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct chat_template: qwen2_vl # same as qwen2-vl ``` +### SmolVLM2 {#sec-smolvlm2} + +::: {.callout-tip} +Please make sure to install `num2words` via `pip3 install num2words==0.5.14` +::: + +```yaml +base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct +``` + +### LFM2-VL {#sec-lfm2-vl} + +::: {.callout-warning} +Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d` +::: + +```yaml +base_model: LiquidAI/LFM2-VL-450M +``` + ## Dataset Format For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format. - A message is a list of `role` and `content`. - `role` can be `system`, `user`, `assistant`, etc. -- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`). +- `content` is a list of `type` and (`text`, `image`, `path`, `url`, `base64`, or `audio`). + +### Image ::: {.callout-note} For backwards compatibility: @@ -141,15 +205,43 @@ For backwards compatibility: - If `content` is a string, it will be converted to a list with `type` as `text`. ::: -::: {.callout-tip} For image loading, you can use the following keys within `content` alongside `"type": "image"`: - `"path": "/path/to/image.jpg"` - `"url": "https://example.com/image.jpg"` - `"base64": "..."` - `"image": PIL.Image` + +### Audio + +For audio loading, you can use the following keys within `content` alongside `"type": "audio"`: + +- `"path": "/path/to/audio.mp3"` +- `"url": "https://example.com/audio.mp3"` +- `"audio": np.ndarray` + +::: {.callout-tip} + +You may need to install `librosa` via `pip3 install librosa==0.11.0`. + ::: +### Video + +::: {.callout-warning} + +This is not well tested at the moment. We welcome contributors! + +::: + +For video loading, you can use the following keys within `content` alongside `"type": "video"`: + +- `"path": "/path/to/video.mp4"` +- `"url": "https://example.com/video.mp4"` +- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned) + +### Example + Here is an example of a multi-modal dataset: ```json [ @@ -178,3 +270,9 @@ Here is an example of a multi-modal dataset: } ] ``` + +## FAQ + +1. `PIL.UnidentifiedImageError: cannot identify image file ...` + +`PIL` could not retrieve the file at `url` using `requests`. Please check for typo. One alternative reason is that the request is blocked by the server. diff --git a/docs/nd_parallelism.qmd b/docs/nd_parallelism.qmd new file mode 100644 index 000000000..435e53e21 --- /dev/null +++ b/docs/nd_parallelism.qmd @@ -0,0 +1,108 @@ +--- +title: "N-D Parallelism (Beta)" +--- + +Axolotl enables training models at scale by composing different parallelism techniques. This is essential when: + +- A model's weights are too large to fit on a single GPU's memory. +- A model's activations, especially with very long contexts, are too large for a single GPU. +- You want to accelerate training by using multiple GPUs or nodes. + +or combinations of the above! + +## Core Concepts + +Parallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch's `DeviceMesh` is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid. + +### Data Parallelism {#sec-dp} + +Data Parallelism focuses on splitting the global data batch across GPUs. + +- Distributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires that each GPU is able to hold the entire model, its gradients, and optimizer states. + +- [Fully Sharded Data Parallel (FSDP)](multi-gpu.qmd#fully-sharded-data-parallel-(fsdp)): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's *parameters, gradients, and optimizer states* across the GPUs in the data-parallel group. During computation, each GPU receives the specific parameters it needs via an `all_gather` operation just before they are used, and they can be discarded immediately after (`reshard-after-forward`). + - FSDP maps to ZeRO stages: + - ZeRO-2 (`reshard_after_forward=False`): Shards gradients and optimizer states. Model weights are replicated on each GPU. + - ZeRO-3 (`reshard_after_forward=True`): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes). + +### [Experimental] Tensor Parallelism (TP) {#sec-tp} + +Also known as "horizontal model parallelism," as described in the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Instead of splitting the batch, TP splits the model's layers themselves across GPUs. + +- How it works: For a linear layer `Y = XA`, the weight matrix `A` is split column-wise (`A = [A_1, A_2]`). The computation becomes `Y_1 = XA_1` and `Y_2 = XA_2`, which can happen in parallel on different GPUs. The final output `Y` is simply the concatenation of `Y_1` and `Y_2`. Check [this comment](https://github.com/huggingface/transformers/issues/10321#issuecomment-783543530) for more detailed info. +- Requirement: TP involves frequent, small communications within a forward/backward pass. It requires a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes. + +### Context Parallelism (CP) {#sec-cp} + +Context Parallelism, also called [Sequence Parallelism](sequence_parallelism.qmd), addresses the memory bottleneck from long sequences. The input sequence itself is split along the sequence length dimension and distributed across GPUs. + +- How it works: If you have a sequence of 8192 tokens and a `context_parallel_size` of 4, each GPU will only handle a chunk of 2048 tokens. +- The Challenge: Attention is not local; every token needs to "attend to" every other token. Splitting the sequence breaks this. +- The Solution (`ring-flash-attention`): An efficient communication protocol is used. To compute attention for its local sequence chunk, each GPU passes its Key-Value (KV) cache to its neighbor in a "ring." After `N-1` steps, every GPU has seen the KV-cache from all other GPUs, allowing it to compute the correct attention values for its chunk. This is implemented using the highly optimized `flash-attention` kernel at each step. + +### Hybrid Sharding Data Parallel (HSDP) {#sec-hsdp} + +HSDP is a 2D strategy that intelligently combines FSDP and DDP, typically for multi-node training. + +- Intra-Node (within a machine): Use FSDP. This is efficient because GPUs on the same node have fast interconnects (NVLink), making the `all_gather` operations for sharded parameters fast. +- Inter-Node (across machines): Use DDP. The gradient synchronization between nodes is less frequent than FSDP's parameter gathering, making it a better fit for the slower node-to-node network (e.g., Ethernet/Infiniband). +- Example: With 2 nodes of 8 GPUs each (16 total), you could have `dp_shard_size=8` (FSDP within each node) and `dp_replicate_size=2` (DDP across the two nodes). + +## Usage + +```yaml +# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp +fsdp_version: 2 +fsdp_config: + # ... + +# The number of GPUs to shard the model parameters across (FSDP dimension). +dp_shard_size: 4 + +# The number of times to replicate the sharded model (DDP dimension). +dp_replicate_size: 2 + +# Number of GPUs for Tensor Parallelism. +tensor_parallel_size: 1 # (default is 1, no TP) + +# Number of GPUs for Context/Sequence Parallelism. +context_parallel_size: 1 # (default is 1, no CP) +``` + +Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size`. + +## Examples + +::: {.callout-tip} +See our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel). +::: + +1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total): + - You want FSDP within each node and DDP across nodes. + - Set `dp_shard_size: 4` and `dp_replicate_size: 2`. + +2. FSDP + TP on a single 8-GPU node: + - You want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP. + - Set `dp_shard_size: 4` and `tensor_parallel_size: 2`. + +3. FSDP + CP on a single 8-GPU node for long context: + - You want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs. + - Set `dp_shard_size: 8` and `context_parallel_size: 8`. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group. + +## Support Matrix + +This matrix describes how different parallelism methods can be combined in Axolotl. + +| Combination | `dp_replicate_size` | `dp_shard_size` | `tp_size` | `cp_size` | Status & Notes | +| --- | :---: | :---: |:---:|:---:|---| +| **FSDP** (ZeRO-3) | 1 | >1 | 1 | 1 | ✅ Fully supported. Shards model across all GPUs. | +| **HSDP** | >1 | >1 | 1 | 1 | ✅ Fully supported. FSDP intra-node, DDP inter-node. | +| **FSDP + TP** | 1 | >1 | >1 | 1 | ✅ **2D Parallelism**. Shards the model across a `dp_shard` group, and TP-splits layers within the `tp` group. | +| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. | +| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. | +| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. | +| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). | +| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. | + +- `tp_size` refers to `tensor_parallel_size` +- `cp_size` refers to `context_parallel_size` diff --git a/docs/optimizations.qmd b/docs/optimizations.qmd new file mode 100644 index 000000000..967ec2d34 --- /dev/null +++ b/docs/optimizations.qmd @@ -0,0 +1,133 @@ +--- +title: Optimizations Guide +description: A guide to the performance and memory optimizations available in Axolotl. +--- + +Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models. + +This guide provides a high-level overview and directs you to the detailed documentation for each feature. + +## Speed Optimizations + +These optimizations focus on increasing training throughput and reducing total training time. + +### Sample Packing + +Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below. + +- **Config:** `sample_packing: true` +- **Learn more:** [Sample Packing](multipack.qmd) + +### Attention Implementations + +Using an optimized attention implementation is critical for training speed. + +- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support). +- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`. +- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation. +- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16. + +*Note: You should only enable one attention backend.* + +### LoRA Optimizations + +Leverages optimized kernels to accelerate LoRA training and reduce memory usage. + +- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd) + +## Memory Optimizations + +These techniques help you fit larger models or use bigger batch sizes on your existing hardware. + +### Parameter Efficient Finetuning (LoRA & QLoRA) + +Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique. + +- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3). +- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd). + +### Gradient Checkpointing & Activation Offloading + +These techniques save VRAM by changing how activations are handled. + +- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM. +- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM. +- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd) + +### Cut Cross Entropy (CCE) + +Reduces VRAM usage by using an optimized cross-entropy loss calculation. + +- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy) + +### Liger Kernels + +Provides efficient Triton kernels to improve training speed and reduce memory usage. + +- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels) + +## Long Context Models + +Techniques to train models on sequences longer than their original context window. + +### RoPE Scaling + +Extends a model's context window by interpolating its Rotary Position Embeddings. + +- **Config:** Pass the `rope_scaling` config under the `overrides_of_model_config: `. To learn how to set RoPE, check the respective model config. + +### Sequence Parallelism + +Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device. + +- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd) + +### Artic Long Sequence Training (ALST) + +ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves: + +- TiledMLP to reduce memory usage in MLP layers. +- Tiled Loss functions (like [CCE](#cut-cross-entropy-(cce) or [Liger](#liger-kernels)). +- Activation Offloading to CPU. + +- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) + +## Large Models (Distributed Training) + +To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes. + +- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd) +- **Learn more:** [Multi-Node Guide](multi-node.qmd) + +### N-D Parallelism (Beta) + +For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once. + +- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd) + + +## Quantization + +Techniques to reduce the precision of model weights for memory savings. + +### 4-bit Training (QLoRA) + +The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Adapter Finetuning](#adapter-finetuning-lora-qlora) for details. + +### FP8 Training + +Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains. + +- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml) + +### Quantization Aware Training (QAT) + +Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model. + +- **Learn more:** [QAT Documentation](qat.qmd) + +### GPTQ + +Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method. + +- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml) diff --git a/docs/optimizers.qmd b/docs/optimizers.qmd new file mode 100644 index 000000000..45eea1d3a --- /dev/null +++ b/docs/optimizers.qmd @@ -0,0 +1,129 @@ +--- +title: Optimizers +description: Configuring optimizers +--- + +## Overview + +Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187) + +Here is a list of optimizers supported by transformers as of `v4.54.0`: + +- `adamw_torch` +- `adamw_torch_fused` +- `adamw_torch_xla` +- `adamw_torch_npu_fused` +- `adamw_apex_fused` +- `adafactor` +- `adamw_anyprecision` +- `adamw_torch_4bit` +- `adamw_torch_8bit` +- `ademamix` +- `sgd` +- `adagrad` +- `adamw_bnb_8bit` +- `adamw_8bit` # alias for adamw_bnb_8bit +- `ademamix_8bit` +- `lion_8bit` +- `lion_32bit` +- `paged_adamw_32bit` +- `paged_adamw_8bit` +- `paged_ademamix_32bit` +- `paged_ademamix_8bit` +- `paged_lion_32bit` +- `paged_lion_8bit` +- `rmsprop` +- `rmsprop_bnb` +- `rmsprop_bnb_8bit` +- `rmsprop_bnb_32bit` +- `galore_adamw` +- `galore_adamw_8bit` +- `galore_adafactor` +- `galore_adamw_layerwise` +- `galore_adamw_8bit_layerwise` +- `galore_adafactor_layerwise` +- `lomo` +- `adalomo` +- `grokadamw` +- `schedule_free_radam` +- `schedule_free_adamw` +- `schedule_free_sgd` +- `apollo_adamw` +- `apollo_adamw_layerwise` +- `stable_adamw` + + +## Custom Optimizers + +Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below. + +### optimi_adamw + +```yaml +optimizer: optimi_adamw +``` + +### ao_adamw_4bit + +Deprecated: Please use `adamw_torch_4bit`. + +### ao_adamw_8bit + +Deprecated: Please use `adamw_torch_8bit`. + +### ao_adamw_fp8 + + +```yaml +optimizer: ao_adamw_fp8 +``` + +### adopt_adamw + +GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt) +Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853) + +```yaml +optimizer: adopt_adamw +``` + +### came_pytorch + +GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master) +Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047) + +```yaml +optimizer: came_pytorch + +# optional args (defaults below) +adam_beta1: 0.9 +adam_beta2: 0.999 +adam_beta3: 0.9999 +adam_epsilon: 1e-30 +adam_epsilon2: 1e-16 +``` + +### muon + +Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/) +Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1) + +```yaml +optimizer: muon +``` + +### dion + +Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient +orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication. + +GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion) +Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295) +Note: Implementation written for PyTorch 2.7+ for DTensor + +```yaml +optimizer: dion +dion_lr: 0.01 +dion_momentum: 0.95 +lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW +``` diff --git a/docs/qat.qmd b/docs/qat.qmd index 0531388de..91fe5180c 100644 --- a/docs/qat.qmd +++ b/docs/qat.qmd @@ -23,10 +23,18 @@ To enable QAT in axolotl, add the following to your configuration file: ```yaml qat: - activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8" - weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8" + activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8" + weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4". group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after ``` -Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize` command](./quantize.md) to do this. +We support the following quantization schemas: + +- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl) +- `Int8DynamicActivationInt4Weight` +- `Float8DynamicActivationFloat8Weight` +- `Float8DynamicActivationInt4Weight` +- `NVFP4` + +Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this. diff --git a/docs/quantize.qmd b/docs/quantize.qmd index 294efda8b..9c3de1ef1 100644 --- a/docs/quantize.qmd +++ b/docs/quantize.qmd @@ -22,8 +22,8 @@ Quantization is configured using the `quantization` key in your configuration fi ```yaml base_model: # The path to the model to quantize. quantization: - weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8 - activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8" + activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8" + weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4". group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer. @@ -32,16 +32,15 @@ output_dir: # The path to the output directory. Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory. -You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.md) - you can do this by using the existing QAT configuration file which +You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd) - you can do this by using the existing QAT configuration file which you used to train the model: ```yaml # qat.yml qat: activation_dtype: int8 - weight_dtype: int8 + weight_dtype: int4 group_size: 256 - quantize_embedding: true output_dir: # The path to the output directory used during training where the final checkpoint has been saved. ``` @@ -51,3 +50,11 @@ axolotl quantize qat.yml ``` This ensures that an identical quantization configuration is used to quantize the model as was used to train it. + + +::: {.callout-note} + +If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it, +e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w` + +::: diff --git a/docs/reward_modelling.qmd b/docs/reward_modelling.qmd index 386dc1f57..b5cf3010d 100644 --- a/docs/reward_modelling.qmd +++ b/docs/reward_modelling.qmd @@ -11,6 +11,7 @@ We support the reward modelling techniques supported by `trl`. ### (Outcome) Reward Models Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step). +For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)). ```yaml base_model: google/gemma-2-2b diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index b2687a8f9..4a67b7559 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -17,7 +17,6 @@ feedback. Various methods include, but not limited to: - [Kahneman-Tversky Optimization (KTO)](#kto) - [Odds Ratio Preference Optimization (ORPO)](#orpo) - [Group Relative Policy Optimization (GRPO)](#grpo) -- Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!) ## RLHF using Axolotl @@ -275,15 +274,14 @@ rl: dpo datasets: - path: ... split: train - type: user_defined.default - - field_prompt: "prompt" - field_system: "system" - field_chosen: "chosen" - field_rejected: "rejected" - prompt_format: "{prompt}" - chosen_format: "{chosen}" - rejected_format: "{rejected}" + type: + field_prompt: "prompt" + field_system: "system" + field_chosen: "chosen" + field_rejected: "rejected" + prompt_format: "{prompt}" + chosen_format: "{chosen}" + rejected_format: "{rejected}" ``` The input format is a simple JSON input with customizable fields based on the above config. @@ -476,14 +474,13 @@ rl: kto datasets: - path: ... split: train - type: user_defined.default - - field_prompt: "prompt" - field_system: "system" - field_completion: "completion" - field_label: "label" - prompt_format: "{prompt}" - completion_format: "{completion}" + type: + field_prompt: "prompt" + field_system: "system" + field_completion: "completion" + field_label: "label" + prompt_format: "{prompt}" + completion_format: "{completion}" ``` The input format is a simple JSON input with customizable fields based on the above config. @@ -500,7 +497,7 @@ The input format is a simple JSON input with customizable fields based on the ab ### GRPO ::: {.callout-tip} -Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo). +Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code). ::: In the latest GRPO implementation, `vLLM` is used to significantly speedup trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM: diff --git a/docs/scripts/generate_config_docs.py b/docs/scripts/generate_config_docs.py new file mode 100644 index 000000000..6efa2038b --- /dev/null +++ b/docs/scripts/generate_config_docs.py @@ -0,0 +1,749 @@ +# type: ignore + +""" +Quarto documentation generation from Pydantic models. Uses Pydantic model source code +to automatically group fields, including inherited fields from parent classes. +""" + +import ast +import inspect +import textwrap +import types +import typing +from typing import Any, FrozenSet, Type, Union + +from pydantic import BaseModel + +from axolotl.utils.schemas.config import AxolotlInputConfig + + +class QuartoGenerator: + """Generate Quarto documentation from Pydantic models.""" + + def __init__(self): + self._class_fields_cache = {} + self._inheritance_map_cache = {} + self._nested_models_cache = {} + + def _get_direct_fields(self, cls: Type[BaseModel]) -> FrozenSet[str]: + """Get fields defined directly in a single class (not inherited).""" + if cls in self._class_fields_cache: + return self._class_fields_cache[cls] + + fields = set() + + # Get annotated fields + if hasattr(cls, "__annotations__"): + fields.update(cls.__annotations__.keys()) + + # Filter out private/special methods + fields = {f for f in fields if not f.startswith("_")} + + result = frozenset(fields) + self._class_fields_cache[cls] = result + return result + + def _is_pydantic_model(self, type_obj) -> bool: + """Check if a type is a Pydantic BaseModel.""" + return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel) + + def _extract_nested_type(self, field_type) -> Any: + """Extract the actual type from complex type annotations.""" + # Handle Annotated types (Python 3.9+) + if hasattr(typing, "get_origin") and hasattr(typing, "get_args"): + origin = typing.get_origin(field_type) + args = typing.get_args(field_type) + + if origin is not None: + # Handle Annotated[SomeType, ...] - extract the first argument + if hasattr(typing, "Annotated") and origin is typing.Annotated: + if args: + return self._extract_nested_type( + args[0] + ) # Recursively process the actual type + + # Handle list[SomeType], List[SomeType], etc. + elif origin in (list, typing.List): + if args: + return self._extract_nested_type( + args[0] + ) # Extract element type + + # Handle Union types (including | syntax) + elif origin is typing.Union: + # Get non-None types from the Union + non_none_types = [arg for arg in args if arg is not type(None)] + if len(non_none_types) >= 1: + # Prioritize Pydantic models over primitive types + pydantic_models = [ + arg + for arg in non_none_types + if self._is_pydantic_model(arg) + ] + if pydantic_models: + # Return the first Pydantic model found + return self._extract_nested_type(pydantic_models[0]) + + # No Pydantic models, return the first non-None type + return self._extract_nested_type(non_none_types[0]) + + # Handle new Python 3.10+ union syntax (PeftConfig | None) + if hasattr(field_type, "__class__") and field_type.__class__ is types.UnionType: + # Get non-None types from the Union + non_none_types = [ + arg for arg in field_type.__args__ if arg is not type(None) + ] + if len(non_none_types) >= 1: + # Prioritize Pydantic models over primitive types + pydantic_models = [ + arg for arg in non_none_types if self._is_pydantic_model(arg) + ] + if pydantic_models: + return self._extract_nested_type(pydantic_models[0]) + return self._extract_nested_type(non_none_types[0]) + + # Handle old typing.Union syntax (fallback) + if hasattr(field_type, "__origin__"): + if field_type.__origin__ is Union: + # Get non-None types from the Union + non_none_types = [ + arg for arg in field_type.__args__ if arg is not type(None) + ] + if len(non_none_types) >= 1: + # Prioritize Pydantic models over primitive types + pydantic_models = [ + arg for arg in non_none_types if self._is_pydantic_model(arg) + ] + if pydantic_models: + return self._extract_nested_type(pydantic_models[0]) + return self._extract_nested_type(non_none_types[0]) + # Handle other generic types like dict[str, Any], etc. + elif hasattr(field_type, "__args__"): + return field_type + + return field_type + + def _extract_all_pydantic_models_from_type( + self, field_type + ) -> list[type[BaseModel]]: + """Extract all Pydantic models from a type annotation, including from Unions.""" + models = [] + + if field_type is None: + return models + + # Handle Annotated types + if hasattr(typing, "get_origin") and hasattr(typing, "get_args"): + origin = typing.get_origin(field_type) + args = typing.get_args(field_type) + + if origin is not None: + # Handle Annotated[SomeType, ...] - extract from the first argument + if hasattr(typing, "Annotated") and origin is typing.Annotated: + if args: + models.extend( + self._extract_all_pydantic_models_from_type(args[0]) + ) + return models + + # Handle list[SomeType], List[SomeType], etc. + if origin in (list, typing.List): + if args: + models.extend( + self._extract_all_pydantic_models_from_type(args[0]) + ) + return models + + # Handle Union types + if origin is typing.Union: + for arg in args: + if arg is not type(None): # Skip None type + models.extend( + self._extract_all_pydantic_models_from_type(arg) + ) + return models + + # Handle new Python 3.10+ union syntax + if hasattr(field_type, "__class__") and field_type.__class__ is types.UnionType: + for arg in field_type.__args__: + if arg is not type(None): # Skip None type + models.extend(self._extract_all_pydantic_models_from_type(arg)) + return models + + # Handle old typing.Union syntax (fallback) + if hasattr(field_type, "__origin__") and field_type.__origin__ is Union: + for arg in field_type.__args__: + if arg is not type(None): # Skip None type + models.extend(self._extract_all_pydantic_models_from_type(arg)) + return models + + # Check if this type itself is a Pydantic model + if self._is_pydantic_model(field_type): + models.append(field_type) + + return models + + def _get_nested_models( + self, model_class: type[BaseModel], visited=None + ) -> dict[str, type[BaseModel]]: + """Get all nested Pydantic models from a model class.""" + if visited is None: + visited = set() + + # Avoid infinite recursion + if model_class in visited: + return {} + + if model_class in self._nested_models_cache: + return self._nested_models_cache[model_class] + + visited.add(model_class) + nested_models = {} + + # Check all fields in the model + for field_info in model_class.model_fields.values(): + field_type = self._extract_nested_type(field_info.annotation) + + if self._is_pydantic_model(field_type): + nested_models[field_type.__name__] = field_type + # Recursively get nested models from this nested model + deeper_nested = self._get_nested_models(field_type, visited.copy()) + nested_models.update(deeper_nested) + + self._nested_models_cache[model_class] = nested_models + return nested_models + + def _build_inheritance_map(self, child_class: Type[BaseModel]): + """Build inheritance map for a class and all its parents.""" + if child_class in self._inheritance_map_cache: + return self._inheritance_map_cache[child_class] + + inheritance_map = {} + + # Get MRO and filter out BaseModel and object + mro_classes = [ + cls + for cls in child_class.__mro__ + if cls not in (BaseModel, object) and hasattr(cls, "__annotations__") + ] + + # Process each class in the MRO + for cls in mro_classes: + inheritance_map[cls] = self._get_direct_fields(cls) + + self._inheritance_map_cache[child_class] = inheritance_map + return inheritance_map + + def _wrap_comment(self, text: str, width: int = 88) -> list[str]: + """Wrap a comment to specified width, accounting for '# ' prefix.""" + if not text.strip(): + return ["#"] + + # Account for "# " prefix (2 characters) + content_width = width - 2 + wrapped_lines = textwrap.wrap(text, width=content_width) + return [f"# {line}" for line in wrapped_lines] + + def _extract_type_from_source( + self, model_class: type[BaseModel], field_name: str + ) -> str: + """Extract the actual type annotation text from source code, checking inheritance chain.""" + # Use inheritance map to check classes efficiently + inheritance_map = self._build_inheritance_map(model_class) + + # Check classes in MRO order + for cls in model_class.__mro__: + if cls in inheritance_map and field_name in inheritance_map[cls]: + type_annotation = self._get_type_from_class_source(cls, field_name) + if type_annotation != "unknown": + return type_annotation + + return "unknown" + + def _get_type_from_class_source(self, class_obj: type, field_name: str) -> str: + """Extract type annotation from a specific class's source code.""" + try: + source = inspect.getsource(class_obj) + tree = ast.parse(source) + except (OSError, TypeError): + return "unknown" + + # Find the class definition + for node in tree.body: + if isinstance(node, ast.ClassDef) and node.name == class_obj.__name__: + # Find the field assignment + for body_node in node.body: + if isinstance(body_node, ast.AnnAssign) and isinstance( + body_node.target, ast.Name + ): + if body_node.target.id == field_name and body_node.annotation: + return ast.unparse(body_node.annotation) + break + + return "unknown" + + def _extract_field_groups_from_all_classes( + self, model_class: type[BaseModel] + ) -> list[dict]: + """Extract field groups from all classes in the inheritance hierarchy.""" + all_groups = [] + inheritance_map = self._build_inheritance_map(model_class) + + # Get all Pydantic base classes in MRO order (most specific first) + # This puts AxolotlInputConfig fields first, then parent class fields + pydantic_classes = [ + cls + for cls in model_class.__mro__ + if cls in inheritance_map and inheritance_map[cls] + ] + + # Extract groups from each class + for cls in pydantic_classes: + class_groups = self._extract_field_groups_from_source(cls) + for group in class_groups: + all_groups.append(group) + + # If no groups found, create a default grouping by class + if not all_groups: + for cls in pydantic_classes: + fields_in_class = inheritance_map[cls] + if fields_in_class: + all_groups.append( + { + "fields": list(fields_in_class), + } + ) + + return all_groups + + def _extract_field_groups_from_source( + self, model_class: type[BaseModel] + ) -> list[dict]: + """Extract field groups from source code based on blank lines and comments.""" + try: + source = inspect.getsource(model_class) + tree = ast.parse(source) + except (OSError, TypeError): + # Fallback if we can't get source code + fields_in_class = self._get_direct_fields(model_class) + if fields_in_class: + return [ + { + "fields": list(fields_in_class), + } + ] + return [] + + groups = [] + current_group_fields = [] + current_group_comment = None + + # Find the class definition + class_node = None + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == model_class.__name__: + class_node = node + break + + if not class_node: + fields_in_class = self._get_direct_fields(model_class) + if fields_in_class: + return [ + { + "fields": list(fields_in_class), + } + ] + return [] + + # Parse the source lines to detect groupings + source_lines = source.split("\n") + + # Get fields that are actually defined in this specific class + fields_in_class = self._get_direct_fields(model_class) + + # Find assignments that correspond to model fields for THIS class only + field_assignments = [] + for node in class_node.body: + if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): + field_name = node.target.id + if field_name in fields_in_class: + field_assignments.append( + { + "name": field_name, + "lineno": node.lineno, + "end_lineno": getattr(node, "end_lineno", node.lineno), + } + ) + + if not field_assignments: + if fields_in_class: + return [ + { + "fields": list(fields_in_class), + } + ] + return [] + + # Sort by line number + field_assignments.sort(key=lambda x: x["lineno"]) + + # Group fields based on blank lines and comments + for i, field_info in enumerate(field_assignments): + field_name = field_info["name"] + current_line = field_info["lineno"] + + # Check if this starts a new group (blank line before or significant gap) + is_new_group = False + + if i == 0: + is_new_group = True + else: + prev_end_line = field_assignments[i - 1]["end_lineno"] + + # Check for blank lines or comments between fields + lines_between = source_lines[prev_end_line : current_line - 1] + has_blank_line = any(line.strip() == "" for line in lines_between) + has_comment = any( + line.strip().startswith("#") for line in lines_between + ) + + # Start new group if there's a blank line or comment, or significant gap + if has_blank_line or has_comment or (current_line - prev_end_line > 3): + is_new_group = True + + if is_new_group and current_group_fields: + # Save the previous group + groups.append( + { + "fields": current_group_fields.copy(), + "description": current_group_comment, + } + ) + current_group_fields = [] + current_group_comment = None + + current_group_fields.append(field_name) + + # Add the final group + if current_group_fields: + groups.append( + { + "fields": current_group_fields, + "description": current_group_comment, + } + ) + + return groups + + def _generate_field_documentation( + self, + model_class: type[BaseModel], + field_name: str, + field_info: dict, + field_type_str: str, + is_required: bool, + indent_level: int = 0, + visited_models: set = None, + ) -> list[str]: + """Generate documentation for a single field, expanding nested models inline.""" + if visited_models is None: + visited_models = set() + + lines = [] + indent = " " * indent_level + + # Get the actual field type for nested model detection + if field_name in model_class.model_fields: + pydantic_field_info = model_class.model_fields[field_name] + actual_field_type = pydantic_field_info.annotation + else: + actual_field_type = None + + # Add description comment if available + description = field_info.get("description", "") + if description: + wrapped_lines = self._wrap_comment(description, width=88 - len(indent)) + for line in wrapped_lines: + lines.append(f"{indent}{line}") + + # Extract nested Pydantic models from the type annotation + nested_models = self._extract_all_pydantic_models_from_type(actual_field_type) + + # Filter out already visited models to prevent infinite recursion + expandable_models = [ + model for model in nested_models if model not in visited_models + ] + + if expandable_models: + # This field contains Pydantic models that can be expanded + + # Show the field with its full type annotation + field_line = f"{indent}{field_name}: {field_type_str}" + if field_info.get("default") is not None: + field_line += f" = {field_info['default']}" + if is_required: + field_line += " (required)" + lines.append(field_line) + + # Add to visited to prevent infinite recursion + new_visited = visited_models.copy() + new_visited.update(expandable_models) + + # Expand each nested Pydantic model + for i, nested_model in enumerate(expandable_models): + if i > 0: + lines.append("\n") + lines.append(f"{indent} # For {nested_model.__name__}:") + + # Get nested model schema + try: + nested_schema = nested_model.model_json_schema() + nested_properties = nested_schema.get("properties", {}) + nested_required = nested_schema.get("required", []) + except Exception: + # Fallback: use model fields directly + nested_properties = {} + nested_required = [] + for ( + nested_field_name, + nested_field_info, + ) in nested_model.model_fields.items(): + nested_description = "" + if ( + hasattr(nested_field_info, "json_schema_extra") + and nested_field_info.json_schema_extra + ): + nested_description = ( + nested_field_info.json_schema_extra.get( + "description", "" + ) + ) + elif ( + hasattr(nested_field_info, "description") + and nested_field_info.description + ): + nested_description = nested_field_info.description + + nested_default_val = None + if ( + hasattr(nested_field_info, "default") + and nested_field_info.default is not None + ): + if str(nested_field_info.default) != "PydanticUndefined": + nested_default_val = nested_field_info.default + + nested_properties[nested_field_name] = { + "type": "unknown", + "description": nested_description, + "default": nested_default_val, + } + + if nested_field_info.is_required(): + nested_required.append(nested_field_name) + + # Get field groups for the nested model + nested_field_groups = self._extract_field_groups_from_all_classes( + nested_model + ) + + # Generate nested fields with increased indentation + for i, group in enumerate(nested_field_groups): + if not group["fields"]: + continue + + # Add blank line between groups (except before first group) + if i > 0: + lines.append("") + + # Process nested fields + for nested_field_name in group["fields"]: + if nested_field_name not in nested_properties: + continue + + nested_field_info = nested_properties[nested_field_name] + nested_field_type = self._extract_type_from_source( + nested_model, nested_field_name + ) + nested_is_required = nested_field_name in nested_required + + # Recursively generate documentation for nested field + nested_lines = self._generate_field_documentation( + nested_model, + nested_field_name, + nested_field_info, + nested_field_type, + nested_is_required, + indent_level + 1, + new_visited, + ) + lines.extend(nested_lines) + else: + # Regular field (no expandable nested models) + field_line = f"{indent}{field_name}: {field_type_str}" + if field_info.get("default") is not None: + field_line += f" = {field_info['default']}" + if is_required: + field_line += " (required)" + lines.append(field_line) + + return lines + + def generate_qmd( + self, + model_class: type[BaseModel], + title: str | None = None, + expand_nested: bool = True, + ) -> str: + """Auto-generate config reference documentation including inherited fields.""" + + if title is None: + title = f"{model_class.__name__} Reference" + + # Try to get JSON schema, with fallback for serialization issues + try: + schema = model_class.model_json_schema() + properties = schema.get("properties", {}) + required = schema.get("required", []) + except Exception as e: + print( + f"Warning: Could not generate JSON schema ({e}). Using model fields instead." + ) + # Fallback: use model fields directly + properties = {} + required = [] + for field_name, field_info in model_class.model_fields.items(): + # Extract description from json_schema_extra or field info + description = "" + if ( + hasattr(field_info, "json_schema_extra") + and field_info.json_schema_extra + ): + description = field_info.json_schema_extra.get("description", "") + elif hasattr(field_info, "description") and field_info.description: + description = field_info.description + + # Get default value + default_val = None + if hasattr(field_info, "default") and field_info.default is not None: + # Handle special Pydantic default markers + if str(field_info.default) != "PydanticUndefined": + default_val = field_info.default + + properties[field_name] = { + "type": "unknown", + "description": description, + "default": default_val, + } + + if field_info.is_required(): + required.append(field_name) + + # Extract field groups from all classes in inheritance hierarchy + field_groups = self._extract_field_groups_from_all_classes(model_class) + + # Start building QMD content + qmd_lines = [ + "---", + f"title: {title}", + "description: A complete list of all configuration options.", + "---", + "", + ] + + # Generate one big code block with all fields (inline nested expansion) + qmd_lines.append("```yaml") + + for i, group in enumerate(field_groups): + if not group["fields"]: + continue + + # Add blank line between groups (except before first group) + if i > 0: + qmd_lines.append("") + + # Process fields in the order they appear in source + for field_name in group["fields"]: + if field_name not in properties: + continue + + field_info = properties[field_name] + field_type = self._extract_type_from_source(model_class, field_name) + is_required = field_name in required + + if expand_nested: + # Check if this field has nested models + if field_name in model_class.model_fields: + pydantic_field_info = model_class.model_fields[field_name] + nested_models = self._extract_all_pydantic_models_from_type( + pydantic_field_info.annotation + ) + has_nested = bool(nested_models) + else: + has_nested = False + + # Add blank line before nested config + if has_nested: + qmd_lines.append("") + + # Use the new inline generation method + field_lines = self._generate_field_documentation( + model_class, + field_name, + field_info, + field_type, + is_required, + indent_level=0, + visited_models=set(), + ) + qmd_lines.extend(field_lines) + + # Add blank line after nested config + if has_nested: + qmd_lines.append("") + else: + # Original simple approach + description = field_info.get("description", "") + default = field_info.get("default") + + # Add wrapped comment for description + if description: + wrapped_lines = self._wrap_comment(description) + qmd_lines.extend(wrapped_lines) + + line = f"{field_name}: {field_type}" + if default is not None: + line += f" = {default}" + if is_required: + line += " (required)" + qmd_lines.append(line) + + qmd_lines.append("```") + + # Join all lines and clean up any double newlines + content = "\n".join(qmd_lines) + + # Replace multiple consecutive newlines with just two newlines (one blank line) + import re + + content = re.sub(r"\n{3,}", "\n\n", content) + + # Ensure single newline at the very end + content = content.rstrip("\n") + "\n" + + return content + + +def main(): + generator = QuartoGenerator() + + print("Generating config reference content...") + qmd_content = generator.generate_qmd(AxolotlInputConfig, "Config Reference", True) + + print("Writing to file...") + with open("docs/config-reference.qmd", "w", encoding="utf-8") as f: + f.write(qmd_content) + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd index b98206135..d1933a145 100644 --- a/docs/sequence_parallelism.qmd +++ b/docs/sequence_parallelism.qmd @@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file: ```yaml # Set to a divisor (> 1) of the number of GPUs available -sequence_parallel_degree: 4 # Split sequences across 4 GPUs +context_parallel_size: 4 # Split sequences across 4 GPUs # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to @@ -30,7 +30,7 @@ heads_k_stride: 1 ring_attn_func: ``` -The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example: +The `context_parallel_size` should be a divisor of the total number of GPUs. For example: - With 8 GPUs, valid values would be 2, 4, or 8 - With 4 GPUs, valid values would be 2 or 4 @@ -66,7 +66,7 @@ sequence_len: 8192 ... -sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU +context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to @@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality. ## Effect on Batch Size -When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because: +When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because: -- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence) +- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence) - The number of batches processed per step decreases For example: - With 8 GPUs and no sequence parallelism: 8 different batches processed per step -- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs) +- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs) - If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4 diff --git a/docs/streaming.qmd b/docs/streaming.qmd new file mode 100644 index 000000000..2a233a4fc --- /dev/null +++ b/docs/streaming.qmd @@ -0,0 +1,120 @@ +--- +title: Streaming Datasets +description: How to use streaming mode for large-scale datasets and memory-efficient training +order: 10 +--- + +Streaming enables memory-efficient training with large datasets by loading data +incrementally rather than loading the entire dataset into memory at once. + +Use streaming when: + +- Your dataset is too large to fit in memory (e.g. when you're doing pretraining with massive text corpora) +- You want to start training immediately without preprocessing the entire dataset + +Streaming works with both remote and locally stored datasets! + +::: {.callout-note} +Streaming currently only supports a single dataset. Multi-dataset support will be added soon. +::: + + +## Configuration + +### Basic Streaming + +Enable streaming mode by setting the `streaming` flag: + +```yaml +streaming: true +``` + +### Pretraining with Streaming + +For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`: + +```yaml +pretraining_dataset: + - path: HuggingFaceFW/fineweb-edu + type: pretrain + text_column: text + split: train + +# Optionally, enable sample packing +streaming_multipack_buffer_size: 10000 +sample_packing: true +``` + +### SFT with Streaming + +For supervised fine-tuning with streaming: + +```yaml +streaming: true +datasets: + - path: tatsu-lab/alpaca + type: alpaca + split: train + +# Optionally, enable sample packing +streaming_multipack_buffer_size: 10000 +sample_packing: true +``` + +## Configuration Options + +### `streaming_multipack_buffer_size` + +Controls the buffer size for multipack streaming (default: 10,000). This determines how +many samples are buffered before packing. Larger buffers can improve packing efficiency +but use more memory. + +### `shuffle_merged_datasets` + +When enabled, shuffles the streaming dataset using the buffer. This requires additional +memory for the shuffle buffer. + +## Sample Packing with Streaming + +Sample packing is supported for streaming datasets. When enabled, multiple samples are +packed into a single sequence to maximize GPU utilization: + +```yaml +sample_packing: true +streaming_multipack_buffer_size: 10000 + +# For SFT: attention is automatically isolated between packed samples +# For pretraining: control with pretrain_multipack_attn +pretrain_multipack_attn: true # prevent cross-attention between packed samples +``` + +For more information, see our [documentation](multipack.qmd) on multipacking. + +## Important Considerations + +### Memory Usage + +While streaming reduces memory usage compared to loading entire datasets, you still need +to consider: + +- You can control the memory usage by adjusting `streaming_multipack_buffer_size` +- Sample packing requires buffering multiple samples +- Shuffling requires additional memory for the shuffle buffer + +### Performance + +- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly +- Network speed and disk read speed are important when streaming from remote sources or a local dataset, respectively +- Consider using `axolotl preprocess` for smaller or more frequently used datasets + +### Evaluation Datasets + +Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're +loaded normally even when training uses streaming. + +## Examples + +See the `examples/streaming/` directory for complete configuration examples: + +- `pretrain.yaml`: Pretraining with streaming dataset +- `sft.yaml`: Supervised fine-tuning with streaming diff --git a/examples/LiquidAI/README.md b/examples/LiquidAI/README.md new file mode 100644 index 000000000..8a18d9eb1 --- /dev/null +++ b/examples/LiquidAI/README.md @@ -0,0 +1,67 @@ +# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl + +[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models. + +LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference. + +This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl. + +Thanks to the team at LiquidAI for giving us early access to prepare for these releases. + +## Getting Started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install from pip: + ```bash + # Ensure you have a compatible version of Pytorch installed + pip3 install packaging setuptools wheel ninja + pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' + ``` + +2. Run one of the finetuning examples below. + + **LFM2** + ```bash + # FFT SFT (1x48GB @ 25GiB) + axolotl train examples/LiquidAI/lfm2-350m-fft.yaml + ``` + + **LFM2-VL** + ```bash + # LoRA SFT (1x48GB @ 2.7GiB) + axolotl train examples/LiquidAI/lfm2-vl-lora.yaml + ``` + + **LFM2-MoE** + ```bash + pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6 + + # LoRA SFT (1x48GB @ 16.2GiB) + axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml + ``` + +### TIPS + +- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it: + ```bash + pip uninstall -y causal-conv1d + ``` + +- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html). +- **Dataset Formats**: + - For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + - For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details. + +## Optimization Guides + +- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html) + +## Related Resources + +- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models) +- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models) +- [LFM2-MoE Blog](https://www.liquid.ai/blog/lfm2-8b-a1b-an-efficient-on-device-mixture-of-experts) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/LiquidAI/lfm2-350m-fft.yaml b/examples/LiquidAI/lfm2-350m-fft.yaml new file mode 100644 index 000000000..145b56dd1 --- /dev/null +++ b/examples/LiquidAI/lfm2-350m-fft.yaml @@ -0,0 +1,50 @@ +base_model: LiquidAI/LFM2-350M + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +eot_tokens: + - "<|im_end|>" +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] + field_messages: conversations + message_field_role: from + message_field_content: value +dataset_prepared_path: last_run_prepared +val_set_size: 0.05 +output_dir: ./outputs/out + +sequence_len: 4096 +sample_packing: true + + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 4 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-5 + +bf16: true +tf32: true + +gradient_checkpointing: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 2 +saves_per_epoch: 1 + +weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/LiquidAI/lfm2-8b-a1b-lora.yaml b/examples/LiquidAI/lfm2-8b-a1b-lora.yaml new file mode 100644 index 000000000..73cbfcce7 --- /dev/null +++ b/examples/LiquidAI/lfm2-8b-a1b-lora.yaml @@ -0,0 +1,59 @@ +base_model: LiquidAI/LFM2-8B-A1B + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: true + +eot_tokens: + - "<|im_end|>" +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] + field_messages: conversations + message_field_role: from + message_field_content: value +dataset_prepared_path: last_run_prepared +val_set_size: 0.05 +output_dir: ./outputs/out + +sequence_len: 4096 +sample_packing: true + +adapter: lora +lora_model_dir: + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 4 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-5 + +bf16: true +tf32: true + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 2 +saves_per_epoch: 1 + +weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/LiquidAI/lfm2-vl-lora.yaml b/examples/LiquidAI/lfm2-vl-lora.yaml new file mode 100644 index 000000000..313da8274 --- /dev/null +++ b/examples/LiquidAI/lfm2-vl-lora.yaml @@ -0,0 +1,61 @@ +base_model: LiquidAI/LFM2-VL-450M +trust_remote_code: true +model_type: AutoModelForImageTextToText +processor_type: AutoProcessor + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/alst/README.md b/examples/alst/README.md new file mode 100644 index 000000000..6d201f826 --- /dev/null +++ b/examples/alst/README.md @@ -0,0 +1,30 @@ +# Arctic Long Sequence Training (ALST) + +Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization +techniques. It is a combination of: +- TiledMLP: Leverage tiling over the sequence dimension on MLP layers to reduce memory usage +- Tiled Loss: Using optimized loss functions like Liger-Kernel or Cut Cross Entropy to reduce memory usage +- Activation Offloading: Offload activations to CPU RAM to reduce memory usage + +For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996). + +## Usage + +```yaml +tiled_mlp: true + +# See Sequence Parallelism docs +# https://docs.axolotl.ai/docs/sequence_parallelism.html +context_parallel_size: int + +plugins: +# See Cut Cross Entropy docs +# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +# or Liger Kernel docs +# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels + - axolotl.integrations.liger.LigerPlugin +# ... + +``` diff --git a/examples/alst/llama3-8b-deepspeed-alst.yaml b/examples/alst/llama3-8b-deepspeed-alst.yaml new file mode 100644 index 000000000..dea23c5ee --- /dev/null +++ b/examples/alst/llama3-8b-deepspeed-alst.yaml @@ -0,0 +1,53 @@ +base_model: meta-llama/Llama-3.1-8B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +datasets: + - path: togethercomputer/Long-Data-Collections + type: completion + field: text + data_files: + - pretrain/rp_sub.jsonl.zst + - path: princeton-nlp/TextbookChapters + type: completion + field: chapter +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 500_000 +min_sample_len: 200_000 +sample_packing: true + +tiled_mlp: true +context_parallel_size: 8 +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: auto +tf32: true + +gradient_checkpointing: true +activation_offloading: legacy + +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_steps: 100 +saves_per_epoch: 1 +evals_per_epoch: 2 +weight_decay: 0.0 +special_tokens: + pad_token: <|end_of_text|> + +deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/alst/llama3-8b-fsdp2-alst.yaml b/examples/alst/llama3-8b-fsdp2-alst.yaml new file mode 100644 index 000000000..c8a978264 --- /dev/null +++ b/examples/alst/llama3-8b-fsdp2-alst.yaml @@ -0,0 +1,59 @@ +base_model: meta-llama/Llama-3.1-8B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +datasets: + - path: togethercomputer/Long-Data-Collections + type: completion + field: text + data_files: + - pretrain/rp_sub.jsonl.zst + - path: princeton-nlp/TextbookChapters + type: completion + field: chapter +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 500_000 +min_sample_len: 200_000 +sample_packing: true + +tiled_mlp: true +context_parallel_size: 8 +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: auto +tf32: true + +gradient_checkpointing: true +activation_offloading: legacy + +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_steps: 100 +saves_per_epoch: 1 +evals_per_epoch: 2 +weight_decay: 0.0 +special_tokens: + pad_token: <|end_of_text|> + +fsdp_version: 2 +fsdp_config: + offload_params: false # offloading is currently not compatible with SP + torchao optimizer + state_dict_type: SHARDED_STATE_DICT + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: LlamaDecoderLayer + reshard_after_forward: true + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/apertus/README.md b/examples/apertus/README.md new file mode 100644 index 000000000..774286333 --- /dev/null +++ b/examples/apertus/README.md @@ -0,0 +1,110 @@ +# Finetune Swiss-AI's Apertus with Axolotl + +[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of opensource models trained by Swiss-ai. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Apertus is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main for pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy +python scripts/cutcrossentropy_install.py | sh +``` + +2. (Optional, highly recommended) Install XIELU CUDA + +```bash +## Recommended for reduced VRAM and faster speeds + +# Point to CUDA toolkit directory +# For those using our Docker image, use the below path. +export CUDA_HOME=/usr/local/cuda + +pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps +``` + +For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues) + +3. Run the finetuning example: + +```bash +axolotl train examples/apertus/apertus-8b-qlora.yaml +``` + +This config uses about 8.7 GiB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### Tips + +- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`. +- You can instead use full paremter fine-tuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +### XIELU Installation Issues + +#### `ModuleNotFoundError: No module named 'torch'` + +Please check these one by one: +- Running in correct environment +- Env has PyTorch installed +- CUDA toolkit is at `CUDA_HOME` + +If those didn't help, please try the below solutions: + +1. Pass env for CMAKE and try install again: + + ```bash + Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps + ``` + +2. Git clone the repo and manually hardcode python path: + + ```bash + git clone https://github.com/nickjbrowning/XIELU + cd xielu + git checkout 59d6031 + + cd xielu + nano CMakeLists.txt # or vi depending on your preference + ``` + + ```diff + execute_process( + - COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)" + + COMMAND /root/miniconda3/envs/py3.11/bin/python -c "import torch.utils; print(torch.utils.cmake_prefix_path)" + RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT + OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT + ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR + ) + ``` + + ```bash + pip3 install . --no-build-isolation --no-deps + ``` + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/apertus/apertus-8b-qlora.yaml b/examples/apertus/apertus-8b-qlora.yaml new file mode 100644 index 000000000..521b282da --- /dev/null +++ b/examples/apertus/apertus-8b-qlora.yaml @@ -0,0 +1,64 @@ +base_model: swiss-ai/Apertus-8B-Instruct-2509 + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/arcee/README.md b/examples/arcee/README.md new file mode 100644 index 000000000..23f63663e --- /dev/null +++ b/examples/arcee/README.md @@ -0,0 +1,56 @@ +# Finetune ArceeAI's AFM with Axolotl + +[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM model. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main for pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy +python scripts/cutcrossentropy_install.py | sh +``` + +2. Run the finetuning example: + +```bash +axolotl train examples/arcee/afm-4.5b-qlora.yaml +``` + +This config uses about 7.8GiB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/arcee/afm-4.5b-qlora.yaml b/examples/arcee/afm-4.5b-qlora.yaml new file mode 100644 index 000000000..2cb42cacd --- /dev/null +++ b/examples/arcee/afm-4.5b-qlora.yaml @@ -0,0 +1,64 @@ +base_model: arcee-ai/AFM-4.5B + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/archived/README.md b/examples/archived/README.md new file mode 100644 index 000000000..da797c552 --- /dev/null +++ b/examples/archived/README.md @@ -0,0 +1,5 @@ +# Archived Examples + +This directory contains examples that are no longer maintained and may no longer be functional. + +We keep them around for archival purposes in case they are useful to others. diff --git a/examples/cerebras/btlm-ft.yml b/examples/archived/cerebras/btlm-ft.yml similarity index 98% rename from examples/cerebras/btlm-ft.yml rename to examples/archived/cerebras/btlm-ft.yml index c9878779d..c3495d287 100644 --- a/examples/cerebras/btlm-ft.yml +++ b/examples/archived/cerebras/btlm-ft.yml @@ -66,7 +66,7 @@ flash_optimum: gptq_groupsize: gptq_model_v1: -warmup_steps: 32 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 save_total_limit: diff --git a/examples/cerebras/qlora.yml b/examples/archived/cerebras/qlora.yml similarity index 98% rename from examples/cerebras/qlora.yml rename to examples/archived/cerebras/qlora.yml index 55cc597f1..4598a8338 100644 --- a/examples/cerebras/qlora.yml +++ b/examples/archived/cerebras/qlora.yml @@ -43,7 +43,7 @@ xformers_attention: true flash_attention: gptq_groupsize: gptq_model_v1: -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 diff --git a/examples/code-llama/13b/lora.yml b/examples/archived/code-llama/13b/lora.yml similarity index 96% rename from examples/code-llama/13b/lora.yml rename to examples/archived/code-llama/13b/lora.yml index 0ed2382ba..ace94b619 100644 --- a/examples/code-llama/13b/lora.yml +++ b/examples/archived/code-llama/13b/lora.yml @@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -47,7 +47,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/code-llama/13b/qlora.yml b/examples/archived/code-llama/13b/qlora.yml similarity index 96% rename from examples/code-llama/13b/qlora.yml rename to examples/archived/code-llama/13b/qlora.yml index 22bd1691b..f4ed17af5 100644 --- a/examples/code-llama/13b/qlora.yml +++ b/examples/archived/code-llama/13b/qlora.yml @@ -20,7 +20,7 @@ lora_model_dir: sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -48,7 +48,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/code-llama/34b/lora.yml b/examples/archived/code-llama/34b/lora.yml similarity index 96% rename from examples/code-llama/34b/lora.yml rename to examples/archived/code-llama/34b/lora.yml index 25dc9f421..0a1d71467 100644 --- a/examples/code-llama/34b/lora.yml +++ b/examples/archived/code-llama/34b/lora.yml @@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -47,7 +47,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/code-llama/34b/qlora.yml b/examples/archived/code-llama/34b/qlora.yml similarity index 96% rename from examples/code-llama/34b/qlora.yml rename to examples/archived/code-llama/34b/qlora.yml index 0e33e2a45..ec17bf200 100644 --- a/examples/code-llama/34b/qlora.yml +++ b/examples/archived/code-llama/34b/qlora.yml @@ -20,7 +20,7 @@ lora_model_dir: sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -48,7 +48,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/code-llama/7b/lora.yml b/examples/archived/code-llama/7b/lora.yml similarity index 95% rename from examples/code-llama/7b/lora.yml rename to examples/archived/code-llama/7b/lora.yml index d288b9f65..174c17d2c 100644 --- a/examples/code-llama/7b/lora.yml +++ b/examples/archived/code-llama/7b/lora.yml @@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -47,7 +47,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/code-llama/7b/qlora.yml b/examples/archived/code-llama/7b/qlora.yml similarity index 96% rename from examples/code-llama/7b/qlora.yml rename to examples/archived/code-llama/7b/qlora.yml index de41c0123..08e67d8c2 100644 --- a/examples/code-llama/7b/qlora.yml +++ b/examples/archived/code-llama/7b/qlora.yml @@ -20,7 +20,7 @@ lora_model_dir: sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -48,7 +48,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/code-llama/README.md b/examples/archived/code-llama/README.md similarity index 100% rename from examples/code-llama/README.md rename to examples/archived/code-llama/README.md diff --git a/examples/dbrx/16bit-lora.yaml b/examples/archived/dbrx/16bit-lora.yaml similarity index 98% rename from examples/dbrx/16bit-lora.yaml rename to examples/archived/dbrx/16bit-lora.yaml index 852654d49..05946dfe9 100644 --- a/examples/dbrx/16bit-lora.yaml +++ b/examples/archived/dbrx/16bit-lora.yaml @@ -54,7 +54,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: saves_per_epoch: 1 diff --git a/examples/dbrx/8bit-lora.yaml b/examples/archived/dbrx/8bit-lora.yaml similarity index 98% rename from examples/dbrx/8bit-lora.yaml rename to examples/archived/dbrx/8bit-lora.yaml index 0b9402194..f159bf7fa 100644 --- a/examples/dbrx/8bit-lora.yaml +++ b/examples/archived/dbrx/8bit-lora.yaml @@ -57,7 +57,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: saves_per_epoch: 1 diff --git a/examples/dbrx/README.md b/examples/archived/dbrx/README.md similarity index 100% rename from examples/dbrx/README.md rename to examples/archived/dbrx/README.md diff --git a/examples/dbrx/fft-ds-zero3.yaml b/examples/archived/dbrx/fft-ds-zero3.yaml similarity index 98% rename from examples/dbrx/fft-ds-zero3.yaml rename to examples/archived/dbrx/fft-ds-zero3.yaml index e42c16673..13cd0d997 100644 --- a/examples/dbrx/fft-ds-zero3.yaml +++ b/examples/archived/dbrx/fft-ds-zero3.yaml @@ -41,7 +41,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: saves_per_epoch: 1 diff --git a/examples/deepcoder/deepcoder-14B-preview-lora.yml b/examples/archived/deepcoder/deepcoder-14B-preview-lora.yml similarity index 86% rename from examples/deepcoder/deepcoder-14B-preview-lora.yml rename to examples/archived/deepcoder/deepcoder-14B-preview-lora.yml index 9e92c0a07..3223ec19a 100644 --- a/examples/deepcoder/deepcoder-14B-preview-lora.yml +++ b/examples/archived/deepcoder/deepcoder-14B-preview-lora.yml @@ -9,10 +9,6 @@ strict: false datasets: - path: fozziethebeat/alpaca_messages_2k_test type: chat_template - field_messages: messages - message_property_mappings: - role: role - content: content dataset_prepared_path: val_set_size: 0.05 @@ -21,7 +17,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -51,7 +47,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/falcon/config-7b-lora.yml b/examples/archived/falcon/config-7b-lora.yml similarity index 98% rename from examples/falcon/config-7b-lora.yml rename to examples/archived/falcon/config-7b-lora.yml index 391d4dd94..f4fedbede 100644 --- a/examples/falcon/config-7b-lora.yml +++ b/examples/archived/falcon/config-7b-lora.yml @@ -47,7 +47,7 @@ xformers_attention: true flash_attention: gptq_groupsize: gptq_model_v1: -warmup_steps: 40 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/falcon/config-7b-qlora.yml b/examples/archived/falcon/config-7b-qlora.yml similarity index 99% rename from examples/falcon/config-7b-qlora.yml rename to examples/archived/falcon/config-7b-qlora.yml index a9af8574c..a44cc40a6 100644 --- a/examples/falcon/config-7b-qlora.yml +++ b/examples/archived/falcon/config-7b-qlora.yml @@ -77,7 +77,7 @@ xformers_attention: true flash_attention: gptq_groupsize: gptq_model_v1: -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.000001 diff --git a/examples/falcon/config-7b.yml b/examples/archived/falcon/config-7b.yml similarity index 98% rename from examples/falcon/config-7b.yml rename to examples/archived/falcon/config-7b.yml index 3cc553daa..5481fb236 100644 --- a/examples/falcon/config-7b.yml +++ b/examples/archived/falcon/config-7b.yml @@ -44,7 +44,7 @@ xformers_attention: true flash_attention: gptq_groupsize: gptq_model_v1: -warmup_steps: 40 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/gemma/qlora.yml b/examples/archived/gemma/qlora.yml similarity index 97% rename from examples/gemma/qlora.yml rename to examples/archived/gemma/qlora.yml index 2738112b4..80829b3c9 100644 --- a/examples/gemma/qlora.yml +++ b/examples/archived/gemma/qlora.yml @@ -25,7 +25,7 @@ lora_target_linear: true sequence_len: 4096 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + wandb_project: wandb_entity: diff --git a/examples/gptj/qlora.yml b/examples/archived/gptj/qlora.yml similarity index 98% rename from examples/gptj/qlora.yml rename to examples/archived/gptj/qlora.yml index c3cf9f973..6348566c2 100644 --- a/examples/gptj/qlora.yml +++ b/examples/archived/gptj/qlora.yml @@ -40,7 +40,7 @@ xformers_attention: true flash_attention: gptq_groupsize: gptq_model_v1: -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 diff --git a/examples/jeopardy-bot/config.yml b/examples/archived/jeopardy-bot/config.yml similarity index 98% rename from examples/jeopardy-bot/config.yml rename to examples/archived/jeopardy-bot/config.yml index 3609bd97e..ab1d19784 100644 --- a/examples/jeopardy-bot/config.yml +++ b/examples/archived/jeopardy-bot/config.yml @@ -41,7 +41,7 @@ xformers_attention: true flash_attention: gptq_groupsize: gptq_model_v1: -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 diff --git a/examples/mpt-7b/README.md b/examples/archived/mpt-7b/README.md similarity index 100% rename from examples/mpt-7b/README.md rename to examples/archived/mpt-7b/README.md diff --git a/examples/mpt-7b/config.yml b/examples/archived/mpt-7b/config.yml similarity index 98% rename from examples/mpt-7b/config.yml rename to examples/archived/mpt-7b/config.yml index e7485fad7..1fff51b6e 100644 --- a/examples/mpt-7b/config.yml +++ b/examples/archived/mpt-7b/config.yml @@ -42,7 +42,7 @@ logging_steps: 5 flash_attention: gptq_groupsize: gptq_model_v1: -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0001 diff --git a/examples/openllama-3b/README.md b/examples/archived/openllama-3b/README.md similarity index 100% rename from examples/openllama-3b/README.md rename to examples/archived/openllama-3b/README.md diff --git a/examples/openllama-3b/config.yml b/examples/archived/openllama-3b/config.yml similarity index 98% rename from examples/openllama-3b/config.yml rename to examples/archived/openllama-3b/config.yml index 17eeb73ae..63056ed6d 100644 --- a/examples/openllama-3b/config.yml +++ b/examples/archived/openllama-3b/config.yml @@ -42,7 +42,7 @@ logging_steps: 1 flash_attention: true gptq_groupsize: gptq_model_v1: -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 diff --git a/examples/openllama-3b/lora.yml b/examples/archived/openllama-3b/lora.yml similarity index 98% rename from examples/openllama-3b/lora.yml rename to examples/archived/openllama-3b/lora.yml index 073117f11..b70821ce2 100644 --- a/examples/openllama-3b/lora.yml +++ b/examples/archived/openllama-3b/lora.yml @@ -50,7 +50,7 @@ logging_steps: 1 flash_attention: true gptq_groupsize: gptq_model_v1: -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 diff --git a/examples/openllama-3b/qlora.yml b/examples/archived/openllama-3b/qlora.yml similarity index 98% rename from examples/openllama-3b/qlora.yml rename to examples/archived/openllama-3b/qlora.yml index b4fca2c07..a34f2964b 100644 --- a/examples/openllama-3b/qlora.yml +++ b/examples/archived/openllama-3b/qlora.yml @@ -43,7 +43,7 @@ logging_steps: 1 flash_attention: true gptq_groupsize: gptq_model_v1: -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 diff --git a/examples/pythia-12b/README.md b/examples/archived/pythia-12b/README.md similarity index 100% rename from examples/pythia-12b/README.md rename to examples/archived/pythia-12b/README.md diff --git a/examples/pythia-12b/config.yml b/examples/archived/pythia-12b/config.yml similarity index 100% rename from examples/pythia-12b/config.yml rename to examples/archived/pythia-12b/config.yml diff --git a/examples/pythia/lora.yml b/examples/archived/pythia/lora.yml similarity index 100% rename from examples/pythia/lora.yml rename to examples/archived/pythia/lora.yml diff --git a/examples/qwen/README.md b/examples/archived/qwen/README.md similarity index 100% rename from examples/qwen/README.md rename to examples/archived/qwen/README.md diff --git a/examples/qwen/lora.yml b/examples/archived/qwen/lora.yml similarity index 98% rename from examples/qwen/lora.yml rename to examples/archived/qwen/lora.yml index 9a2843236..29de25611 100644 --- a/examples/qwen/lora.yml +++ b/examples/archived/qwen/lora.yml @@ -49,7 +49,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/qwen/qlora.yml b/examples/archived/qwen/qlora.yml similarity index 98% rename from examples/qwen/qlora.yml rename to examples/archived/qwen/qlora.yml index 5f85b44dd..d46669444 100644 --- a/examples/qwen/qlora.yml +++ b/examples/archived/qwen/qlora.yml @@ -49,7 +49,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/qwen/qwen2-moe-lora.yaml b/examples/archived/qwen/qwen2-moe-lora.yaml similarity index 98% rename from examples/qwen/qwen2-moe-lora.yaml rename to examples/archived/qwen/qwen2-moe-lora.yaml index afce443a0..1d5e1b524 100644 --- a/examples/qwen/qwen2-moe-lora.yaml +++ b/examples/archived/qwen/qwen2-moe-lora.yaml @@ -45,7 +45,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/qwen/qwen2-moe-qlora.yaml b/examples/archived/qwen/qwen2-moe-qlora.yaml similarity index 98% rename from examples/qwen/qwen2-moe-qlora.yaml rename to examples/archived/qwen/qwen2-moe-qlora.yaml index 92a6842cf..08731441b 100644 --- a/examples/qwen/qwen2-moe-qlora.yaml +++ b/examples/archived/qwen/qwen2-moe-qlora.yaml @@ -48,7 +48,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/redpajama/README.md b/examples/archived/redpajama/README.md similarity index 100% rename from examples/redpajama/README.md rename to examples/archived/redpajama/README.md diff --git a/examples/redpajama/config-3b.yml b/examples/archived/redpajama/config-3b.yml similarity index 98% rename from examples/redpajama/config-3b.yml rename to examples/archived/redpajama/config-3b.yml index 3e2999df9..c5b229c3d 100644 --- a/examples/redpajama/config-3b.yml +++ b/examples/archived/redpajama/config-3b.yml @@ -43,7 +43,7 @@ logging_steps: 5 flash_attention: gptq_groupsize: gptq_model_v1: -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0001 diff --git a/examples/replit-3b/config-lora.yml b/examples/archived/replit-3b/config-lora.yml similarity index 98% rename from examples/replit-3b/config-lora.yml rename to examples/archived/replit-3b/config-lora.yml index 5a02ba10c..d8561762c 100644 --- a/examples/replit-3b/config-lora.yml +++ b/examples/archived/replit-3b/config-lora.yml @@ -41,7 +41,7 @@ logging_steps: 1 flash_attention: gptq_groupsize: gptq_model_v1: -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0 diff --git a/examples/stablelm-2/1.6b/fft.yml b/examples/archived/stablelm-2/1.6b/fft.yml similarity index 93% rename from examples/stablelm-2/1.6b/fft.yml rename to examples/archived/stablelm-2/1.6b/fft.yml index 9b45b399f..585888f43 100644 --- a/examples/stablelm-2/1.6b/fft.yml +++ b/examples/archived/stablelm-2/1.6b/fft.yml @@ -16,7 +16,7 @@ output_dir: ./outputs/out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + adapter: lora_model_dir: @@ -47,10 +47,9 @@ logging_steps: 1 flash_attention: true flash_attn_cross_entropy: false flash_attn_rms_norm: true -flash_attn_fuse_qkv: false flash_attn_fuse_mlp: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 diff --git a/examples/stablelm-2/1.6b/lora.yml b/examples/archived/stablelm-2/1.6b/lora.yml similarity index 96% rename from examples/stablelm-2/1.6b/lora.yml rename to examples/archived/stablelm-2/1.6b/lora.yml index 31e5ad933..6d358bdd8 100644 --- a/examples/stablelm-2/1.6b/lora.yml +++ b/examples/archived/stablelm-2/1.6b/lora.yml @@ -19,7 +19,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -51,7 +51,7 @@ flash_attention: true flash_attn_cross_entropy: false flash_attn_rms_norm: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/stablelm-2/README.md b/examples/archived/stablelm-2/README.md similarity index 100% rename from examples/stablelm-2/README.md rename to examples/archived/stablelm-2/README.md diff --git a/examples/starcoder2/qlora.yml b/examples/archived/starcoder2/qlora.yml similarity index 95% rename from examples/starcoder2/qlora.yml rename to examples/archived/starcoder2/qlora.yml index 18d85f9c3..fecf98d23 100644 --- a/examples/starcoder2/qlora.yml +++ b/examples/archived/starcoder2/qlora.yml @@ -19,7 +19,7 @@ lora_model_dir: sequence_len: 8192 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -48,7 +48,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 4 eval_steps: saves_per_epoch: 4 diff --git a/examples/tiny-llama/README.md b/examples/archived/tiny-llama/README.md similarity index 100% rename from examples/tiny-llama/README.md rename to examples/archived/tiny-llama/README.md diff --git a/examples/tiny-llama/lora-mps.yml b/examples/archived/tiny-llama/lora-mps.yml similarity index 95% rename from examples/tiny-llama/lora-mps.yml rename to examples/archived/tiny-llama/lora-mps.yml index 66cf7cfb3..125090a78 100644 --- a/examples/tiny-llama/lora-mps.yml +++ b/examples/archived/tiny-llama/lora-mps.yml @@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + eval_sample_packing: false adapter: lora @@ -49,7 +49,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: false -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 0 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/tiny-llama/lora.yml b/examples/archived/tiny-llama/lora.yml similarity index 95% rename from examples/tiny-llama/lora.yml rename to examples/archived/tiny-llama/lora.yml index 90998880f..817481e18 100644 --- a/examples/tiny-llama/lora.yml +++ b/examples/archived/tiny-llama/lora.yml @@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -47,7 +47,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/tiny-llama/pretrain.yml b/examples/archived/tiny-llama/pretrain.yml similarity index 97% rename from examples/tiny-llama/pretrain.yml rename to examples/archived/tiny-llama/pretrain.yml index 5b3706bcb..f15c6ce19 100644 --- a/examples/tiny-llama/pretrain.yml +++ b/examples/archived/tiny-llama/pretrain.yml @@ -38,7 +38,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/tiny-llama/qlora.yml b/examples/archived/tiny-llama/qlora.yml similarity index 95% rename from examples/tiny-llama/qlora.yml rename to examples/archived/tiny-llama/qlora.yml index 8b2a4565a..d3ff59cb8 100644 --- a/examples/tiny-llama/qlora.yml +++ b/examples/archived/tiny-llama/qlora.yml @@ -21,7 +21,7 @@ lora_model_dir: sequence_len: 4096 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -49,7 +49,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/xgen-7b/xgen-7b-8k-qlora.yml b/examples/archived/xgen-7b/xgen-7b-8k-qlora.yml similarity index 99% rename from examples/xgen-7b/xgen-7b-8k-qlora.yml rename to examples/archived/xgen-7b/xgen-7b-8k-qlora.yml index 48066b130..fc09a1e7b 100644 --- a/examples/xgen-7b/xgen-7b-8k-qlora.yml +++ b/examples/archived/xgen-7b/xgen-7b-8k-qlora.yml @@ -75,7 +75,7 @@ xformers_attention: true flash_attention: gptq_groupsize: gptq_model_v1: -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 diff --git a/examples/yi-34B-chat/README.md b/examples/archived/yi-34B-chat/README.md similarity index 100% rename from examples/yi-34B-chat/README.md rename to examples/archived/yi-34B-chat/README.md diff --git a/examples/yi-34B-chat/qlora.yml b/examples/archived/yi-34B-chat/qlora.yml similarity index 98% rename from examples/yi-34B-chat/qlora.yml rename to examples/archived/yi-34B-chat/qlora.yml index a0a95d86f..ba8d12fc8 100644 --- a/examples/yi-34B-chat/qlora.yml +++ b/examples/archived/yi-34B-chat/qlora.yml @@ -20,7 +20,7 @@ special_tokens: datasets: - path: mhenrichsen/alpaca_2k_test type: alpaca -warmup_steps: 10 +warmup_ratio: 0.1 # Iterations num_epochs: 1 diff --git a/examples/cloud/baseten.yaml b/examples/cloud/baseten.yaml new file mode 100644 index 000000000..23c4b52d6 --- /dev/null +++ b/examples/cloud/baseten.yaml @@ -0,0 +1,10 @@ +provider: baseten +project_name: + +secrets: + - HF_TOKEN + - WANDB_API_KEY + +gpu: h100 +gpu_count: 8 +node_count: 1 diff --git a/examples/cohere/command-r-7b-qlora.yml b/examples/cohere/command-r-7b-qlora.yml index 4a30e9a77..b4741636b 100644 --- a/examples/cohere/command-r-7b-qlora.yml +++ b/examples/cohere/command-r-7b-qlora.yml @@ -27,7 +27,7 @@ lora_target_linear: true sequence_len: 2048 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -35,7 +35,6 @@ wandb_watch: wandb_name: wandb_log_model: - gradient_accumulation_steps: 4 micro_batch_size: 1 num_epochs: 4 @@ -56,3 +55,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 0b373c28c..cea1aeda0 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -2,356 +2,9943 @@ "cells": [ { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "OPLSwmgdrB7g" + }, "source": [ - "## Setting up" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n", - "assert (torch.cuda.is_available()==True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install --no-build-isolation axolotl[deepspeed]" + "# Fine-Tune Qwen3 14B with Axolotl\n", + "\n", + "[\"Built](https://github.com/axolotl-ai-cloud/axolotl)\n", + "\n", + "Axolotl is the most performant LLM post-training framework available, delivering faster training with efficient, consistent and stable performance. Train your workload and ship your product 30% faster; saving you both time and money.\n", + "\n", + "- ⭐ us on [GitHub](https://github.com/axolotl-ai-cloud/axolotl)\n", + "- 📜 Read the [Docs](http://docs.axolotl.ai/)\n", + "- 💬 Chat with us on [Discord](https://discord.gg/mnpEYgRUmD)\n", + "- 📰 Get updates on [X/Twitter](https://x.com/axolotl_ai)\n" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "rVjKD7CbxIP3" + }, "source": [ - "## Hugging Face login (optional)" + "# Installation\n", + "\n", + "Axolotl is easy to install from [pip](https://pypi.org/project/axolotl/), or use our [pre-built Docker images](http://docs.axolotl.ai/docs/docker.html) for a hassle free dependency experience. See our [docs](http://docs.axolotl.ai/docs/installation.html) for more information." ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "id": "msOCO4NRmRLa" + }, "outputs": [], + "source": [ + "%%capture\n", + "# This step can take ~5-10 minutes to install dependencies\n", + "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", + "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N0OW0YeksDLr" + }, + "source": [ + "## Demo: Talk Like a Pirate\n", + "\n", + "In this demo, we are training the model ***to respond like a pirate***. This was chosen as a way to easily show how to train a model to respond in a certain style of your choosing (without being prompted) and is quite easy to validate within the scope of a Colab." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8Du2fANTsNCK" + }, + "source": [ + "### Upload your own dataset or use a Huggingface dataset\n", + "\n", + "You can choose to use your own JSONL file from your own [Google Drive](https://drive.google.com/drive/home); for example downloading the [Pirate-Ultrachat JSONL](https://huggingface.co/datasets/winglian/pirate-ultrachat-10k/blob/main/train.jsonl) to your Google Drive. JSONL datasets should be formatted similar to the [OpenAI dataset format](https://cookbook.openai.com/examples/chat_finetuning_data_prep).\n", + "\n", + "You can also simply use the [`winglian/pirate-ultrachat-10k`](https://huggingface.co/datasets/winglian/pirate-ultrachat-10k) dataset directly.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fGEEjyQ-r_IV" + }, + "outputs": [], + "source": [ + "# Default to HF dataset location\n", + "dataset_id = \"winglian/pirate-ultrachat-10k\"\n", + "uploaded = {}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c5MyYqk7vIsG" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# Optionally, upload your own JSONL to your Google Drive\n", + "GOOGLE_DRIVE_PATH = \"\" # ex: \"MyDrive/Colab\\ Notebooks/train.jsonl\"\n", + "\n", + "# \"Select All\" permissions, or you may get the error:\n", + "# \"MessageError: Error: credential propagation was unsuccessful\"\n", + "if GOOGLE_DRIVE_PATH:\n", + " from google.colab import drive\n", + "\n", + " # Mount your Google Drive\n", + " GOOGLE_DRIVE_MNT = \"/content/drive/\"\n", + " drive.mount(GOOGLE_DRIVE_MNT, force_remount=True)\n", + " tmp_path = os.path.join(GOOGLE_DRIVE_MNT, GOOGLE_DRIVE_PATH.lstrip(\"/\"))\n", + " # make sure file exists\n", + " if not os.path.isfile(tmp_path):\n", + " raise ValueError(f\"File {tmp_path} does not exist\")\n", + " dataset_id = tmp_path" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "U6pTk3A9xj1W" + }, + "source": [ + "# Configure for Supervised Fine-Tuning (SFT)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 151, + "referenced_widgets": [ + "388f618924274d21a066f098f4f1e744", + "7c95f85a2b1f47a1bd846d110c47bb3c", + "083f9cda8d754c168beee10d2f8955a2", + "62e1a65582f446a78612eaa804e08a7d", + "487a177d020f4605834878b2fdc7afa3", + "7fd44cf9ca6e4726bfd7ac21846d6a14", + "366a343b62fa47d8985a3bd464d99f9e", + "a0a11e929edd4189b79723d618522c33", + "e87ea87fcff247b5bbcc331ba79a8dc2", + "5e18768f7ad6434ba8b8b8a2e853e204", + "bb33aec33a6447078c31bfd728942994" + ] + }, + "id": "fdRioqytmTtX", + "outputId": "f0acdcec-4b41-4a3f-ffed-c2d2d929158e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2025-05-08 13:40:27,488] [INFO] [root.register:348] [PID:174] Attempting to load plugin: axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n", + "[2025-05-08 13:40:27,493] [INFO] [root.register:351] [PID:174] Plugin loaded successfully: axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n", + "[2025-05-08 13:40:27,959] [INFO] [axolotl.utils.schemas.config.check_eval_packing:721] [PID:174] [RANK:0] explicitly setting `eval_sample_packing` to match `sample_packing`\u001b[39m\n", + "[2025-05-08 13:40:27,960] [INFO] [axolotl.utils.schemas.config.hint_sample_packing_padding:514] [PID:174] [RANK:0] Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing\u001b[39m\n", + "[2025-05-08 13:40:27,961] [INFO] [axolotl.utils.schemas.config.check_bf16:1251] [PID:174] [RANK:0] bf16 support detected, but not enabled for this configuration.\u001b[39m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "388f618924274d21a066f098f4f1e744", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/728 [00:00\"],\n", + " }\n", + " ],\n", + " dataloader_prefetch_factor=8, # dataloader optimizations\n", + " dataloader_num_workers=2,\n", + " dataloader_pin_memory=True,\n", + ")\n", + "\n", + "# validates the configuration\n", + "cfg = load_cfg(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "715UpvnSoBIS" + }, + "outputs": [], + "source": [ + "from axolotl.utils import set_pytorch_cuda_alloc_conf\n", + "\n", + "# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n", + "set_pytorch_cuda_alloc_conf()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Vc6MC-hwyH-n" + }, + "source": [ + "# Datasets\n", + "\n", + "Axolotl has a robust suite of loaders and transforms to parse most open datasets of any format into the appropriate chat template for your model. Axolotl will mask input tokens from the user's prompt so that the train loss is only calculated against the model's response. For more information, [see our documentation](http://docs.axolotl.ai/docs/dataset-formats/conversation.html) on dataset preparation.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "b82aa8c57f7c422a9a9c90f333ed2a99", + "c0991cf63ee6458b96e9a75e7a88b61a", + "71c8af139cd248b1b51101fd46a93f35", + "1d5117195d4b49eb8f1a73b18419f7ce", + "3c21e4a511b4441192c03b7f1d0976e9", + "ed28e2e0410d4e0b855467e798e53d66", + "d93f134f802b4b69b575bdaf07dbd27c", + "d0e9dce55cec4c1ca619a0ccf209d924", + "4c727d40ef0443449afc31724ee79f0c", + "0dea5caa27384f5689e3cab51f558727", + "a6f48410b9964fefba0c3009a77dc838", + "95caff42f08a4c2aa14c867b8f37f231", + "de7c37ee83e24f0c889e84d07279c2ec", + "9d4897eefb5f48259ffb2d23e332f752", + "253017b0d0534e54ab44e181f6d7c82d", + "27beaf06e41b472abdb544a43c720c5a", + "34cf3df51fbc41cabfdbba153c007f0e", + "ac764024cf1c4e08ba7749afd2cd20ac", + "30a81da86f8043eca301e86a8651201a", + "e8b7a81040904c1e89e58978223b1737", + "1c6f1f10667545aaab958016ba7e2c94", + "e6e969610738449887259063967f82b0", + "a138859f19b74fc0928dc236ab5359db", + "9b42e08b3c9548818488268768a118b1", + "12b56912736849fea2ad8124456fdc5c", + "879c8ab5873847a8833bd74123be90a4", + "20352e5f58d24bb8b1f3940efd14fe4a", + "d955dcaa0e944e719f3a06139dd54a03", + "d3de2662c7964f1ba96e58da382af720", + "97e36007e1304e1583fd81bfb13f0edd", + "c65dc74c7d6f4bab8f7dd28455161dd8", + "ef223e8504b64e3592589880326aaf41", + "598da69727bd4fb8b1caf465ac736d7a", + "5f86cd894de94c3280fadc1e2fd0ee13", + "a20927bf5f2c41f58c1e31ac858ab36c", + "0a46ad75c198463d843fb35e813642cb", + "09007681cf8d42aeb8c1d2f6a74e470a", + "ebc80d1a55fa47f4a5ea2756588569ec", + "1811cda0644e4190a9469d1774435d82", + "35c811d2ae8e43f3b5cecbdd3cfa857f", + "b8e39e4dddc3497fbc29ae45c66da759", + "63b4e563e85c4f03b1b72beda9577bcc", + "b195f160ca20442fadd8b5aed0ee41af", + "ca65e32eb52f48c09a84b33cb18f22cd", + "7cd0b85ebd204b7aba908417811ce4e0", + "7baeab52d6694c32b1efd1ea1a0a7782", + "519a7b154022443db6703f04a9142bae", + "d4183e9715f34d249942b8271cca3bdf", + "da2347ac94764a3fa2743343cf0d3cd2", + "93a44a11aa4846fa8efc6c1413ef1627", + "a55060adc3564407ac81ad7297d34aaa", + "d02274afd47b462291c745f261209d42", + "0f417447a7bd4a33acca96fa37aec877", + "63580b6fb30642479fe3000915bf551a", + "8f726dbfb45d4528afa33e36a6313267", + "03b093d592ba4386aa61f7b8483da660", + "b8766a88716948cf968f4563531a76d9", + "6f3a28b912714c6e931003549664bfa3", + "16d1283741404b7bb319094c992fce01", + "2a5bb0e818ab47be8cf6465988328503", + "2b3a2659b12244bd8548320320016dbf", + "0cd7efffbb3c4c4b972e63749f61ab97", + "5ca240f31e6b44e3882c5eb37cd5a309", + "5eb06edeb58e4930b1affef2a59eae81", + "a4e5789584564049b83df7c6c54a3e08", + "ff3a94b146a948b6907f5d80c7157f99", + "258b7c635c1045329d4669e48c46ccd5", + "6f68ed9889f54ad2ae8a3b95ac263a83", + "80366349d81e4dcc892db6cd56e384f3", + "c73055099c084dca996159e23e162d0b", + "977f799afaac4a55b2dc1cffa7d5b63b", + "41f3b32c2f6b4034ae7a3b9124e28bc7", + "a10d0a76010f4e508c65a9b69ebc5156", + "f8ef805b776145c3bfa9ba8d90972058", + "cc587493c33c4f118d1b1170f85be24c", + "e40d1c1ac9494b3bade9858324e7ffdf", + "d65b6b060d9845779299491ac5599c31", + "0f6907ebbc6242c8bde059cef1e1bd29", + "5bdfd87fc6cd4f9dabef7cfee29c8060", + "64f54d4a744a4627a07c3c0120276f3b", + "65b75b9b8bc143cf997796af68ff6668", + "d6fe74e4255444368f8f90a62157d869", + "4d468f96ec924681ad65eb671674b93e", + "ad7599de524549c48bf2d3124ad4b299", + "0546d04aae644dde846c58a4afb598a6", + "897b77a56c09479bb11d7f2a30997e55", + "81c3db71ac704280ad030072655f1537", + "042e091f75694c47aee761e760e76773", + "ef0a3c7a6f14460fb4da096928ae249e", + "07fb3a2c8315494e97b447e672dfae06", + "ec030fc3c346426f9abc3a89892258d3", + "e3fb3fc6afe04b3c9b7ac61809ce78fa", + "c3be9109d63c485d9c0ef4f9bc0f9218", + "12815f401eba44658caa7b2e490137a8", + "30e02aa2d0d241979369e598287f2639", + "dfd2a2649b8341ef913207526708aff1", + "4f1977d7e4824ef1a14b65f0f42bba10", + "c6164e05a1914ae48083db9ad7f4ef7c", + "813621384dc748b0ad06775e22761c0b", + "dc892a596f6942d7973c616c38f0eebb", + "c84cc07789be48aebb322c23d355289e", + "bed8726b8069434687c75452e21f19e5", + "16a188a0b06d45f980dcf3933509fe0a", + "60c1a0d765c14a1d888317e6a507e4ea", + "0077aedc3d174560bce924ee89e9c006", + "00321cce58884f6f9b3855a21fcd9187", + "fa864b41586f4a7aa56aeafd1d84eb75", + "3225603166b54e7aab766b9964a2f660", + "349eee9f56d64f0cba6fc24ff2c50c9b", + "7e5d3774060e4589aa65982da5ea4ef4", + "7c2485c6cdfe463da6fdb35982a1070d", + "ad1236893754446881e153adc9d5c962", + "daee63fd167e4441a32324b51b00ad2b", + "fe41858c6bd04c58840112b67c19a336", + "d262c82138024169b9f3aa034ca756fa", + "62e302ebdad64aada0ffe64ae1c873f3", + "bd1b0dfed6d34d16af33a4a58330f5ec", + "d07c8b97d3314f1c852e44bdd40f61ed", + "ebb69a2c3d0a4299a484698287b3087c", + "e5a82df528bb4e408797a3b6c2758f4a", + "f113ebd8c1c34806bea4dd7ed3035173" + ] + }, + "id": "KQQhgK8FoDfF", + "outputId": "f69441d8-95f9-4885-c306-6c8709090ff6" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b82aa8c57f7c422a9a9c90f333ed2a99", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer_config.json: 0%| | 0.00/9.68k [00:00\u001b[39m\n", + "[2025-05-08 13:41:00,845] [DEBUG] [axolotl.utils.models.load_tokenizer:442] [PID:174] [RANK:0] BOS: None / None\u001b[39m\n", + "[2025-05-08 13:41:00,846] [DEBUG] [axolotl.utils.models.load_tokenizer:443] [PID:174] [RANK:0] PAD: 151643 / <|endoftext|>\u001b[39m\n", + "[2025-05-08 13:41:00,847] [DEBUG] [axolotl.utils.models.load_tokenizer:444] [PID:174] [RANK:0] UNK: None / None\u001b[39m\n", + "[2025-05-08 13:41:00,869] [INFO] [axolotl.utils.data.sft.load_tokenized_prepared_datasets:271] [PID:174] [RANK:0] Unable to find prepared dataset in last_run_prepared/97037817611d38b3a9c681753c3c4c95\u001b[39m\n", + "[2025-05-08 13:41:00,870] [INFO] [axolotl.utils.data.sft.load_tokenized_prepared_datasets:272] [PID:174] [RANK:0] Loading raw datasets...\u001b[39m\n", + "\u001b[33m[2025-05-08 13:41:00,870] [WARNING] [axolotl.utils.data.sft.load_tokenized_prepared_datasets:274] [PID:174] [RANK:0] Processing datasets during training can lead to VRAM instability. Please pre-process your dataset.\u001b[39m\n", + "[2025-05-08 13:41:00,871] [INFO] [axolotl.utils.data.sft.load_tokenized_prepared_datasets:281] [PID:174] [RANK:0] No seed provided, using default seed of 42\u001b[39m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7cd0b85ebd204b7aba908417811ce4e0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "train.jsonl: 0%| | 0.00/27.3M [00:00system\\n' }}\n", + " {%- if messages[0].role == 'system' %}\n", + " {{- messages[0].content + '\\n\\n' }}\n", + " {%- endif %}\n", + " {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n", + " {%- for tool in tools %}\n", + " {{- \"\\n\" }}\n", + " {{- tool | tojson }}\n", + " {%- endfor %}\n", + " {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n", + "{%- else %}\n", + " {%- if messages[0].role == 'system' %}\n", + " {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n", + " {%- endif %}\n", + "{%- endif %}\n", + "{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n", + "{%- for message in messages[::-1] %}\n", + " {%- set index = (messages|length - 1) - loop.index0 %}\n", + " {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n", + " {%- set ns.multi_step_tool = false %}\n", + " {%- set ns.last_query_index = index %}\n", + " {%- endif %}\n", + "{%- endfor %}\n", + "{%- for message in messages %}\n", + " {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n", + " {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n", + " {%- elif message.role == \"assistant\" %}\n", + " {%- set content = message.content %}\n", + " {%- set reasoning_content = '' %}\n", + " {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n", + " {%- set reasoning_content = message.reasoning_content %}\n", + " {%- else %}\n", + " {%- if '' in message.content %}\n", + " {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n", + " {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n", + " {%- endif %}\n", + " {%- endif %}\n", + " {%- if loop.index0 > ns.last_query_index %}\n", + " {%- if loop.last or (not loop.last and reasoning_content) %}\n", + " {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n", + " {%- else %}\n", + " {{- '<|im_start|>' + message.role + '\\n' + content }}\n", + " {%- endif %}\n", + " {%- else %}\n", + " {{- '<|im_start|>' + message.role + '\\n' + content }}\n", + " {%- endif %}\n", + " {%- if message.tool_calls %}\n", + " {%- for tool_call in message.tool_calls %}\n", + " {%- if (loop.first and content) or (not loop.first) %}\n", + " {{- '\\n' }}\n", + " {%- endif %}\n", + " {%- if tool_call.function %}\n", + " {%- set tool_call = tool_call.function %}\n", + " {%- endif %}\n", + " {{- '\\n{\"name\": \"' }}\n", + " {{- tool_call.name }}\n", + " {{- '\", \"arguments\": ' }}\n", + " {%- if tool_call.arguments is string %}\n", + " {{- tool_call.arguments }}\n", + " {%- else %}\n", + " {{- tool_call.arguments | tojson }}\n", + " {%- endif %}\n", + " {{- '}\\n' }}\n", + " {%- endfor %}\n", + " {%- endif %}\n", + " {{- '<|im_end|>\\n' }}\n", + " {%- elif message.role == \"tool\" %}\n", + " {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n", + " {{- '<|im_start|>user' }}\n", + " {%- endif %}\n", + " {{- '\\n\\n' }}\n", + " {{- message.content }}\n", + " {{- '\\n' }}\n", + " {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n", + " {{- '<|im_end|>\\n' }}\n", + " {%- endif %}\n", + " {%- endif %}\n", + "{%- endfor %}\n", + "{%- if add_generation_prompt %}\n", + " {{- '<|im_start|>assistant\\n' }}\n", + " {%- if enable_thinking is defined and enable_thinking is false %}\n", + " {{- '\\n\\n\\n\\n' }}\n", + " {%- endif %}\n", + "{%- endif %}\n", + "---\u001b[39m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "258b7c635c1045329d4669e48c46ccd5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Tokenizing Prompts (num_proc=2): 0%| | 0/9985 [00:00\n", + " \n", + " \n", + " [25/25 09:25, Epoch 0/1]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
11.092300
21.554200
31.041400
41.733800
51.430000
61.258500
71.343600
81.101700
91.086500
100.813200
110.689600
120.826700
131.541800
140.948000
151.357000
161.085800
171.516800
181.146800
190.834800
200.968000
211.388800
221.511500
231.338500
241.206600
251.504600

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2025-05-07 22:12:42,746] [INFO] [axolotl.callbacks.on_step_end:128] [PID:1336] [RANK:0] cuda memory usage while training: 9.768GB (+3.287GB cache, +0.646GB misc)\u001b[39m\n", + "[2025-05-07 22:21:46,859] [INFO] [axolotl.train.save_trained_model:231] [PID:1336] [RANK:0] Training completed! Saving pre-trained model to ./outputs/qwen-sft-pirate-rrr.\u001b[39m\n" + ] + } + ], + "source": [ + "from axolotl.train import train\n", + "\n", + "# just train the first 25 steps for demo.\n", + "# This is sufficient to align the model as we've used packing to maximize the trainable samples per step.\n", + "cfg.max_steps = 25\n", + "model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j1b9ypF78eCb" + }, + "source": [ + "# Inferencing the trained model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "r3_vHhif8YEs", + "outputId": "e5050605-f6c9-421c-98f9-bde56a281eae" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ahoy there, matey! Shiver me timbers, ye be lookin' for the Pythagorean theorem, eh? Well, hold yer horses and listen up, for I'll be tellin' ye all about it in me own special way.\n", + "\n", + "The Pythagorean theorem be a real gem of a mathematical trick that helps ye find the length of a side of a right triangle. Now, a right triangle be a triangle with a right angle, which be that little corner that looks like a square. \n", + "\n", + "The theorem be named after a clever fellow named Pythagoras, who be a mathematician from ancient Greece. He discovered that if ye have a right triangle, the square of the length of the hypotenuse (that be the side opposite the right angle) be equal to the sum of the squares of the other two sides. \n", + "\n", + "In other words, if ye have a triangle with sides of length a, b, and c (\n" + ] + } + ], + "source": [ + "from transformers import TextStreamer\n", + "\n", + "messages = [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"Explain the Pythagorean theorem to me.\",\n", + " },\n", + "]\n", + "\n", + "prompt = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt=True,\n", + " tokenize=False,\n", + " enable_thinking=False,\n", + ")\n", + "\n", + "outputs = model.generate(\n", + " **tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\"),\n", + " max_new_tokens=192,\n", + " temperature=1.0,\n", + " top_p=0.8,\n", + " top_k=32,\n", + " streamer=TextStreamer(tokenizer, skip_prompt=True),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HoGwT2JRSIjA" + }, + "source": [ + "# Saving your trained model\n", + "\n", + "Axolotl automatically saves checkpoints to the `output_dir` path.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5BmSbiy6NaaS", + "outputId": "f5e1d913-7d55-42d2-8340-f9f1b0bc2b38" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total 506M\n", + "-rw-r--r-- 1 root root 845 May 7 22:21 adapter_config.json\n", + "-rw-r--r-- 1 root root 491M May 7 22:21 adapter_model.safetensors\n", + "-rw-r--r-- 1 root root 707 May 7 22:11 added_tokens.json\n", + "drwxr-xr-x 2 root root 4.0K May 7 22:17 checkpoint-13\n", + "drwxr-xr-x 2 root root 4.0K May 7 22:21 checkpoint-25\n", + "-rw-r--r-- 1 root root 1.2K May 7 22:11 config.json\n", + "-rw-r--r-- 1 root root 1.6M May 7 22:11 merges.txt\n", + "-rw-r--r-- 1 root root 2.6K May 7 22:21 README.md\n", + "-rw-r--r-- 1 root root 613 May 7 22:11 special_tokens_map.json\n", + "-rw-r--r-- 1 root root 9.5K May 7 22:11 tokenizer_config.json\n", + "-rw-r--r-- 1 root root 11M May 7 22:11 tokenizer.json\n", + "-rw-r--r-- 1 root root 2.7M May 7 22:11 vocab.json\n" + ] + } + ], + "source": [ + "# Show the saved checkpoints in the output_dir\n", + "!ls -lh \"./outputs/qwen-sft-pirate-rrr\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_PCIFWxuOZd6" + }, + "source": [ + "Setting `hub_model_id: ` in the original config would have automatically uploaded the model to HuggingFace Hub (e.g. `hub_model_id: username/model_id`)\n", + "\n", + "If you prefer to manually upload the training artifacts, we can still upload the entire final checkpoint to HuggingFace from the CLI." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 955, + "referenced_widgets": [ + "c12ea43372ac4d57bb9605f1a429b397", + "86816687746246b4a6105e8010384e25", + "6f05e9bebf7b40c9835808e77de6c236", + "c7433acd3c4841e6958ae8f7e87b1808", + "19c1e38389fa46c7b7e2152a56e1df34", + "0e067d8db8ed48308a718d5f57683fd1", + "131065f118274a1586ac38e39ed84ef0", + "8640ac440fbc4644b9a3af7ba3ae7183", + "5cea7996f02040b187ece0bb2d6a8d1f", + "2e257c8be2da40b4bb67a9e4ab6811f3", + "56e3768bef5a4b9db4168c5c17f509c2", + "62c028fdef904dedb9cdeca2b3bda725", + "a7cf477e80fc43e0ad82c7997b076dce", + "835bcc28a5564fb9b3d651bc8e32dc46", + "9f1c9a0695384bdaa6f8b847ef89bee8", + "b1bea589efa14258a9982071b87938bf", + "590eef89881545aa8bbef9a8bbe7fb00", + "4b1f04ff63d14a118fdd15814dff50e4", + "39789237703c4a418134243055c9cbf5", + "a3a945817f684328b34651fe052393ec" + ] + }, + "id": "2yw8pLvlSMl8", + "outputId": "6e489ab2-4abe-4e28-84ca-959f912433a4" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c12ea43372ac4d57bb9605f1a429b397", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='

\n", + " sys.exit(main())\n", + " ^^^^^^\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/commands/huggingface_cli.py\", line 57, in main\n", + " service.run()\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/commands/upload.py\", line 207, in run\n", + " print(self._upload())\n", + " ^^^^^^^^^^^^^^\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/commands/upload.py\", line 302, in _upload\n", + " return self.api.upload_folder(\n", + " ^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py\", line 114, in _inner_fn\n", + " return fn(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py\", line 1633, in _inner\n", + " return fn(self, *args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py\", line 4942, in upload_folder\n", + " commit_info = self.create_commit(\n", + " ^^^^^^^^^^^^^^^^^^^\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py\", line 114, in _inner_fn\n", + " return fn(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py\", line 1633, in _inner\n", + " return fn(self, *args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py\", line 4202, in create_commit\n", + " self.preupload_lfs_files(\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/hf_api.py\", line 4483, in preupload_lfs_files\n", + " _upload_xet_files(**upload_kwargs, create_pr=create_pr) # type: ignore [arg-type]\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_validators.py\", line 114, in _inner_fn\n", + " return fn(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^\n", + " File \"/usr/local/lib/python3.11/dist-packages/huggingface_hub/_commit_api.py\", line 592, in _upload_xet_files\n", + " with progress_cm as progress:\n", + " File \"/usr/local/lib/python3.11/dist-packages/tqdm/std.py\", line 1138, in __exit__\n", + " def __exit__(self, exc_type, exc_value, traceback):\n", + "\n", + "KeyboardInterrupt\n", + "^C\n" + ] + } + ], "source": [ "from huggingface_hub import notebook_login\n", - "notebook_login()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Example configuration" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import yaml\n", "\n", - "yaml_string = \"\"\"\n", - "base_model: NousResearch/Meta-Llama-3.1-8B\n", + "# remove the partial epoch checkpoints\n", + "!rm -rf \"./outputs/qwen-sft-pirate-rrr/checkpoint-*\"\n", "\n", - "load_in_8bit: false\n", - "load_in_4bit: true\n", - "strict: false\n", + "# HF Notebook login widget\n", + "notebook_login()\n", "\n", - "datasets:\n", - " - path: tatsu-lab/alpaca\n", - " type: alpaca\n", - "dataset_prepared_path: last_run_prepared\n", - "val_set_size: 0.05\n", - "output_dir: ./outputs/lora-out\n", - "\n", - "sequence_len: 2048\n", - "sample_packing: true\n", - "eval_sample_packing: true\n", - "pad_to_sequence_len: true\n", - "\n", - "adapter: qlora\n", - "lora_model_dir:\n", - "lora_r: 32\n", - "lora_alpha: 16\n", - "lora_dropout: 0.05\n", - "lora_target_linear: true\n", - "lora_fan_in_fan_out:\n", - "lora_modules_to_save:\n", - " - embed_tokens\n", - " - lm_head\n", - "\n", - "wandb_project:\n", - "wandb_entity:\n", - "wandb_watch:\n", - "wandb_name:\n", - "wandb_log_model:\n", - "\n", - "gradient_accumulation_steps: 2\n", - "micro_batch_size: 1\n", - "num_epochs: 1\n", - "optimizer: paged_adamw_8bit\n", - "lr_scheduler: cosine\n", - "learning_rate: 2e-5\n", - "\n", - "train_on_inputs: false\n", - "group_by_length: false\n", - "bf16: auto\n", - "fp16:\n", - "tf32: false\n", - "\n", - "gradient_checkpointing: true\n", - "early_stopping_patience:\n", - "resume_from_checkpoint:\n", - "logging_steps: 1\n", - "xformers_attention:\n", - "flash_attention: false\n", - "sdp_attention: true\n", - "\n", - "warmup_steps: 1\n", - "max_steps: 25\n", - "evals_per_epoch: 1\n", - "eval_table_size:\n", - "saves_per_epoch: 1\n", - "debug:\n", - "deepspeed:\n", - "weight_decay: 0.0\n", - "fsdp:\n", - "fsdp_config:\n", - "special_tokens:\n", - " pad_token: <|end_of_text|>\n", - "\"\"\"\n", - "\n", - "\n", - "# Convert the YAML string to a Python dictionary\n", - "yaml_dict = yaml.safe_load(yaml_string)\n", - "\n", - "# Specify your file path\n", - "file_path = 'test_axolotl.yaml'\n", - "\n", - "# Write the YAML file\n", - "with open(file_path, 'w') as file:\n", - " yaml.dump(yaml_dict, file)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Above we have a configuration file with base LLM model and datasets specified, among many other things. Axolotl can automatically detect whether the specified datasets are on HuggingFace repo or local machine.\n", - "\n", - "The Axolotl configuration options encompass model and dataset selection, data pre-processing, and training. Let's go through them line by line:\n", - "\n", - "* \"base model\": String value, specifies the underlying pre-trained LLM that will be used for finetuning\n", - "\n", - "Next we have options for model weights quantization. Quantization allows for reduction in occupied memory on GPUs.\n", - "\n", - "* \"load_in_8bit\": Boolean value, whether to quantize the model weights into 8-bit integer.\n", - "\n", - "* \"load_in_4bit\": Boolean value, whether to quantize the model weights into 4-bit integer.\n", - "\n", - "* \"strict\": Boolean value. If false, it allows for overriding established configuration options in the yaml file when executing in command-line interface.\n", - "\n", - "* \"datasets\": a list of dicts that contain path and type of data sets as well as other optional configurations where datasets are concerned. Supports multiple datasets.\n", - "\n", - "* \"val_set_size\": Either a float value less than one or an integer less than the total size of dataset. Sets the size of validation set from the whole dataset. If float, sets the proportion of the dataset assigned for validation. If integer, sets the direct size of validation set.\n", - "\n", - "* \"output_dir\": String value. Path of trained model.\n", - "\n", - "For data preprocessing:\n", - "\n", - "* \"sequence_len\": Integer. Specifies the maximum sequence length of the input. Typically 2048 or less.\n", - "\n", - "* \"pad_to_sequence_len\": Boolean. Padding input to maximum sequence length.\n", - "\n", - "* \"sample_packing\": Boolean. Specifies whether to use multi-packing with block diagonal attention.\n", - "\n", - "* \"special_tokens\": Python dict, optional. Allows users to specify the additional special tokens to be ignored by the tokenizer.\n", - "\n", - "For LoRA configuration and its hyperparamters:\n", - "\n", - "* \"adapter\": String. Either \"lora\" or \"qlora\", depending on user's choice.\n", - "\n", - "* \"lora_model_dir\": String, Optional. Path to directory that contains LoRA model, if there is already a trained LoRA model the user would like to use.\n", - "\n", - "* \"lora_r\": Integer. Refers to the rank of LoRA decomposition matrices. Higher value will reduce LoRA efficiency. Recommended to be set to 8.\n", - "\n", - "* \"lora_alpha\": Integer. Scale the weight matrices by $\\frac{\\text{lora_alpha}}{\\text{lora_r}}$Recommended to be fixed at 16.\n", - "\n", - "* \"lora_dropout\": Float that is 1 or less. The dropout probability of a lora layer.\n", - "\n", - "* \"lora_target_linear\": Boolean. If true, lora will target all linear modules in the transformers architecture.\n", - "\n", - "* \"lora_modules_to_save\": If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.\n", - "\n", - "See [LoRA](https://arxiv.org/abs/2106.09685) for detailed explanation of LoRA implementation.\n", - "\n", - "For the training configurations:\n", - "\n", - "* \"gradient_accumulation_steps\": Integer. The number of steps over which to accumulate gradient for batch training. E.g. if 2, backprop is performed every two steps.\n", - "\n", - "* \"micro_batch_size\": Integer. Batch size per gpu / gradient_accumulation_steps\n", - "\n", - "* \"num_epochs\": Integer. Number of epochs. One epoch is when training has looped over every batch in the whole data set once.\n", - "\n", - "* \"optimizer\": The optimizer to use for the training.\n", - "\n", - "* \"learning_rate\": The learning rate.\n", - "\n", - "* \"lr_scheduler\": The learning rate scheduler to use for adjusting learning rate during training.\n", - "\n", - "* \"train_on_inputs\": Boolean. Whether to ignore or include the user's prompt from the training labels.\n", - "\n", - "* \"group_by_length\": Boolean. Whether to group similarly sized data to minimize padding.\n", - "\n", - "* \"bf16\": Either \"auto\", \"true\", or \"false\". Whether to use CUDA bf16 floating point format. If set to \"auto\", will automatically apply bf16 should the gpu supports it.\n", - "\n", - "* \"fp16\": Optional. Specifies whether to use CUDA fp16. Automatically set to true if \"bf16\" is set to true. Otherwise false.\n", - "\n", - "* \"tf32\": Boolean. Whether to use CUDA tf32. Will override bf16.\n", - "\n", - "* \"gradient_checkpointing\": Boolean. Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing\n", - "\n", - "* \"gradient_checkpointing_kwargs\": Python Dict. Fed into the trainer.\n", - "\n", - "* \"logging_steps\": Integer. Log training information over every specified number of steps.\n", - "\n", - "* \"flash_attention\": Boolean. Whether to use the [flash attention](https://github.com/Dao-AILab/flash-attention) mechanism.\n", - "\n", - "* \"sdp_attention\": Boolean. Whether to use the Scaled Dot Product attention mechanism (the attention mechanism in the [original implementation](https://arxiv.org/abs/1706.03762) of transformers.)\n", - "\n", - "* \"warmup_steps\": Integer. The number of pre-training steps where a very low learning rate is used.\n", - "\n", - "* \"evals_per_epoch\": Integer. Number of evaluations to be performed within one training epoch.\n", - "\n", - "* \"saves_per_epoch\": Integer. Number of times the model is saved in one training epoch.\n", - "\n", - "* \"weight_decay\": Positive Float. Sets the \"strength\" of weight decay (i.e. setting the coefficient of L2 regularization)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The above is but a snippet aiming to get users familiarized with the types of streamlined configuration options axolotl provides. For a full list of configuration options, see [here](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Train the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Predict with trained model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n", - " --lora_model_dir=\"./outputs/lora-out\" --gradio" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Deeper Dive" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It is also helpful to gain some familiarity over some of the core inner workings of axolotl" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Configuration Normalization" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Axolotl uses a custom Dict class, called ```DictDefault```\n", - "to store configurations specified in the yaml configuration file (into a Python variable named ```cfg```). The definition for this custom Dict can be found in the [utils/dict.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/dict.py)\n", - "\n", - "```DictDefault``` is amended such that calling a missing key from it will result in a ```None``` return type. This is important because if some configuration options aren't specified by the user, the ```None``` type allows Axolotl to perform boolean operations to determine the default settings for missing configurations. For more examples on how this is done, check out [utils/config/__init__.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/__init__.py)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Loading Models, Tokenizers, and Trainer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If we inspect [cli.train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/cli/train.py), we will find that most of the heavy lifting were done by the function ```train()``` which is itself imported from [src/axolotl/train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/train.py).\n", - "\n", - "```train()``` takes care of loading the appropriate tokenizer and pre-trained model through ```load_model()``` and ```load_tokenizer()``` from [src/axolotl/utils/models.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/models.py) respectively.\n", - "\n", - "```load_tokenizer()``` loads in the appropriate tokenizer given the desired model, as well as chat templates.\n", - "\n", - "```ModelLoader``` class follows after tokenizer has been selected. It will automatically discern the base model type, load in the desired model, as well as applying model-appropriate attention mechanism modifications (e.g. flash attention). Depending on which base model the user chooses in the configuration, ```ModelLoader``` will utilize the corresponding \"attention hijacking\" script. For example, if the user specified the base model to be ```NousResearch/Meta-Llama-3.1-8B```, which is of llama type, and set ```flash_attn``` to ```True```, ```ModelLoader``` will load in [llama_attn_hijack_flash.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/llama_attn_hijack_flash.py). For a list of supported attention hijacking, please refer to the directory [/src/axolotl/monkeypatch/](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch)\n", - "\n", - "Another important operation encompassed in ```train()``` is setting up the training that takes into account of user-specified traning configurations (e.g. num_epochs, optimizer) through the use of ```setup_trainer()``` from [/src/axolotl/utils/trainer.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/trainer.py), which in turn relies on modules from [/src/axolotl/core/trainer_builder.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/core/trainer_builder.py).\n", - "```trainer_builder.py``` provides a list of trainer object options bespoke for the task type (Causal or Reinforcement learning ('dpo', 'ipo', 'kto') )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Monkey patch\n", - "\n", - "The [Monkey patch directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch) is where model architecture/optimization patching scripts are stored (these are modifications that are not implemented in the official releases, hence the name monkey patch). It includes attention jacking, ReLoRA, and unsloth optimization." + "# upload the LoRA adapter for your model to HF, remember to update the username/model-name below\n", + "!huggingface-cli upload --repo-type=model winglian/pirate-qwen-14B \"./outputs/qwen-sft-pirate-rrr\"" ] } ], "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, "kernelspec": { "display_name": "Python 3", - "language": "python", "name": "python3" }, "language_info": { - "name": "python", - "version": "3.9.6" + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "00321cce58884f6f9b3855a21fcd9187": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "004d9177a6a14118a5930dc3cc13147b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_a80410b919e442c49aea15acc1ce1a72", + "IPY_MODEL_c6e00f5224364822bc4239b176686919", + "IPY_MODEL_ec11d1e5ae7b42c883d9b1f38a65356e" + ], + "layout": "IPY_MODEL_734185351eb543fa9a00a881dcbb9fe7" + } + }, + "0077aedc3d174560bce924ee89e9c006": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "03a3c744d716431488163b4358b80f92": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "03b093d592ba4386aa61f7b8483da660": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b8766a88716948cf968f4563531a76d9", + "IPY_MODEL_6f3a28b912714c6e931003549664bfa3", + "IPY_MODEL_16d1283741404b7bb319094c992fce01" + ], + "layout": "IPY_MODEL_2a5bb0e818ab47be8cf6465988328503" + } + }, + "042e091f75694c47aee761e760e76773": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0546d04aae644dde846c58a4afb598a6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "054c8dffadba48c6b895a6cc62448ecc": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "07fb3a2c8315494e97b447e672dfae06": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_12815f401eba44658caa7b2e490137a8", + "placeholder": "​", + "style": "IPY_MODEL_30e02aa2d0d241979369e598287f2639", + "value": "Drop Samples with Zero Trainable Tokens (num_proc=2): 100%" + } + }, + "083f9cda8d754c168beee10d2f8955a2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a0a11e929edd4189b79723d618522c33", + "max": 728, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_e87ea87fcff247b5bbcc331ba79a8dc2", + "value": 728 + } + }, + "09007681cf8d42aeb8c1d2f6a74e470a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b195f160ca20442fadd8b5aed0ee41af", + "placeholder": "​", + "style": "IPY_MODEL_ca65e32eb52f48c09a84b33cb18f22cd", + "value": " 11.4M/11.4M [00:00<00:00, 21.8MB/s]" + } + }, + "0a46ad75c198463d843fb35e813642cb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b8e39e4dddc3497fbc29ae45c66da759", + "max": 11422654, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_63b4e563e85c4f03b1b72beda9577bcc", + "value": 11422654 + } + }, + "0aa8ab56b85f4171a79c3bc210594025": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "0b4c9753a7cb4354b8e5f187e6e1ad7c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0cd7efffbb3c4c4b972e63749f61ab97": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0dea5caa27384f5689e3cab51f558727": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0e067d8db8ed48308a718d5f57683fd1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b1bea589efa14258a9982071b87938bf", + "placeholder": "​", + "style": "IPY_MODEL_590eef89881545aa8bbef9a8bbe7fb00", + "value": "\nPro Tip: If you don't already have one, you can create a dedicated\n'notebooks' token with 'write' access, that you can then easily reuse for all\nnotebooks.
" + } + }, + "0e50870ed0c643e0b6c18cc5d7ddae7f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bfcdbba993b74972a9e3e575f86908ff", + "placeholder": "​", + "style": "IPY_MODEL_6ebb2ec171414e47a14765505f64bb3c", + "value": " 3.84G/3.84G [00:09<00:00, 664MB/s]" + } + }, + "0e936d9dbf9c4fdd86bbfe9730dedc47": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0f417447a7bd4a33acca96fa37aec877": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "0f480e3a0b0a45d2a2d2dec3cad923f3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0f6907ebbc6242c8bde059cef1e1bd29": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_5bdfd87fc6cd4f9dabef7cfee29c8060", + "IPY_MODEL_64f54d4a744a4627a07c3c0120276f3b", + "IPY_MODEL_65b75b9b8bc143cf997796af68ff6668" + ], + "layout": "IPY_MODEL_d6fe74e4255444368f8f90a62157d869" + } + }, + "114dece49dba437c8572ef94b23c3b1e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "12815f401eba44658caa7b2e490137a8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "12b56912736849fea2ad8124456fdc5c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_97e36007e1304e1583fd81bfb13f0edd", + "max": 1671853, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c65dc74c7d6f4bab8f7dd28455161dd8", + "value": 1671853 + } + }, + "131065f118274a1586ac38e39ed84ef0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": "center", + "align_self": null, + "border": null, + "bottom": null, + "display": "flex", + "flex": null, + "flex_flow": "column", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "50%" + } + }, + "158c8b85dbf34de6a94b4e35e2fc7d5a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "16a188a0b06d45f980dcf3933509fe0a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_349eee9f56d64f0cba6fc24ff2c50c9b", + "placeholder": "​", + "style": "IPY_MODEL_7e5d3774060e4589aa65982da5ea4ef4", + "value": " 9985/9985 [00:04<00:00, 2604.11 examples/s]" + } + }, + "16d1283741404b7bb319094c992fce01": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a4e5789584564049b83df7c6c54a3e08", + "placeholder": "​", + "style": "IPY_MODEL_ff3a94b146a948b6907f5d80c7157f99", + "value": " 9985/0 [00:00<00:00, 50763.46 examples/s]" + } + }, + "1811cda0644e4190a9469d1774435d82": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "18357b321ce44d7b8bd9d1c886f69275": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e366ae3fceec4566b9ed303d6c5f90af", + "placeholder": "​", + "style": "IPY_MODEL_5dd7d150dbe04f08b165ce7f2c27cd11", + "value": "model-00008-of-00008.safetensors: 100%" + } + }, + "19127c7bb1554ccbac877059f9a82db0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e400cbf14bcc446a9d33b210cd93550b", + "max": 3963750880, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_71002199df6b40c9a1ac40df5fb27a1b", + "value": 3963750502 + } + }, + "19c1e38389fa46c7b7e2152a56e1df34": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ButtonView", + "button_style": "", + "description": "Login", + "disabled": false, + "icon": "", + "layout": "IPY_MODEL_835bcc28a5564fb9b3d651bc8e32dc46", + "style": "IPY_MODEL_9f1c9a0695384bdaa6f8b847ef89bee8", + "tooltip": "" + } + }, + "1bec6297c90242a88672d195bc09d429": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1c6f1f10667545aaab958016ba7e2c94": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1d5117195d4b49eb8f1a73b18419f7ce": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0dea5caa27384f5689e3cab51f558727", + "placeholder": "​", + "style": "IPY_MODEL_a6f48410b9964fefba0c3009a77dc838", + "value": " 9.68k/9.68k [00:00<00:00, 812kB/s]" + } + }, + "1f7d30f71bbd4547a9150d21da071055": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "200df5e79b9244849e589ecb0250a520": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f4a1795dc7514a718f478245f521f0ba", + "placeholder": "​", + "style": "IPY_MODEL_5e746eb25bbe416fb585fa24e79f5177", + "value": "model-00002-of-00008.safetensors: 100%" + } + }, + "20352e5f58d24bb8b1f3940efd14fe4a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "253017b0d0534e54ab44e181f6d7c82d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1c6f1f10667545aaab958016ba7e2c94", + "placeholder": "​", + "style": "IPY_MODEL_e6e969610738449887259063967f82b0", + "value": " 2.78M/2.78M [00:00<00:00, 17.8MB/s]" + } + }, + "258b7c635c1045329d4669e48c46ccd5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_6f68ed9889f54ad2ae8a3b95ac263a83", + "IPY_MODEL_80366349d81e4dcc892db6cd56e384f3", + "IPY_MODEL_c73055099c084dca996159e23e162d0b" + ], + "layout": "IPY_MODEL_977f799afaac4a55b2dc1cffa7d5b63b" + } + }, + "279937fe03bc4e4eb25b472d7e9df163": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b634bb73cfa743d09a5999101b840976", + "max": 1912371880, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_742b1030acfd414bbd9d5327b7e3826d", + "value": 1912371698 + } + }, + "27beaf06e41b472abdb544a43c720c5a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2860e3bb3baf4f7da058465850e800c5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_3efd18ea8eaa41918894883da9541bfa", + "IPY_MODEL_e09f1bcbb9d94c09be53e5e1303642c2", + "IPY_MODEL_82177df57a494de8900c14c2f5185175" + ], + "layout": "IPY_MODEL_ccfcdc95baf646f8aeb3d516742383f2" + } + }, + "2a51b36be41745468e4c2d7a21b1c0d2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2a5bb0e818ab47be8cf6465988328503": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2b3a2659b12244bd8548320320016dbf": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2e257c8be2da40b4bb67a9e4ab6811f3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2e2b0c1599c341a198f632f46a40c90e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_be724f04b03942b2a033a7e8898bb4fd", + "placeholder": "​", + "style": "IPY_MODEL_fcbab4d8dced41a18dfccce81e3a45a0", + "value": "model-00005-of-00008.safetensors: 100%" + } + }, + "3036608c71904ce9ae4bb2a9fa8802d9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5ca6be24acb548cea130bd58e9954c7c", + "placeholder": "​", + "style": "IPY_MODEL_5cfb02ee044b4011a378efa8b54a370f", + "value": " 3.96G/3.96G [00:10<00:00, 531MB/s]" + } + }, + "30a81da86f8043eca301e86a8651201a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "30e02aa2d0d241979369e598287f2639": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3225603166b54e7aab766b9964a2f660": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "33b3b1d0295646edaac7b4822761aeb0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "349eee9f56d64f0cba6fc24ff2c50c9b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "34c9c0137b504cd799c6bd6de69507c2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "34cf3df51fbc41cabfdbba153c007f0e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "35c811d2ae8e43f3b5cecbdd3cfa857f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "35cc989ca3374e7dba0cb166febc4bde": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "366a343b62fa47d8985a3bd464d99f9e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "37de928300e34184881039378bd75e7f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "388f618924274d21a066f098f4f1e744": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7c95f85a2b1f47a1bd846d110c47bb3c", + "IPY_MODEL_083f9cda8d754c168beee10d2f8955a2", + "IPY_MODEL_62e1a65582f446a78612eaa804e08a7d" + ], + "layout": "IPY_MODEL_487a177d020f4605834878b2fdc7afa3" + } + }, + "39789237703c4a418134243055c9cbf5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3aaecbf540f54a2db9ab0931e3b1fe57": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3c21e4a511b4441192c03b7f1d0976e9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3efd18ea8eaa41918894883da9541bfa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8f5bd719974e41c3a8dd9a5b0d3d71e6", + "placeholder": "​", + "style": "IPY_MODEL_b87c84de30e84b3abf4871461fb9cbd3", + "value": "Loading checkpoint shards: 100%" + } + }, + "41f3b32c2f6b4034ae7a3b9124e28bc7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4471ff62258549fba9514bb67050f965": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_9cd5211b5d8b457aa0002f1d17b80028", + "IPY_MODEL_19127c7bb1554ccbac877059f9a82db0", + "IPY_MODEL_f4667818b9d34a09891cd727a429a610" + ], + "layout": "IPY_MODEL_9ed02dc43412471a9ab47f3620ccf3a5" + } + }, + "4540927d98f54466b434ba4c0edf045d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "487a177d020f4605834878b2fdc7afa3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4b1f04ff63d14a118fdd15814dff50e4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "LabelModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "LabelModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "LabelView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_39789237703c4a418134243055c9cbf5", + "placeholder": "​", + "style": "IPY_MODEL_a3a945817f684328b34651fe052393ec", + "value": "Connecting..." + } + }, + "4b27c267393640f28f6eae0875bd2ed9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4c727d40ef0443449afc31724ee79f0c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "4d05314858354e729d76094b3b0ce761": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c42acf646f344a88b8c11f81e67f7206", + "IPY_MODEL_7be6f04c284e4326bb4ff3d301e7b3c6", + "IPY_MODEL_ffdbb12a2f2c4d14911685e7683e0ef0" + ], + "layout": "IPY_MODEL_bee3501b2a17427784a717e50a85e7fa" + } + }, + "4d468f96ec924681ad65eb671674b93e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4f1977d7e4824ef1a14b65f0f42bba10": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "4fd114abe9f5494ab59858949f5055f1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "500e272208a246089613bf788a165271": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_200df5e79b9244849e589ecb0250a520", + "IPY_MODEL_cc94432d08464affa3e58b560bdad194", + "IPY_MODEL_3036608c71904ce9ae4bb2a9fa8802d9" + ], + "layout": "IPY_MODEL_adacfdcc1b0140efac56918e9ccf064e" + } + }, + "519a7b154022443db6703f04a9142bae": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d02274afd47b462291c745f261209d42", + "max": 27341251, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_0f417447a7bd4a33acca96fa37aec877", + "value": 27341251 + } + }, + "56e3768bef5a4b9db4168c5c17f509c2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "590eef89881545aa8bbef9a8bbe7fb00": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "598da69727bd4fb8b1caf465ac736d7a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5bdfd87fc6cd4f9dabef7cfee29c8060": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4d468f96ec924681ad65eb671674b93e", + "placeholder": "​", + "style": "IPY_MODEL_ad7599de524549c48bf2d3124ad4b299", + "value": "Dropping Long Sequences (num_proc=2): 100%" + } + }, + "5ca240f31e6b44e3882c5eb37cd5a309": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "20px" + } + }, + "5ca6be24acb548cea130bd58e9954c7c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5cea7996f02040b187ece0bb2d6a8d1f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5cfb02ee044b4011a378efa8b54a370f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5dd7d150dbe04f08b165ce7f2c27cd11": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5e18768f7ad6434ba8b8b8a2e853e204": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5e5e15b0569b474c9620083b3ec6af55": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5e746eb25bbe416fb585fa24e79f5177": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5eb06edeb58e4930b1affef2a59eae81": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "5f86cd894de94c3280fadc1e2fd0ee13": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_a20927bf5f2c41f58c1e31ac858ab36c", + "IPY_MODEL_0a46ad75c198463d843fb35e813642cb", + "IPY_MODEL_09007681cf8d42aeb8c1d2f6a74e470a" + ], + "layout": "IPY_MODEL_ebc80d1a55fa47f4a5ea2756588569ec" + } + }, + "60c1a0d765c14a1d888317e6a507e4ea": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "62c028fdef904dedb9cdeca2b3bda725": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "62e1a65582f446a78612eaa804e08a7d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5e18768f7ad6434ba8b8b8a2e853e204", + "placeholder": "​", + "style": "IPY_MODEL_bb33aec33a6447078c31bfd728942994", + "value": " 728/728 [00:00<00:00, 20.3kB/s]" + } + }, + "62e302ebdad64aada0ffe64ae1c873f3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "63580b6fb30642479fe3000915bf551a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "63b4e563e85c4f03b1b72beda9577bcc": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "64f54d4a744a4627a07c3c0120276f3b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0546d04aae644dde846c58a4afb598a6", + "max": 9985, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_897b77a56c09479bb11d7f2a30997e55", + "value": 9985 + } + }, + "65b75b9b8bc143cf997796af68ff6668": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_81c3db71ac704280ad030072655f1537", + "placeholder": "​", + "style": "IPY_MODEL_042e091f75694c47aee761e760e76773", + "value": " 9985/9985 [00:02<00:00, 3977.47 examples/s]" + } + }, + "67da6c4260574869aa24c3cbc1bc1654": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6932489232ec4ab18a160b1e7fbcdfe1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6ebb2ec171414e47a14765505f64bb3c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6f05e9bebf7b40c9835808e77de6c236": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "PasswordModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "PasswordModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "PasswordView", + "continuous_update": true, + "description": "Token:", + "description_tooltip": null, + "disabled": false, + "layout": "IPY_MODEL_2e257c8be2da40b4bb67a9e4ab6811f3", + "placeholder": "​", + "style": "IPY_MODEL_56e3768bef5a4b9db4168c5c17f509c2", + "value": "" + } + }, + "6f3a28b912714c6e931003549664bfa3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5ca240f31e6b44e3882c5eb37cd5a309", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_5eb06edeb58e4930b1affef2a59eae81", + "value": 1 + } + }, + "6f68ed9889f54ad2ae8a3b95ac263a83": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_41f3b32c2f6b4034ae7a3b9124e28bc7", + "placeholder": "​", + "style": "IPY_MODEL_a10d0a76010f4e508c65a9b69ebc5156", + "value": "Tokenizing Prompts (num_proc=2): 100%" + } + }, + "704f2f5a9b1c49d5a75a0025a5dda11b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "71002199df6b40c9a1ac40df5fb27a1b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "71c8af139cd248b1b51101fd46a93f35": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d0e9dce55cec4c1ca619a0ccf209d924", + "max": 9675, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_4c727d40ef0443449afc31724ee79f0c", + "value": 9675 + } + }, + "734185351eb543fa9a00a881dcbb9fe7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "735d4f225b24414294fc1b213c61223c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "742b1030acfd414bbd9d5327b7e3826d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "77304d1a46b3468a98483e02ec0ac4a4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7baeab52d6694c32b1efd1ea1a0a7782": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_93a44a11aa4846fa8efc6c1413ef1627", + "placeholder": "​", + "style": "IPY_MODEL_a55060adc3564407ac81ad7297d34aaa", + "value": "train.jsonl: 100%" + } + }, + "7be6f04c284e4326bb4ff3d301e7b3c6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9503a45960984adc97b58e16c50662e0", + "max": 3963750880, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_da6e93f3e4984780b930fe7a706983ea", + "value": 3963750502 + } + }, + "7c2485c6cdfe463da6fdb35982a1070d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_ad1236893754446881e153adc9d5c962", + "IPY_MODEL_daee63fd167e4441a32324b51b00ad2b", + "IPY_MODEL_fe41858c6bd04c58840112b67c19a336" + ], + "layout": "IPY_MODEL_d262c82138024169b9f3aa034ca756fa" + } + }, + "7c95f85a2b1f47a1bd846d110c47bb3c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7fd44cf9ca6e4726bfd7ac21846d6a14", + "placeholder": "​", + "style": "IPY_MODEL_366a343b62fa47d8985a3bd464d99f9e", + "value": "config.json: 100%" + } + }, + "7cd0b85ebd204b7aba908417811ce4e0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7baeab52d6694c32b1efd1ea1a0a7782", + "IPY_MODEL_519a7b154022443db6703f04a9142bae", + "IPY_MODEL_d4183e9715f34d249942b8271cca3bdf" + ], + "layout": "IPY_MODEL_da2347ac94764a3fa2743343cf0d3cd2" + } + }, + "7e5d3774060e4589aa65982da5ea4ef4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "7fd44cf9ca6e4726bfd7ac21846d6a14": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "80366349d81e4dcc892db6cd56e384f3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f8ef805b776145c3bfa9ba8d90972058", + "max": 9985, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_cc587493c33c4f118d1b1170f85be24c", + "value": 9985 + } + }, + "813621384dc748b0ad06775e22761c0b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "81c3db71ac704280ad030072655f1537": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "82177df57a494de8900c14c2f5185175": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_67da6c4260574869aa24c3cbc1bc1654", + "placeholder": "​", + "style": "IPY_MODEL_94b9088614464f60a203de39dbcae853", + "value": " 8/8 [01:47<00:00, 11.64s/it]" + } + }, + "823f1c78f15043e38bbd4dca3932a86a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_03a3c744d716431488163b4358b80f92", + "max": 239, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_a5434ee714f9498d83870544b67c0cb7", + "value": 239 + } + }, + "835bcc28a5564fb9b3d651bc8e32dc46": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8640ac440fbc4644b9a3af7ba3ae7183": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "86816687746246b4a6105e8010384e25": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8640ac440fbc4644b9a3af7ba3ae7183", + "placeholder": "​", + "style": "IPY_MODEL_5cea7996f02040b187ece0bb2d6a8d1f", + "value": "

Copy a token from your Hugging Face\ntokens page and paste it below.
Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file.
" + } + }, + "879c8ab5873847a8833bd74123be90a4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ef223e8504b64e3592589880326aaf41", + "placeholder": "​", + "style": "IPY_MODEL_598da69727bd4fb8b1caf465ac736d7a", + "value": " 1.67M/1.67M [00:00<00:00, 19.0MB/s]" + } + }, + "897b77a56c09479bb11d7f2a30997e55": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "8bc9d8ba866c442b9118d9630009939c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8c4d4fc5a30f4e7cb3be53fe2adda33d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8f5bd719974e41c3a8dd9a5b0d3d71e6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8f726dbfb45d4528afa33e36a6313267": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9327977822be4b1294f80e876552e305": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_37de928300e34184881039378bd75e7f", + "placeholder": "​", + "style": "IPY_MODEL_0e936d9dbf9c4fdd86bbfe9730dedc47", + "value": " 3.96G/3.96G [00:13<00:00, 273MB/s]" + } + }, + "936d04b5fe1b4c63bf0b080e423d051b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "93a44a11aa4846fa8efc6c1413ef1627": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "94b9088614464f60a203de39dbcae853": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9503a45960984adc97b58e16c50662e0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "95caff42f08a4c2aa14c867b8f37f231": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_de7c37ee83e24f0c889e84d07279c2ec", + "IPY_MODEL_9d4897eefb5f48259ffb2d23e332f752", + "IPY_MODEL_253017b0d0534e54ab44e181f6d7c82d" + ], + "layout": "IPY_MODEL_27beaf06e41b472abdb544a43c720c5a" + } + }, + "977f799afaac4a55b2dc1cffa7d5b63b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "97e36007e1304e1583fd81bfb13f0edd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9858cb74a09748a39e8149baac96702c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9b42e08b3c9548818488268768a118b1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d955dcaa0e944e719f3a06139dd54a03", + "placeholder": "​", + "style": "IPY_MODEL_d3de2662c7964f1ba96e58da382af720", + "value": "merges.txt: 100%" + } + }, + "9cd5211b5d8b457aa0002f1d17b80028": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6932489232ec4ab18a160b1e7fbcdfe1", + "placeholder": "​", + "style": "IPY_MODEL_4540927d98f54466b434ba4c0edf045d", + "value": "model-00007-of-00008.safetensors: 100%" + } + }, + "9d4897eefb5f48259ffb2d23e332f752": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_30a81da86f8043eca301e86a8651201a", + "max": 2776833, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_e8b7a81040904c1e89e58978223b1737", + "value": 2776833 + } + }, + "9e333ed3b5014069ac1dd969255dd591": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9ed02dc43412471a9ab47f3620ccf3a5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9f1c9a0695384bdaa6f8b847ef89bee8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ButtonStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ButtonStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "button_color": null, + "font_weight": "" + } + }, + "9f56a2d9979c4bd8928c644c22c3ecdf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a0a11e929edd4189b79723d618522c33": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a10d0a76010f4e508c65a9b69ebc5156": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a138859f19b74fc0928dc236ab5359db": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_9b42e08b3c9548818488268768a118b1", + "IPY_MODEL_12b56912736849fea2ad8124456fdc5c", + "IPY_MODEL_879c8ab5873847a8833bd74123be90a4" + ], + "layout": "IPY_MODEL_20352e5f58d24bb8b1f3940efd14fe4a" + } + }, + "a1959759c5424da9961fb2a308d4dee4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3aaecbf540f54a2db9ab0931e3b1fe57", + "placeholder": "​", + "style": "IPY_MODEL_9e333ed3b5014069ac1dd969255dd591", + "value": " 239/239 [00:00<00:00, 30.9kB/s]" + } + }, + "a20927bf5f2c41f58c1e31ac858ab36c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1811cda0644e4190a9469d1774435d82", + "placeholder": "​", + "style": "IPY_MODEL_35c811d2ae8e43f3b5cecbdd3cfa857f", + "value": "tokenizer.json: 100%" + } + }, + "a3a945817f684328b34651fe052393ec": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a44f630e099e43899f20a77084ae60cd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ed5ca967ad5342929e578ac6aa4dc4c0", + "placeholder": "​", + "style": "IPY_MODEL_af401d117d5047629d3a6e2361757b62", + "value": "model-00001-of-00008.safetensors: 100%" + } + }, + "a4e5789584564049b83df7c6c54a3e08": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a5434ee714f9498d83870544b67c0cb7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a55060adc3564407ac81ad7297d34aaa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a6f48410b9964fefba0c3009a77dc838": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a7cf477e80fc43e0ad82c7997b076dce": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a80410b919e442c49aea15acc1ce1a72": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fa1282ccc7544e4f818e2f03ccffe4a5", + "placeholder": "​", + "style": "IPY_MODEL_bbbf575d2a4b4c6ea8389be79b2a6039", + "value": "model.safetensors.index.json: 100%" + } + }, + "ab93eabd7cea4b94b4b7a387f101e8a1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ac764024cf1c4e08ba7749afd2cd20ac": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ad1236893754446881e153adc9d5c962": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_62e302ebdad64aada0ffe64ae1c873f3", + "placeholder": "​", + "style": "IPY_MODEL_bd1b0dfed6d34d16af33a4a58330f5ec", + "value": "Saving the dataset (1/1 shards): 100%" + } + }, + "ad7599de524549c48bf2d3124ad4b299": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "adacfdcc1b0140efac56918e9ccf064e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "af401d117d5047629d3a6e2361757b62": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b191ac001a2e4962bc9a245fcdf26e6b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b195f160ca20442fadd8b5aed0ee41af": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b1bea589efa14258a9982071b87938bf": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b5b65414154544aa8a71b1a39164aad7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b634bb73cfa743d09a5999101b840976": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b82aa8c57f7c422a9a9c90f333ed2a99": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c0991cf63ee6458b96e9a75e7a88b61a", + "IPY_MODEL_71c8af139cd248b1b51101fd46a93f35", + "IPY_MODEL_1d5117195d4b49eb8f1a73b18419f7ce" + ], + "layout": "IPY_MODEL_3c21e4a511b4441192c03b7f1d0976e9" + } + }, + "b8766a88716948cf968f4563531a76d9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2b3a2659b12244bd8548320320016dbf", + "placeholder": "​", + "style": "IPY_MODEL_0cd7efffbb3c4c4b972e63749f61ab97", + "value": "Generating train split: " + } + }, + "b87c84de30e84b3abf4871461fb9cbd3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b8e39e4dddc3497fbc29ae45c66da759": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bb33aec33a6447078c31bfd728942994": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "bbbf575d2a4b4c6ea8389be79b2a6039": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "bca2c7185b6749fd899c06a2ba4c5e46": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0f480e3a0b0a45d2a2d2dec3cad923f3", + "placeholder": "​", + "style": "IPY_MODEL_fcb30372e7404c5d8a1ad4df91e6c7b2", + "value": " 1.91G/1.91G [00:05<00:00, 444MB/s]" + } + }, + "bd1b0dfed6d34d16af33a4a58330f5ec": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "be724f04b03942b2a033a7e8898bb4fd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bed8726b8069434687c75452e21f19e5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fa864b41586f4a7aa56aeafd1d84eb75", + "max": 9985, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_3225603166b54e7aab766b9964a2f660", + "value": 9985 + } + }, + "bee3501b2a17427784a717e50a85e7fa": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bfcdbba993b74972a9e3e575f86908ff": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bff139df987d4a62abec6456cb27f3d4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c1f9c267ba3f40039cdb5eb3267e8043", + "max": 3963750880, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_33b3b1d0295646edaac7b4822761aeb0", + "value": 3963750502 + } + }, + "c0892a1881de4eb4bfabc6a68f87ae99": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_158c8b85dbf34de6a94b4e35e2fc7d5a", + "placeholder": "​", + "style": "IPY_MODEL_0b4c9753a7cb4354b8e5f187e6e1ad7c", + "value": " 3.96G/3.96G [00:15<00:00, 564MB/s]" + } + }, + "c0991cf63ee6458b96e9a75e7a88b61a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ed28e2e0410d4e0b855467e798e53d66", + "placeholder": "​", + "style": "IPY_MODEL_d93f134f802b4b69b575bdaf07dbd27c", + "value": "tokenizer_config.json: 100%" + } + }, + "c12ea43372ac4d57bb9605f1a429b397": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "VBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "VBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "VBoxView", + "box_style": "", + "children": [], + "layout": "IPY_MODEL_131065f118274a1586ac38e39ed84ef0" + } + }, + "c1314f241a434c41b45d84dc4d3b30f8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c1f9c267ba3f40039cdb5eb3267e8043": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c33ced495f70464aa4a3a91922090853": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c3725c7f79fe415fbd1ea336f0cc9cf1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b191ac001a2e4962bc9a245fcdf26e6b", + "max": 3841788544, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_054c8dffadba48c6b895a6cc62448ecc", + "value": 3841788178 + } + }, + "c3be9109d63c485d9c0ef4f9bc0f9218": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c42acf646f344a88b8c11f81e67f7206": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8bc9d8ba866c442b9118d9630009939c", + "placeholder": "​", + "style": "IPY_MODEL_9f56a2d9979c4bd8928c644c22c3ecdf", + "value": "model-00003-of-00008.safetensors: 100%" + } + }, + "c6164e05a1914ae48083db9ad7f4ef7c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c65dc74c7d6f4bab8f7dd28455161dd8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c6e00f5224364822bc4239b176686919": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2a51b36be41745468e4c2d7a21b1c0d2", + "max": 36514, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_4fd114abe9f5494ab59858949f5055f1", + "value": 36514 + } + }, + "c73055099c084dca996159e23e162d0b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e40d1c1ac9494b3bade9858324e7ffdf", + "placeholder": "​", + "style": "IPY_MODEL_d65b6b060d9845779299491ac5599c31", + "value": " 9985/9985 [01:04<00:00, 189.08 examples/s]" + } + }, + "c7433acd3c4841e6958ae8f7e87b1808": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "CheckboxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "CheckboxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "CheckboxView", + "description": "Add token as git credential?", + "description_tooltip": null, + "disabled": false, + "indent": true, + "layout": "IPY_MODEL_62c028fdef904dedb9cdeca2b3bda725", + "style": "IPY_MODEL_a7cf477e80fc43e0ad82c7997b076dce", + "value": false + } + }, + "c84cc07789be48aebb322c23d355289e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0077aedc3d174560bce924ee89e9c006", + "placeholder": "​", + "style": "IPY_MODEL_00321cce58884f6f9b3855a21fcd9187", + "value": "Add position_id column (Sample Packing) (num_proc=2): 100%" + } + }, + "ca65e32eb52f48c09a84b33cb18f22cd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "cc587493c33c4f118d1b1170f85be24c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "cc94432d08464affa3e58b560bdad194": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b5b65414154544aa8a71b1a39164aad7", + "max": 3963750816, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f0a58fbd0fca4340890041f99fa2f8c8", + "value": 3963750438 + } + }, + "ccfcdc95baf646f8aeb3d516742383f2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cdebbc55a1164c018546c2ac6f8c620c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_a44f630e099e43899f20a77084ae60cd", + "IPY_MODEL_c3725c7f79fe415fbd1ea336f0cc9cf1", + "IPY_MODEL_0e50870ed0c643e0b6c18cc5d7ddae7f" + ], + "layout": "IPY_MODEL_c33ced495f70464aa4a3a91922090853" + } + }, + "d02274afd47b462291c745f261209d42": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d07c8b97d3314f1c852e44bdd40f61ed": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d0e9dce55cec4c1ca619a0ccf209d924": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d1f9b10c130542f094c8fd3d1e23b5e9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d262c82138024169b9f3aa034ca756fa": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d3de2662c7964f1ba96e58da382af720": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "d4183e9715f34d249942b8271cca3bdf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_63580b6fb30642479fe3000915bf551a", + "placeholder": "​", + "style": "IPY_MODEL_8f726dbfb45d4528afa33e36a6313267", + "value": " 27.3M/27.3M [00:00<00:00, 31.0MB/s]" + } + }, + "d43c6df07ddb466587807d6dbe1ff614": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8c4d4fc5a30f4e7cb3be53fe2adda33d", + "placeholder": "​", + "style": "IPY_MODEL_e90658f4bcb642baa78426012f863152", + "value": "model-00004-of-00008.safetensors: 100%" + } + }, + "d65b6b060d9845779299491ac5599c31": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "d6fe74e4255444368f8f90a62157d869": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d93f134f802b4b69b575bdaf07dbd27c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "d955dcaa0e944e719f3a06139dd54a03": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "da2347ac94764a3fa2743343cf0d3cd2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "da6e93f3e4984780b930fe7a706983ea": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "daee63fd167e4441a32324b51b00ad2b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d07c8b97d3314f1c852e44bdd40f61ed", + "max": 9985, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_ebb69a2c3d0a4299a484698287b3087c", + "value": 9985 + } + }, + "dc892a596f6942d7973c616c38f0eebb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c84cc07789be48aebb322c23d355289e", + "IPY_MODEL_bed8726b8069434687c75452e21f19e5", + "IPY_MODEL_16a188a0b06d45f980dcf3933509fe0a" + ], + "layout": "IPY_MODEL_60c1a0d765c14a1d888317e6a507e4ea" + } + }, + "dd0e646fad3f4a89ba23b39d162bd8d9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_d43c6df07ddb466587807d6dbe1ff614", + "IPY_MODEL_e0e8b840b8ea4d0d9db09afe99fa287d", + "IPY_MODEL_9327977822be4b1294f80e876552e305" + ], + "layout": "IPY_MODEL_77304d1a46b3468a98483e02ec0ac4a4" + } + }, + "de7c37ee83e24f0c889e84d07279c2ec": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_34cf3df51fbc41cabfdbba153c007f0e", + "placeholder": "​", + "style": "IPY_MODEL_ac764024cf1c4e08ba7749afd2cd20ac", + "value": "vocab.json: 100%" + } + }, + "dfd2a2649b8341ef913207526708aff1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e09f1bcbb9d94c09be53e5e1303642c2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e7d8e4fe58384e93a106de546068c65e", + "max": 8, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_0aa8ab56b85f4171a79c3bc210594025", + "value": 8 + } + }, + "e0e8b840b8ea4d0d9db09afe99fa287d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f7434f3e03124a1c938a39af79d7fa59", + "max": 3963750880, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c1314f241a434c41b45d84dc4d3b30f8", + "value": 3963750502 + } + }, + "e21e180307e5485cbbe908672fd6639a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_2e2b0c1599c341a198f632f46a40c90e", + "IPY_MODEL_bff139df987d4a62abec6456cb27f3d4", + "IPY_MODEL_ebe1cc366d324ad59b264c8b3c431441" + ], + "layout": "IPY_MODEL_114dece49dba437c8572ef94b23c3b1e" + } + }, + "e366ae3fceec4566b9ed303d6c5f90af": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e3fb3fc6afe04b3c9b7ac61809ce78fa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c6164e05a1914ae48083db9ad7f4ef7c", + "placeholder": "​", + "style": "IPY_MODEL_813621384dc748b0ad06775e22761c0b", + "value": " 9985/9985 [00:03<00:00, 3622.89 examples/s]" + } + }, + "e400cbf14bcc446a9d33b210cd93550b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e40d1c1ac9494b3bade9858324e7ffdf": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e575d87a7efe4ec7b1efde489839d4a6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e5a82df528bb4e408797a3b6c2758f4a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e6e969610738449887259063967f82b0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e7d8e4fe58384e93a106de546068c65e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e87ea87fcff247b5bbcc331ba79a8dc2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "e8b7a81040904c1e89e58978223b1737": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "e90658f4bcb642baa78426012f863152": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "eb1c9535e6a546098b760528b2ea387c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_18357b321ce44d7b8bd9d1c886f69275", + "IPY_MODEL_279937fe03bc4e4eb25b472d7e9df163", + "IPY_MODEL_bca2c7185b6749fd899c06a2ba4c5e46" + ], + "layout": "IPY_MODEL_1f7d30f71bbd4547a9150d21da071055" + } + }, + "ebb69a2c3d0a4299a484698287b3087c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "ebc80d1a55fa47f4a5ea2756588569ec": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ebe1cc366d324ad59b264c8b3c431441": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fba7aa824b38467ab3061b226114cdec", + "placeholder": "​", + "style": "IPY_MODEL_f3075dccbd2747b4a7913b66f44f2596", + "value": " 3.96G/3.96G [00:13<00:00, 398MB/s]" + } + }, + "ec030fc3c346426f9abc3a89892258d3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_dfd2a2649b8341ef913207526708aff1", + "max": 9985, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_4f1977d7e4824ef1a14b65f0f42bba10", + "value": 9985 + } + }, + "ec11d1e5ae7b42c883d9b1f38a65356e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_936d04b5fe1b4c63bf0b080e423d051b", + "placeholder": "​", + "style": "IPY_MODEL_f1cef8e8dc2646fb9fd09f3b09081074", + "value": " 36.5k/36.5k [00:00<00:00, 4.32MB/s]" + } + }, + "ed28e2e0410d4e0b855467e798e53d66": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ed5ca967ad5342929e578ac6aa4dc4c0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "edc99591b9c747b689b94d0052fec14c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ef0a3c7a6f14460fb4da096928ae249e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_07fb3a2c8315494e97b447e672dfae06", + "IPY_MODEL_ec030fc3c346426f9abc3a89892258d3", + "IPY_MODEL_e3fb3fc6afe04b3c9b7ac61809ce78fa" + ], + "layout": "IPY_MODEL_c3be9109d63c485d9c0ef4f9bc0f9218" + } + }, + "ef223e8504b64e3592589880326aaf41": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f0a58fbd0fca4340890041f99fa2f8c8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "f113ebd8c1c34806bea4dd7ed3035173": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f1cef8e8dc2646fb9fd09f3b09081074": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f3075dccbd2747b4a7913b66f44f2596": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f365820a3d3c42b2948abfe32065de14": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_735d4f225b24414294fc1b213c61223c", + "placeholder": "​", + "style": "IPY_MODEL_5e5e15b0569b474c9620083b3ec6af55", + "value": "generation_config.json: 100%" + } + }, + "f4667818b9d34a09891cd727a429a610": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4b27c267393640f28f6eae0875bd2ed9", + "placeholder": "​", + "style": "IPY_MODEL_9858cb74a09748a39e8149baac96702c", + "value": " 3.96G/3.96G [00:11<00:00, 457MB/s]" + } + }, + "f4a1795dc7514a718f478245f521f0ba": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f60a2bdb6b6b4e0e8c3508580e247132": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "danger", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_edc99591b9c747b689b94d0052fec14c", + "max": 3963750880, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_35cc989ca3374e7dba0cb166febc4bde", + "value": 3963750502 + } + }, + "f7434f3e03124a1c938a39af79d7fa59": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f8ef805b776145c3bfa9ba8d90972058": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fa1282ccc7544e4f818e2f03ccffe4a5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fa864b41586f4a7aa56aeafd1d84eb75": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fba7aa824b38467ab3061b226114cdec": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fcb30372e7404c5d8a1ad4df91e6c7b2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fcbab4d8dced41a18dfccce81e3a45a0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fd4f333f7ece4450b04e1a9af1f9d2f6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d1f9b10c130542f094c8fd3d1e23b5e9", + "placeholder": "​", + "style": "IPY_MODEL_e575d87a7efe4ec7b1efde489839d4a6", + "value": "model-00006-of-00008.safetensors: 100%" + } + }, + "fe18bba7f3fb4c31bf840541f36b3425": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_fd4f333f7ece4450b04e1a9af1f9d2f6", + "IPY_MODEL_f60a2bdb6b6b4e0e8c3508580e247132", + "IPY_MODEL_c0892a1881de4eb4bfabc6a68f87ae99" + ], + "layout": "IPY_MODEL_1bec6297c90242a88672d195bc09d429" + } + }, + "fe41858c6bd04c58840112b67c19a336": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e5a82df528bb4e408797a3b6c2758f4a", + "placeholder": "​", + "style": "IPY_MODEL_f113ebd8c1c34806bea4dd7ed3035173", + "value": " 9985/9985 [00:00<00:00, 44264.88 examples/s]" + } + }, + "fea1b70fb46745feb5111b3929175b5d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_f365820a3d3c42b2948abfe32065de14", + "IPY_MODEL_823f1c78f15043e38bbd4dca3932a86a", + "IPY_MODEL_a1959759c5424da9961fb2a308d4dee4" + ], + "layout": "IPY_MODEL_34c9c0137b504cd799c6bd6de69507c2" + } + }, + "ff3a94b146a948b6907f5d80c7157f99": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ffdbb12a2f2c4d14911685e7683e0ef0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ab93eabd7cea4b94b4b7a387f101e8a1", + "placeholder": "​", + "style": "IPY_MODEL_704f2f5a9b1c49d5a75a0025a5dda11b", + "value": " 3.96G/3.96G [00:12<00:00, 656MB/s]" + } + } + } } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 0 } diff --git a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml index 2c0495ced..97d1bb6b3 100644 --- a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml +++ b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml @@ -9,10 +9,6 @@ strict: false datasets: - path: fozziethebeat/alpaca_messages_2k_test type: chat_template - field_messages: messages - message_property_mappings: - role: role - content: content dataset_prepared_path: val_set_size: 0.05 @@ -21,7 +17,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -51,8 +47,10 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml index de9c956e0..b80cc5bc0 100644 --- a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml +++ b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml @@ -9,10 +9,6 @@ strict: false datasets: - path: fozziethebeat/alpaca_messages_2k_test type: chat_template - field_messages: messages - message_property_mappings: - role: role - content: content dataset_prepared_path: val_set_size: 0.05 @@ -21,7 +17,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -51,8 +47,10 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml index 0ed97db36..6e936da16 100644 --- a/examples/deepseek-v2/fft-fsdp-16b.yaml +++ b/examples/deepseek-v2/fft-fsdp-16b.yaml @@ -12,7 +12,7 @@ output_dir: ./outputs/out sequence_len: 2048 sample_packing: true -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -37,7 +37,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 2 saves_per_epoch: 1 weight_decay: 0.0 @@ -55,3 +55,5 @@ fsdp_config: fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml index 34dbeaafe..aab5034a0 100644 --- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml +++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml @@ -30,7 +30,7 @@ output_dir: ./outputs/out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -61,7 +61,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 2 saves_per_epoch: 1 weight_decay: 0.0 @@ -79,3 +79,5 @@ fsdp_config: fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/devstral/README.md b/examples/devstral/README.md new file mode 100644 index 000000000..ae0860662 --- /dev/null +++ b/examples/devstral/README.md @@ -0,0 +1,73 @@ +# Finetune Devstral with Axolotl + +Devstral Small is a 24B parameter opensource model from MistralAI found on HuggingFace [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505) and [Devstral-Small-2507](https://huggingface.co/mistralai/Devstral-Small-2507). `Devstral-Small-2507` is the latest version of the model and has [function calling](https://mistralai.github.io/mistral-common/usage/tools/) support. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking. + +The model was fine-tuned ontop of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context of up to 128k tokens. + +Thanks to the team at MistralAI for giving us early access to prepare for this release. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install from pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' +``` + +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage + +```bash +python scripts/cutcrossentropy_install.py | sh +``` + +3. Run the finetuning example: + +```bash +axolotl train examples/devstral/devstral-small-qlora.yml +``` + +This config uses about 21GB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). +- Learn how to use function calling with Axolotl at [docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) +- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) +- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels) + +## Limitations + +We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only. + +In addition, we do not support overriding tokens yet. + +## Related Resources + +- [MistralAI Devstral Blog](https://mistral.ai/news/devstral) +- [MistralAI Devstral 1.1 Blog](https://mistral.ai/news/devstral-2507) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) + + +## Future Work + +- Add parity to Preference Tuning, RL, Multi-modal, etc. +- Add parity to other tokenizer configs like overriding tokens. diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml new file mode 100644 index 000000000..7fe4dd433 --- /dev/null +++ b/examples/devstral/devstral-small-qlora.yml @@ -0,0 +1,66 @@ +base_model: mistralai/Devstral-Small-2507 + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +load_in_8bit: false +load_in_4bit: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/qlora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0 +lora_target_linear: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_ratio: 0.05 +evals_per_epoch: 4 +saves_per_epoch: 1 + +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/distributed-parallel/README.md b/examples/distributed-parallel/README.md new file mode 100644 index 000000000..ad7c48d5f --- /dev/null +++ b/examples/distributed-parallel/README.md @@ -0,0 +1,52 @@ +# ND Parallelism Examples + +This directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training. + +## Quick Start + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + +2. Run the command below: + +```bash +# Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node +axolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml + +# Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total) +axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml +``` + +## Example Configurations + +### Single Node (8 GPUs) + +**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml)) +- Uses all 3 parallelism dimensions on a single node +- Ideal for: when model weights, activations, and/or context are too large to fit on single GPU + +```yaml +dp_shard_size: 2 # FSDP across 2 GPUs +tensor_parallel_size: 2 # TP across 2 GPUs +context_parallel_size: 2 # CP across 2 GPUs +# Total: 2 × 2 × 2 = 8 GPUs +``` + +### Multi-Node + +**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml)) +- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication +- Ideal for: Scaling to multiple nodes while maintaining training efficiency + +```yaml +dp_shard_size: 4 # FSDP within each 4-GPU group +tensor_parallel_size: 2 # TP within each node +dp_replicate_size: 2 # DDP across 2 groups +# Total: (4 × 2) × 2 = 16 GPUs (2 nodes) +``` + +## Learn More + +- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html) +- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel) +- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml b/examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml new file mode 100644 index 000000000..f10dc9bd2 --- /dev/null +++ b/examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml @@ -0,0 +1,47 @@ +base_model: meta-llama/Llama-3.1-8B + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +dp_shard_size: 4 +dp_replicate_size: 2 +tensor_parallel_size: 2 +# context_parallel_size: 2 + +dataset_prepared_path: last_run_prepared + +special_tokens: + pad_token: <|end_of_text|> + +fsdp_version: 2 +fsdp_config: + offload_params: false + state_dict_type: FULL_STATE_DICT + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: LlamaDecoderLayer + reshard_after_forward: true + +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/ndp-out/ + +sequence_len: 2048 +sample_packing: true +flash_attention: true + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 2 +optimizer: adamw_torch_fused +lr_scheduler: constant_with_warmup +learning_rate: 2e-6 + +bf16: true +tf32: true + +logging_steps: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 diff --git a/examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml b/examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml new file mode 100644 index 000000000..584a33f44 --- /dev/null +++ b/examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml @@ -0,0 +1,46 @@ +base_model: Qwen/Qwen3-8B + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +dp_shard_size: 2 +# dp_replicate_size: 1 +context_parallel_size: 2 +tensor_parallel_size: 2 + +dataset_prepared_path: last_run_prepared + +fsdp_version: 2 +fsdp_config: + offload_params: false + state_dict_type: FULL_STATE_DICT + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Qwen3DecoderLayer + reshard_after_forward: true + +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/ndp-out/ + +sequence_len: 8192 +sample_packing: true +flash_attention: true + +gradient_accumulation_steps: 1 +micro_batch_size: 1 # must be 1 when using context parallel +num_epochs: 2 +optimizer: adamw_torch_fused +lr_scheduler: constant_with_warmup +learning_rate: 2e-6 + +bf16: true +tf32: true + +logging_steps: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 + +special_tokens: diff --git a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml new file mode 100644 index 000000000..2473179f0 --- /dev/null +++ b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml @@ -0,0 +1,73 @@ +base_model: tiiuae/Falcon-H1-1.5B-Deep-Base +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false + + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-1b-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-qlora.yaml new file mode 100644 index 000000000..bfb7836ef --- /dev/null +++ b/examples/falcon-h1/falcon-h1-1b-qlora.yaml @@ -0,0 +1,72 @@ +base_model: tiiuae/Falcon-H1-1.5B-Base +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false + + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-34b-qlora.yaml b/examples/falcon-h1/falcon-h1-34b-qlora.yaml new file mode 100644 index 000000000..80a9d45b5 --- /dev/null +++ b/examples/falcon-h1/falcon-h1-34b-qlora.yaml @@ -0,0 +1,73 @@ +base_model: tiiuae/Falcon-H1-34B-Base +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false + + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-3b-qlora.yaml b/examples/falcon-h1/falcon-h1-3b-qlora.yaml new file mode 100644 index 000000000..02be8ac5d --- /dev/null +++ b/examples/falcon-h1/falcon-h1-3b-qlora.yaml @@ -0,0 +1,73 @@ +base_model: tiiuae/Falcon-H1-3B-Base +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false + + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-500m-qlora.yaml b/examples/falcon-h1/falcon-h1-500m-qlora.yaml new file mode 100644 index 000000000..b112d5d85 --- /dev/null +++ b/examples/falcon-h1/falcon-h1-500m-qlora.yaml @@ -0,0 +1,73 @@ +base_model: tiiuae/Falcon-H1-0.5B-Instruct +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false + + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/falcon-h1/falcon-h1-7b-qlora.yaml b/examples/falcon-h1/falcon-h1-7b-qlora.yaml new file mode 100644 index 000000000..c5505873d --- /dev/null +++ b/examples/falcon-h1/falcon-h1-7b-qlora.yaml @@ -0,0 +1,73 @@ +base_model: tiiuae/Falcon-H1-7B-Base +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: falcon_h1 +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - in_proj + - gate_proj + - up_proj + - down_proj + +sequence_len: 2048 +sample_packing: false +eval_sample_packing: false + + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml index cb96a32c1..8a295a1f8 100644 --- a/examples/gemma2/qlora.yml +++ b/examples/gemma2/qlora.yml @@ -31,7 +31,7 @@ lora_target_linear: true sequence_len: 2048 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -60,3 +60,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml index ce01a4572..67b1228b2 100644 --- a/examples/gemma2/reward-model.yaml +++ b/examples/gemma2/reward-model.yaml @@ -18,7 +18,7 @@ remove_unused_columns: false sequence_len: 2048 sample_packing: false eval_sample_packing: false -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -50,3 +50,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml index 44310558c..115717db7 100644 --- a/examples/gemma3/gemma-3-1b-qlora.yml +++ b/examples/gemma3/gemma-3-1b-qlora.yml @@ -13,6 +13,8 @@ load_in_4bit: true # huggingface repo chat_template: gemma3 +eot_tokens: + - datasets: - path: cgato/SlimOrcaDedupCleaned type: chat_template @@ -33,7 +35,7 @@ lora_target_linear: true sequence_len: 2048 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -64,3 +66,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma3/gemma-3-270m-qlora.yml b/examples/gemma3/gemma-3-270m-qlora.yml new file mode 100644 index 000000000..8744fad26 --- /dev/null +++ b/examples/gemma3/gemma-3-270m-qlora.yml @@ -0,0 +1,68 @@ +base_model: google/gemma-3-270m-it +# optionally might have model_type or tokenizer_type +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +# gemma3 doesn't seem to play nice with ddp +ddp_find_unused_parameters: true + +load_in_8bit: false +load_in_4bit: true + +# huggingface repo +chat_template: gemma3 +eot_tokens: + - +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true + +sequence_len: 2048 +sample_packing: true +eval_sample_packing: false + + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml index 0d89d9ffb..44ba9c879 100644 --- a/examples/gemma3/gemma-3-4b-qlora.yml +++ b/examples/gemma3/gemma-3-4b-qlora.yml @@ -6,6 +6,8 @@ load_in_4bit: true ddp_find_unused_parameters: true chat_template: gemma3 +eot_tokens: + - datasets: - path: cgato/SlimOrcaDedupCleaned type: chat_template @@ -23,7 +25,7 @@ lora_model_dir: sequence_len: 2048 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -58,3 +60,5 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml index 339df92e5..b42b6b492 100644 --- a/examples/gemma3/gemma-3-4b-vision-qlora.yml +++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml @@ -12,11 +12,13 @@ sample_packing: false ddp_find_unused_parameters: true chat_template: gemma3 +eot_tokens: + - datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template split: train[:1%] - field_messages: messages + dataset_prepared_path: last_run_prepared val_set_size: 0.01 output_dir: ./outputs/out @@ -60,3 +62,5 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gemma3n/README.md b/examples/gemma3n/README.md new file mode 100644 index 000000000..ff3946c90 --- /dev/null +++ b/examples/gemma3n/README.md @@ -0,0 +1,70 @@ +# Finetune Gemma-3n with Axolotl + +Gemma-3n is a family of multimodal models from Google found on [HuggingFace](https://huggingface.co/collections/google/gemma-3n-685065323f5984ef315c93f4). This guide shows how to fine-tune it with Axolotl. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install from pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' +``` + +2. In addition to Axolotl's requirements, Gemma-3n requires: + +```bash +pip3 install timm==1.0.17 + +# for loading audio data +pip3 install librosa==0.11.0 +``` + +3. Download sample dataset files + +```bash +# for text + vision + audio only +wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg +wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga +``` + +4. Run the finetuning example: + +```bash +# text only +axolotl train examples/gemma3n/gemma-3n-e2b-qlora.yml + +# text + vision +axolotl train examples/gemma3n/gemma-3n-e2b-vision-qlora.yml + +# text + vision + audio +axolotl train examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml +``` + +Let us know how it goes. Happy finetuning! 🚀 + +WARNING: The loss and grad norm will be much higher than normal. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look. + +### TIPS + +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). +- The multimodal dataset format follows the OpenAI multi-content Messages format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [Gemma 3n Blog](https://ai.google.dev/gemma/docs/gemma-3n) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/gemma3n/gemma-3n-e2b-qlora.yml b/examples/gemma3n/gemma-3n-e2b-qlora.yml new file mode 100644 index 000000000..ad7ab5726 --- /dev/null +++ b/examples/gemma3n/gemma-3n-e2b-qlora.yml @@ -0,0 +1,74 @@ +base_model: google/gemma-3n-E2B-it + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin +cut_cross_entropy: true + +load_in_8bit: false +load_in_4bit: true + +# for use with fft to only train on language model layers +# unfrozen_parameters: + # - model.language_model.* + # - lm_head + # - embed_tokens + + +chat_template: gemma3n +eot_tokens: + - +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + split: train[:1%] + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +# lora_target_linear: # Does not work with gemma3n currently +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' + +sequence_len: 2048 +sample_packing: true +eval_sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +# flash_attention: true # Any attention impl does not work with gemma3n now + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml b/examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml new file mode 100644 index 000000000..d72d7fbc0 --- /dev/null +++ b/examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml @@ -0,0 +1,78 @@ +base_model: google/gemma-3n-E2B-it +processor_type: AutoProcessor + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin +cut_cross_entropy: true + +# for use with fft to only train on language model layers +# unfrozen_parameters: + # - model.language_model.* + # - lm_head + # - embed_tokens + +load_in_4bit: true + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +# gemma3 doesn't seem to play nice with ddp +ddp_find_unused_parameters: true + +chat_template: gemma3n +eot_tokens: + - + +# sample dataset below requires downloading audio/image in advance +# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg +# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga +datasets: + - path: Nanobit/text-vision-audio-2k-test + type: chat_template +dataset_prepared_path: +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +logging_steps: 1 +# flash_attention: true # Any attention impl does not work with gemma3n now + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 diff --git a/examples/gemma3n/gemma-3n-e2b-vision-qlora.yml b/examples/gemma3n/gemma-3n-e2b-vision-qlora.yml new file mode 100644 index 000000000..c87eca663 --- /dev/null +++ b/examples/gemma3n/gemma-3n-e2b-vision-qlora.yml @@ -0,0 +1,75 @@ +base_model: google/gemma-3n-E2B-it +processor_type: AutoProcessor + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin +cut_cross_entropy: true + +# for use with fft to only train on language model layers +# unfrozen_parameters: + # - model.language_model.* + # - lm_head + # - embed_tokens + +load_in_4bit: true + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +# gemma3 doesn't seem to play nice with ddp +ddp_find_unused_parameters: true + +chat_template: gemma3n +eot_tokens: + - +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] +dataset_prepared_path: +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +logging_steps: 1 +# flash_attention: true # Any attention impl does not work with gemma3n now + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 diff --git a/examples/glm4/qlora-32b.yaml b/examples/glm4/qlora-32b.yaml index 86d9b43f8..832abde05 100644 --- a/examples/glm4/qlora-32b.yaml +++ b/examples/glm4/qlora-32b.yaml @@ -17,7 +17,7 @@ lora_model_dir: sequence_len: 2048 sample_packing: true eval_sample_packing: true -pad_to_sequence_len: true + lora_r: 16 lora_alpha: 32 @@ -55,8 +55,10 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/gpt-oss/README.md b/examples/gpt-oss/README.md new file mode 100644 index 000000000..fb6c67498 --- /dev/null +++ b/examples/gpt-oss/README.md @@ -0,0 +1,135 @@ +# Finetune OpenAI's GPT-OSS with Axolotl + +[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install from pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' +``` + +2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b)) + +```bash +# LoRA SFT linear layers (1x48GB @ ~44GiB) +axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml + +# FFT SFT with offloading (2x24GB @ ~21GiB/GPU) +axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml + +# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU) +axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml +``` + +Note: Memory usage taken from `device_mem_reserved(gib)` from logs. + +### Training 120B + +On 8xH100s, make sure you have ~3TB of free disk space. With each checkpoint clocking in at ~720GB, along with the base +model, and final model output, you may need at least 3TB of free disk space to keep at least 2 checkpoints. + +```bash +# FFT SFT with offloading (8x80GB @ ~49GiB/GPU) +axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml +``` + +To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node +training of the 120B model using Baseten Truss. You can read more about this recipe on +[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can +be found on their +[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training). + +ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`. +See https://github.com/huggingface/transformers/pull/40207 for the status of this issue. + +```bash +sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json +``` + +When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your +configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to +merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded +weights to `{output_dir}/merged`. + +```bash +axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml +mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/ +``` + + +### Inferencing your fine-tuned model + +#### vLLM + +GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425 +for more information about using a special vllm-openai docker image for inferencing with vLLM. + +Optionally, vLLM can be installed from nightly: + +```bash +pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly +``` +and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment): +```bash +vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8 +``` + +#### SGLang + +SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing +SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server: + +```bash +python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8 +``` + +### Tool use + +GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning. + +Here is an example dataset config: +```yaml +datasets: + - path: Nanobit/text-tools-2k-test + type: chat_template +``` + +See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset. + +Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info. + +### Thinking and chat_template masking conflict + +OpenAI’s Harmony template hides `thinking` in all non-final turns, which conflicts with Axolotl’s `chat_template` masking. + +If your dataset has `thinking` content mid-turn, there are two paths we recommend: + +- Train only on the last turn. This can be accomplished via chat_template's [train on last doc](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#training-on-last-message). + +- Adjust your dataset to only have `thinking` content in the last turn. + +### TIPS + +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) + +## Related Resources + +- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml new file mode 100644 index 000000000..62f3167e8 --- /dev/null +++ b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml @@ -0,0 +1,68 @@ +# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading +# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model +base_model: axolotl-ai-co/gpt-oss-120b-dequantized + +use_kernels: false + +dp_shard_size: 16 # requires 2x8xH100 nodes + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding + +datasets: + - path: HuggingFaceH4/Multilingual-Thinking + type: chat_template + field_thinking: thinking + template_thinking_key: thinking + +dataset_prepared_path: last_run_prepared +val_set_size: 0 +output_dir: ./outputs/gpt-oss-out/ +save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2 + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 1 +num_epochs: 1 + +optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload +lr_scheduler: constant_with_warmup +learning_rate: 2e-5 + +bf16: true +tf32: true + +flash_attention: true +attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 + +gradient_checkpointing: true +activation_offloading: true + +logging_steps: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.03 + +special_tokens: +eot_tokens: + - "<|end|>" + +fsdp_version: 2 +fsdp_config: + offload_params: true + state_dict_type: SHARDED_STATE_DICT + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: GptOssDecoderLayer + reshard_after_forward: true + cpu_ram_efficient_loading: true diff --git a/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml b/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml new file mode 100644 index 000000000..ccb84e28e --- /dev/null +++ b/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml @@ -0,0 +1,58 @@ +base_model: openai/gpt-oss-20b +use_kernels: false +model_quantization_config: Mxfp4Config +model_quantization_config_kwargs: + dequantize: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding + +datasets: + - path: HuggingFaceH4/Multilingual-Thinking + type: chat_template + field_thinking: thinking + template_thinking_key: thinking + +dataset_prepared_path: last_run_prepared +val_set_size: 0 +output_dir: ./outputs/gpt-oss-out/ + +sequence_len: 4096 +sample_packing: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 1 +num_epochs: 1 + +optimizer: adamw_torch_8bit +lr_scheduler: constant_with_warmup +learning_rate: 2e-5 + +bf16: true +tf32: true + +flash_attention: true +attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 + +gradient_checkpointing: true +activation_offloading: true + +logging_steps: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.03 + +special_tokens: +eot_tokens: + - "<|end|>" + +# choose the zero3 configuration that best fits your system capabilities +deepspeed: deepspeed_configs/zero3_bf16.json diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml new file mode 100644 index 000000000..69a3c434d --- /dev/null +++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml @@ -0,0 +1,68 @@ +base_model: openai/gpt-oss-20b +use_kernels: true +model_quantization_config: Mxfp4Config +model_quantization_config_kwargs: + dequantize: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding + +datasets: + - path: HuggingFaceH4/Multilingual-Thinking + type: chat_template + field_thinking: thinking + template_thinking_key: thinking + +dataset_prepared_path: ./outputs/last_run_prepared +val_set_size: 0 +output_dir: ./outputs/gpt-oss-out/ + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 1 +num_epochs: 1 + +optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload +lr_scheduler: constant_with_warmup +learning_rate: 2e-5 + +bf16: true +tf32: true + +flash_attention: true +attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 + +gradient_checkpointing: true +activation_offloading: true + +logging_steps: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.03 + +special_tokens: +eot_tokens: + - "<|end|>" + +fsdp_version: 2 +fsdp_config: + offload_params: true + state_dict_type: SHARDED_STATE_DICT + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: GptOssDecoderLayer + reshard_after_forward: true + # cpu_ram_efficient_loading: true + +# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization. +# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized` diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml new file mode 100644 index 000000000..4a0f1ad70 --- /dev/null +++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml @@ -0,0 +1,64 @@ +base_model: openai/gpt-oss-20b +use_kernels: false +model_quantization_config: Mxfp4Config +model_quantization_config_kwargs: + dequantize: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding + +datasets: + - path: HuggingFaceH4/Multilingual-Thinking + type: chat_template + field_thinking: thinking + template_thinking_key: thinking + +dataset_prepared_path: ./outputs/last_run_prepared +val_set_size: 0 +output_dir: ./outputs/gpt-oss-out/ + +sequence_len: 4096 +sample_packing: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 1 +num_epochs: 1 + +optimizer: adamw_torch_8bit +lr_scheduler: constant_with_warmup +learning_rate: 2e-5 + +bf16: true +tf32: true + +flash_attention: true +attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 + +gradient_checkpointing: true +activation_offloading: true + +logging_steps: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.03 + +special_tokens: +eot_tokens: + - "<|end|>" + +fsdp_version: 2 +fsdp_config: + offload_params: false + state_dict_type: SHARDED_STATE_DICT + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: GptOssDecoderLayer + reshard_after_forward: true +# cpu_ram_efficient_loading: true diff --git a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml new file mode 100644 index 000000000..b6deacb1b --- /dev/null +++ b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml @@ -0,0 +1,67 @@ +base_model: openai/gpt-oss-20b +use_kernels: true +model_quantization_config: Mxfp4Config +model_quantization_config_kwargs: + dequantize: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding + +datasets: + - path: HuggingFaceH4/Multilingual-Thinking + type: chat_template + field_thinking: thinking + template_thinking_key: thinking + +dataset_prepared_path: last_run_prepared +val_set_size: 0 +output_dir: ./outputs/gpt-oss-out/ + +sequence_len: 4096 +sample_packing: true + +adapter: lora +lora_r: 8 +lora_alpha: 16 +lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters +lora_target_linear: true + +# TODO: not supported for now, see peft#2710 +#lora_target_parameters: # target the experts in the last two layers +# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj" +# - "22._checkpoint_wrapped_module.mlp.experts.down_proj" +# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj" +# - "23._checkpoint_wrapped_module.mlp.experts.down_proj" + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 8 +micro_batch_size: 1 +num_epochs: 1 + +optimizer: adamw_torch_8bit +lr_scheduler: constant_with_warmup +learning_rate: 2e-4 + +bf16: true +tf32: true + +flash_attention: true +attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 + +gradient_checkpointing: true +activation_offloading: true + +logging_steps: 1 +saves_per_epoch: 1 +warmup_ratio: 0.1 + +special_tokens: +eot_tokens: + - "<|end|>" diff --git a/examples/hunyuan/README.md b/examples/hunyuan/README.md new file mode 100644 index 000000000..96c6bbcfa --- /dev/null +++ b/examples/hunyuan/README.md @@ -0,0 +1,85 @@ +# Finetune HunYuan with Axolotl + +Tencent released a family of opensource models called HunYuan with varying parameter scales of 0.5B, 1.8B, 4B, and 7B scale for both Pre-trained and Instruct variants. The models can be found at [HuggingFace](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as HunYuan is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main for pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy +python scripts/cutcrossentropy_install.py | sh +``` + +2. Run the finetuning example: + +```bash +axolotl train examples/hunyuan/hunyuan-v1-dense-qlora.yaml +``` + +This config uses about 4.7 GB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### Dataset + +HunYuan Instruct models can choose to enter a slow think or fast think pattern. For best performance on fine-tuning their Instruct models, your dataset should be adjusted to match their pattern. + +```python +# fast think pattern +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "/no_think What color is the sun?" }, + {"role": "assistant", "content": "\n\n\n\nThe sun is yellow.\n"} +] + +# slow think pattern +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "/no_think What color is the sun?" }, + {"role": "assistant", "content": "\nThe user is asking about the color of the sun. I need to ...\n\n\nThe sun is yellow.\n"} +] +``` + +### TIPS + +- For inference, the official Tencent team recommends + +```json + +{ + "do_sample": true, + "top_k": 20, + "top_p": 0.8, + "repetition_penalty": 1.05, + "temperature": 0.7 +} + +``` + +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [Tencent HunYuan Blog](https://hunyuan.tencent.com/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/hunyuan/hunyuan-v1-dense-qlora.yaml b/examples/hunyuan/hunyuan-v1-dense-qlora.yaml new file mode 100644 index 000000000..a94345a61 --- /dev/null +++ b/examples/hunyuan/hunyuan-v1-dense-qlora.yaml @@ -0,0 +1,64 @@ +base_model: tencent/Hunyuan-0.5B-Instruct + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml index 2cb0eea41..538ed3a10 100644 --- a/examples/jamba/qlora.yaml +++ b/examples/jamba/qlora.yaml @@ -49,8 +49,10 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/jamba/qlora_deepspeed.yaml b/examples/jamba/qlora_deepspeed.yaml index d13ce6483..b288635e7 100644 --- a/examples/jamba/qlora_deepspeed.yaml +++ b/examples/jamba/qlora_deepspeed.yaml @@ -48,10 +48,12 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: saves_per_epoch: 1 deepspeed: deepspeed_configs/zero2.json weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml index 6badaba19..150e5e2ec 100644 --- a/examples/jamba/qlora_fsdp_large.yaml +++ b/examples/jamba/qlora_fsdp_large.yaml @@ -23,7 +23,7 @@ save_safetensors: true adapter: qlora sequence_len: 2048 sample_packing: true -pad_to_sequence_len: true + lora_r: 16 lora_alpha: 16 @@ -47,7 +47,7 @@ gradient_checkpointing_kwargs: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 @@ -64,3 +64,5 @@ fsdp_config: fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml index 86b1b6a21..ea119348e 100644 --- a/examples/llama-2/fft_optimized.yml +++ b/examples/llama-2/fft_optimized.yml @@ -14,7 +14,7 @@ output_dir: ./outputs/out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + adapter: lora_model_dir: @@ -45,13 +45,14 @@ logging_steps: 1 flash_attention: true flash_attn_cross_entropy: false flash_attn_rms_norm: true -flash_attn_fuse_qkv: false flash_attn_fuse_mlp: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 deepspeed: #deepspeed_configs/zero2.json # multi-gpu only weight_decay: 0.1 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml index 0f1b34016..de1caaa05 100644 --- a/examples/llama-2/gptq-lora.yml +++ b/examples/llama-2/gptq-lora.yml @@ -56,7 +56,7 @@ logging_steps: 1 flash_attention: sdp_attention: flash_optimum: -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 @@ -64,3 +64,5 @@ special_tokens: bos_token: "" eos_token: "" unk_token: "" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml index a76a792ae..d21c01a49 100644 --- a/examples/llama-2/lisa.yml +++ b/examples/llama-2/lisa.yml @@ -14,7 +14,7 @@ output_dir: ./outputs/lisa-out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + adapter: lora_model_dir: @@ -49,10 +49,9 @@ logging_steps: 1 flash_attention: true flash_attn_cross_entropy: false flash_attn_rms_norm: true -flash_attn_fuse_qkv: false flash_attn_fuse_mlp: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 @@ -60,3 +59,5 @@ special_tokens: bos_token: "" eos_token: "" unk_token: "" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml index 22dbf2d99..619e5bcce 100644 --- a/examples/llama-2/loftq.yml +++ b/examples/llama-2/loftq.yml @@ -14,7 +14,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -47,8 +47,10 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 679aed3a9..0a677f11a 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -47,8 +47,10 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml index a42eabd4b..1e7064de8 100644 --- a/examples/llama-2/qlora-fsdp.yml +++ b/examples/llama-2/qlora-fsdp.yml @@ -20,7 +20,7 @@ lora_model_dir: sequence_len: 512 sample_packing: false -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -50,7 +50,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 @@ -66,4 +66,7 @@ fsdp_config: fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_state_dict_type: FULL_STATE_DICT + # fsdp_cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index de65928bc..327d88c15 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -20,7 +20,7 @@ lora_model_dir: sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -48,8 +48,10 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml index e0a5f7068..fabdf0e0f 100644 --- a/examples/llama-2/relora.yml +++ b/examples/llama-2/relora.yml @@ -18,16 +18,19 @@ lora_model_dir: sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + lora_r: 8 lora_alpha: 16 lora_dropout: 0.05 lora_target_linear: true -relora_steps: 150 -relora_warmup_steps: 10 +relora: true +relora_prune_ratio: 0.9 relora_cpu_offload: false +jagged_restart_steps: 150 +jagged_restart_warmup_steps: 10 +jagged_restart_anneal_steps: false wandb_project: wandb_entity: @@ -50,7 +53,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 @@ -58,3 +61,5 @@ special_tokens: bos_token: "" eos_token: "" unk_token: "" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml index 2b0ae2c70..adbb61643 100644 --- a/examples/llama-3-vision/lora-11b.yaml +++ b/examples/llama-3-vision/lora-11b.yaml @@ -15,8 +15,7 @@ datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template split: train[:1%] - field_messages: messages -dataset_prepared_path: last_run_prepared +dataset_prepared_path: val_set_size: 0.0 output_dir: ./outputs/out @@ -50,10 +49,12 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true -eager_attention: +# flash_attention: true # use for text-only mode +sdp_attention: true warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/3b-fp8-fsdp2.yaml b/examples/llama-3/3b-fp8-fsdp2.yaml new file mode 100644 index 000000000..b7de7ca52 --- /dev/null +++ b/examples/llama-3/3b-fp8-fsdp2.yaml @@ -0,0 +1,76 @@ +base_model: meta-llama/Llama-3.2-3B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true + +datasets: + - path: yahma/alpaca-cleaned + type: alpaca + +output_dir: ./outputs/fp8_out/ + +sample_packing: true +pad_to_sequence_len: true +sequence_len: 512 + +flex_attention: true +flex_attn_compile_kwargs: + dynamic: false + mode: max-autotune-no-cudagraphs +save_strategy: no +torch_compile: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 +num_epochs: 1 +optimizer: adamw_torch_fused + +cosine_constant_lr_ratio: 0 +cosine_min_lr_ratio: 1.0 +learning_rate: 2e-5 +save_only_model: true + +fp8: true +fp8_enable_fsdp_float8_all_gather: true + +resume_from_checkpoint: +logging_steps: 1 + +evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_steps: 10 +weight_decay: 0.0 + +fsdp_version: 2 +fsdp_config: + offload_params: false + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: LlamaDecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: false + +special_tokens: + pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml index 5d979c96c..0c5a87891 100644 --- a/examples/llama-3/3b-qat-fsdp2.yaml +++ b/examples/llama-3/3b-qat-fsdp2.yaml @@ -15,20 +15,18 @@ liger_glu_activation: true liger_layer_norm: true liger_fused_linear_cross_entropy: true + datasets: - path: yahma/alpaca-cleaned type: alpaca + split: train[:95%] output_dir: ./outputs/qat_out/ +dataset_prepared_path: ./outputs/qat_out/dataset_prepared -sample_packing: true -pad_to_sequence_len: true -sequence_len: 512 - -flex_attention: true -flex_attn_compile_kwargs: - dynamic: false - mode: max-autotune-no-cudagraphs +sample_packing: false +sequence_len: 8192 +flash_attention: true qat: activation_dtype: int8 @@ -58,7 +56,7 @@ logging_steps: 1 evals_per_epoch: 1 saves_per_epoch: 1 -warmup_steps: 10 +warmup_ratio: 0.1 weight_decay: 0.0 fsdp: - full_shard @@ -67,7 +65,7 @@ fsdp: fsdp_config: fsdp_version: 2 fsdp_offload_params: false - fsdp_cpu_ram_efficient_loading: true + fsdp_cpu_ram_efficient_loading: false fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_state_dict_type: FULL_STATE_DICT @@ -76,4 +74,6 @@ fsdp_config: fsdp_activation_checkpointing: true special_tokens: - pad_token: <|end_of_text|> + pad_token: <|finetune_right_pad_id|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/3b-qat-nvfp4.yaml b/examples/llama-3/3b-qat-nvfp4.yaml new file mode 100644 index 000000000..1ec809bbe --- /dev/null +++ b/examples/llama-3/3b-qat-nvfp4.yaml @@ -0,0 +1,64 @@ +base_model: meta-llama/Llama-3.2-3B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true + +datasets: + - path: yahma/alpaca-cleaned + type: alpaca + split: train[:95%] + +output_dir: ./outputs/qat_out/ +dataset_prepared_path: ./outputs/dataset_prepared + +sequence_len: 8192 +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_checkpointing: true +gradient_accumulation_steps: 1 +micro_batch_size: 64 +num_epochs: 1 +optimizer: adamw_torch_fused + +cosine_constant_lr_ratio: 0 +cosine_min_lr_ratio: 1.0 +learning_rate: 2e-5 +save_only_model: true +bf16: true + +resume_from_checkpoint: +logging_steps: 1 + +evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 + +special_tokens: + pad_token: <|finetune_right_pad_id|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/diffusion/pretrain-1b.yaml b/examples/llama-3/diffusion/pretrain-1b.yaml new file mode 100644 index 000000000..8d05e4c60 --- /dev/null +++ b/examples/llama-3/diffusion/pretrain-1b.yaml @@ -0,0 +1,56 @@ +base_model: meta-llama/Llama-3.2-1B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +pretraining_dataset: + - path: wikitext + name: wikitext-103-raw-v1 + type: completion + field: text + +plugins: + - axolotl.integrations.diffusion.DiffusionPlugin + +diffusion: + noise_schedule: cosine + min_mask_ratio: 0.15 + max_mask_ratio: 0.85 + num_diffusion_steps: 128 + eps: 5e-4 + importance_weighting: true + mask_token_id: 128002 + generate_samples: true + generation_interval: 250 + +output_dir: ./outputs/model-out + +sequence_len: 512 +sample_packing: true + +gradient_accumulation_steps: 8 +micro_batch_size: 4 +max_steps: 10000 +warmup_ratio: 0.1 + +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 3e-4 +sdp_attention: true + +bf16: auto +tf32: true + +logging_steps: 1 +save_strategy: steps +save_steps: 1000 + +special_tokens: + pad_token: "<|end_of_text|>" + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/diffusion/sft-1b.yaml b/examples/llama-3/diffusion/sft-1b.yaml new file mode 100644 index 000000000..f3b29a809 --- /dev/null +++ b/examples/llama-3/diffusion/sft-1b.yaml @@ -0,0 +1,59 @@ +base_model: meta-llama/Llama-3.2-1B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +datasets: + - path: teknium/GPT4-LLM-Cleaned + type: alpaca +val_set_size: 0.05 + +plugins: + - axolotl.integrations.diffusion.DiffusionPlugin + +diffusion: + noise_schedule: cosine + min_mask_ratio: 0.1 + max_mask_ratio: 0.9 + num_diffusion_steps: 128 + eps: 1e-3 + importance_weighting: true + mask_token_id: 128002 + generate_samples: true + generation_interval: 250 + +output_dir: ./outputs/model-out + +sequence_len: 512 +sample_packing: true +eval_sample_packing: true + +gradient_accumulation_steps: 4 +micro_batch_size: 4 +num_epochs: 1 +warmup_steps: 0.1 + +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 1e-5 + +bf16: auto +tf32: true + +gradient_checkpointing: true +resume_from_checkpoint: +sdp_attention: true + +logging_steps: 1 +save_strategy: best +eval_strategy: epoch + +special_tokens: + pad_token: "<|end_of_text|>" + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml index eccfa6d8c..a655b97a9 100644 --- a/examples/llama-3/fft-8b-liger-fsdp.yaml +++ b/examples/llama-3/fft-8b-liger-fsdp.yaml @@ -26,7 +26,7 @@ output_dir: ./outputs/out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -51,7 +51,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 2 saves_per_epoch: 1 weight_decay: 0.0 @@ -72,3 +72,5 @@ fsdp_config: special_tokens: pad_token: <|finetune_right_pad_id|> eos_token: <|eot_id|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml index fdae3e6c4..c72ec6662 100644 --- a/examples/llama-3/fft-8b.yaml +++ b/examples/llama-3/fft-8b.yaml @@ -11,7 +11,7 @@ output_dir: ./outputs/out sequence_len: 8192 sample_packing: true -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -36,9 +36,11 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 2 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml index 13082294f..cf823353b 100644 --- a/examples/llama-3/instruct-dpo-lora-8b.yml +++ b/examples/llama-3/instruct-dpo-lora-8b.yml @@ -5,6 +5,10 @@ tokenizer_type: AutoTokenizer # Automatically upload checkpoint and final model to HF # hub_model_id: username/custom_model_name +special_tokens: + pad_token: <|finetune_right_pad_id|> + eos_token: <|eot_id|> + load_in_8bit: true load_in_4bit: false @@ -33,7 +37,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -63,7 +67,9 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml index acab862f6..401df1d72 100644 --- a/examples/llama-3/instruct-lora-8b.yml +++ b/examples/llama-3/instruct-lora-8b.yml @@ -12,15 +12,6 @@ chat_template: llama3 datasets: - path: fozziethebeat/alpaca_messages_2k_test type: chat_template - field_messages: messages - message_property_mappings: - role: role - content: content - roles: - user: - - user - assistant: - - assistant dataset_prepared_path: val_set_size: 0.05 @@ -28,7 +19,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -58,9 +49,11 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml index 10e9747cb..2897636f4 100644 --- a/examples/llama-3/lora-1b-deduplicate-dpo.yml +++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml @@ -49,7 +49,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -79,7 +79,9 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml index 630ec92f6..c5190d892 100644 --- a/examples/llama-3/lora-1b-deduplicate-sft.yml +++ b/examples/llama-3/lora-1b-deduplicate-sft.yml @@ -22,7 +22,7 @@ dataset_exact_deduplication: true sequence_len: 4096 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -55,9 +55,11 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml index a2d07ca49..0bcf46b17 100644 --- a/examples/llama-3/lora-1b-kernels.yml +++ b/examples/llama-3/lora-1b-kernels.yml @@ -14,7 +14,7 @@ lora_model_dir: sequence_len: 2048 sample_packing: true -pad_to_sequence_len: true + lora_r: 16 lora_alpha: 32 @@ -59,9 +59,11 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: "<|end_of_text|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b-ray.yml b/examples/llama-3/lora-1b-ray.yml index bb23164eb..46c83348e 100644 --- a/examples/llama-3/lora-1b-ray.yml +++ b/examples/llama-3/lora-1b-ray.yml @@ -15,7 +15,7 @@ lora_model_dir: sequence_len: 2048 sample_packing: true eval_sample_packing: true -pad_to_sequence_len: true + lora_r: 16 lora_alpha: 32 @@ -53,7 +53,7 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 @@ -64,3 +64,5 @@ special_tokens: use_ray: true ray_num_workers: 4 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b-sample-packing-sequentially.yml b/examples/llama-3/lora-1b-sample-packing-sequentially.yml index 769dd32e6..dba78597b 100644 --- a/examples/llama-3/lora-1b-sample-packing-sequentially.yml +++ b/examples/llama-3/lora-1b-sample-packing-sequentially.yml @@ -24,7 +24,7 @@ sample_packing: true sample_packing_sequentially: true curriculum_sampling: true eval_sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -57,9 +57,11 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml index acc17e21f..2ae2f0056 100644 --- a/examples/llama-3/lora-1b.yml +++ b/examples/llama-3/lora-1b.yml @@ -15,7 +15,7 @@ lora_model_dir: sequence_len: 2048 sample_packing: true eval_sample_packing: true -pad_to_sequence_len: true + lora_r: 16 lora_alpha: 32 @@ -54,9 +54,11 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: "<|end_of_text|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml index ad50cd38a..d72b6527d 100644 --- a/examples/llama-3/lora-8b.yml +++ b/examples/llama-3/lora-8b.yml @@ -18,7 +18,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: true eval_sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -51,9 +51,11 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml index 89a51ea68..a6a84e7b1 100644 --- a/examples/llama-3/qlora-1b-kto.yaml +++ b/examples/llama-3/qlora-1b-kto.yaml @@ -55,9 +55,11 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: "<|end_of_text|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml index 5c8fe6628..1e4f97438 100644 --- a/examples/llama-3/qlora-1b.yml +++ b/examples/llama-3/qlora-1b.yml @@ -18,7 +18,7 @@ lora_model_dir: sequence_len: 2048 sample_packing: true eval_sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -56,9 +56,11 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: "<|end_of_text|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml index 2b7d51925..8ddb84d65 100644 --- a/examples/llama-3/qlora-fsdp-405b.yaml +++ b/examples/llama-3/qlora-fsdp-405b.yaml @@ -18,7 +18,7 @@ adapter: qlora sequence_len: 2048 sample_packing: true -pad_to_sequence_len: true + lora_r: 16 lora_alpha: 16 @@ -41,7 +41,7 @@ gradient_checkpointing_kwargs: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 @@ -60,3 +60,5 @@ fsdp_config: fsdp_sharding_strategy: FULL_SHARD special_tokens: pad_token: <|finetune_right_pad_id|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/qlora-fsdp-70b.yaml b/examples/llama-3/qlora-fsdp-70b.yaml index 412b6721c..c052bc19d 100644 --- a/examples/llama-3/qlora-fsdp-70b.yaml +++ b/examples/llama-3/qlora-fsdp-70b.yaml @@ -20,7 +20,7 @@ lora_model_dir: sequence_len: 512 sample_packing: false -pad_to_sequence_len: true + lora_r: 8 lora_alpha: 16 @@ -50,7 +50,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 @@ -69,3 +69,5 @@ fsdp_config: fsdp_sharding_strategy: FULL_SHARD special_tokens: pad_token: <|end_of_text|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml index 4cc9fc3db..a8f47a0e2 100644 --- a/examples/llama-3/qlora.yml +++ b/examples/llama-3/qlora.yml @@ -20,7 +20,7 @@ lora_model_dir: sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -48,9 +48,11 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: "<|end_of_text|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/sparse-finetuning.yaml b/examples/llama-3/sparse-finetuning.yaml index 1bbb88028..348756b70 100644 --- a/examples/llama-3/sparse-finetuning.yaml +++ b/examples/llama-3/sparse-finetuning.yaml @@ -16,7 +16,7 @@ output_dir: ./outputs/out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + eval_sample_packing: false wandb_project: @@ -47,7 +47,7 @@ logging_steps: 1 xformers_attention: flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 2 eval_table_size: saves_per_epoch: 1 @@ -75,3 +75,5 @@ llmcompressor: ] start: 0 save_compressed: true + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml index 2be94f4ef..b20f79758 100644 --- a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml +++ b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml @@ -47,7 +47,7 @@ output_dir: ./outputs/out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + gradient_accumulation_steps: 1 micro_batch_size: 1 @@ -66,7 +66,7 @@ gradient_checkpointing: offload gradient_checkpointing_kwargs: use_reentrant: false -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 @@ -84,5 +84,7 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD special_tokens: - pad_token: <|finetune_right_pad_id|> + pad_token: <|finetune_right_pad|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml index eeae872a6..40449009c 100644 --- a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml +++ b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml @@ -48,7 +48,7 @@ output_dir: ./outputs/out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -69,7 +69,7 @@ tf32: true logging_steps: 1 flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 @@ -88,5 +88,7 @@ fsdp_config: fsdp_sharding_strategy: FULL_SHARD fsdp_activation_checkpointing: true special_tokens: - pad_token: <|finetune_right_pad_id|> + pad_token: <|finetune_right_pad|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml index 17ad70634..abdc51378 100644 --- a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml +++ b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml @@ -51,7 +51,7 @@ output_dir: ./outputs/out sequence_len: 4096 # up to 8k will work on a single H100 sample_packing: true -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -76,10 +76,12 @@ gradient_checkpointing: offload gradient_checkpointing_kwargs: use_reentrant: false -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: - pad_token: <|finetune_right_pad_id|> + pad_token: <|finetune_right_pad|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml index eff708e4d..4136dc14a 100644 --- a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml +++ b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml @@ -46,7 +46,6 @@ datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template split: train[:1%] - field_messages: messages dataset_prepared_path: last_run_prepared val_set_size: 0.0 @@ -65,7 +64,7 @@ tf32: true logging_steps: 1 flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 @@ -84,5 +83,7 @@ fsdp_config: fsdp_sharding_strategy: FULL_SHARD fsdp_activation_checkpointing: true special_tokens: - pad_token: <|finetune_right_pad_id|> + pad_token: <|finetune_right_pad|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml index 9a411883e..02c04c691 100644 --- a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml +++ b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml @@ -46,7 +46,7 @@ output_dir: ./outputs/out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + gradient_accumulation_steps: 1 micro_batch_size: 2 @@ -64,7 +64,7 @@ flex_attn_compile_kwargs: dynamic: false mode: max-autotune-no-cudagraphs -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 @@ -74,7 +74,7 @@ fsdp: fsdp_config: fsdp_version: 2 fsdp_offload_params: false - fsdp_cpu_ram_efficient_loading: true + # fsdp_cpu_ram_efficient_loading: true # does not work with load_in_8bit/4bit fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer fsdp_state_dict_type: SHARDED_STATE_DICT @@ -82,5 +82,7 @@ fsdp_config: fsdp_reshard_after_forward: true fsdp_activation_checkpointing: true special_tokens: - pad_token: <|finetune_right_pad_id|> + pad_token: <|finetune_right_pad|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/scout-qlora-single-h100-flex.yaml b/examples/llama-4/scout-qlora-single-h100-flex.yaml index 20352f81e..33a691189 100644 --- a/examples/llama-4/scout-qlora-single-h100-flex.yaml +++ b/examples/llama-4/scout-qlora-single-h100-flex.yaml @@ -51,7 +51,7 @@ output_dir: ./outputs/out sequence_len: 4096 # up to 8k will work on a single H100 sample_packing: true -pad_to_sequence_len: true + gradient_accumulation_steps: 1 micro_batch_size: 1 @@ -74,11 +74,13 @@ gradient_checkpointing_kwargs: use_reentrant: false logging_steps: 1 -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: - pad_token: <|finetune_right_pad_id|> + pad_token: <|finetune_right_pad|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml index 9fbd34107..5972c2ae3 100644 --- a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml +++ b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml @@ -45,7 +45,6 @@ datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template split: train[:1%] - field_messages: messages dataset_prepared_path: last_run_prepared val_set_size: 0.0 @@ -67,7 +66,7 @@ flex_attn_compile_kwargs: dynamic: false mode: max-autotune-no-cudagraphs -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 @@ -85,5 +84,7 @@ fsdp_config: fsdp_reshard_after_forward: true fsdp_activation_checkpointing: true special_tokens: - pad_token: <|finetune_right_pad_id|> + pad_token: <|finetune_right_pad|> eos_token: <|eot|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml index 5198c8e74..77ef7474d 100644 --- a/examples/llava/lora-7b.yaml +++ b/examples/llava/lora-7b.yaml @@ -11,8 +11,7 @@ datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template split: train[:1%] - field_messages: messages -dataset_prepared_path: last_run_prepared +dataset_prepared_path: val_set_size: 0.0 output_dir: ./outputs/out @@ -53,3 +52,5 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/magistral/README.md b/examples/magistral/README.md new file mode 100644 index 000000000..a09138744 --- /dev/null +++ b/examples/magistral/README.md @@ -0,0 +1,81 @@ +# Finetune Magistral Small with Axolotl + +Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506), [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)), and [2509](https://huggingface.co/mistralai/Magistral-Small-2509) (see [Vision](#vision)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +MistralAI has also released a proprietary medium-sized version called Magistral Medium. + +Thanks to the team at MistralAI for giving us early access to prepare for these releases. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install from pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' +``` + +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage + +```bash +python scripts/cutcrossentropy_install.py | sh +``` + +3. Run the finetuning example: + +```bash +axolotl train examples/magistral/magistral-small-qlora.yaml +``` + +This config uses about 24GB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### Thinking + +MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps. + +📚 **[See the Thinking fine-tuning guide →](./think/README.md)** + +### Vision + +MistralAI has released their [2509](https://huggingface.co/mistralai/Magistral-Small-2509) model with vision capabilities. + +📚 **[See the Vision fine-tuning guide →](./vision/README.md)** + +### Tips + +- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`. +- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Limitations + +We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only. + +In addition, we do not support overriding tokens yet. + +## Related Resources + +- [MistralAI Magistral Blog](https://mistral.ai/news/magistral/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) + + +## Future Work + +- Add parity to Preference Tuning, RL, etc. +- Add parity to other tokenizer configs like overriding tokens. diff --git a/examples/magistral/magistral-small-fsdp-qlora.yaml b/examples/magistral/magistral-small-fsdp-qlora.yaml new file mode 100644 index 000000000..d46c49fe0 --- /dev/null +++ b/examples/magistral/magistral-small-fsdp-qlora.yaml @@ -0,0 +1,76 @@ +base_model: mistralai/Magistral-Small-2506 + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true +eval_sample_packing: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer + fsdp_activation_checkpointing: true + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/magistral/magistral-small-qlora.yaml b/examples/magistral/magistral-small-qlora.yaml new file mode 100644 index 000000000..188924d39 --- /dev/null +++ b/examples/magistral/magistral-small-qlora.yaml @@ -0,0 +1,67 @@ +base_model: mistralai/Magistral-Small-2506 + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/magistral/think/README.md b/examples/magistral/think/README.md new file mode 100644 index 000000000..29950f59e --- /dev/null +++ b/examples/magistral/think/README.md @@ -0,0 +1,73 @@ +# Magistral Small Thinking Fine-tuning + +This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mistralai/Magistral-Small-2507) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections. + +## Prerequisites + +Before starting, ensure you have: +- Installed Axolotl (see [main README](../README.md)) + +## Getting Started + +Run the thinking model fine-tuning: + +```bash +axolotl train magistral-small-think-qlora.yaml +``` + +This config uses about 19.1 GiB VRAM. + +### Tips + +- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below. +- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent. + +## Dataset Format + +The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages. + +Example format: + +```json +{ + "messages": [ + { + "role": "system", + "content": [ + { "type": "text", "text": "{SYSTEM_PROMPT}"} + ] + }, + { + "role": "user", + "content": [ + { "type": "text", "text": "Solve this step by step: What is 15% of 240?"} + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36." + }, + { + "type": "text", + "text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36." + } + ] + } + ] +} +``` + +### Advanced Options + +The `thinking` section supports an optional `closed` parameter: + +```json +{ + "type": "thinking", + "thinking": "Internal reasoning here...", + "closed": true // Default: true, controls adding the closing [/THINK] tag +} +``` diff --git a/examples/magistral/think/magistral-small-think-qlora.yaml b/examples/magistral/think/magistral-small-think-qlora.yaml new file mode 100644 index 000000000..b715b3156 --- /dev/null +++ b/examples/magistral/think/magistral-small-think-qlora.yaml @@ -0,0 +1,67 @@ +base_model: mistralai/Magistral-Small-2507 + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: Nanobit/text-think-2k-test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/magistral/vision/README.md b/examples/magistral/vision/README.md new file mode 100644 index 000000000..932a3631e --- /dev/null +++ b/examples/magistral/vision/README.md @@ -0,0 +1,60 @@ +# Magistral Small Vision Fine-tuning + +This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mistralai/Magistral-Small-2509) with vision capabilities using Axolotl. + +## Prerequisites + +Before starting, ensure you have: +- Installed Axolotl from source (see [main README](../README.md#getting-started)) + +## Getting started + +1. Install the required vision lib: + ```bash + pip install 'mistral-common[opencv]==1.8.5' + ``` + +2. Download the example dataset image: + ```bash + wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg + ``` + +3. Run the fine-tuning: + ```bash + axolotl train magistral-small-vision-24B-qlora.yml + ``` + +This config uses about 17GiB VRAM. + +WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look. + +### Tips + +Key differences from text-only model: +- `max_tokens: 131072` for inference +- Multi-modal dataset format required +- Sample packing not supported + +## Dataset Format + +The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format). + +One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now. + +Example: +```json +{ + "messages": [ + {"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]}, + {"role": "user", "content": [ + { "type": "text", "text": "What's in this image?"}, + {"type": "image", "path": "path/to/image.jpg" } + ]}, + {"role": "assistant", "content": [{ "type": "text", "text": "..." }]}, + ], +} +``` + +## Limitations + +- Sample Packing is not supported for multi-modality training currently. diff --git a/examples/magistral/vision/magistral-small-vision-24B-qlora.yml b/examples/magistral/vision/magistral-small-vision-24B-qlora.yml new file mode 100644 index 000000000..397db383e --- /dev/null +++ b/examples/magistral/vision/magistral-small-vision-24B-qlora.yml @@ -0,0 +1,64 @@ +base_model: mistralai/Magistral-Small-2509 +processor_type: AutoProcessor + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_4bit: true + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +# sample dataset below requires downloading image in advance +# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg +datasets: + - path: Nanobit/text-vision-2k-test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml index 3d4583932..e6b335804 100644 --- a/examples/mamba/config.yml +++ b/examples/mamba/config.yml @@ -41,10 +41,12 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: tokens: save_safetensors: False + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral/bigstral-ds-zero3.yaml similarity index 92% rename from examples/mistral/bigstral-ds-zero3.yaml rename to examples/mistral/bigstral/bigstral-ds-zero3.yaml index f626a92a1..a8dc36216 100644 --- a/examples/mistral/bigstral-ds-zero3.yaml +++ b/examples/mistral/bigstral/bigstral-ds-zero3.yaml @@ -27,7 +27,7 @@ output_dir: ./outputs/out sequence_len: 2048 sample_packing: true -pad_to_sequence_len: true + gradient_accumulation_steps: 1 micro_batch_size: 1 @@ -53,3 +53,5 @@ special_tokens: eos_token: "<|im_end|>" tokens: - "<|im_start|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml index 15edffb44..e74162537 100644 --- a/examples/mistral/config.yml +++ b/examples/mistral/config.yml @@ -14,7 +14,7 @@ output_dir: ./outputs/out sequence_len: 8192 sample_packing: true -pad_to_sequence_len: true + eval_sample_packing: false wandb_project: @@ -38,8 +38,10 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mistral-dpo-qlora.yml b/examples/mistral/dpo/mistral-dpo-qlora.yml similarity index 93% rename from examples/mistral/mistral-dpo-qlora.yml rename to examples/mistral/dpo/mistral-dpo-qlora.yml index af707973f..8fea14a0f 100644 --- a/examples/mistral/mistral-dpo-qlora.yml +++ b/examples/mistral/dpo/mistral-dpo-qlora.yml @@ -31,7 +31,7 @@ output_dir: ./outputs/dpo-qlora sequence_len: 2048 sample_packing: false -pad_to_sequence_len: true + adapter: qlora lora_model_dir: @@ -73,10 +73,12 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: false -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: bos_token: "<|im_start|>" eos_token: "<|im_end|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/lora.yml b/examples/mistral/lora.yml index 9af4274fd..757287f19 100644 --- a/examples/mistral/lora.yml +++ b/examples/mistral/lora.yml @@ -20,7 +20,7 @@ lora_model_dir: sequence_len: 8192 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -59,8 +59,10 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml index e234b19a2..8e1f03d24 100644 --- a/examples/mistral/mistral-qlora-fsdp.yml +++ b/examples/mistral/mistral-qlora-fsdp.yml @@ -56,7 +56,7 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 @@ -74,3 +74,5 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml new file mode 100644 index 000000000..ec197f333 --- /dev/null +++ b/examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml @@ -0,0 +1,62 @@ +base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503 +processor_type: AutoProcessor + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +load_in_8bit: true + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +# sample dataset below requires downloading image in advance +# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg +datasets: + - path: Nanobit/text-vision-2k-test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 2048 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral/mixtral-8x22b-qlora-fsdp.yml similarity index 92% rename from examples/mistral/mixtral-8x22b-qlora-fsdp.yml rename to examples/mistral/mixtral/mixtral-8x22b-qlora-fsdp.yml index af6ba5a76..dc7bd9c37 100644 --- a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml +++ b/examples/mistral/mixtral/mixtral-8x22b-qlora-fsdp.yml @@ -54,7 +54,7 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 @@ -72,3 +72,5 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral/mixtral-qlora-fsdp.yml similarity index 93% rename from examples/mistral/mixtral-qlora-fsdp.yml rename to examples/mistral/mixtral/mixtral-qlora-fsdp.yml index b1843a138..5151e1292 100644 --- a/examples/mistral/mixtral-qlora-fsdp.yml +++ b/examples/mistral/mixtral/mixtral-qlora-fsdp.yml @@ -56,7 +56,7 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 @@ -77,3 +77,5 @@ fsdp_config: fsdp_forward_prefetch: false fsdp_backward_prefetch: BACKWARD_PRE special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral/mixtral.yml similarity index 93% rename from examples/mistral/mixtral.yml rename to examples/mistral/mixtral/mixtral.yml index 4c256420c..d1981a699 100644 --- a/examples/mistral/mixtral.yml +++ b/examples/mistral/mixtral/mixtral.yml @@ -34,7 +34,7 @@ lora_model_dir: sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -74,10 +74,12 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 deepspeed: deepspeed_configs/zero2.json weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mixtral_22.yml b/examples/mistral/mixtral/mixtral_22.yml similarity index 92% rename from examples/mistral/mixtral_22.yml rename to examples/mistral/mixtral/mixtral_22.yml index 25e1d7155..0b606b7d7 100644 --- a/examples/mistral/mixtral_22.yml +++ b/examples/mistral/mixtral/mixtral_22.yml @@ -25,7 +25,7 @@ output_dir: ./outputs/out sequence_len: 8000 sample_packing: true -pad_to_sequence_len: true + gradient_accumulation_steps: 1 micro_batch_size: 1 @@ -51,3 +51,5 @@ special_tokens: eos_token: "<|im_end|>" tokens: - "<|im_start|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/lora-mps.yml b/examples/mistral/mps/lora-mps.yml similarity index 91% rename from examples/mistral/lora-mps.yml rename to examples/mistral/mps/lora-mps.yml index e6f46affb..07ce191dc 100644 --- a/examples/mistral/lora-mps.yml +++ b/examples/mistral/mps/lora-mps.yml @@ -18,7 +18,7 @@ lora_model_dir: sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -59,8 +59,10 @@ sdp_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/orpo/mistral-qlora-orpo.yml similarity index 91% rename from examples/mistral/mistral-qlora-orpo.yml rename to examples/mistral/orpo/mistral-qlora-orpo.yml index 6c0212b7c..850d286f3 100644 --- a/examples/mistral/mistral-qlora-orpo.yml +++ b/examples/mistral/orpo/mistral-qlora-orpo.yml @@ -25,7 +25,7 @@ lora_model_dir: sequence_len: 4096 sample_packing: false -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -64,8 +64,10 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml index 607e33701..2a7495e95 100644 --- a/examples/mistral/qlora.yml +++ b/examples/mistral/qlora.yml @@ -20,7 +20,7 @@ lora_model_dir: sequence_len: 8192 sample_packing: true -pad_to_sequence_len: true + lora_r: 32 lora_alpha: 16 @@ -59,8 +59,10 @@ flash_attention: true loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/orpheus/finetune.yml b/examples/orpheus/finetune.yml index 9bcbbeee0..f4bc8054e 100644 --- a/examples/orpheus/finetune.yml +++ b/examples/orpheus/finetune.yml @@ -18,7 +18,7 @@ output_dir: ./outputs/out sequence_len: 8192 sample_packing: true -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -43,10 +43,12 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 20 +warmup_ratio: 0.1 evals_per_epoch: 5 saves_per_epoch: 5 weight_decay: 0.05 special_tokens: pad_token: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml index ad4ce9cd4..c10014dab 100644 --- a/examples/phi/lora-3.5.yaml +++ b/examples/phi/lora-3.5.yaml @@ -12,15 +12,6 @@ chat_template: phi_3 datasets: - path: fozziethebeat/alpaca_messages_2k_test type: chat_template - field_messages: messages - message_property_mappings: - role: role - content: content - roles: - user: - - user - assistant: - - assistant dataset_prepared_path: val_set_size: 0.05 @@ -28,7 +19,7 @@ output_dir: ./outputs/lora-out sequence_len: 4096 sample_packing: false -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -59,7 +50,9 @@ gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 4 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml index 1562a7353..717a45929 100644 --- a/examples/phi/phi-ft.yml +++ b/examples/phi/phi-ft.yml @@ -15,7 +15,7 @@ output_dir: ./outputs/phi-sft-out sequence_len: 2048 sample_packing: true -pad_to_sequence_len: true + adapter: lora_model_dir: @@ -50,10 +50,12 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 resize_token_embeddings_to_32x: true special_tokens: pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml index 4cd53db97..0fe1abea5 100644 --- a/examples/phi/phi-qlora.yml +++ b/examples/phi/phi-qlora.yml @@ -18,7 +18,7 @@ output_dir: ./outputs/phi-sft-out sequence_len: 2048 sample_packing: true -pad_to_sequence_len: true + adapter: qlora lora_model_dir: @@ -53,10 +53,12 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 resize_token_embeddings_to_32x: true special_tokens: pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml index ca733cc71..e470c0d24 100644 --- a/examples/phi/phi2-ft.yml +++ b/examples/phi/phi2-ft.yml @@ -15,7 +15,7 @@ output_dir: ./outputs/phi-sft-out sequence_len: 2048 sample_packing: true -pad_to_sequence_len: true + adapter: lora_model_dir: @@ -50,10 +50,12 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 resize_token_embeddings_to_32x: true special_tokens: pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/phi3-ft-fsdp.yml b/examples/phi/phi3-ft-fsdp.yml index d0d14fea6..1793737b5 100644 --- a/examples/phi/phi3-ft-fsdp.yml +++ b/examples/phi/phi3-ft-fsdp.yml @@ -15,7 +15,7 @@ output_dir: ./phi-sft-out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + trust_remote_code: true adapter: @@ -51,7 +51,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 100 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.1 @@ -71,3 +71,5 @@ fsdp_config: resize_token_embeddings_to_32x: true special_tokens: pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/phi/phi3-ft.yml b/examples/phi/phi3-ft.yml index 17c48da6f..0b204963c 100644 --- a/examples/phi/phi3-ft.yml +++ b/examples/phi/phi3-ft.yml @@ -18,7 +18,7 @@ output_dir: ./out sequence_len: 4096 sample_packing: true -pad_to_sequence_len: true + adapter: lora lora_model_dir: @@ -59,3 +59,5 @@ warmup_ratio: 0.2 debug: true weight_decay: 0.1 resize_token_embeddings_to_32x: true + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml index 6ad0a5e99..0e6489914 100644 --- a/examples/pixtral/lora-12b.yml +++ b/examples/pixtral/lora-12b.yml @@ -11,8 +11,7 @@ datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template split: train[:1%] - field_messages: messages -dataset_prepared_path: last_run_prepared +dataset_prepared_path: val_set_size: 0.0 output_dir: ./outputs/out @@ -46,8 +45,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet -eager_attention: +flash_attention: true warmup_ratio: 0.1 evals_per_epoch: 1 @@ -55,3 +53,5 @@ saves_per_epoch: 1 weight_decay: 0.0 special_tokens: pad_token: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml index 55773bc3d..285a35cbb 100644 --- a/examples/qwen2-vl/lora-7b.yaml +++ b/examples/qwen2-vl/lora-7b.yaml @@ -11,7 +11,7 @@ datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template split: train[:1%] - field_messages: messages + dataset_prepared_path: last_run_prepared val_set_size: 0.0 output_dir: ./outputs/out @@ -25,7 +25,7 @@ pad_to_sequence_len: false lora_r: 32 lora_alpha: 16 lora_dropout: 0.05 -lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' wandb_project: wandb_entity: @@ -53,3 +53,5 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml index bd896c2b3..3e87766d6 100644 --- a/examples/qwen2/dpo.yaml +++ b/examples/qwen2/dpo.yaml @@ -27,7 +27,7 @@ output_dir: ./outputs/dpo-out sequence_len: 2048 sample_packing: false -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -50,7 +50,9 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2/prm.yaml b/examples/qwen2/prm.yaml index 4afa24f3c..a709a598d 100644 --- a/examples/qwen2/prm.yaml +++ b/examples/qwen2/prm.yaml @@ -22,7 +22,7 @@ remove_unused_columns: false sequence_len: 2048 sample_packing: false eval_sample_packing: false -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -55,3 +55,5 @@ eval_steps: 100 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml index ed2670ab6..337619b61 100644 --- a/examples/qwen2/qlora-fsdp.yaml +++ b/examples/qwen2/qlora-fsdp.yaml @@ -17,7 +17,7 @@ output_dir: ./outputs/out sequence_len: 2048 sample_packing: true eval_sample_packing: true -pad_to_sequence_len: true + adapter: qlora lora_model_dir: @@ -49,7 +49,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 @@ -67,3 +67,5 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml index 822407a1f..08b8b4552 100644 --- a/examples/qwen2/reward-model.yaml +++ b/examples/qwen2/reward-model.yaml @@ -18,7 +18,7 @@ remove_unused_columns: false sequence_len: 2048 sample_packing: false eval_sample_packing: false -pad_to_sequence_len: true + wandb_project: wandb_entity: @@ -26,7 +26,6 @@ wandb_watch: wandb_name: wandb_log_model: - gradient_accumulation_steps: 4 micro_batch_size: 2 num_epochs: 4 @@ -50,3 +49,5 @@ evals_per_epoch: saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/qwen2_5-vl/lora-7b.yaml similarity index 74% rename from examples/mistral/mistral-small-3.1-24B-lora.yml rename to examples/qwen2_5-vl/lora-7b.yaml index 3e3b45862..7d499d841 100644 --- a/examples/mistral/mistral-small-3.1-24B-lora.yml +++ b/examples/qwen2_5-vl/lora-7b.yaml @@ -1,27 +1,25 @@ -base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503 +base_model: Qwen/Qwen2.5-VL-7B-Instruct processor_type: AutoProcessor -load_in_8bit: true - # these 3 lines are needed for now to handle vision chat templates w images skip_prepare_dataset: true remove_unused_columns: false sample_packing: false -chat_template: mistral_v7_tekken +chat_template: qwen2_vl datasets: - path: HuggingFaceH4/llava-instruct-mix-vsft type: chat_template split: train[:1%] - field_messages: messages + dataset_prepared_path: last_run_prepared -val_set_size: 0.01 +val_set_size: 0.0 output_dir: ./outputs/out adapter: lora lora_model_dir: -sequence_len: 2048 +sequence_len: 8192 pad_to_sequence_len: false lora_r: 32 @@ -35,7 +33,7 @@ wandb_watch: wandb_name: wandb_log_model: -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 4 micro_batch_size: 1 num_epochs: 1 optimizer: adamw_bnb_8bit @@ -48,11 +46,12 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet. +flash_attention: true eager_attention: warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 -special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen3-next/README.md b/examples/qwen3-next/README.md new file mode 100644 index 000000000..678175fd4 --- /dev/null +++ b/examples/qwen3-next/README.md @@ -0,0 +1,64 @@ +# Finetune Qwen3-Next with Axolotl + +[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) represents the next-generation foundation models optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main for pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy +python scripts/cutcrossentropy_install.py | sh +``` + +2. Install Qwen3-Next transformers commit +```bash +pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654" +``` + +3. Install FLA for improved performance +```bash +pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2 +``` + +4. Run the finetuning example: + +```bash +axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml +``` + +This config uses about 45.62 GiB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. See [Multi-GPU](#optimization-guides) section below. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml b/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml new file mode 100644 index 000000000..db841beab --- /dev/null +++ b/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml @@ -0,0 +1,68 @@ +base_model: Qwen/Qwen3-Next-80B-A3B-Instruct + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 16 +lora_alpha: 8 +lora_dropout: 0.05 +lora_target_modules: + - linear_attn.in_proj_ba + - linear_attn.in_proj_qkvz + - linear_attn.out_proj + - shared_expert.up_proj + - shared_expert.down_proj + - shared_expert.gate_proj + - shared_expert_gate + - mlp.gate + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen3/32b-qlora.yaml b/examples/qwen3/32b-qlora.yaml index 45a4395ac..f4a4f2816 100644 --- a/examples/qwen3/32b-qlora.yaml +++ b/examples/qwen3/32b-qlora.yaml @@ -22,7 +22,7 @@ dataset_prepared_path: last_run_prepared sequence_len: 2048 sample_packing: true eval_sample_packing: true -pad_to_sequence_len: true + load_in_4bit: true adapter: qlora @@ -62,8 +62,10 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen3/8b-qat-fsdp2.yml b/examples/qwen3/8b-qat-fsdp2.yml index 6832b6af7..cfbe5a4b7 100644 --- a/examples/qwen3/8b-qat-fsdp2.yml +++ b/examples/qwen3/8b-qat-fsdp2.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/qat_out/ sequence_len: 2048 sample_packing: true flex_attention: true -pad_to_sequence_len: true + flex_attn_compile_kwargs: dynamic: false @@ -58,7 +58,7 @@ logging_steps: 1 evals_per_epoch: 1 saves_per_epoch: 1 -warmup_steps: 10 +warmup_ratio: 0.1 weight_decay: 0.0 fsdp: - full_shard @@ -76,3 +76,5 @@ fsdp_config: fsdp_activation_checkpointing: true special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen3/qlora-fsdp.yaml b/examples/qwen3/qlora-fsdp.yaml index dc3377b4f..e4d584dc7 100644 --- a/examples/qwen3/qlora-fsdp.yaml +++ b/examples/qwen3/qlora-fsdp.yaml @@ -16,7 +16,7 @@ output_dir: ./outputs/out sequence_len: 2048 sample_packing: true eval_sample_packing: true -pad_to_sequence_len: true + adapter: qlora lora_model_dir: @@ -48,7 +48,7 @@ resume_from_checkpoint: logging_steps: 1 flash_attention: true -warmup_steps: 10 +warmup_ratio: 0.1 evals_per_epoch: 4 saves_per_epoch: 1 weight_decay: 0.0 @@ -66,3 +66,5 @@ fsdp_config: fsdp_state_dict_type: FULL_STATE_DICT fsdp_sharding_strategy: FULL_SHARD special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qwen3/reward-model.yaml b/examples/qwen3/reward-model.yaml new file mode 100644 index 000000000..43c62ecc4 --- /dev/null +++ b/examples/qwen3/reward-model.yaml @@ -0,0 +1,44 @@ +base_model: Skywork/Skywork-Reward-V2-Qwen3-8B +model_type: AutoModelForSequenceClassification +num_labels: 1 + +reward_model: true +center_rewards_coefficient: 0.01 # Incentivize mean-zero rewards for improved stability +chat_template: qwen3 +datasets: + - path: argilla/distilabel-intel-orca-dpo-pairs + type: bradley_terry.chat_template + +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 8192 +sample_packing: false +eval_sample_packing: false +pad_to_sequence_len: true + +deepspeed: deepspeed_configs/zero1.json + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +eval_batch_size: 1 +num_epochs: 3 +optimizer: adamw_bnb_8bit +lr_scheduler: linear +learning_rate: 0.00002 + +bf16: true +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +warmup_ratio: 0.1 +logging_steps: 1 +weight_decay: 0.01 diff --git a/examples/seed-oss/README.md b/examples/seed-oss/README.md new file mode 100644 index 000000000..5610c1316 --- /dev/null +++ b/examples/seed-oss/README.md @@ -0,0 +1,54 @@ +# Finetune ByteDance's Seed-OSS with Axolotl + +[Seed-OSS](https://huggingface.co/collections/ByteDance-Seed/seed-oss-68a609f4201e788db05b5dcd) are a series of 36B parameter open source models trained by ByteDance's Seed Team. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Seed-OSS is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main for pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install Cut Cross Entropy +python scripts/cutcrossentropy_install.py | sh +``` + +2. Run the finetuning example: + +```bash +axolotl train examples/seed-oss/seed-oss-36b-qlora.yaml +``` + +This config uses about 27.7 GiB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- For inference, the official Seed Team recommends `top_p=0.95` and `temperature=1.1`. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [ByteDance Seed Website](https://seed.bytedance.com/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/seed-oss/seed-oss-36b-qlora.yaml b/examples/seed-oss/seed-oss-36b-qlora.yaml new file mode 100644 index 000000000..00e7cf3eb --- /dev/null +++ b/examples/seed-oss/seed-oss-36b-qlora.yaml @@ -0,0 +1,56 @@ +base_model: ByteDance-Seed/Seed-OSS-36B-Instruct + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/slurm/README.md b/examples/slurm/README.md new file mode 100644 index 000000000..4c116b713 --- /dev/null +++ b/examples/slurm/README.md @@ -0,0 +1,66 @@ +# SLURM Multi-Node Training + +This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster. + +## Prerequisites + +- Access to a SLURM cluster with GPU nodes +- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html)) + +## Usage + +### Standard SLURM Clusters + +1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory. +2. Place your Axolotl config file (`train.yaml`) in the same directory. +3. Set the appropriate environment variables for the job: + ```bash + export HF_TOKEN="your-huggingface-token" + + # metric tracking + # export WANDB_API_KEY="your-wandb-api-key" + # ... + ``` +4. Submit the job: + ```bash + sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=,PRIMARY_PORT=29400 axolotl.slurm + ``` + + Where: + - `NUM_NODES`: Number of nodes to use + - `NUM_TRAINERS`: GPUs per node (typically 8) + - `PRIMARY_ADDR`: Hostname/IP of the master node + - `PRIMARY_PORT`: Port for distributed training (default: 29400) + +5. (Optional) Run other slurm commands: + ```bash + # check job info + scontrol show job axolotl-cli + + # check job queue + squeue + + # check cluster status + sinfo + ``` + +### RunPod Instant Clusters + +Axolotl works with RunPod Instant Clusters. This feature provides managed SLURM clusters with zero configuration. + +1. **Deploy a SLURM Cluster**: + - Go to [RunPod Instant Clusters](https://console.runpod.io/cluster) + - Click "Create a Cluster" + - Choose your GPU type, node count, and region + - Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud) + - Deploy the cluster + +2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH + +3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)** + +## Additional Resources + +- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [SLURM Documentation](https://slurm.schedmd.com/documentation.html) +- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters) diff --git a/examples/slurm/axolotl.slurm b/examples/slurm/axolotl.slurm new file mode 100644 index 000000000..741d68ced --- /dev/null +++ b/examples/slurm/axolotl.slurm @@ -0,0 +1,20 @@ +#!/bin/bash +# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; i.e. +# export HF_TOKEN="..." +# export WANDB_API_KEY="..." +# + +# ---------- SBATCH commands ---------- # +#SBATCH --job-name=axolotl-slurm-multinode +#SBATCH --ntasks-per-node=1 +#SBATCH --nodes=$NUM_NODES +#SBATCH --gpus-per-task=8 +#SBATCH --cpus-per-task=128 + +export TORCH_DIST_INIT_BARRIER=0 + +srun axolotl preprocess train.yaml + +srun axolotl train train.yaml --launcher torchrun -- \ + --nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \ + --rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800" diff --git a/examples/smolvlm2/README.md b/examples/smolvlm2/README.md new file mode 100644 index 000000000..9c0ae4836 --- /dev/null +++ b/examples/smolvlm2/README.md @@ -0,0 +1,49 @@ +# Finetune SmolVLM2 with Axolotl + +[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) are a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content. + +These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M. + +This guide shows how to fine-tune SmolVLM2 models with Axolotl. + +## Getting Started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install from pip: + ```bash + # Ensure you have a compatible version of Pytorch installed + pip3 install packaging setuptools wheel ninja + pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' + ``` + +2. Install an extra dependency: + + ```bash + pip3 install num2words==0.5.14 + ``` + +3. Run the finetuning example: + + ```bash + # LoRA SFT (1x48GB @ 6.8GiB) + axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml + ``` + +## TIPS + +- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format). +- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) + +## Related Resources + +- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/smolvlm2/smolvlm2-2B-lora.yaml b/examples/smolvlm2/smolvlm2-2B-lora.yaml new file mode 100644 index 000000000..1aeff408d --- /dev/null +++ b/examples/smolvlm2/smolvlm2-2B-lora.yaml @@ -0,0 +1,56 @@ +base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct +trust_remote_code: true +processor_type: AutoProcessor + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true +eager_attention: + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/streaming/README.md b/examples/streaming/README.md new file mode 100644 index 000000000..cdbb5baea --- /dev/null +++ b/examples/streaming/README.md @@ -0,0 +1,50 @@ +# Streaming Dataset Examples + +This directory contains example configurations for using Axolotl's streaming dataset +functionality, which enables memory-efficient training with large datasets. + +## Examples + +Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no +`axolotl preprocess` required! + +### Pretraining (`pretrain.yaml`) + +Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset +with SmolLM2-135M. + +- Uses `pretraining_dataset` configuration for automatic streaming +- Multipack attention control to prevent cross-attention between packed sequences +- Buffer size configuration for memory management + +### SFT (`sft.yaml`) + +Shows how to use streaming for supervised fine-tuning with the Alpaca dataset. + +- Explicit `streaming: true` flag for SFT datasets +- Memory-efficient training on instruction datasets +- Evaluation datasets are currently not streamed + +## Key Configuration Options + +### `streaming` +- Enables streaming mode for standard datasets +- Automatically enabled for `pretraining_dataset` + +### `streaming_multipack_buffer_size` +- Controls buffer size for sample packing (default: 10,000) +- Larger values improve packing efficiency but use more memory +- Adjust based on available memory + +### `shuffle_merged_datasets` +- Enables shuffling of streaming datasets +- Requires additional memory for shuffle buffer + +### `sample_packing` +- Packs multiple samples into single sequences +- Minimize per-step padding tokens + +## Performance Tips + +- Download small / frequently-used datasets locally for better performance +- Larger buffer sizes improve packing efficiency diff --git a/examples/streaming/pretrain.yaml b/examples/streaming/pretrain.yaml new file mode 100644 index 000000000..bc8edefd6 --- /dev/null +++ b/examples/streaming/pretrain.yaml @@ -0,0 +1,57 @@ +base_model: HuggingFaceTB/SmolLM2-135M + +# Streaming pretraining configuration +pretraining_dataset: + - path: HuggingFaceFW/fineweb-edu + name: sample-10BT + type: pretrain + text_column: text + split: train + +# Streaming-specific settings +streaming_multipack_buffer_size: 10000 +shuffle_merged_datasets: true + +# Training configuration +max_steps: 1000 +output_dir: ./outputs/smollm2-135m-pretrain-streaming + +# Sequence and packing settings +sequence_len: 1024 +sample_packing: true +pretrain_multipack_attn: true # Prevent cross-attention between packed sequences +flash_attention: true + +# Batch size settings +gradient_accumulation_steps: 8 +micro_batch_size: 1 + +# Optimizer and scheduler +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 5e-4 +warmup_ratio: 0.1 +weight_decay: 0.01 + +# Precision and performance +bf16: auto +tf32: true + +# Logging and checkpointing +logging_steps: 10 +save_strategy: steps +save_steps: 250 +save_total_limit: 3 + +# Weights & Biases (optional) +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# Special tokens +special_tokens: + pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/streaming/sft.yaml b/examples/streaming/sft.yaml new file mode 100644 index 000000000..47b9f493f --- /dev/null +++ b/examples/streaming/sft.yaml @@ -0,0 +1,55 @@ +base_model: HuggingFaceTB/SmolLM2-135M + +# Dataset configuration +datasets: + - path: tatsu-lab/alpaca + type: alpaca + split: train + +# Streaming-specific settings +streaming: true +streaming_multipack_buffer_size: 10000 +shuffle_merged_datasets: true + +# Training configuration +max_steps: 1000 +output_dir: ./outputs/smollm2-135m-sft-streaming + +# Sequence and packing settings +sequence_len: 1024 +sample_packing: true +flash_attention: true + +# Batch size settings +gradient_accumulation_steps: 4 +micro_batch_size: 1 + +# Optimizer and scheduler +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 2e-4 +warmup_ratio: 0.1 +weight_decay: 0.0 + +# Precision and performance +bf16: auto +tf32: true + +# Logging and checkpointing +logging_steps: 10 +save_strategy: steps +save_steps: 100 +save_total_limit: 3 + +# Weights & Biases (optional) +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# Special tokens +special_tokens: + pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/voxtral/README.md b/examples/voxtral/README.md new file mode 100644 index 000000000..b77691d72 --- /dev/null +++ b/examples/voxtral/README.md @@ -0,0 +1,83 @@ +# Finetune Voxtral with Axolotl + +Voxtral is a [3B](https://huggingface.co/mistralai/Voxtral-Mini-3B-2507)/[24B](https://huggingface.co/mistralai/Voxtral-Small-24B-2507) parameter opensource model from MistralAI found on HuggingFace. This guide shows how to fine-tune it with Axolotl. + +Thanks to the team at MistralAI for giving us early access to prepare for this release. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install from pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' +``` + +2. Please install the below. + +```bash +# audio +pip3 install librosa==0.11.0 +pip3 install 'mistral_common[audio]==1.8.3' + +# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy +python scripts/cutcrossentropy_install.py | sh +``` + +3. Download sample dataset files + +```bash +# for text + audio only +wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga +``` + +4. Run the finetuning example: + +```bash +# text only +axolotl train examples/voxtral/voxtral-mini-qlora.yml + +# text + audio +axolotl train examples/voxtral/voxtral-mini-audio-qlora.yml +``` + +These configs use about 4.8 GB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- For inference, the official MistralAI team recommends `temperature: 0.2` and `top_p: 0.95` for audio understanding and `temperature: 0.0` for transcription. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). +- The multimodal dataset format follows the OpenAI multi-content Messages format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format). + + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Limitations + +We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only. + +In addition, we do not support overriding tokens yet. + +## Related Resources + +- [MistralAI Magistral Blog](https://mistral.ai/news/magistral/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) + +## Future Work + +- Add parity to Preference Tuning, RL, etc. +- Add parity to other tokenizer configs like overriding tokens. diff --git a/examples/voxtral/voxtral-mini-audio-qlora.yml b/examples/voxtral/voxtral-mini-audio-qlora.yml new file mode 100644 index 000000000..8fe6adbff --- /dev/null +++ b/examples/voxtral/voxtral-mini-audio-qlora.yml @@ -0,0 +1,78 @@ +base_model: mistralai/Voxtral-Mini-3B-2507 +processor_type: AutoProcessor + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +# for use with fft to only train on language model layers +# unfrozen_parameters: + # - language_model.model.* + # - lm_head + # - embed_tokens + +load_in_4bit: true + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +# gemma3 doesn't seem to play nice with ddp +ddp_find_unused_parameters: true + +eot_tokens: + - + +# sample dataset below requires downloading audio/image in advance +# wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga +datasets: + - path: NanoBit/text-audio-2k-test + type: chat_template +dataset_prepared_path: +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 diff --git a/examples/voxtral/voxtral-mini-qlora.yml b/examples/voxtral/voxtral-mini-qlora.yml new file mode 100644 index 000000000..bdbc5f867 --- /dev/null +++ b/examples/voxtral/voxtral-mini-qlora.yml @@ -0,0 +1,73 @@ +base_model: mistralai/Voxtral-Mini-3B-2507 + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +# for use with fft to only train on language model layers +# unfrozen_parameters: + # - language_model.model.* + # - lm_head + # - embed_tokens + +eot_tokens: + - +datasets: + - path: cgato/SlimOrcaDedupCleaned + type: chat_template + split: train[:1%] + field_messages: conversations + message_property_mappings: + role: from + content: value + +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: qlora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj' + +sequence_len: 2048 +sample_packing: true +eval_sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/favicon.jpg b/favicon.jpg index 43c690244..4ec358746 100644 Binary files a/favicon.jpg and b/favicon.jpg differ diff --git a/pyproject.toml b/pyproject.toml index 36138c65d..4213bc963 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,3 +26,34 @@ include-package-data = true [tool.setuptools.cmdclass] build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand" + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "W", "C90", "B", "I"] +ignore = [ + "E203", # Whitespace before ':' + "E501", # Line too long + "C901", # Too complex + "B019", # Use of functools.cache on methods + "E722", # Bare except + "F821", # Undefined name (for dynamic exec) +] + +[tool.ruff.lint.isort] +known-third-party = ["wandb", "comet_ml"] +known-local-folder = ["src", "tests"] +# Black-compatible isort settings +force-single-line = false +combine-as-imports = true +split-on-trailing-comma = true + +[tool.ruff.format] +# Use black's formatting style exactly +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" +docstring-code-format = false diff --git a/requirements.txt b/requirements.txt index e790dfaed..acac68ff3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,32 +1,33 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ # START section of dependencies that don't install on Darwin/MacOS -bitsandbytes==0.45.4 +bitsandbytes==0.47.0 triton>=3.0.0 mamba-ssm==1.2.0.post1 xformers>=0.0.23.post1 -autoawq==0.2.7.post3 -liger-kernel==0.5.10 +liger-kernel==0.6.1 # END section packaging==23.2 -huggingface_hub==0.32.2 -peft==0.15.2 -transformers==4.52.3 +huggingface_hub>=0.33.0 +peft>=0.17.1 tokenizers>=0.21.1 -accelerate==1.7.0 -datasets==3.6.0 +transformers==4.57.0 +accelerate==1.10.1 +datasets==4.0.0 deepspeed>=0.17.0 -trl==0.18.1 -hf_xet==1.1.2 +trl==0.23.0 +hf_xet==1.1.5 +kernels==0.9.0 +trackio optimum==1.16.2 hf_transfer sentencepiece -gradio==5.23.3 +gradio==5.41.1 -modal==0.70.5 +modal==1.0.2 pydantic==2.10.6 addict fire @@ -62,11 +63,15 @@ langdetect==1.0.9 immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 -torchao==0.10.0 +torchao==0.13.0 schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.6 axolotl-contribs-mit==0.0.3 +axolotl-contribs-mit==0.0.5 + # telemetry posthog>=4.2.0 + +mistral-common==1.8.5 diff --git a/scripts/chat_datasets.py b/scripts/chat_datasets.py index 1a85fcef9..0c1e0bd03 100644 --- a/scripts/chat_datasets.py +++ b/scripts/chat_datasets.py @@ -27,7 +27,7 @@ def parse_dataset(dataset=None, split="train"): break if not field_messages: raise ValueError( - f'No conversation field found in dataset: {", ".join(feature_keys)}' + f"No conversation field found in dataset: {', '.join(feature_keys)}" ) ds_cfg["field_messages"] = field_messages @@ -40,7 +40,7 @@ def parse_dataset(dataset=None, split="train"): break if not message_property_mappings["role"]: raise ValueError( - f'No role field found in messages: {", ".join(message_fields)}' + f"No role field found in messages: {', '.join(message_fields)}" ) for key in ["content", "text", "value"]: @@ -49,7 +49,7 @@ def parse_dataset(dataset=None, split="train"): break if not message_property_mappings["content"]: raise ValueError( - f'No content field found in messages: {", ".join(message_fields)}' + f"No content field found in messages: {', '.join(message_fields)}" ) ds_cfg["message_property_mappings"] = message_property_mappings diff --git a/scripts/cloud-entrypoint.sh b/scripts/cloud-entrypoint.sh index 2d3e29181..c98e7c0d0 100755 --- a/scripts/cloud-entrypoint.sh +++ b/scripts/cloud-entrypoint.sh @@ -44,8 +44,13 @@ add_keys_to_authorized() { chmod 700 -R ~/.ssh } +# Set SSH port +if [ ! -z "$SSH_PORT" ]; then + sed -i "s/#Port 22/Port $SSH_PORT/" /etc/ssh/sshd_config +fi + if [[ $PUBLIC_KEY ]]; then - # runpod + # runpod, prime intellect add_keys_to_authorized "$PUBLIC_KEY" # Start the SSH service in the background service ssh start @@ -76,5 +81,13 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs fi +# start the runpod slurm init +SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}" + +if [[ -f "$SLURM_INIT" ]]; then + echo "[entrypoint] running $SLURM_INIT..." + bash "$SLURM_INIT" +fi + # Execute the passed arguments (CMD) exec "$@" diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index 4a92746c1..cb498c002 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"' ) diff --git a/scripts/motd b/scripts/motd index f842bd076..275a4fcba 100644 --- a/scripts/motd +++ b/scripts/motd @@ -13,6 +13,8 @@ Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands: +Need help with your post-training workloads? Reach out us at contact@axolotl.ai for assistance. + ``` cd /workspace rm -rf /workspace/axolotl diff --git a/scripts/unsloth_install.py b/scripts/unsloth_install.py index acbd05e90..c0e5bbe70 100644 --- a/scripts/unsloth_install.py +++ b/scripts/unsloth_install.py @@ -1,11 +1,10 @@ # noqa -# pylint: skip-file import sys try: import torch -except ImportError: - raise ImportError("Install torch via `pip install torch`") +except ImportError as error: + raise ImportError("Install torch via `pip install torch`") from error from packaging.version import Version as V use_uv = "--uv" in sys.argv[1:] diff --git a/setup.py b/setup.py index 28f71f789..b2eeb92d6 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,6 @@ def parse_requirements(extras_require_map): _install_requires.append(line) try: xformers_version = [req for req in _install_requires if "xformers" in req][0] - autoawq_version = [req for req in _install_requires if "autoawq" in req][0] if "Darwin" in platform.system(): # skip packages not compatible with OSX skip_packages = [ @@ -34,7 +33,6 @@ def parse_requirements(extras_require_map): "triton", "mamba-ssm", "xformers", - "autoawq", "liger-kernel", ] _install_requires = [ @@ -64,24 +62,32 @@ def parse_requirements(extras_require_map): else: raise ValueError("Invalid version format") - if (major, minor) >= (2, 7): + if (major, minor) >= (2, 8): + pass + elif (major, minor) >= (2, 7): _install_requires.pop(_install_requires.index(xformers_version)) - # _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0 - extras_require_map["vllm"] = ["vllm==0.8.5.post1"] + if patch == 0: + _install_requires.append("xformers==0.0.30") + # vllm 0.9.x is incompatible with latest transformers + extras_require_map.pop("vllm") + else: + _install_requires.append("xformers==0.0.31") + extras_require_map["vllm"] = ["vllm>=0.10.0"] elif (major, minor) >= (2, 6): _install_requires.pop(_install_requires.index(xformers_version)) - _install_requires.append( - "xformers==0.0.29.post2" - ) # vllm needs post2 w torch 2.6 - extras_require_map["vllm"] = ["vllm==0.8.5.post1"] + _install_requires.append("xformers==0.0.29.post3") + # since we only support 2.6.0+cu126 + _dependency_links.append("https://download.pytorch.org/whl/cu126") + extras_require_map.pop("vllm") elif (major, minor) >= (2, 5): _install_requires.pop(_install_requires.index(xformers_version)) if patch == 0: _install_requires.append("xformers==0.0.28.post2") else: _install_requires.append("xformers>=0.0.28.post3") - _install_requires.pop(_install_requires.index(autoawq_version)) + extras_require_map.pop("vllm") elif (major, minor) >= (2, 4): + extras_require_map.pop("vllm") if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27") @@ -111,14 +117,13 @@ def get_package_version(): extras_require = { - "flash-attn": ["flash-attn==2.7.4.post1"], + "flash-attn": ["flash-attn==2.8.3"], "ring-flash-attn": [ - "flash-attn==2.7.4.post1", - "ring-flash-attn>=0.1.4", - "yunchang==0.6.0", + "flash-attn==2.8.3", + "ring-flash-attn>=0.1.7", ], "deepspeed": [ - "deepspeed==0.17.0", + "deepspeed==0.17.5", "deepspeed-kernels", ], "mamba-ssm": [ @@ -148,13 +153,13 @@ extras_require = { "ray[train]", ], "vllm": [ - "vllm==0.7.2", + "vllm==0.10.0", ], "llmcompressor": [ "llmcompressor==0.5.1", ], + "fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"], } - install_requires, dependency_links, extras_require_build = parse_requirements( extras_require ) diff --git a/src/axolotl/__init__.py b/src/axolotl/__init__.py index 63f28adda..e08d43cc3 100644 --- a/src/axolotl/__init__.py +++ b/src/axolotl/__init__.py @@ -4,4 +4,4 @@ import pkgutil __path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package -__version__ = "0.10.0.dev0" +__version__ = "0.13.0.dev" diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 8955eca3e..fa647be65 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -4,5 +4,7 @@ import os from axolotl.logging_config import configure_logging -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") +os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") + configure_logging() diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index e8571a900..14dafa43f 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -14,9 +14,13 @@ class PreprocessCliArgs: prompter: Optional[str] = field(default=None) download: Optional[bool] = field(default=True) iterable: Optional[bool] = field( - default=None, + default=False, metadata={ - "help": "Use IterableDataset for streaming processing of large datasets" + "help": ( + "Deprecated in v0.13.0, will be removed in v0.14.0. For streaming " + "datasets, use 'axolotl train' and set 'streaming: true' in your YAML " + "config, or pass --streaming instead in the CLI." + ) }, ) @@ -30,8 +34,6 @@ class TrainerCliArgs: debug_num_examples: int = field(default=0) prompter: Optional[str] = field(default=None) shard: bool = field(default=False) - main_process_port: Optional[int] = field(default=None) - num_processes: Optional[int] = field(default=None) @dataclass @@ -42,6 +44,12 @@ class VllmServeCliArgs: default=None, metadata={"help": "Number of tensor parallel workers to use."}, ) + data_parallel_size: Optional[int] = field( + default=None, + metadata={ + "help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference." + }, + ) host: Optional[str] = field( default=None, # nosec B104 metadata={"help": "Host address to run the server on."}, @@ -107,6 +115,7 @@ class QuantizeCliArgs: quantize_embedding: Optional[bool] = field(default=None) group_size: Optional[int] = field(default=None) output_dir: Optional[str] = field(default=None) + hub_model_id: Optional[str] = field(default=None) @dataclass diff --git a/src/axolotl/cli/art.py b/src/axolotl/cli/art.py index 2051784e9..81dbb9831 100644 --- a/src/axolotl/cli/art.py +++ b/src/axolotl/cli/art.py @@ -22,7 +22,7 @@ HAS_PRINTED_LOGO = False def print_axolotl_text_art(): """Prints axolotl ASCII art.""" - global HAS_PRINTED_LOGO # pylint: disable=global-statement + global HAS_PRINTED_LOGO if HAS_PRINTED_LOGO: return if is_main_process(): diff --git a/src/axolotl/cli/checks.py b/src/axolotl/cli/checks.py index 10086c2a4..a743e74dc 100644 --- a/src/axolotl/cli/checks.py +++ b/src/axolotl/cli/checks.py @@ -6,6 +6,7 @@ from pathlib import Path from accelerate.commands.config import config_args from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError +from requests import HTTPError from axolotl.utils.logging import get_logger @@ -46,3 +47,8 @@ def check_user_token() -> bool: "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." ) return False + except HTTPError: + LOG.warning( + "Error accessing HuggingFace. This may be due to a network issue or rate limiting." + ) + return False diff --git a/src/axolotl/cli/cloud/__init__.py b/src/axolotl/cli/cloud/__init__.py index 5d6900d3e..60f6a51ce 100644 --- a/src/axolotl/cli/cloud/__init__.py +++ b/src/axolotl/cli/cloud/__init__.py @@ -3,16 +3,17 @@ launch axolotl in supported cloud platforms """ from pathlib import Path -from typing import Union +from typing import Literal import yaml -from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.cloud.base import Cloud +from axolotl.cli.cloud.baseten import BasetenCloud from axolotl.cli.cloud.modal_ import ModalCloud from axolotl.utils.dict import DictDefault -def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault: +def load_cloud_cfg(cloud_config: Path | str) -> DictDefault: """Load and validate cloud configuration.""" # Load cloud configuration. with open(cloud_config, encoding="utf-8") as file: @@ -21,10 +22,9 @@ def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault: def do_cli_preprocess( - cloud_config: Union[Path, str], - config: Union[Path, str], + cloud_config: Path | str, + config: Path | str, ) -> None: - print_axolotl_text_art() cloud_cfg = load_cloud_cfg(cloud_config) cloud = ModalCloud(cloud_cfg) with open(config, "r", encoding="utf-8") as file: @@ -33,28 +33,40 @@ def do_cli_preprocess( def do_cli_train( - cloud_config: Union[Path, str], - config: Union[Path, str], - accelerate: bool = True, + cloud_config: Path | str, + config: Path | str, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + launcher_args: list[str] | None = None, cwd=None, **kwargs, ) -> None: - print_axolotl_text_art() - cloud_cfg = load_cloud_cfg(cloud_config) - cloud = ModalCloud(cloud_cfg) + cloud_cfg: DictDefault = load_cloud_cfg(cloud_config) + provider = cloud_cfg.provider or "modal" + cloud: Cloud | None + if provider == "modal": + cloud = ModalCloud(cloud_cfg) + elif provider == "baseten": + cloud = BasetenCloud(cloud_cfg.to_dict()) + else: + raise ValueError(f"Unsupported cloud provider: {provider}") with open(config, "r", encoding="utf-8") as file: config_yaml = file.read() local_dirs = {} if cwd and not Path(cwd).joinpath("src", "axolotl").exists(): local_dirs = {"/workspace/mounts": cwd} - cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs) + cloud.train( + config_yaml, + launcher=launcher, + launcher_args=launcher_args, + local_dirs=local_dirs, + **kwargs, + ) def do_cli_lm_eval( - cloud_config: Union[Path, str], - config: Union[Path, str], + cloud_config: Path | str, + config: Path | str, ) -> None: - print_axolotl_text_art() cloud_cfg = load_cloud_cfg(cloud_config) cloud = ModalCloud(cloud_cfg) with open(config, "r", encoding="utf-8") as file: diff --git a/src/axolotl/cli/cloud/base.py b/src/axolotl/cli/cloud/base.py index eba8be49a..c498e8691 100644 --- a/src/axolotl/cli/cloud/base.py +++ b/src/axolotl/cli/cloud/base.py @@ -3,6 +3,7 @@ base class for cloud platforms from cli """ from abc import ABC, abstractmethod +from typing import Literal class Cloud(ABC): @@ -15,5 +16,12 @@ class Cloud(ABC): pass @abstractmethod - def train(self, config_yaml: str, accelerate: bool = True) -> str: + def train( + self, + config_yaml: str, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + launcher_args: list[str] | None = None, + local_dirs: dict[str, str] | None = None, + **kwargs, + ): pass diff --git a/src/axolotl/cli/cloud/baseten/__init__.py b/src/axolotl/cli/cloud/baseten/__init__.py new file mode 100644 index 000000000..914504de3 --- /dev/null +++ b/src/axolotl/cli/cloud/baseten/__init__.py @@ -0,0 +1,48 @@ +"""Baseten Cloud CLI""" + +import shutil +import subprocess # nosec B404 +import tempfile +from os.path import dirname +from typing import Literal + +import yaml + +from axolotl.cli.cloud.base import Cloud + + +class BasetenCloud(Cloud): + """Baseten Cloud Axolotl CLI""" + + def __init__(self, config: dict): + self.config = config + + def preprocess(self, config_yaml: str, *args, **kwargs) -> None: + raise NotImplementedError( + "Separate preprocess function for Baseten is not " + "implemented and will happen during hte train step." + ) + + def train( + self, + config_yaml: str, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + launcher_args: list[str] | None = None, + local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument + **kwargs, + ): + with tempfile.TemporaryDirectory() as tmp_dir: + config = self.config.copy() + config["launcher"] = launcher + config["launcher_args"] = launcher_args + with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout: + yaml.dump(config, cloud_fout) + with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout: + config_fout.write(config_yaml) + shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh") + shutil.copyfile( + dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py" + ) + subprocess.run( # nosec B603 B607 + ["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False + ) diff --git a/src/axolotl/cli/cloud/baseten/template/run.sh b/src/axolotl/cli/cloud/baseten/template/run.sh new file mode 100644 index 000000000..37dc9688f --- /dev/null +++ b/src/axolotl/cli/cloud/baseten/template/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -eux + +export NCCL_SOCKET_IFNAME="^docker0,lo" +export NCCL_IB_DISABLE=0 +export NCCL_TIMEOUT=1800000 + +axolotl preprocess train.yaml +axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS} diff --git a/src/axolotl/cli/cloud/baseten/template/train_sft.py b/src/axolotl/cli/cloud/baseten/template/train_sft.py new file mode 100644 index 000000000..137fb9171 --- /dev/null +++ b/src/axolotl/cli/cloud/baseten/template/train_sft.py @@ -0,0 +1,71 @@ +""" +Baseten Training Script for Axolotl +""" + +# pylint: skip-file +import yaml +from truss.base import truss_config + +# Import necessary classes from the Baseten Training SDK +from truss_train import definitions + +cloud_config = yaml.safe_load(open("cloud.yaml", "r")) +gpu = cloud_config.get("gpu", "h100") +gpu_count = int(cloud_config.get("gpu_count", 1)) +node_count = int(cloud_config.get("node_count", 1)) +project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project" +secrets = cloud_config.get("secrets", []) +launcher = cloud_config.get("launcher", "accelerate") +launcher_args = cloud_config.get("launcher_args", []) +script_name = "run.sh" + +launcher_args_str = "" +if launcher_args: + launcher_args_str = "-- " + " ".join(launcher_args) + +# 1. Define a base image for your training job +# must use torch 2.7.0 for vllm +BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1" + +# 2. Define the Runtime Environment for the Training Job +# This includes start commands and environment variables.a +# Secrets from the baseten workspace like API keys are referenced using +# `SecretReference`. + +env_vars = { + "AXOLOTL_LAUNCHER": launcher, + "AXOLOTL_LAUNCHER_ARGS": launcher_args_str, +} +for secret_name in secrets: + env_vars[secret_name] = definitions.SecretReference(name=secret_name) + +training_runtime = definitions.Runtime( + start_commands=[ # Example: list of commands to run your training script + f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'" + ], + environment_variables=env_vars, +) + +# 3. Define the Compute Resources for the Training Job +training_compute = definitions.Compute( + node_count=node_count, + accelerator=truss_config.AcceleratorSpec( + accelerator=truss_config.Accelerator.H100, + count=gpu_count, + ), +) + +# 4. Define the Training Job +# This brings together the image, compute, and runtime configurations. +my_training_job = definitions.TrainingJob( + image=definitions.Image(base_image=BASE_IMAGE), + compute=training_compute, + runtime=training_runtime, +) + + +# This config will be pushed using the Truss CLI. +# The association of the job to the project happens at the time of push. +first_project_with_job = definitions.TrainingProject( + name=project_name, job=my_training_job +) diff --git a/src/axolotl/cli/cloud/modal_.py b/src/axolotl/cli/cloud/modal_.py index 83cdd7b72..7f953372d 100644 --- a/src/axolotl/cli/cloud/modal_.py +++ b/src/axolotl/cli/cloud/modal_.py @@ -8,7 +8,7 @@ import os import subprocess # nosec B404 from pathlib import Path from random import randint -from typing import Optional +from typing import Literal import modal @@ -41,7 +41,7 @@ def run_cmd(cmd: str, run_folder: str, volumes=None): if exit_code := subprocess.call( # nosec B603 cmd.split(), cwd=run_folder, env=new_env ): - exit(exit_code) # pylint: disable=consider-using-sys-exit + exit(exit_code) # Commit writes to volume. if volumes: @@ -82,7 +82,7 @@ class ModalCloud(Cloud): return res def get_image(self): - docker_tag = "main-py3.11-cu124-2.6.0" + docker_tag = "main-py3.11-cu126-2.7.1" if self.config.docker_tag: docker_tag = self.config.docker_tag docker_image = f"axolotlai/axolotl:{docker_tag}" @@ -130,7 +130,6 @@ class ModalCloud(Cloud): res = [] if self.config.secrets: for key in self.config.get("secrets", []): - # pylint: disable=duplicate-code if isinstance(key, str): if val := os.environ.get(key, ""): res.append(modal.Secret.from_dict({key: val})) @@ -177,8 +176,8 @@ class ModalCloud(Cloud): with self.app.run(detach=True): modal_fn.remote( config_yaml, - volumes={k: v[0] for k, v in self.volumes.items()}, *args, + volumes={k: v[0] for k, v in self.volumes.items()}, **kwargs, ) @@ -187,7 +186,7 @@ class ModalCloud(Cloud): return int(self.config.timeout) return 60 * 60 * 24 # 24 hours - def get_train_gpu(self): # pylint: disable=too-many-return-statements + def get_train_gpu(self): count = self.config.gpu_count or 1 family = self.config.gpu.lower() or "l40s" @@ -200,7 +199,7 @@ class ModalCloud(Cloud): if family in ["a10", "a10g"]: return modal.gpu.A10G(count=count) if family == "h100": - return modal.gpu.H100(count=count) + return f"H100:{count}" if family == "t4": return modal.gpu.T4(count=count) if family == "l4": @@ -230,8 +229,9 @@ class ModalCloud(Cloud): def train( self, config_yaml: str, - accelerate: bool = True, - local_dirs: Optional[dict[str, str]] = None, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + launcher_args: list[str] | None = None, + local_dirs: dict[str, str] | None = None, **kwargs, ): modal_fn = self.get_train_env(local_dirs)(_train) @@ -239,7 +239,8 @@ class ModalCloud(Cloud): with self.app.run(detach=True): modal_fn.remote( config_yaml, - accelerate=accelerate, + launcher=launcher, + launcher_args=launcher_args, volumes={k: v[0] for k, v in self.volumes.items()}, **kwargs, ) @@ -270,20 +271,35 @@ def _preprocess(config_yaml: str, volumes=None): ) -def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs): +def _train( + config_yaml: str, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + launcher_args: list[str] | None = None, + volumes=None, + **kwargs, +): Path("/workspace/mounts").mkdir(parents=True, exist_ok=True) with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out: f_out.write(config_yaml) run_folder = "/workspace/mounts" - if accelerate: - accelerate_args = "--accelerate" + + launcher_args = launcher_args or [] + + # Build the base command + if launcher == "accelerate": + launcher_arg = "--launcher accelerate" + elif launcher == "torchrun": + launcher_arg = "--launcher torchrun" else: - accelerate_args = "--no-accelerate" - num_processes_args = "" - if num_processes := kwargs.pop("num_processes", None): - num_processes_args = f"--num-processes {num_processes}" + launcher_arg = "--launcher python" + + # Build launcher args string + launcher_args_str = "" + if launcher_args: + launcher_args_str = "-- " + " ".join(launcher_args) + run_cmd( - f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml", + f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(), run_folder, volumes, ) diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 94b3b02b7..3c4ace7b0 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -25,10 +25,13 @@ from axolotl.utils.config import ( from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.mlflow_ import setup_mlflow_env_vars -from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env +from axolotl.utils.tee import prepare_debug_log +from axolotl.utils.trainer import prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars -LOG = get_logger(__name__, use_environ=True) +LOG = get_logger(__name__) + +API_KEY_FIELDS = {"comet_api_key"} TELEMETRY_MANAGER = TelemetryManager.get_instance() @@ -155,6 +158,8 @@ def prepare_plugins(cfg: DictDefault): plugin_manager = PluginManager.get_instance() for plugin_name in cfg["plugins"]: plugin_manager.register(plugin_name) + for plugin in plugin_manager.plugins.values(): + plugin.register(cfg) def plugin_set_cfg(cfg: DictDefault): @@ -202,19 +207,18 @@ def load_cfg( # If there are any options passed in the cli, if it is something that seems valid # from the yaml, then overwrite the value cfg_keys = cfg.keys() - for k, _ in kwargs.items(): - # if not strict, allow writing to cfg even if it's not in the yml already - if k in cfg_keys or not cfg.strict: - # handle booleans - if isinstance(cfg[k], bool): - cfg[k] = bool(kwargs[k]) + for key, value in kwargs.items(): + # If not strict, allow writing to cfg even if it's not in the yml already + if key in cfg_keys or not cfg.strict: + if isinstance(cfg[key], bool): + cfg[key] = bool(value) else: - cfg[k] = kwargs[k] + cfg[key] = value try: device_props = torch.cuda.get_device_properties("cuda") gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) - except: # pylint: disable=bare-except # noqa: E722 + except: gpu_version = None prepare_plugins(cfg) @@ -231,8 +235,11 @@ def load_cfg( }, ) + # NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we + # have to wait for cfg.output to be resolved. We could call this earlier if we write + # to a temporary file, and then move it later. + prepare_debug_log(cfg) prepare_optim_env(cfg) - prepare_opinionated_env(cfg) normalize_config(cfg) normalize_cfg_datasets(cfg) setup_wandb_env_vars(cfg) @@ -241,5 +248,14 @@ def load_cfg( plugin_set_cfg(cfg) TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg) + cfg_to_log = { + k: "[REDACTED]" if k in API_KEY_FIELDS else v + for k, v in cfg.items() + if v is not None + } + LOG.info( + "config:\n%s", + json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True), + ) return cfg diff --git a/src/axolotl/cli/delinearize_llama4.py b/src/axolotl/cli/delinearize_llama4.py index c92bae930..4f5448a14 100644 --- a/src/axolotl/cli/delinearize_llama4.py +++ b/src/axolotl/cli/delinearize_llama4.py @@ -9,7 +9,6 @@ from typing import Generator, Union import fire import torch from accelerate import init_empty_weights -from dotenv import load_dotenv from transformers import AutoProcessor @@ -86,9 +85,7 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None: unpatch_llama4 = patch_llama4_linearized_modeling() from transformers import Llama4ForConditionalGeneration - model_ = Llama4ForConditionalGeneration.from_pretrained( - model, torch_dtype=torch.bfloat16 - ) + model_ = Llama4ForConditionalGeneration.from_pretrained(model, dtype=torch.bfloat16) processor = AutoProcessor.from_pretrained(model) processor.save_pretrained(output) @@ -152,5 +149,4 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None: if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index f131f7083..1a73937a2 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -5,16 +5,13 @@ from pathlib import Path from typing import Union import fire -from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser from axolotl.cli.args import TrainerCliArgs -from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.evaluate import evaluate -from axolotl.utils import patch_optimized_env from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -31,11 +28,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: cfg: Dictionary mapping `axolotl` config keys to values. cli_args: CLI arguments. """ - # Enable expandable segments for cuda allocation to improve VRAM usage - patch_optimized_env() - # pylint: disable=duplicate-code - print_axolotl_text_art() check_accelerate_default_config() if int(os.getenv("LOCAL_RANK", "0")) == 0: check_user_token() @@ -56,7 +49,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: config: Path to `axolotl` config YAML file. kwargs: Additional keyword arguments to override config file values. """ - # pylint: disable=duplicate-code + parsed_cfg = load_cfg(config, **kwargs) parser = HfArgumentParser(TrainerCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( @@ -66,5 +59,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index d509c5517..640be3696 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -9,16 +9,18 @@ from typing import Union import fire import torch import transformers -from dotenv import load_dotenv from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from axolotl.cli.args import InferenceCliArgs -from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer +from axolotl.cli.utils.diffusion import ( + diffusion_inference, + launch_diffusion_gradio_ui, +) +from axolotl.integrations.base import PluginManager from axolotl.telemetry.errors import send_errors from axolotl.utils.chat_templates import ( - get_chat_template, get_chat_template_from_config, ) from axolotl.utils.dict import DictDefault @@ -35,10 +37,11 @@ def get_multi_line_input() -> str: Possibly multi-line, possibly empty stdin input as a string. """ print("Give me an instruction (Ctrl + D to submit): ") + print("=" * 80) instruction = "" for line in sys.stdin: - instruction += line # pylint: disable=consider-using-join + instruction += line return instruction @@ -50,9 +53,9 @@ def do_inference( cli_args: InferenceCliArgs, ): """ - Runs inference on the command line in a loop. User input is accepted, a chat template - is (optionally) applied, and the model specified in the `axolotl` config is used to - generate completions according to a default generation config. + Runs inference on the command line in a loop. User input is accepted, a chat + template is (optionally) applied, and the model specified in the `axolotl` config is + used to generate completions according to a default generation config. Args: cfg: Dictionary mapping `axolotl` config keys to values. @@ -68,17 +71,31 @@ def do_inference( importlib.import_module("axolotl.prompters"), prompter ) elif cfg.chat_template: - chat_template_str = get_chat_template(cfg.chat_template) - elif cfg.datasets[0].type == "chat_template": + chat_template_str = get_chat_template_from_config( + cfg, ds_cfg=None, tokenizer=tokenizer + ) + elif cfg.datasets and cfg.datasets[0].type == "chat_template": chat_template_str = get_chat_template_from_config( cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer ) model = model.to(cfg.device, dtype=cfg.torch_dtype) + # Detect diffusion mode + plugin_manager = PluginManager.get_instance() + is_diffusion = any( + plugin.__class__.__name__ == "DiffusionPlugin" + for plugin in plugin_manager.plugins.values() + ) + + if is_diffusion: + print("=" * 80) + print("Commands:") + print(":complete N -> completion mode with N tokens (default 64)") + print(":mask R -> random masking with ratio R (0.0–1.0)") + while True: print("=" * 80) - # support for multiline inputs instruction = get_multi_line_input() if not instruction: return @@ -108,9 +125,19 @@ def do_inference( else: batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) - print("=" * 40) + print("=" * 80) model.eval() with torch.no_grad(): + if is_diffusion: + diffusion_inference( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompt=prompt, + chat_template_str=chat_template_str, + ) + continue + generation_config = GenerationConfig( repetition_penalty=1.1, max_new_tokens=1024, @@ -133,7 +160,7 @@ def do_inference( generation_config=generation_config, streamer=streamer, ) - print("=" * 40) + print("=" * 80) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) @@ -164,15 +191,37 @@ def do_inference_gradio( importlib.import_module("axolotl.prompters"), prompter ) elif cfg.chat_template: - chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer) + chat_template_str = get_chat_template_from_config( + cfg, ds_cfg=None, tokenizer=tokenizer + ) + elif cfg.datasets and cfg.datasets[0].type == "chat_template": + chat_template_str = get_chat_template_from_config( + cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer + ) model = model.to(cfg.device, dtype=cfg.torch_dtype) + # Detect diffusion mode + plugin_manager = PluginManager.get_instance() + is_diffusion = any( + plugin.__class__.__name__ == "DiffusionPlugin" + for plugin in plugin_manager.plugins.values() + ) + + if is_diffusion: + launch_diffusion_gradio_ui( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompter_module=prompter_module, + chat_template_str=chat_template_str, + ) + return + def generate(instruction): if not instruction: return if prompter_module: - # pylint: disable=stop-iteration-return prompt: str = next( prompter_module().build_prompt(instruction=instruction.strip("\n")) ) @@ -257,8 +306,7 @@ def do_cli( config: Path to `axolotl` config YAML file. kwargs: Additional keyword arguments to override config file values. """ - # pylint: disable=duplicate-code - print_axolotl_text_art() + parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs) parsed_cfg.sample_packing = False parser = transformers.HfArgumentParser(InferenceCliArgs) @@ -273,5 +321,4 @@ def do_cli( if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 3dafa552b..dc6cca489 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -1,15 +1,10 @@ """Click CLI definitions for various axolotl commands.""" -# pylint: disable=redefined-outer-name - import os import subprocess # nosec B404 -import tempfile -from pathlib import Path -from typing import Optional +from typing import Literal, Optional import click -import yaml from dotenv import load_dotenv import axolotl @@ -20,26 +15,36 @@ from axolotl.cli.args import ( TrainerCliArgs, VllmServeCliArgs, ) -from axolotl.cli.sweeps import generate_sweep_configs +from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.utils import ( add_options_from_config, add_options_from_dataclass, build_command, fetch_from_github, filter_none_kwargs, + generate_config_files, + launch_training, ) from axolotl.integrations.lm_eval.cli import lm_eval -from axolotl.utils import patch_optimized_env +from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import AxolotlInputConfig LOG = get_logger(__name__) +LAUNCHER_COMMAND_MAPPING = { + "accelerate": ["accelerate", "launch"], + "torchrun": ["torchrun"], +} + @click.group() @click.version_option(version=axolotl.__version__, prog_name="axolotl") def cli(): """Axolotl CLI - Train and fine-tune large language models""" + print_axolotl_text_art() + load_dotenv() + set_pytorch_cuda_alloc_conf() @cli.command() @@ -48,7 +53,7 @@ def cli(): @add_options_from_dataclass(PreprocessCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs -def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None: +def preprocess(config: str, cloud: Optional[str] = None, **kwargs): """ Preprocess datasets before training. @@ -58,7 +63,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None: kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - patch_optimized_env() if cloud: from axolotl.cli.cloud import do_cli_preprocess @@ -70,12 +74,15 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None: do_cli(config=config, **kwargs) -@cli.command() +@cli.command( + context_settings={"ignore_unknown_options": True, "allow_extra_args": True} +) @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( - "--accelerate/--no-accelerate", - default=True, - help="Use accelerate launch for multi-GPU training", + "--launcher", + type=click.Choice(["accelerate", "torchrun", "python"]), + default="accelerate", + help="Launcher to use for multi-GPU training", ) @click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str)) @click.option( @@ -86,126 +93,82 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None: @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs +@click.pass_context def train( + ctx: click.Context, config: str, - accelerate: bool, - cloud: Optional[str] = None, - sweep: Optional[str] = None, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + cloud: str | None = None, + sweep: str | None = None, **kwargs, -) -> None: +): """ Train or fine-tune a model. Args: + ctx: Click context for extra args. config: Path to `axolotl` config YAML file. - accelerate: Whether to use `accelerate` launcher. + launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python"). cloud: Path to a cloud accelerator configuration file sweep: Path to YAML config for sweeping hyperparameters. kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - # Enable expandable segments for cuda allocation to improve VRAM usage - patch_optimized_env() + # Extract launcher args from extra args (after --) + launcher_args = ctx.args if ctx.args else [] - if "use_ray" in kwargs and kwargs["use_ray"]: - accelerate = False - if sweep: - # load the sweep configuration yaml file - with open(sweep, "r", encoding="utf-8") as fin: - sweep_config: dict[str, list] = yaml.safe_load(fin) - with open(config, "r", encoding="utf-8") as fin: - base_config: dict[str, list] = yaml.safe_load(fin) + # Handle Ray launcher override + _launcher = None if kwargs.get("use_ray") else launcher - # generate all possible configurations - permutations = generate_sweep_configs(base_config, sweep_config) - - def iter_configs(): - for perm in permutations: - # open temp directory for temporary configurations - with tempfile.TemporaryDirectory() as temp_dir: - with open( - Path(temp_dir) / "config.yaml", "w", encoding="utf-8" - ) as fout: - yaml.dump(perm, fout) - yield str(Path(temp_dir) / "config.yaml") - - else: - - def iter_configs(): - yield config - - for cfg_file in iter_configs(): - # handle errors from subprocess so we can continue rest of sweeps + # Process each configuration + for cfg_file, is_group in generate_config_files(config, sweep): try: - if accelerate: - if cloud: - from axolotl.cli.cloud import do_cli_train - - cwd = os.getcwd() - do_cli_train( - cloud_config=cloud, - config=config, - accelerate=True, - cwd=cwd, - **kwargs, - ) - else: - accelerate_args = [] - if "main_process_port" in kwargs: - main_process_port = kwargs.pop("main_process_port", None) - accelerate_args.append("--main_process_port") - accelerate_args.append(str(main_process_port)) - if "num_processes" in kwargs: - num_processes = kwargs.pop("num_processes", None) - accelerate_args.append("--num_processes") - accelerate_args.append(str(num_processes)) - - base_cmd = ["accelerate", "launch"] - base_cmd.extend(accelerate_args) - base_cmd.extend(["-m", "axolotl.cli.train"]) - if cfg_file: - base_cmd.append(cfg_file) - cmd = build_command(base_cmd, kwargs) - subprocess.run(cmd, check=True) # nosec B603 - else: - if cloud: - from axolotl.cli.cloud import do_cli_train - - do_cli_train( - cloud_config=cloud, config=config, accelerate=False, **kwargs - ) - else: - from axolotl.cli.train import do_cli - - do_cli(config=cfg_file, **kwargs) + use_exec = is_group is not True + launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec) except subprocess.CalledProcessError as exc: LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}") if not sweep: raise exc + finally: + # Only delete temp files, not the original config + if cfg_file != config: + os.unlink(cfg_file) -@cli.command() +@cli.command( + context_settings={"ignore_unknown_options": True, "allow_extra_args": True} +) @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( - "--accelerate/--no-accelerate", - default=True, - help="Use accelerate launch for multi-GPU training", + "--launcher", + type=click.Choice(["accelerate", "torchrun", "python"]), + default="accelerate", + help="Launcher to use for multi-GPU evaluation", ) @add_options_from_dataclass(EvaluateCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs -def evaluate(config: str, accelerate: bool, **kwargs) -> None: +@click.pass_context +def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs): """ Evaluate a model. Args: + ctx: Click context for extra args. config: Path to `axolotl` config YAML file. - accelerate: Whether to use `accelerate` launcher. + launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python"). kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - if accelerate: - base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"] + # Extract launcher args from extra args (after --) + launcher_args = ctx.args if ctx.args else [] + + if launcher in LAUNCHER_COMMAND_MAPPING: + base_cmd = ( + LAUNCHER_COMMAND_MAPPING[launcher] + + launcher_args + + ["-m", "axolotl.cli.evaluate"] + ) if config: base_cmd.append(config) cmd = build_command(base_cmd, kwargs) @@ -216,30 +179,42 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None: do_cli(config=config, **kwargs) -@cli.command() +@cli.command( + context_settings={"ignore_unknown_options": True, "allow_extra_args": True} +) @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( - "--accelerate/--no-accelerate", - default=False, - help="Use accelerate launch for multi-GPU inference", + "--launcher", + type=click.Choice(["accelerate", "torchrun", "python"]), + default="accelerate", + help="Launcher to use for multi-GPU inference", ) @click.option("--gradio", is_flag=True, help="Launch Gradio interface") @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs -def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: +@click.pass_context +def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs): """ Run inference with a trained model. Args: + ctx: Click context for extra args. config: Path to `axolotl` config YAML file. - accelerate: Whether to use `accelerate` launcher. + launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python"). gradio: Whether to use Gradio browser interface or command line for inference. kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - if accelerate: - base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"] + # Extract launcher args from extra args (after --) + launcher_args = ctx.args if ctx.args else [] + + if launcher in LAUNCHER_COMMAND_MAPPING: + base_cmd = ( + LAUNCHER_COMMAND_MAPPING[launcher] + + launcher_args + + ["-m", "axolotl.cli.inference"] + ) if config: base_cmd.append(config) if gradio: @@ -252,33 +227,42 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: do_cli(config=config, gradio=gradio, **kwargs) -@cli.command() +@cli.command( + context_settings={"ignore_unknown_options": True, "allow_extra_args": True} +) @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( - "--accelerate/--no-accelerate", - default=True, - help="Use accelerate launch for weight merging", + "--launcher", + type=click.Choice(["accelerate", "torchrun", "python"]), + default="accelerate", + help="Launcher to use for weight merging", ) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs -def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None: +@click.pass_context +def merge_sharded_fsdp_weights( + ctx: click.Context, config: str, launcher: str, **kwargs +): """ Merge sharded FSDP model weights. Args: + ctx: Click context for extra args. config: Path to `axolotl` config YAML file. - accelerate: Whether to use `accelerate` launcher. + launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python"). kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - if accelerate: - base_cmd = [ - "accelerate", - "launch", - "-m", - "axolotl.cli.merge_sharded_fsdp_weights", - ] + # Extract launcher args from extra args (after --) + launcher_args = ctx.args if ctx.args else [] + + if launcher in LAUNCHER_COMMAND_MAPPING: + base_cmd = ( + LAUNCHER_COMMAND_MAPPING[launcher] + + launcher_args + + ["-m", "axolotl.cli.merge_sharded_fsdp_weights"] + ) if config: base_cmd.append(config) cmd = build_command(base_cmd, kwargs) @@ -294,7 +278,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None: @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs -def merge_lora(config: str, **kwargs) -> None: +def merge_lora(config: str, **kwargs): """ Merge trained LoRA adapters into a base model. @@ -311,7 +295,7 @@ def merge_lora(config: str, **kwargs) -> None: @cli.command() @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) @click.option("--dest", help="Destination directory") -def fetch(directory: str, dest: Optional[str]) -> None: +def fetch(directory: str, dest: Optional[str]): """ Fetch example configs or other resources. @@ -349,7 +333,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs): @cli.command() @click.argument("model", type=click.Path(exists=True, path_type=str)) @click.argument("output", type=click.Path(exists=False, path_type=str)) -def delinearize_llama4(model: str, output: str) -> None: +def delinearize_llama4(model: str, output: str): from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4 do_delinearize_llama4(model, output) @@ -363,5 +347,4 @@ def main(): if __name__ == "__main__": - load_dotenv() main() diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 4fa87e90b..482767b12 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -4,9 +4,7 @@ from pathlib import Path from typing import Union import fire -from dotenv import load_dotenv -from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer from axolotl.telemetry.errors import send_errors @@ -25,8 +23,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None: Args: cfg: Dictionary mapping `axolotl` config keys to values. """ - print_axolotl_text_art() - model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg) safe_serialization = cfg.save_safetensors is True @@ -49,7 +45,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None: safe_serialization=safe_serialization, progressbar=True, ) - tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) + tokenizer.save_pretrained( + str(Path(cfg.output_dir) / "merged"), + save_jinja_files=cfg.tokenizer_save_jinja_files, + ) if processor: processor.save_pretrained(str(Path(cfg.output_dir) / "merged")) @@ -75,7 +74,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: load_in_8bit=False, load_in_4bit=False, flash_attention=False, - sequence_parallel_degree=None, + context_parallel_size=None, deepspeed=None, fsdp=None, fsdp_config=None, @@ -93,5 +92,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index e251f8dbf..1d9736b9d 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -10,6 +10,7 @@ import fire import torch import torch.distributed.checkpoint as dist_cp import torch.distributed.checkpoint.format_utils as dist_cp_format_utils +from accelerate import PartialState from accelerate.utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, @@ -17,15 +18,14 @@ from accelerate.utils import ( WEIGHTS_NAME, is_torch_version, ) -from dotenv import load_dotenv from huggingface_hub import split_torch_state_dict_into_shards from safetensors.torch import save_file as safe_save_file from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner -from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg from axolotl.telemetry.errors import send_errors from axolotl.utils.logging import get_logger +from axolotl.utils.train import determine_last_checkpoint LOG = get_logger(__name__) @@ -33,7 +33,7 @@ LOG = get_logger(__name__) class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): """A custom planner to cast tensors to bfloat16 on the fly during loading.""" - def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument + def commit_tensor(self, read_item, tensor): tensor.copy_(tensor.to(torch.bfloat16)) @@ -60,10 +60,10 @@ def _distributed_checkpoint_to_merged_weights( state_dict: Dict = {} save_path_ = Path(save_path) save_path_.mkdir(exist_ok=True) - dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access + dist_cp_format_utils._load_state_dict( state_dict, storage_reader=dist_cp.FileSystemReader(checkpoint_dir), - planner=BFloat16CastPlanner(), # pylint: disable=protected-access + planner=BFloat16CastPlanner(), no_dist=True, ) @@ -147,7 +147,6 @@ def merge_fsdp_weights( ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist. """ checkpoint_dir_ = Path(checkpoint_dir) - from accelerate.state import PartialState if not is_torch_version(">=", "2.3.0"): raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`") @@ -184,7 +183,6 @@ def merge_fsdp_weights( if remove_checkpoint_dir: LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}") shutil.rmtree(checkpoint_dir_) - state.wait_for_everyone() def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): @@ -195,18 +193,37 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): config: Path to `axolotl` config YAML file. kwargs: Additional keyword arguments to override config file values. """ - # pylint: disable=duplicate-code - print_axolotl_text_art() + parsed_cfg = load_cfg(config, **kwargs) fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" + if not fsdp_dir.exists(): + checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False) + if checkpoint_dir: + fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0" + if not fsdp_dir.exists(): + raise ValueError( + f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}" + ) + + output_path = str(Path(parsed_cfg.output_dir) / "merged") merge_fsdp_weights( checkpoint_dir=str(fsdp_dir), - output_path=str(Path(parsed_cfg.output_dir) / "merged"), + output_path=output_path, safe_serialization=True, ) + state = PartialState() + state.wait_for_everyone() + LOG.info( + f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}", + main_process_only=True, + ) + LOG.info( + "Merged weights are only the safetensors and doesn't include the model configuration " + f"or tokenizer which may be found in {parsed_cfg.output_dir}.", + main_process_only=True, + ) if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 4fdc102f9..af35dd801 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -1,5 +1,6 @@ """CLI to run preprocessing of a dataset.""" +import os import warnings from pathlib import Path from typing import Union @@ -8,11 +9,9 @@ import fire import transformers from accelerate import init_empty_weights from colorama import Fore -from dotenv import load_dotenv from transformers import AutoModelForCausalLM from axolotl.cli.args import PreprocessCliArgs -from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH @@ -35,10 +34,26 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: cfg: Dictionary mapping `axolotl` config keys to values. cli_args: Preprocessing-specific CLI arguments. """ - print_axolotl_text_art() check_accelerate_default_config() check_user_token() + if cli_args.iterable: + LOG.error( + "The --iterable CLI argument for 'axolotl preprocess' is no longer " + "supported. For training, set 'streaming: true' in your YAML config or " + "pass '--streaming' in your 'axolotl train' command for on-the-fly " + "preprocessing." + ) + return + + for key in ["skip_prepare_dataset", "pretraining_dataset"]: + if cfg.get(key): + LOG.error( + f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl " + "train' CLI directly instead." + ) + return + if not cfg.dataset_prepared_path: msg = ( Fore.RED @@ -70,7 +85,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=True ) - except Exception as exc: # pylint: disable=broad-exception-caught,unused-variable # nosec B110 # noqa F841 + except Exception: # nosec B110 pass # fmt: on @@ -92,8 +107,10 @@ def do_cli( config: Path to `axolotl` config YAML file. kwargs: Additional keyword arguments to override config file values. """ - # pylint: disable=duplicate-code - parsed_cfg = load_cfg(config, **kwargs) + + os.environ["AXOLOTL_IS_PREPROCESS"] = "1" + is_preprocess = kwargs.pop("is_preprocess", True) + parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs) parsed_cfg.is_preprocess = True parser = transformers.HfArgumentParser(PreprocessCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( @@ -104,5 +121,4 @@ def do_cli( if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py index 63d51fadf..c11bcc6d9 100644 --- a/src/axolotl/cli/quantize.py +++ b/src/axolotl/cli/quantize.py @@ -5,13 +5,17 @@ CLI to post-training quantize a model using torchao from pathlib import Path from typing import Union -from transformers import AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig -from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg from axolotl.loaders import load_tokenizer from axolotl.utils.logging import get_logger -from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq +from axolotl.utils.quantization import ( + TorchAOQuantDType, + get_quantization_config, + quantization_config_to_str, + quantize_model, +) LOG = get_logger(__name__) @@ -27,7 +31,6 @@ def do_quantize( config (Union[Path, str]): The path to the config file cli_args (dict): Additional command-line arguments """ - print_axolotl_text_art() cfg = load_cfg(config) @@ -45,13 +48,13 @@ def do_quantize( "No quantization configuration found. Please specify either qat or quantization in your config file." ) - model_path = cli_args.get("model_path") or cfg.output_dir + model_path = cli_args.get("base_model") or cfg.output_dir if weight_dtype := cli_args.get("weight_dtype"): - weight_dtype = TorchIntDType[weight_dtype] + weight_dtype = TorchAOQuantDType.from_string(weight_dtype) else: weight_dtype = quantize_cfg.weight_dtype if activation_dtype := cli_args.get("activation_dtype"): - activation_dtype = TorchIntDType[activation_dtype] + activation_dtype = TorchAOQuantDType.from_string(activation_dtype) else: activation_dtype = quantize_cfg.activation_dtype group_size = cli_args.get("group_size") or quantize_cfg.group_size @@ -59,10 +62,15 @@ def do_quantize( cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding ) output_dir = cli_args.get("output_dir") or cfg.output_dir + hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id - LOG.info(f"Loading model from {model_path}...") + LOG.info(f"Loading model from {model_path}.") tokenizer = load_tokenizer(cfg) - model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") + config = AutoConfig.from_pretrained(model_path) + torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None + model = AutoModelForCausalLM.from_pretrained( + model_path, device_map="auto", dtype=torch_dtype + ) LOG.info( f"Quantizing model with configuration: \n" @@ -72,11 +80,21 @@ def do_quantize( f"\tquantize_embedding: {quantize_embedding}" ) - quantize_model_for_ptq( + quantize_model( model, weight_dtype, group_size, activation_dtype, quantize_embedding ) - LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...") + quantization_config = get_quantization_config( + weight_dtype, activation_dtype, group_size + ) + + ao_config = TorchAoConfig( + quant_type=quantization_config, + include_input_output_embeddings=quantize_embedding, + ) + model.config.quantization_config = ao_config + + LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.") model.save_pretrained( str(Path(output_dir) / "quantized"), safe_serialization=False, @@ -86,5 +104,16 @@ def do_quantize( str(Path(output_dir) / "quantized"), safe_serialization=False, progressbar=True, + save_jinja_files=cfg.tokenizer_save_jinja_files, ) - LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...") + + if hub_model_id: + hub_model_id = ( + hub_model_id.rstrip("-") + + f"-{quantization_config_to_str[type(quantization_config)]}" + ) + model.push_to_hub(hub_model_id, safe_serialization=False) + tokenizer.push_to_hub(hub_model_id) + LOG.info(f"Quantized model pushed to: {hub_model_id}.") + + LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.") diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index fef80fdba..6b3bfbd57 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -7,19 +7,17 @@ from typing import Union import fire from accelerate import Accelerator -from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser from axolotl.cli.args import TrainerCliArgs -from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.integrations.base import PluginManager from axolotl.train import train -from axolotl.utils import patch_optimized_env from axolotl.utils.config import normalize_config, resolve_dtype from axolotl.utils.dict import DictDefault +from axolotl.utils.trainer import prepare_optim_env def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): @@ -32,10 +30,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): cfg: Dictionary mapping `axolotl` config keys to values. cli_args: Training-specific CLI arguments. """ - # Enable expandable segments for cuda allocation to improve VRAM usage - patch_optimized_env() - - print_axolotl_text_art() check_accelerate_default_config() if int(os.getenv("LOCAL_RANK", "0")) == 0: check_user_token() @@ -66,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): config: Path to `axolotl` config YAML file. kwargs: Additional keyword arguments to override config file values. """ - # pylint: disable=duplicate-code parsed_cfg = load_cfg(config, **kwargs) parser = HfArgumentParser(TrainerCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( @@ -99,23 +92,30 @@ def ray_train_func(kwargs: dict): # cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict) # also renormalize the config now that TorchTrainer has spawned distributed workers cfg = DictDefault(kwargs["cfg"]) + prepare_optim_env(cfg) normalize_config(cfg) # now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype resolve_dtype(cfg) # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict - if cfg.deepspeed: + if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"): cfg.deepspeed = cfg.deepspeed.to_dict() # initialize accelerator before model instantiation Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps) + # Register plugins in Ray workers + if cfg.get("plugins"): + from axolotl.cli.config import plugin_set_cfg, prepare_plugins + + prepare_plugins(cfg) + plugin_set_cfg(cfg) + kwargs["cfg"] = cfg do_train(**kwargs) if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py deleted file mode 100644 index d28795361..000000000 --- a/src/axolotl/cli/utils.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Utility methods for axolotl CLI.""" - -import concurrent.futures -import dataclasses -import hashlib -import json -from functools import wraps -from pathlib import Path -from types import NoneType -from typing import Any, Callable, Type, Union, get_args, get_origin - -import click -import requests -from pydantic import BaseModel -from transformers import ( - PreTrainedModel, - PreTrainedTokenizer, - PreTrainedTokenizerFast, - ProcessorMixin, -) - -from axolotl.loaders import load_processor, load_tokenizer -from axolotl.loaders.model import ModelLoader -from axolotl.utils.dict import DictDefault -from axolotl.utils.logging import get_logger - -LOG = get_logger(__name__) - - -def strip_optional_type(field_type: type | str | None): - """ - Extracts the non-`None` type from an `Optional` / `Union` type. - - Args: - field_type: Type of field for Axolotl CLI command. - - Returns: - If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise - returns the input type unchanged. - """ - if get_origin(field_type) is Union and type(None) in get_args(field_type): - field_type = next( - t for t in get_args(field_type) if not isinstance(t, NoneType) - ) - - return field_type - - -def filter_none_kwargs(func: Callable) -> Callable: - """ - Wraps function to remove `None`-valued `kwargs`. - - Args: - func: Function to wrap. - - Returns: - Wrapped function. - """ - - @wraps(func) - def wrapper(*args, **kwargs) -> Callable: - """Filters out `None`-valued `kwargs`.""" - filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} - - return func(*args, **filtered_kwargs) - - return wrapper - - -def add_options_from_dataclass(config_class: Type[Any]) -> Callable: - """ - Create Click options from the fields of a dataclass. - - Args: - config_class: Dataclass with fields to parse from the CLI. - - Returns: - Function decorator for Axolotl CLI command. - """ - - def decorator(function: Callable) -> Callable: - # Process dataclass fields in reverse order for correct option ordering - for field in reversed(dataclasses.fields(config_class)): - field_type = strip_optional_type(field.type) - - if field_type == bool: - field_name = field.name.replace("_", "-") - option_name = f"--{field_name}/--no-{field_name}" - function = click.option( - option_name, - default=field.default, - help=field.metadata.get("description"), - )(function) - else: - option_name = f"--{field.name.replace('_', '-')}" - function = click.option( - option_name, - type=field_type, - default=field.default, - help=field.metadata.get("description"), - )(function) - - return function - - return decorator - - -def add_options_from_config(config_class: Type[BaseModel]) -> Callable: - """ - Create Click options from the fields of a Pydantic model. - - Args: - config_class: PyDantic model with fields to parse from the CLI - - Returns: - Function decorator for Axolotl CLI command. - """ - - def decorator(function: Callable) -> Callable: - # Process model fields in reverse order for correct option ordering - for name, field in reversed(config_class.model_fields.items()): - field_type = strip_optional_type(field.annotation) - - if field_type == bool: - field_name = name.replace("_", "-") - option_name = f"--{field_name}/--no-{field_name}" - function = click.option( - option_name, default=None, help=field.description - )(function) - else: - option_name = f"--{name.replace('_', '-')}" - function = click.option( - option_name, default=None, help=field.description - )(function) - - return function - - return decorator - - -def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]: - """ - Build command list from base command and options. - - Args: - base_cmd: Command without options. - options: Options to parse and append to base command. - - Returns: - List of strings giving shell command. - """ - cmd = base_cmd.copy() - - for key, value in options.items(): - if value is None: - continue - - key = key.replace("_", "-") - - if isinstance(value, bool): - if value: - cmd.append(f"--{key}") - else: - cmd.extend([f"--{key}", str(value)]) - - return cmd - - -def download_file( - file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str -) -> tuple[str, str]: - """ - Download a single file and return its processing status. - - Args: - file_info: Tuple of (file_path, remote_sha). - raw_base_url: Base URL for raw GitHub content. - dest_path: Local destination directory. - dir_prefix: Directory prefix to filter files. - - Returns: - Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'. - """ - file_path, remote_sha = file_info - raw_url = f"{raw_base_url}/{file_path}" - dest_file = dest_path / file_path.split(dir_prefix)[-1] - - # Check if file exists and needs updating - if dest_file.exists(): - with open(dest_file, "rb") as file: - content = file.read() - # Calculate git blob SHA - blob = b"blob " + str(len(content)).encode() + b"\0" + content - local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest() - - if local_sha == remote_sha: - print(f"Skipping {file_path} (unchanged)") - return file_path, "unchanged" - - print(f"Updating {file_path}") - status = "new" - else: - print(f"Downloading {file_path}") - status = "new" - - # Create directories if needed - dest_file.parent.mkdir(parents=True, exist_ok=True) - - # Download and save file - try: - response = requests.get(raw_url, timeout=30) - response.raise_for_status() - - with open(dest_file, "wb") as file: - file.write(response.content) - - return file_path, status - except (requests.RequestException, IOError) as request_error: - print(f"Error downloading {file_path}: {str(request_error)}") - return file_path, "error" - - -def fetch_from_github( - dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5 -) -> None: - """ - Sync files from a specific directory in the GitHub repository. - Only downloads files that don't exist locally or have changed. - - Args: - dir_prefix: Directory prefix to filter files (e.g., 'examples/', - 'deepspeed_configs/'). - dest_dir: Local destination directory. - max_workers: Maximum number of concurrent downloads. - """ - api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1" - raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main" - - # Get repository tree with timeout - response = requests.get(api_url, timeout=30) - response.raise_for_status() - tree = json.loads(response.text) - - # Filter for files and get their SHA - files = { - item["path"]: item["sha"] - for item in tree["tree"] - if item["type"] == "blob" and item["path"].startswith(dir_prefix) - } - - if not files: - raise click.ClickException(f"No files found in {dir_prefix}") - - # Default destination directory is the last part of dir_prefix - default_dest = Path(dir_prefix.rstrip("/")) - dest_path = Path(dest_dir) if dest_dir else default_dest - - # Keep track of processed files for summary - files_processed: dict[str, list[str]] = { - "new": [], - "updated": [], - "unchanged": [], - "error": [], - } - - # Process files in parallel using ThreadPoolExecutor - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_file = { - executor.submit( - download_file, - (file_path, remote_sha), - raw_base_url, - dest_path, - dir_prefix, - ): file_path - for file_path, remote_sha in files.items() - } - - # Process completed tasks as they finish - for future in concurrent.futures.as_completed(future_to_file): - file_path = future_to_file[future] - try: - file_path, status = future.result() - files_processed[status].append(file_path) - except (requests.RequestException, IOError) as request_error: - print(f"Error processing {file_path}: {str(request_error)}") - files_processed["error"].append(file_path) - - # Log summary - LOG.info("\nSync Summary:") - LOG.info(f"New files: {len(files_processed['new'])}") - LOG.info(f"Updated files: {len(files_processed['updated'])}") - LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}") - if files_processed["error"]: - LOG.info(f"Failed files: {len(files_processed['error'])}") - - -def load_model_and_tokenizer( - *, - cfg: DictDefault, - inference: bool = False, -) -> tuple[ - PreTrainedModel, - PreTrainedTokenizer | PreTrainedTokenizerFast | Any, - ProcessorMixin | None, -]: - """ - Helper function for loading a model, tokenizer, and processor specified in the given `axolotl` - config. - - Args: - cfg: Dictionary mapping `axolotl` config keys to values. - inference: Boolean denoting inference mode. - - Returns: - Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin). - """ - LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") - tokenizer = load_tokenizer(cfg) - - LOG.info("loading model...") - model_loader = ModelLoader(cfg, tokenizer, inference=inference) - model, _ = model_loader.load() - - processor = None - if cfg.is_multimodal: - LOG.info("loading processor...") - processor = load_processor(cfg, tokenizer) - - return model, tokenizer, processor diff --git a/src/axolotl/cli/utils/__init__.py b/src/axolotl/cli/utils/__init__.py new file mode 100644 index 000000000..583130339 --- /dev/null +++ b/src/axolotl/cli/utils/__init__.py @@ -0,0 +1,23 @@ +"""Init for axolotl.cli.utils module.""" + +from .args import ( + add_options_from_config, + add_options_from_dataclass, + filter_none_kwargs, +) +from .fetch import fetch_from_github +from .load import load_model_and_tokenizer +from .sweeps import generate_sweep_configs +from .train import build_command, generate_config_files, launch_training + +__all__ = [ + "filter_none_kwargs", + "add_options_from_dataclass", + "add_options_from_config", + "build_command", + "generate_config_files", + "generate_sweep_configs", + "load_model_and_tokenizer", + "launch_training", + "fetch_from_github", +] diff --git a/src/axolotl/cli/utils/args.py b/src/axolotl/cli/utils/args.py new file mode 100644 index 000000000..0aec737b8 --- /dev/null +++ b/src/axolotl/cli/utils/args.py @@ -0,0 +1,120 @@ +"""Utilities for axolotl CLI args.""" + +import dataclasses +from functools import wraps +from types import NoneType +from typing import Any, Callable, Type, Union, get_args, get_origin + +import click +from pydantic import BaseModel + + +def _strip_optional_type(field_type: type | str | None): + """ + Extracts the non-`None` type from an `Optional` / `Union` type. + + Args: + field_type: Type of field for Axolotl CLI command. + + Returns: + If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise + returns the input type unchanged. + """ + if get_origin(field_type) is Union and type(None) in get_args(field_type): + field_type = next( + t for t in get_args(field_type) if not isinstance(t, NoneType) + ) + + return field_type + + +def filter_none_kwargs(func: Callable) -> Callable: + """ + Wraps function to remove `None`-valued `kwargs`. + + Args: + func: Function to wrap. + + Returns: + Wrapped function. + """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Callable: + """Filters out `None`-valued `kwargs`.""" + filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + + return func(*args, **filtered_kwargs) + + return wrapper + + +def add_options_from_dataclass(config_class: Type[Any]) -> Callable: + """ + Create Click options from the fields of a dataclass. + + Args: + config_class: Dataclass with fields to parse from the CLI. + + Returns: + Function decorator for Axolotl CLI command. + """ + + def decorator(function: Callable) -> Callable: + # Process dataclass fields in reverse order for correct option ordering + for field in reversed(dataclasses.fields(config_class)): + field_type = _strip_optional_type(field.type) + + if field_type is bool: + field_name = field.name.replace("_", "-") + option_name = f"--{field_name}/--no-{field_name}" + function = click.option( + option_name, + default=field.default, + help=field.metadata.get("description"), + )(function) + else: + option_name = f"--{field.name.replace('_', '-')}" + function = click.option( + option_name, + type=field_type, + default=field.default, + help=field.metadata.get("description"), + )(function) + + return function + + return decorator + + +def add_options_from_config(config_class: Type[BaseModel]) -> Callable: + """ + Create Click options from the fields of a Pydantic model. + + Args: + config_class: PyDantic model with fields to parse from the CLI + + Returns: + Function decorator for Axolotl CLI command. + """ + + def decorator(function: Callable) -> Callable: + # Process model fields in reverse order for correct option ordering + for name, field in reversed(config_class.model_fields.items()): + field_type = _strip_optional_type(field.annotation) + + if field_type is bool: + field_name = name.replace("_", "-") + option_name = f"--{field_name}/--no-{field_name}" + function = click.option( + option_name, default=None, help=field.description + )(function) + else: + option_name = f"--{name.replace('_', '-')}" + function = click.option( + option_name, default=None, help=field.description + )(function) + + return function + + return decorator diff --git a/src/axolotl/cli/utils/diffusion.py b/src/axolotl/cli/utils/diffusion.py new file mode 100644 index 000000000..1157bfd66 --- /dev/null +++ b/src/axolotl/cli/utils/diffusion.py @@ -0,0 +1,374 @@ +"""Helpers for diffusion-mode inference in CLI and Gradio.""" + +from __future__ import annotations + +import gradio as gr +from colorama import Fore, Style + +from axolotl.integrations.diffusion import generate, resolve_mask_token_id +from axolotl.utils.dict import DictDefault + + +def diffusion_inference( + model, + tokenizer, + cfg, + prompt: str, + chat_template_str: str | None = None, +): + """Diffusion inference helper method.""" + mode = "random" + completion_tokens = 0 + target_mask_ratio = None + mode, completion_tokens, target_mask_ratio, cleaned = _parse_commands(prompt) + + if cleaned: + prompt = cleaned + + info = run_diffusion( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompt=prompt, + chat_template_str=chat_template_str, + mode=mode, + target_mask_ratio=target_mask_ratio, + completion_tokens=completion_tokens, + ) + masked_text = info["masked_text"] + mask_ratio = info["mask_ratio"] + generated_ids = info["generated_ids"] + masked_positions = info["masked_positions"] + orig_ids = info["orig_ids"] + + # Display with masked preview and colored diff + if masked_text is not None and mask_ratio is not None: + print(f"Masked ({mask_ratio:.1%}):\n{masked_text}\n") + if generated_ids is not None: + # Compute per-token style + styles: list[str] = [] + for i, tid in enumerate(generated_ids): + if i in masked_positions: + if i < len(orig_ids) and tid == orig_ids[i]: + styles.append("green") # correct fill + elif i < len(orig_ids): + styles.append("red") # incorrect fill + else: + styles.append("normal") # appended + else: + same = i < len(orig_ids) and tid == orig_ids[i] + styles.append("dim" if same else "normal") + + # Group contiguous spans by style + styled_spans: list[tuple[str, int, int]] = [] + if generated_ids: + current_style = styles[0] + start = 0 + for i in range(1, len(generated_ids)): + s = styles[i] + if s != current_style: + styled_spans.append((current_style, start, i)) + current_style, start = s, i + styled_spans.append((current_style, start, len(generated_ids))) + + out_parts = [] + for style_name, a, b in styled_spans: + chunk_text = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False) + if style_name == "green": + out_parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL) + elif style_name == "red": + out_parts.append(Fore.RED + chunk_text + Style.RESET_ALL) + else: + if style_name == "dim": + out_parts.append(Style.DIM + chunk_text + Style.RESET_ALL) + else: + out_parts.append(chunk_text) + print("Generated:\n" + "".join(out_parts)) + else: + print("Generated:\n(no output)") + + +def _parse_commands(text: str): + """ + Parse leading diffusion commands. + + Supported at start of input (can be chained): + :complete N -> completion mode with N tokens (default 64) + :mask R -> random masking with ratio R in [0, 1] + """ + tokens = text.strip().split() + i = 0 + mode = "random" + completion_tokens = 0 + target_mask_ratio = None + consumed = 0 + while i < len(tokens) and tokens[i].startswith(":"): + cmd = tokens[i] + i += 1 + consumed = i + if cmd == ":complete": + mode = "completion" + if i < len(tokens): + try: + completion_tokens = int(tokens[i]) + i += 1 + consumed = i + except Exception: + completion_tokens = 64 + else: + completion_tokens = 64 + elif cmd == ":mask": + mode = "random" + if i < len(tokens): + try: + target_mask_ratio = float(tokens[i]) + i += 1 + consumed = i + except Exception: + target_mask_ratio = None + else: + i -= 1 + consumed = i + break + + cleaned = " ".join(tokens[consumed:]) + + return mode, completion_tokens, target_mask_ratio, cleaned + + +def run_diffusion( + *, + model, + tokenizer, + cfg: DictDefault, + prompt: str, + chat_template_str: str | None, + mode: str = "random", + target_mask_ratio: float | None = None, + completion_tokens: int = 0, +): + """Run a single diffusion generation and return a structured result dict.""" + if chat_template_str: + batch = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False) + + seq = batch["input_ids"].to(cfg.device) + gen_mode = "completion" if mode == "completion" else "random" + comp_tokens = int(completion_tokens) if gen_mode == "completion" else 0 + + result = generate( + model, + tokenizer, + original_sequence=seq[:1], + num_diffusion_steps=cfg.diffusion.num_diffusion_steps, + temperature=cfg.diffusion.generation_temperature, + mask_token_id=int(mask_token_id), + mode=gen_mode, # type: ignore[arg-type] + completion_tokens=comp_tokens, + target_mask_ratio=target_mask_ratio, + ) + + masked_text = result.get("masked") if isinstance(result, dict) else None + mask_ratio = result.get("mask_ratio") if isinstance(result, dict) else None + generated_ids = result.get("generated_ids") if isinstance(result, dict) else None + masked_positions = ( + set(result.get("masked_positions") or []) if isinstance(result, dict) else set() + ) + orig_ids = seq[0].detach().cpu().tolist() + + return { + "masked_text": masked_text, + "mask_ratio": mask_ratio, + "generated_ids": generated_ids, + "masked_positions": masked_positions, + "orig_ids": orig_ids, + } + + +def render_html( + *, + generated_ids: list[int] | None, + orig_ids: list[int], + masked_positions: set[int], + tokenizer, +) -> str: + """Render HTML visualizing diffusion outputs.""" + if not generated_ids: + return "
Generated:\n(no output)
" + + def _style_for(i: int, tid: int) -> str: + if i in masked_positions: + if i < len(orig_ids) and tid == orig_ids[i]: + return "green" + if i < len(orig_ids): + return "red" + return "normal" + same = i < len(orig_ids) and tid == orig_ids[i] + return "dim" if same else "normal" + + # Group contiguous spans by style to reduce HTML size + spans: list[tuple[str, int, int]] = [] + if generated_ids: + cur = _style_for(0, generated_ids[0]) + start = 0 + for i in range(1, len(generated_ids)): + s = _style_for(i, generated_ids[i]) + if s != cur: + spans.append((cur, start, i)) + cur, start = s, i + spans.append((cur, start, len(generated_ids))) + + html_parts = [] + for style_name, a, b in spans: + txt = tokenizer.decode(generated_ids[a:b], skip_special_tokens=False) + if style_name == "green": + html_parts.append(f'{txt}') + elif style_name == "red": + html_parts.append(f'{txt}') + elif style_name == "dim": + html_parts.append(f'{txt}') + else: + html_parts.append(txt) + + legend = ( + '
' + 'correct, ' + 'incorrect, ' + 'unchanged' + "
" + ) + + return ( + legend + + '
Generated:\n'
+        + "".join(html_parts)
+        + "
" + ) + + +def launch_diffusion_gradio_ui( + *, + model, + tokenizer, + cfg: DictDefault, + prompter_module=None, + chat_template_str: str | None = None, +): + """Build and launch a simple Gradio UI for diffusion inference.""" + with gr.Blocks( + title=cfg.get("gradio_title", "Axolotl Diffusion Interface") + ) as demo: + gr.Markdown( + """ + ## Axolotl Diffusion Inference + - Mode "Random" masks tokens at a target ratio and fills them. + - Mode "Completion" appends N masked tokens at the end and fills them. + """ + ) + + with gr.Row(): + mode = gr.Radio( + choices=["random", "completion"], + value="random", + label="Mode", + ) + mask_ratio = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.05, + value=0.4, + label="Mask ratio (random mode)", + interactive=True, + ) + completion_tokens = gr.Number( + value=64, + precision=0, + label="Completion tokens (completion mode)", + interactive=True, + visible=False, + ) + + instruction = gr.Textbox(label="Instruction", lines=6) + run_btn = gr.Button("Generate") + + masked_preview = gr.Textbox(label="Masked preview", lines=6) + html_out = gr.HTML(label="Generated") + + def _toggle_controls(selected_mode: str): + return ( + gr.update(visible=(selected_mode == "random")), + gr.update(visible=(selected_mode == "completion")), + ) + + mode.change( + _toggle_controls, + inputs=[mode], + outputs=[mask_ratio, completion_tokens], + ) + + def _gen(instruction_text: str, selected_mode: str, mratio: float, ctoks: int): + if not instruction_text: + return "", "
Generated:\n(no output)
" + + if prompter_module: + prompt: str = next( + prompter_module().build_prompt( + instruction=instruction_text.strip("\n") + ) + ) + else: + prompt = instruction_text.strip() + + info = run_diffusion( + model=model, + tokenizer=tokenizer, + cfg=cfg, + prompt=prompt, + chat_template_str=chat_template_str, + mode=selected_mode, + target_mask_ratio=mratio if selected_mode == "random" else None, + completion_tokens=int(ctoks) if selected_mode == "completion" else 0, + ) + + masked_text = info.get("masked_text") + mask_ratio_val = info.get("mask_ratio") + generated_ids = info.get("generated_ids") + masked_positions = info.get("masked_positions") or set() + orig_ids = info.get("orig_ids") or [] + + preview = ( + f"Masked ({mask_ratio_val:.1%}):\n{masked_text}" + if masked_text is not None and mask_ratio_val is not None + else "" + ) + html = render_html( + generated_ids=generated_ids, + orig_ids=orig_ids, + masked_positions=masked_positions, + tokenizer=tokenizer, + ) + return preview, html + + run_btn.click( + _gen, + inputs=[instruction, mode, mask_ratio, completion_tokens], + outputs=[masked_preview, html_out], + ) + + demo.queue().launch( + show_api=False, + share=cfg.get("gradio_share", True), + server_name=cfg.get("gradio_server_name", "127.0.0.1"), + server_port=cfg.get("gradio_server_port", None), + ) diff --git a/src/axolotl/cli/utils/fetch.py b/src/axolotl/cli/utils/fetch.py new file mode 100644 index 000000000..441b7f6f7 --- /dev/null +++ b/src/axolotl/cli/utils/fetch.py @@ -0,0 +1,142 @@ +"""Utilities for axolotl fetch CLI command.""" + +import concurrent.futures +import hashlib +import json +from pathlib import Path + +import click +import requests + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def _download_file( + file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str +) -> tuple[str, str]: + """ + Download a single file and return its processing status. + + Args: + file_info: Tuple of (file_path, remote_sha). + raw_base_url: Base URL for raw GitHub content. + dest_path: Local destination directory. + dir_prefix: Directory prefix to filter files. + + Returns: + Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'. + """ + file_path, remote_sha = file_info + raw_url = f"{raw_base_url}/{file_path}" + dest_file = dest_path / file_path.split(dir_prefix)[-1] + + # Check if file exists and needs updating + if dest_file.exists(): + with open(dest_file, "rb") as file: + content = file.read() + # Calculate git blob SHA + blob = b"blob " + str(len(content)).encode() + b"\0" + content + local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest() + + if local_sha == remote_sha: + print(f"Skipping {file_path} (unchanged)") + return file_path, "unchanged" + + print(f"Updating {file_path}") + status = "updated" + else: + print(f"Downloading {file_path}") + status = "new" + + # Create directories if needed + dest_file.parent.mkdir(parents=True, exist_ok=True) + + # Download and save file + try: + response = requests.get(raw_url, timeout=30) + response.raise_for_status() + + with open(dest_file, "wb") as file: + file.write(response.content) + + return file_path, status + except (requests.RequestException, IOError) as request_error: + print(f"Error downloading {file_path}: {str(request_error)}") + return file_path, "error" + + +def fetch_from_github( + dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5 +) -> None: + """ + Sync files from a specific directory in the GitHub repository. + Only downloads files that don't exist locally or have changed. + + Args: + dir_prefix: Directory prefix to filter files (e.g., 'examples/', + 'deepspeed_configs/'). + dest_dir: Local destination directory. + max_workers: Maximum number of concurrent downloads. + """ + api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1" + raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main" + + # Get repository tree with timeout + response = requests.get(api_url, timeout=30) + response.raise_for_status() + tree = json.loads(response.text) + + # Filter for files and get their SHA + files = { + item["path"]: item["sha"] + for item in tree["tree"] + if item["type"] == "blob" and item["path"].startswith(dir_prefix) + } + + if not files: + raise click.ClickException(f"No files found in {dir_prefix}") + + # Default destination directory is the last part of dir_prefix + default_dest = Path(dir_prefix.rstrip("/")) + dest_path = Path(dest_dir) if dest_dir else default_dest + + # Keep track of processed files for summary + files_processed: dict[str, list[str]] = { + "new": [], + "updated": [], + "unchanged": [], + "error": [], + } + + # Process files in parallel using ThreadPoolExecutor + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_file = { + executor.submit( + _download_file, + (file_path, remote_sha), + raw_base_url, + dest_path, + dir_prefix, + ): file_path + for file_path, remote_sha in files.items() + } + + # Process completed tasks as they finish + for future in concurrent.futures.as_completed(future_to_file): + file_path = future_to_file[future] + try: + file_path, status = future.result() + files_processed[status].append(file_path) + except (requests.RequestException, IOError) as request_error: + print(f"Error processing {file_path}: {str(request_error)}") + files_processed["error"].append(file_path) + + # Log summary + LOG.info("\nSync Summary:") + LOG.info(f"New files: {len(files_processed['new'])}") + LOG.info(f"Updated files: {len(files_processed['updated'])}") + LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}") + if files_processed["error"]: + LOG.info(f"Failed files: {len(files_processed['error'])}") diff --git a/src/axolotl/cli/utils/load.py b/src/axolotl/cli/utils/load.py new file mode 100644 index 000000000..610a81306 --- /dev/null +++ b/src/axolotl/cli/utils/load.py @@ -0,0 +1,52 @@ +"""Utilities for model, tokenizer, etc. loading.""" + +from typing import Any + +from transformers import ( + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, +) + +from axolotl.loaders import load_processor, load_tokenizer +from axolotl.loaders.model import ModelLoader +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def load_model_and_tokenizer( + *, + cfg: DictDefault, + inference: bool = False, +) -> tuple[ + PreTrainedModel, + PreTrainedTokenizer | PreTrainedTokenizerFast | Any, + ProcessorMixin | None, +]: + """ + Helper function for loading a model, tokenizer, and processor specified in the + given `axolotl` config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + inference: Boolean denoting inference mode. + + Returns: + Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin). + """ + LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") + tokenizer = load_tokenizer(cfg) + + LOG.info("loading model...") + model_loader = ModelLoader(cfg, tokenizer, inference=inference) + model, _ = model_loader.load() + + processor = None + if cfg.is_multimodal: + LOG.info("loading processor...") + processor = load_processor(cfg, tokenizer) + + return model, tokenizer, processor diff --git a/src/axolotl/cli/sweeps.py b/src/axolotl/cli/utils/sweeps.py similarity index 92% rename from src/axolotl/cli/sweeps.py rename to src/axolotl/cli/utils/sweeps.py index d21664964..2a0aa1367 100644 --- a/src/axolotl/cli/sweeps.py +++ b/src/axolotl/cli/utils/sweeps.py @@ -3,11 +3,12 @@ import random from copy import deepcopy from itertools import product +from typing import Any def generate_sweep_configs( base_config: dict[str, list], sweeps_config: dict[str, list] -) -> list[dict[str, list]]: +) -> list[dict[str, Any]]: """ Recursively generates all possible configurations by applying sweeps to the base config. @@ -48,7 +49,10 @@ def generate_sweep_configs( new_config = {} # new_config = deepcopy(base_config) # Combine regular parameters with paired parameters - full_combo = {**dict(zip(param_names, reg_combo)), **paired_set} + full_combo = { + **dict(zip(param_names, reg_combo, strict=False)), + **paired_set, + } for param_name, param_value in full_combo.items(): new_config[param_name] = param_value print(new_config) @@ -57,7 +61,7 @@ def generate_sweep_configs( # If no paired values, just use regular combinations # new_config = deepcopy(base_config) new_config = {} - for param_name, param_value in zip(param_names, reg_combo): + for param_name, param_value in zip(param_names, reg_combo, strict=False): new_config[param_name] = param_value print(new_config) all_combinations.append(new_config) diff --git a/src/axolotl/cli/utils/train.py b/src/axolotl/cli/utils/train.py new file mode 100644 index 000000000..6ce7d8df3 --- /dev/null +++ b/src/axolotl/cli/utils/train.py @@ -0,0 +1,225 @@ +"""Utilities for axolotl train CLI command.""" + +import os +import subprocess # nosec +import sys +import tempfile +from pathlib import Path +from typing import Any, Iterator, Literal + +import yaml + +from axolotl.cli.utils.sweeps import generate_sweep_configs + + +def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]: + """ + Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing. + + Args: + launcher_args: List of launcher arguments + + Returns: + Updated launcher args with defaults added if needed + """ + args = launcher_args.copy() + + # Check if rdzv_endpoint is present + has_rdzv_endpoint = any("--rdzv_endpoint" in arg for arg in args) + + if has_rdzv_endpoint: + # Check if rdzv_backend is already provided + has_rdzv_backend = any("--rdzv_backend" in arg for arg in args) + if not has_rdzv_backend: + args.extend(["--rdzv_backend", "c10d"]) + + # Check if rdzv_id is already provided + has_rdzv_id = any("--rdzv_id" in arg for arg in args) + if not has_rdzv_id: + import uuid + + args.extend(["--rdzv_id", str(uuid.uuid4())[:8]]) + + return args + + +def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]: + """ + Build command list from base command and options. + + Args: + base_cmd: Command without options. + options: Options to parse and append to base command. + + Returns: + List of strings giving shell command. + """ + cmd = base_cmd.copy() + + for key, value in options.items(): + if value is None: + continue + + key = key.replace("_", "-") + cmd.append(f"--{key}={value}") + + return cmd + + +def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]: + """ + Generate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating + whether this is a group of configurations (i.e., a sweep). + + Args: + config: Base configuration file + sweep: Sweep configuration file + """ + + if not sweep: + yield config, False + return + + # Load sweep and base configurations + with open(sweep, "r", encoding="utf-8") as fin: + sweep_config: dict[str, list] = yaml.safe_load(fin) + with open(config, "r", encoding="utf-8") as fin: + base_config: dict[str, list] = yaml.safe_load(fin) + + # Generate all possible configurations + permutations = generate_sweep_configs(base_config, sweep_config) + is_group = len(permutations) > 1 + base_output_dir = base_config.get("output_dir", "./model-out") + for idx, permutation in enumerate(permutations, start=1): + permutation_dir = Path(permutation.get("output_dir", base_output_dir)) + permutation_id = f"sweep{idx:04d}" + permutation["output_dir"] = str(permutation_dir / permutation_id) + + temp_file = tempfile.NamedTemporaryFile( + mode="w", + suffix=".yaml", + delete=False, + encoding="utf-8", + ) + yaml.dump(permutation, temp_file) + temp_file.close() + yield temp_file.name, is_group + + +def launch_training( + cfg_file: str, + launcher: Literal["accelerate", "torchrun", "python"] | None, + cloud: str | None, + kwargs: dict, + launcher_args: list[str] | None = None, + use_exec: bool = False, +) -> None: + """Execute training with the given configuration.""" + launcher_args = launcher_args or [] + + if cloud: + _launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args) + elif launcher: + if launcher == "accelerate": + _launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec) + elif launcher == "torchrun": + _launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec) + elif launcher == "python": + _launch_python_training(cfg_file, kwargs) + elif launcher is None: + # handle ray train launch + _launch_python_training(cfg_file, kwargs) + + +def _launch_cloud_training( + cloud: str, + cfg_file: str, + launcher: Literal["accelerate", "torchrun", "python"] | None, + kwargs: dict, + launcher_args: list[str] | None = None, +) -> None: + """Execute training via cloud launcher.""" + from axolotl.cli.cloud import do_cli_train + + launcher_args = launcher_args or [] + cwd = os.getcwd() if launcher else None + + do_cli_train( + cloud_config=cloud, + config=cfg_file, + launcher=launcher or "accelerate", + launcher_args=launcher_args, + cwd=cwd, + **kwargs, + ) + + +def _launch_accelerate_training( + cfg_file: str, + kwargs: dict, + launcher_args: list[str] | None = None, + use_exec: bool = False, +) -> None: + """Execute training via accelerate launcher.""" + launcher_args = launcher_args or [] + internal_launcher_args = [] + + # Extract launcher-specific arguments from kwargs (legacy support) + if "main_process_port" in kwargs: + main_process_port = kwargs.pop("main_process_port") + internal_launcher_args.extend(["--main_process_port", str(main_process_port)]) + + if "num_processes" in kwargs: + num_processes = kwargs.pop("num_processes") + internal_launcher_args.extend(["--num_processes", str(num_processes)]) + + # Combine internal args with user-provided launcher args + all_launcher_args = internal_launcher_args + launcher_args + + base_cmd = ( + ["accelerate", "launch"] + all_launcher_args + ["-m", "axolotl.cli.train"] + ) + if cfg_file: + base_cmd.append(cfg_file) + + cmd = build_command(base_cmd, kwargs) + if use_exec: + # make sure to flush stdout and stderr before replacing the process + sys.stdout.flush() + sys.stderr.flush() + os.execvpe(cmd[0], cmd, os.environ) # nosec B606 + else: + subprocess.run(cmd, check=True) # nosec B603 + + +def _launch_torchrun_training( + cfg_file: str, + kwargs: dict, + launcher_args: list[str] | None = None, + use_exec: bool = False, +) -> None: + """Execute training via torchrun launcher.""" + launcher_args = launcher_args or [] + + # Add default RDZV arguments if rdzv_endpoint is set + launcher_args = _add_default_rdzv_args(launcher_args) + + base_cmd = ["torchrun"] + launcher_args + ["-m", "axolotl.cli.train"] + if cfg_file: + base_cmd.append(cfg_file) + + cmd = build_command(base_cmd, kwargs) + if use_exec: + # make sure to flush stdout and stderr before replacing the process + sys.stdout.flush() + sys.stderr.flush() + os.execvpe(cmd[0], cmd, os.environ) # nosec B606 + else: + subprocess.run(cmd, check=True) # nosec B603 + + +def _launch_python_training(cfg_file: str, kwargs: dict) -> None: + """Execute training via python launcher.""" + from axolotl.cli.train import do_cli + + do_cli(config=cfg_file, **kwargs) diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py index 448b25a7e..ea454fc96 100644 --- a/src/axolotl/cli/vllm_serve.py +++ b/src/axolotl/cli/vllm_serve.py @@ -2,12 +2,10 @@ CLI to start the vllm server for online RL """ -import os from dataclasses import dataclass, field from pathlib import Path from typing import Union -import trl from trl.scripts.vllm_serve import ScriptArguments from axolotl.cli.config import load_cfg @@ -37,16 +35,22 @@ def do_vllm_serve( Returns: process_id: the process id of the started VLLM server """ - patch_vllm_worker() cfg = load_cfg(config) model = cfg.base_model serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve") - vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main") + vllm_serve_main = __import__(serve_module, fromlist=["main"]).main + tensor_parallel_size = 1 + data_parallel_size = 1 - tensor_parallel_size = ( - cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size - ) + if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size: + tensor_parallel_size = ( + cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size + ) + if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size: + data_parallel_size = ( + cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size + ) host = cli_args.get("host") or cfg.vllm.host port = cli_args.get("port") or cfg.vllm.port gpu_memory_utilization = ( @@ -64,10 +68,10 @@ def do_vllm_serve( cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False ) - # pylint: disable=unexpected-keyword-arg vllm_script_args = AxolotlScriptArguments( model=model, tensor_parallel_size=tensor_parallel_size, + data_parallel_size=data_parallel_size, host=host, port=port, gpu_memory_utilization=gpu_memory_utilization, @@ -78,63 +82,3 @@ def do_vllm_serve( enable_reasoning=enable_reasoning, ) vllm_serve_main(vllm_script_args) - - -def patch_vllm_worker(): - from multiprocessing.connection import Connection - - from vllm import LLM - - def llm_worker( - script_args: AxolotlScriptArguments, - data_parallel_rank: int, - master_port: int, - connection: Connection, - ) -> None: - # Set required environment variables for DP to work with vLLM - os.environ["VLLM_DP_RANK"] = str(data_parallel_rank) - os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank) - os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size) - os.environ["VLLM_DP_MASTER_PORT"] = str(master_port) - - llm = LLM( - model=script_args.model, - revision=script_args.revision, - tensor_parallel_size=script_args.tensor_parallel_size, - gpu_memory_utilization=script_args.gpu_memory_utilization, - enforce_eager=script_args.enforce_eager, - dtype=script_args.dtype, - # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can - # directly reuse the KV cache if it shares the same prefix with one of the existing queries. - # This is particularly useful here because we generate completions from the same prompts. - enable_prefix_caching=script_args.enable_prefix_caching, - kv_cache_dtype=script_args.kv_cache_dtype, - max_model_len=script_args.max_model_len, - worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension", - enable_reasoning=script_args.enable_reasoning, - reasoning_parser=script_args.reasoning_parser, - ) - - # Send ready signal to parent process - connection.send({"status": "ready"}) - - while True: - # Wait for commands from the parent process - try: - command = connection.recv() - except KeyboardInterrupt: - llm.collective_rpc(method="close_communicator") - break - - # Handle commands - if command["type"] in ["call", "fire_and_forget"]: - method_name = command["method"] - args, kwargs = command.get("args", ()), command.get("kwargs", {}) - method = getattr(llm, method_name) - result = method(*args, **kwargs) - if command["type"] == "call": - connection.send(result) - elif command["type"] == "shutdown": - break - - trl.scripts.vllm_serve.llm_worker = llm_worker diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py index 2f77b613e..b754e56ba 100644 --- a/src/axolotl/common/architectures.py +++ b/src/axolotl/common/architectures.py @@ -13,4 +13,6 @@ MOE_ARCH_BLOCK = { "qwen2_moe": "Qwen2MoeSparseMoeBlock", "qwen3_moe": "Qwen3MoeSparseMoeBlock", "deepseek_v2": "DeepseekV2MoE", + "gpt_oss": "GptOssDecoderLayer", + "lfm2_moe": "Lfm2MoeSparseMoeBlock", } diff --git a/src/axolotl/common/const.py b/src/axolotl/common/const.py index fd34ad469..8aae06e99 100644 --- a/src/axolotl/common/const.py +++ b/src/axolotl/common/const.py @@ -1,5 +1,3 @@ -""" -Various shared constants -""" +"""Various shared constants""" DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index 35d5472b0..c95ddb80e 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -3,16 +3,14 @@ import math import random from dataclasses import dataclass -from typing import Optional, Union from datasets import Dataset -import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 +import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.loaders import load_processor, load_tokenizer from axolotl.telemetry.errors import send_errors -from axolotl.utils.data import prepare_dataset -from axolotl.utils.data.rl import load_prepare_preference_datasets +from axolotl.utils.data import prepare_datasets, prepare_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType @@ -31,16 +29,7 @@ class TrainDatasetMeta: def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: - """ - Randomly sample `num_samples` samples from `dataset`. - - Args: - dataset: Dataset. - num_samples: Number of samples to return. - - Returns: - Random sample (with replacement) of examples in `dataset`. - """ + """Randomly sample `num_samples` samples with replacement from `dataset`.""" return dataset.select( [random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec ) @@ -53,55 +42,50 @@ def load_datasets( cli_args: PreprocessCliArgs | TrainerCliArgs | None = None, debug: bool = False, ) -> TrainDatasetMeta: - """ - Loads one or more training or evaluation datasets, calling - `axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information. + """Loads one or more training or evaluation datasets, calling + `axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information. Args: cfg: Dictionary mapping `axolotl` config keys to values. cli_args: Command-specific CLI arguments. - debug: Whether to print out tokenization of sample + debug: Whether to print out tokenization of sample. This is duplicated in + `cfg` and `cli_args`, but is kept due to use in our Colab notebooks. Returns: Dataclass with fields for training and evaluation datasets and the computed - `total_num_steps`. + `total_num_steps`. """ tokenizer = load_tokenizer(cfg) processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None - preprocess_iterable = ( - cli_args - and hasattr(cli_args, "iterable") - and cli_args.iterable is not None - and cli_args.iterable - ) - train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( + train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets( cfg, tokenizer, processor=processor, - preprocess_iterable=preprocess_iterable, ) - if ( # pylint: disable=too-many-boolean-expressions - cli_args - and ( - cli_args.debug - or cfg.debug - or cli_args.debug_text_only - or int(cli_args.debug_num_examples) > 0 - ) - ) or debug: + if ( + cfg.debug + or getattr(cli_args, "debug", False) + or getattr(cli_args, "debug_text_only", False) + or getattr(cli_args, "debug_num_examples", 0) > 0 + or debug + ): LOG.info("check_dataset_labels...") num_examples = cli_args.debug_num_examples if cli_args else 1 text_only = cli_args.debug_text_only if cli_args else False - train_samples = sample_dataset(train_dataset, num_examples) - check_dataset_labels( - train_samples, - tokenizer, - num_examples=num_examples, - text_only=text_only, - ) + try: + train_samples = sample_dataset(train_dataset, num_examples) + check_dataset_labels( + train_samples, + tokenizer, + num_examples=num_examples, + text_only=text_only, + ) + except AttributeError: + # can't sample iterable datasets + pass LOG.info("printing prompters...") for prompter in prompters: @@ -116,13 +100,10 @@ def load_datasets( @send_errors def load_preference_datasets( - *, - cfg: DictDefault, - cli_args: Union[PreprocessCliArgs, TrainerCliArgs], + *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None ) -> TrainDatasetMeta: - """ - Loads one or more training or evaluation datasets for RL training using paired - preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`. + """Loads one or more training or evaluation datasets for RL training using paired + preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`. Optionally, logs out debug information. Args: @@ -133,23 +114,28 @@ def load_preference_datasets( Dataclass with fields for training and evaluation datasets and the computed `total_num_steps`. """ - train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) - total_num_steps: Optional[int] = int( - math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) - ) - if cfg.rl is RLType.GRPO: - total_num_steps = None + tokenizer = load_tokenizer(cfg) + train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer) - if cli_args.debug or cfg.debug: + total_num_steps: int | None = None + if cfg.rl is not RLType.GRPO: + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + + if ((cli_args and cli_args.debug) or cfg.debug) and cfg.rl != RLType.ORPO: LOG.info("check_dataset_labels...") + num_examples = cli_args.debug_num_examples if cli_args else 1 + text_only = cli_args.debug_text_only if cli_args else False + tokenizer = load_tokenizer(cfg) - train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) + train_samples = sample_dataset(train_dataset, num_examples) check_dataset_labels( - train_samples, - tokenizer, - num_examples=cli_args.debug_num_examples, - text_only=cli_args.debug_text_only, + dataset=train_samples, + tokenizer=tokenizer, + num_examples=num_examples, + text_only=text_only, rl_mode=True, ) diff --git a/src/axolotl/convert.py b/src/axolotl/convert.py index d1bdb34db..9e09b37dc 100644 --- a/src/axolotl/convert.py +++ b/src/axolotl/convert.py @@ -67,9 +67,7 @@ class JsonToJsonlConverter: self.json_parser = json_parser self.jsonl_serializer = jsonl_serializer - def convert( - self, input_file_path, output_file_path - ): # pylint: disable=unused-argument + def convert(self, input_file_path, output_file_path): content = self.file_reader.read(input_file_path) data = self.json_parser.parse(content) # data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/__init__.py b/src/axolotl/core/attention/__init__.py similarity index 100% rename from src/axolotl/integrations/cut_cross_entropy/monkeypatch/__init__.py rename to src/axolotl/core/attention/__init__.py diff --git a/src/axolotl/core/attention/flex_block_mask.py b/src/axolotl/core/attention/flex_block_mask.py new file mode 100644 index 000000000..37149983c --- /dev/null +++ b/src/axolotl/core/attention/flex_block_mask.py @@ -0,0 +1,158 @@ +""" +monkeypatch for flex + packing +""" + +import sys +from typing import Callable, Optional, Union + +import torch +from torch.nn.attention.flex_attention import BlockMask +from transformers import Cache, PretrainedConfig +from transformers.masking_utils import ( + ALL_MASK_ATTENTION_FUNCTIONS, + _preprocess_mask_arguments, + and_masks, + causal_mask_function, + or_masks, +) +from transformers.utils import is_torch_greater_or_equal + +_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True) + + +def create_causal_mask( + config: PretrainedConfig, + input_embeds: torch.Tensor, + attention_mask: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Optional[Cache], + or_mask_function: Optional[Callable] = None, + and_mask_function: Optional[Callable] = None, +) -> Optional[Union[torch.Tensor, BlockMask]]: + """ + Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values` + has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align + to what is needed in the `modeling_xxx.py` files). + + Args: + config (`PretrainedConfig`): + The model config. + input_embeds (`torch.Tensor`): + The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the + batch size, query length and dtype. + attention_mask (`torch.Tensor`, optional): + The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length). + It can also be an already prepared 4D mask, in which case it is returned as-is. + cache_position (`torch.Tensor`): + A tensor of shape (query_length,) indicating the current indices of the input sequence elements. + past_key_values (`Cache`, optional): + The past key values, if we use a cache. + or_mask_function (`Callable`, optional): + An optional mask function to combine with the causal mask function (by doing the union of both). This is + useful to easily overlay another mask on top of the causal one, for example for image tokens handling. + and_mask_function (`Callable`, optional): + An optional mask function to combine with the causal mask function (by doing the intersection of both). This is + useful to easily overlay another mask on top of the causal one, for example for image tokens handling. + """ + # If we have an HybridCache structure, here we want to create the mask for the full layers + if ( + past_key_values + and hasattr(past_key_values, "is_sliding") + and False in past_key_values.is_sliding + ): + layer_idx = past_key_values.is_sliding.index(False) + else: + layer_idx = 0 + + original_attention_mask = ( + None + if attention_mask is None + else attention_mask.clone().to(cache_position.device) + ) + early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( + config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + ) + if early_exit: + return attention_mask + + batch_size, total_seq_len = cache_position.shape + key_length = total_seq_len + document_ids = torch.nn.functional.pad( + original_attention_mask, value=0, pad=(0, key_length) + ) + + batch_size, dtype = input_embeds.shape[0], input_embeds.dtype + if attention_mask is not None: + + def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + """ + Defines the logic of a block causal mask by combining both a standard causal mask + and a block diagonal document mask. + See :func:`~torchtune.modules.attention_utils.create_block_causal_mask` + for an illustration. + """ + causal_mask_ = q_idx >= kv_idx # not valid when decoding + document_mask = ( + document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx] + ) + final_mask = causal_mask_ & document_mask + return final_mask + + mask_factory_function = causal_doc_mask_mod + else: + mask_factory_function = causal_mask_function + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + + # Do not allow skip if we are compiling (this is to match BC) + allow_is_causal_skip = ( + not past_key_values.is_compileable if past_key_values is not None else True + ) + + # Allow slight deviations from causal mask + if or_mask_function is not None: + if not _is_torch_greater_or_equal_than_2_6: + raise ValueError( + "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6" + ) + mask_factory_function = or_masks(mask_factory_function, or_mask_function) + allow_is_causal_skip = False + if and_mask_function is not None: + if not _is_torch_greater_or_equal_than_2_6: + raise ValueError( + "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6" + ) + mask_factory_function = and_masks(mask_factory_function, and_mask_function) + allow_is_causal_skip = False + + # We now create the mask + causal_mask = mask_interface( + batch_size=batch_size, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=mask_factory_function, + attention_mask=attention_mask, + allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa + dtype=dtype, # Additional kwarg for eager + config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + ) + return causal_mask + + +def patch_create_causal_mask(model_type): + import transformers.masking_utils + + transformers.masking_utils.create_causal_mask = create_causal_mask + + if model_type: + try: + # Dynamically import the module and attention class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + module = __import__(module_path) + module.create_causal_mask = create_causal_mask + del sys.modules[module_path] + except (ImportError, AttributeError) as e: + raise ValueError( + f"Could not import attention class for model_type: {model_type}. " + f"Error: {str(e)}" + ) from e diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 36f6739b6..639b0e4d7 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -24,10 +24,8 @@ from pathlib import Path from typing import Any import torch -from transformers import ( - TrainerCallback, -) -from transformers.training_args import OptimizerNames +from transformers import TrainerCallback +from transformers.trainer_pt_utils import AcceleratorConfig from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr @@ -36,16 +34,17 @@ from axolotl.telemetry.manager import TelemetryManager from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( GCCallback, - GPUStatsCallback, SaveAxolotlConfigtoWandBCallback, + SaveModelOnFirstStepCallback, ) from axolotl.utils.callbacks.profiler import PytorchProfilerCallback +from axolotl.utils.distributed import build_parallelism_config from axolotl.utils.schemas.enums import CustomSupportedOptimizers LOG = logging.getLogger(__name__) with suppress(ImportError): - import torch._dynamo # pylint: disable=ungrouped-imports + import torch._dynamo class TrainerBuilderBase(abc.ABC): @@ -114,13 +113,6 @@ class TrainerBuilderBase(abc.ABC): plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model) ) - if self.cfg.profiler_steps: - callbacks.append( - PytorchProfilerCallback( - steps_to_profile=self.cfg.profiler_steps, - ) - ) - if self.cfg.gc_steps: callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps)) @@ -144,8 +136,16 @@ class TrainerBuilderBase(abc.ABC): callbacks.append( SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) ) + if self.cfg.save_first_step: + callbacks.append(SaveModelOnFirstStepCallback()) - callbacks.append(GPUStatsCallback(cfg=self.cfg)) + if self.cfg.profiler_steps: + callbacks.append( + PytorchProfilerCallback( + steps_to_profile=self.cfg.profiler_steps, + profiler_steps_start=self.cfg.profiler_steps_start, + ) + ) telemetry_manager = TelemetryManager.get_instance() if telemetry_manager.enabled: @@ -225,7 +225,9 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.bf16 == "full": training_args_kwargs["bf16_full_eval"] = True else: - training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16 + bf16 = self.cfg.bf16 or self.cfg.bfloat16 + bf16 = bf16 if bf16 is not None else False + training_args_kwargs["bf16"] = bf16 def _configure_scheduler(self, training_args_kwargs: dict): if self.cfg.lr_scheduler in ["one_cycle", "rex"]: @@ -262,33 +264,30 @@ class TrainerBuilderBase(abc.ABC): adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon") if self.cfg.optimizer == "muon": - from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module + from axolotl.contribs.mit.muon import ( MuonOptimizerFactory, ) optimizer_cls = MuonOptimizerFactory optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "dion": + from axolotl.contribs.mit.dion import ( + DionOptimizerFactory, + ) + + optimizer_cls = DionOptimizerFactory + optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"] + optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"] + optimizer_kwargs.update(adam_kwargs) + _, device_mesh = build_parallelism_config(self.cfg) + if device_mesh is not None: + optimizer_kwargs["device_mesh"] = device_mesh elif self.cfg.optimizer == "optimi_adamw": from optimi import AdamW optimizer_kwargs["foreach"] = False optimizer_cls = AdamW optimizer_kwargs.update(adam_kwargs) - elif self.cfg.optimizer == "ao_adamw_4bit": - # TODO remove 20250401 - from torchao.prototype.low_bit_optim import AdamW4bit - - optimizer_cls = AdamW4bit - optimizer_kwargs.update(adam_kwargs) - - LOG.warning( - f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead." - ) - elif self.cfg.optimizer == "ao_adamw_8bit": - from torchao.prototype.low_bit_optim import AdamW8bit - - optimizer_cls = AdamW8bit - optimizer_kwargs.update(adam_kwargs) elif self.cfg.optimizer == "ao_adamw_fp8": from torchao.prototype.low_bit_optim import AdamWFp8 @@ -386,14 +385,16 @@ class TrainerBuilderBase(abc.ABC): ) # eval_strategy and eval_steps - if not self.eval_dataset or self.cfg.val_set_size == 0: - # do not eval if no eval_dataset or val_set_size=0 + if not self.eval_dataset and self.cfg.val_set_size == 0: + # do not eval if no eval_dataset and val_set_size=0 training_args_kwargs["eval_strategy"] = "no" elif self.cfg.eval_steps: training_args_kwargs["eval_strategy"] = "steps" training_args_kwargs["eval_steps"] = self.cfg.eval_steps + training_args_kwargs["eval_on_start"] = True elif self.cfg.eval_strategy: training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy + training_args_kwargs["eval_on_start"] = True def _configure_reporting(self, training_args_kwargs: dict): report_to = [] @@ -417,9 +418,8 @@ class TrainerBuilderBase(abc.ABC): def _configure_torch_compile(self, training_args_kwargs: dict): if self.cfg.torch_compile and getattr(torch, "_dynamo", None): - torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access - True - ) + torch._dynamo.config.suppress_errors = True + torch._dynamo.config.accumulated_cache_size_limit = 256 training_args_kwargs["torch_compile"] = self.cfg.torch_compile if self.cfg.torch_compile_backend: training_args_kwargs["torch_compile_backend"] = ( @@ -428,8 +428,20 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.torch_compile_mode: training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode + def _configure_accelerator_config(self, training_args_kwargs: dict): + if self.cfg.accelerator_config: + training_args_kwargs["accelerator_config"] = AcceleratorConfig( + **self.cfg.accelerator_config + ) + else: + training_args_kwargs["accelerator_config"] = AcceleratorConfig() + def _configure_gradient_checkpointing(self, training_args_kwargs: dict): - if self.cfg.gradient_checkpointing: + if self.cfg.activation_offloading is True: + # don't use the HF gradient checkpointing, manually wrap + training_args_kwargs["gradient_checkpointing"] = False + training_args_kwargs["activation_offloading"] = True + elif self.cfg.gradient_checkpointing is not None: training_args_kwargs["gradient_checkpointing"] = ( self.cfg.gradient_checkpointing ) @@ -482,17 +494,30 @@ class TrainerBuilderBase(abc.ABC): "include_tokens_per_second", "weight_decay", "seed", + "dion_momentum", + "dion_rank_fraction", + "dion_rank_multiple_of", + "dataset_num_proc", ]: if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: training_args_kwargs[arg] = getattr(self.cfg, arg) + arg_map = { + "dion_learning_rate": "dion_lr", + } + for kwarg, cfg_arg in arg_map.items(): + if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None: + training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg) + training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size + training_args_kwargs["average_tokens_across_devices"] = False if self.cfg.eval_batch_size: training_args_kwargs["per_device_eval_batch_size"] = ( self.cfg.eval_batch_size ) + training_args_kwargs["include_tkps"] = self.cfg.include_tkps training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1 training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs @@ -500,10 +525,15 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.reward_model or self.cfg.rl: training_args_kwargs["max_length"] = self.cfg.sequence_len + if self.cfg.fsdp_config or self.cfg.fsdp: + training_args_kwargs["fsdp_config"] = self.cfg.fsdp_config + training_args_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp else True + self._configure_reporting(training_args_kwargs) self._configure_hub_parameters(training_args_kwargs) self._configure_scheduler(training_args_kwargs) self._configure_optimizer(training_args_kwargs, trainer_kwargs) self._configure_torch_compile(training_args_kwargs) + self._configure_accelerator_config(training_args_kwargs) return training_args_kwargs, trainer_kwargs diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 7a81616ba..820304230 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -10,6 +10,7 @@ import transformers from transformers import ( DataCollatorWithFlattening, EarlyStoppingCallback, + Trainer, ) from trl.trainer.utils import RewardDataCollatorWithPadding @@ -19,12 +20,6 @@ from axolotl.core.trainers import ( AxolotlPRMTrainer, AxolotlRewardTrainer, AxolotlTrainer, - ReLoRATrainer, -) -from axolotl.core.training_args import ( - AxolotlPRMConfig, - AxolotlRewardConfig, - AxolotlTrainingArguments, ) from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES @@ -32,9 +27,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.processing_strategies import get_processing_strategy from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( - EvalFirstStepCallback, LossWatchDogCallback, - SaveBetterTransformerModelCallback, bench_eval_callback_factory, causal_lm_bench_eval_callback_factory, colab_inference_post_train_callback, @@ -42,6 +35,7 @@ from axolotl.utils.callbacks import ( ) from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.qat import QATCallback +from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, @@ -50,6 +44,7 @@ from axolotl.utils.collators import ( V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator +from axolotl.utils.import_helper import get_cls_from_module_str from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -63,17 +58,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): def get_callbacks(self): callbacks = super().get_callbacks() - callbacks.append(EvalFirstStepCallback()) - if self.cfg.relora_steps: + if self.cfg.relora: callbacks.append(ReLoRACallback(self.cfg)) - if ( - hasattr(self.model, "use_bettertransformer") - and self.model.use_bettertransformer is True - ): - callbacks.append(SaveBetterTransformerModelCallback()) - # TODO: check if can move to base class if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) @@ -81,6 +69,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.qat: callbacks.append(QATCallback(self.cfg.qat)) + if self.cfg.include_tkps: + callbacks.append( + TokensPerSecondCallback( + self.cfg.tensor_parallel_size, self.cfg.context_parallel_size + ) + ) return callbacks def get_post_trainer_create_callbacks(self, trainer): @@ -130,33 +124,44 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return callbacks def _get_trainer_cls(self): + """ + Gets the trainer class for the given configuration. + """ if self.cfg.plugins: plugin_manager = PluginManager.get_instance() trainer_cls = plugin_manager.get_trainer_cls(self.cfg) if trainer_cls: return trainer_cls - if self.cfg.relora_steps: - return ReLoRATrainer if self.cfg.model_config_type == "mamba": return AxolotlMambaTrainer if self.cfg.reward_model: return AxolotlRewardTrainer if self.cfg.process_reward_model: return AxolotlPRMTrainer + + if self.cfg.trainer_cls: + # override the trainer cls + try: + trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls) + LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}") + return trainer_cls + except (ImportError, AttributeError, ValueError) as e: + raise ValueError( + f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}" + ) from e + return AxolotlTrainer def build(self, total_num_steps): + from axolotl.core.training_args import ( + AxolotlPRMConfig, + AxolotlRewardConfig, + AxolotlTrainingArguments, + ) + training_arguments_kwargs, trainer_kwargs = self._set_base_training_args( total_num_steps ) - - if self.cfg.fsdp: - training_arguments_kwargs["fsdp"] = self.cfg.fsdp - if self.cfg.fsdp_config: - training_arguments_kwargs["fsdp_config"] = { - k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items() - } - if self.cfg.adapter == "qlora": training_arguments_kwargs["qlora"] = True @@ -243,14 +248,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) + training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool( + self.cfg.flash_attention + or self.cfg.xformers_attention + or self.cfg.flex_attention + ) training_arguments_kwargs["multipack_real_batches"] = ( self.cfg.multipack_real_batches if self.cfg.multipack_real_batches is not None - else not self.cfg.flash_attention + else not ( + self.cfg.flash_attention + or self.cfg.flex_attention + or self.cfg.xformers_attention + ) ) training_arguments_kwargs["eval_sample_packing"] = bool( self.cfg.eval_sample_packing ) + if self.cfg.sample_packing_sequentially is not None: + training_arguments_kwargs["sample_packing_sequentially"] = ( + self.cfg.sample_packing_sequentially + ) if self.cfg.sample_packing_bin_size is not None: training_arguments_kwargs["sample_packing_bin_size"] = ( self.cfg.sample_packing_bin_size @@ -264,20 +282,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.sample_packing_eff_est ) - if self.cfg.relora_steps: - training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps - training_arguments_kwargs["relora_warmup_steps"] = ( - self.cfg.relora_warmup_steps - ) - if self.cfg.relora_anneal_steps: - training_arguments_kwargs["relora_anneal_steps"] = ( - self.cfg.relora_anneal_steps - ) + if self.cfg.relora and self.cfg.jagged_restart_steps: if self.cfg.relora_prune_ratio: training_arguments_kwargs["relora_prune_ratio"] = ( self.cfg.relora_prune_ratio ) + if self.cfg.jagged_restart_steps: + training_arguments_kwargs["jagged_restart_steps"] = ( + self.cfg.jagged_restart_steps + ) + if self.cfg.jagged_restart_warmup_steps: + training_arguments_kwargs["jagged_restart_warmup_steps"] = ( + self.cfg.jagged_restart_warmup_steps + ) + if self.cfg.jagged_restart_anneal_steps: + training_arguments_kwargs["jagged_restart_anneal_steps"] = ( + self.cfg.jagged_restart_anneal_steps + ) + if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers: training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers training_arguments_kwargs["lisa_step_interval"] = ( @@ -303,48 +326,37 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.neftune_noise_alpha ) - if self.cfg.accelerator_config: - training_arguments_kwargs["accelerator_config"] = ( - self.cfg.accelerator_config - ) - if self.cfg.image_size: training_arguments_kwargs["image_size"] = self.cfg.image_size if self.cfg.image_resize_algorithm: training_arguments_kwargs["image_resize_algorithm"] = ( self.cfg.image_resize_algorithm ) - if self.cfg.kd_ce_alpha is not None: - training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha - if self.cfg.kd_alpha is not None: - training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha - if self.cfg.kd_temperature is not None: - training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature - if self.cfg.kd_zscore_base_temp is not None: - training_arguments_kwargs["kd_zscore_base_temp"] = ( - self.cfg.kd_zscore_base_temp - ) - if self.cfg.kd_top_k_before_softmax is not None: - training_arguments_kwargs["kd_top_k_before_softmax"] = ( - self.cfg.kd_top_k_before_softmax - ) + + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + plugin_training_args = plugin_manager.get_training_args(self.cfg) + if plugin_training_args: + training_arguments_kwargs.update(plugin_training_args) if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig + if self.cfg.center_rewards_coefficient is not None: + training_arguments_kwargs["center_rewards_coefficient"] = ( + self.cfg.center_rewards_coefficient + ) elif self.cfg.process_reward_model: training_args_cls = AxolotlPRMConfig else: training_args_cls = AxolotlTrainingArguments - training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg + training_args = training_args_cls( **training_arguments_kwargs, ) training_args = self.hook_post_create_training_args(training_args) # unset run_name so wandb sets up experiment names if self.cfg.use_wandb and training_args.run_name == training_args.output_dir: - training_args.run_name = ( # pylint: disable=attribute-defined-outside-init - None - ) + training_args.run_name = None data_collator_kwargs = { "padding": True, # True/"longest" is the default @@ -354,7 +366,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil( self.cfg.sequence_len / multiple ) - else: + elif self.cfg.pad_to_sequence_len is None: # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html data_collator_kwargs["pad_to_multiple_of"] = multiple @@ -376,12 +388,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): **data_collator_kwargs, ) sig = inspect.signature(trainer_cls) - if "processing_class" in sig.parameters: + if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer): trainer_kwargs["processing_class"] = self.tokenizer elif "tokenizer" in sig.parameters: trainer_kwargs["tokenizer"] = self.tokenizer + if ( - not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer]) + trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer] and self.cfg.datasets is not None ): trainer_kwargs["dataset_tags"] = [ @@ -397,6 +410,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): **trainer_kwargs, ) trainer = self.hook_post_create_trainer(trainer) + # if the trainer has the `axolotl_cfg` property, set it + if hasattr(trainer, "axolotl_cfg"): + trainer.axolotl_cfg = self.cfg for callback in self.get_post_trainer_create_callbacks(trainer): trainer.add_callback(callback) @@ -408,7 +424,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return trainer def build_collator( - self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs + self, + training_args, # type: "AxolotlTrainingArguments" # type: ignore + is_eval=False, + **kwargs, ): if training_args.pretraining: if ( @@ -416,7 +435,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): or self.cfg.micro_batch_size > 1 ): return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) - return None + if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn): + return None if self.cfg.model_config_type == "mamba": return MambaDataCollator(tokenizer=self.tokenizer) @@ -437,7 +457,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ] ] collator_args = [self.tokenizer] - if self.cfg.reward_model: + + collator_cls_and_kwargs = None + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs( + self.cfg, is_eval=is_eval + ) + + if collator_cls_and_kwargs: + collator = collator_cls_and_kwargs[0] + if kwargs and isinstance(kwargs, dict): + kwargs.update(collator_cls_and_kwargs[1]) + elif self.cfg.reward_model: collator = RewardDataCollatorWithPadding elif use_batch_sampler_collator: # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, @@ -468,16 +500,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator_args.pop(0) kwargs.pop("pad_to_multiple_of", None) kwargs.pop("padding", None) - elif self.cfg.kd_trainer: - from axolotl.integrations.kd.collator import ( - DataCollatorForKD, - KDBatchSamplerDataCollatorForSeq2Seq, - ) - - if self.cfg.sample_packing: - collator = KDBatchSamplerDataCollatorForSeq2Seq - else: - collator = DataCollatorForKD else: collator = DataCollatorForSeq2Seq diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index 14dbfa715..0ceb80008 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -12,13 +12,10 @@ from axolotl.core.trainers import ( from axolotl.core.trainers.dpo import DPOStrategy from axolotl.core.trainers.dpo.args import AxolotlDPOConfig from axolotl.core.trainers.grpo import GRPOStrategy -from axolotl.core.training_args import ( - AxolotlCPOConfig, - AxolotlKTOConfig, - AxolotlORPOConfig, -) from axolotl.integrations.base import PluginManager from axolotl.loaders.utils import ensure_dtype +from axolotl.utils.callbacks.qat import QATCallback +from axolotl.utils.import_helper import get_cls_from_module_str from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType @@ -31,6 +28,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): def get_callbacks(self): callbacks = super().get_callbacks() + if self.cfg.qat: + callbacks.append(QATCallback(self.cfg.qat)) + return callbacks def get_post_trainer_create_callbacks(self, trainer): @@ -54,7 +54,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl is RLType.GRPO: trainer_cls = GRPOStrategy.get_trainer_class( - sequence_parallel=self.cfg.sequence_parallel_degree > 1 + sequence_parallel=self.cfg.context_parallel_size > 1 ) trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) @@ -73,12 +73,28 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") + if self.cfg.trainer_cls: + # override the trainer cls + try: + trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls) + LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}") + except (ImportError, AttributeError, ValueError) as e: + raise ValueError( + f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}" + ) from e + return trainer_cls, trainer_cls_args def _build_training_arguments(self, total_num_steps): """ Returns training_args and trainer_kwargs """ + from axolotl.core.training_args import ( + AxolotlCPOConfig, + AxolotlKTOConfig, + AxolotlORPOConfig, + ) + training_args_kwargs, trainer_kwargs = self._set_base_training_args( total_num_steps=total_num_steps ) @@ -90,10 +106,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: training_args_kwargs["remove_unused_columns"] = False - # only rlhf - if self.cfg.dataset_processes: - training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes - if self.cfg.trl and self.cfg.trl.beta is not None: training_args_kwargs["beta"] = self.cfg.trl.beta elif self.cfg.rl_beta is not None: @@ -108,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.use_wandb: training_args_kwargs["run_name"] = self.cfg.wandb_name + if self.cfg.max_prompt_len: + training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len + else: + training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len + training_args_cls = None blocklist_args_kwargs = [] if self.cfg.rl is RLType.SIMPO: @@ -117,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.cpo_alpha is not None: training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha + # Handle when max_prompt_length == max_length from defaults + # CPOTrainer requires strictly less than + if ( + training_args_kwargs["max_prompt_length"] + == training_args_kwargs["max_length"] + ): + training_args_kwargs["max_prompt_length"] -= 1 + elif self.cfg.rl is RLType.ORPO: training_args_cls = AxolotlORPOConfig - if self.cfg.max_prompt_len: - training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len elif self.cfg.rl is RLType.KTO: training_args_cls = AxolotlKTOConfig @@ -132,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): self.cfg.kto_undesirable_weight or 1.0 ) - if self.cfg.max_prompt_len: - training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len - elif self.cfg.rl is RLType.GRPO: training_args_cls = GRPOStrategy.get_training_args_class() training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) @@ -142,22 +162,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): elif self.cfg.rl in [RLType.DPO, RLType.IPO]: training_args_cls = AxolotlDPOConfig - if self.cfg.rl is RLType.IPO: - training_args_kwargs["loss_type"] = "ipo" - - # Not compatible with IPO - if self.cfg.rl is RLType.DPO and self.cfg.dpo_label_smoothing: - training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing - - training_args_kwargs["max_completion_length"] = None - training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len - training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb - if self.cfg.dpo_use_weighting is not None: - training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting - if self.cfg.dpo_use_logits_to_keep is not None: - training_args_kwargs["use_logits_to_keep"] = ( - self.cfg.dpo_use_logits_to_keep - ) + training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg)) else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") @@ -165,16 +170,20 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if blocklist_key in training_args_kwargs: del training_args_kwargs[blocklist_key] - training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + plugin_training_args = plugin_manager.get_training_args(self.cfg) + if plugin_training_args: + training_args_kwargs.update(plugin_training_args) + + training_args = training_args_cls( logging_first_step=True, **training_args_kwargs, ) # unset run_name so wandb sets up experiment names if self.cfg.use_wandb and training_args.run_name == training_args.output_dir: - training_args.run_name = ( # pylint: disable=attribute-defined-outside-init - None - ) + training_args.run_name = None return training_args, trainer_kwargs @@ -216,7 +225,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): callbacks=self.get_callbacks(), **trainer_kwargs, ) - if self.cfg.fsdp: + if self.cfg.fsdp_config or self.cfg.fsdp: ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype) if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model: ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype) @@ -226,21 +235,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase): trainer.add_callback(callback) return trainer - - -class HFPPOTrainerBuilder(TrainerBuilderBase): - """ - HF Factory class for PPO Trainer - """ - - def get_callbacks(self): - callbacks = super().get_callbacks() - return callbacks - - def get_post_trainer_create_callbacks(self, trainer): - callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) - return callbacks - - def build(self, total_num_steps): - # TODO: build PPOConfig - raise NotImplementedError("PPO trainer builder is not implemented yet.") diff --git a/src/axolotl/core/chat/format/chatml.py b/src/axolotl/core/chat/format/chatml.py index 04c398fe8..deb8a9997 100644 --- a/src/axolotl/core/chat/format/chatml.py +++ b/src/axolotl/core/chat/format/chatml.py @@ -10,7 +10,7 @@ from .shared import wrap_tools def format_message( message: Messages, - message_index: Optional[int] = None, # pylint: disable=unused-argument + message_index: Optional[int] = None, ) -> Messages: if message.is_chat_formatted: return message diff --git a/src/axolotl/core/chat/messages.py b/src/axolotl/core/chat/messages.py index 923b177c1..912a12ca1 100644 --- a/src/axolotl/core/chat/messages.py +++ b/src/axolotl/core/chat/messages.py @@ -15,11 +15,11 @@ class MessageRoles(str, Enum): Message roles for the system, user, assistant, and tools """ - system = "system" # pylint: disable=invalid-name - user = "user" # pylint: disable=invalid-name - assistant = "assistant" # pylint: disable=invalid-name - tool = "tool" # pylint: disable=invalid-name - ipython = ( # pylint: disable=invalid-name + system = "system" + user = "user" + assistant = "assistant" + tool = "tool" + ipython = ( # for responses from builtin tools "ipython" ) @@ -30,12 +30,12 @@ class MessageContentTypes(str, Enum): Message content types for text, image, audio, tool calls, and tool responses """ - special_token = "special_token" # pylint: disable=invalid-name # nosec B105 - text = "text" # pylint: disable=invalid-name - image = "image" # pylint: disable=invalid-name - audio = "audio" # pylint: disable=invalid-name - tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant - tool_response = "tool_response" # pylint: disable=invalid-name + special_token = "special_token" # nosec B105 + text = "text" + image = "image" + audio = "audio" + tool_call = "tool_call" + tool_response = "tool_response" class SpecialToken(str, Enum): @@ -43,8 +43,8 @@ class SpecialToken(str, Enum): Special tokens for beginning of string and end of string """ - bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105 - eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105 + bos_token = "bos_token" # nosec B105 + eos_token = "eos_token" # nosec B105 class ToolCallFunction(BaseModel): @@ -73,7 +73,7 @@ class ToolCallContents(BaseModel): name: str arguments: dict[str, Union[str, int]] - id: Optional[str] = None # pylint: disable=invalid-name + id: Optional[str] = None def __str__(self) -> str: data = {"name": self.name, "arguments": self.arguments} @@ -89,7 +89,7 @@ class ToolResponseContents(BaseModel): name: str content: Union[str, dict[str, Union[str, int, float]]] - id: Optional[str] = None # pylint: disable=invalid-name + id: Optional[str] = None def __str__(self) -> str: data = {"name": self.name, "content": self.content} diff --git a/src/axolotl/core/datasets/chat.py b/src/axolotl/core/datasets/chat.py index 724f12866..a4dc300d9 100644 --- a/src/axolotl/core/datasets/chat.py +++ b/src/axolotl/core/datasets/chat.py @@ -2,7 +2,6 @@ chat dataset module """ -import os from typing import Callable, Optional, Union from datasets import Dataset @@ -41,14 +40,10 @@ class TokenizedChatDataset(Dataset): ) return ex.tokenized(model_transform) - process_or_cpu_count: int = ( - process_count or os.cpu_count() # type: ignore[assignment] - ) - num_proc = min(32, process_or_cpu_count) features = data.features.keys() tokenized_data = data.map( map_fn, - num_proc=num_proc, + num_proc=process_count, keep_in_memory=keep_in_memory, remove_columns=features, desc="Tokenizing Chats", diff --git a/src/axolotl/core/datasets/transforms/chat_builder.py b/src/axolotl/core/datasets/transforms/chat_builder.py index 692fe3ebb..0de0ecb40 100644 --- a/src/axolotl/core/datasets/transforms/chat_builder.py +++ b/src/axolotl/core/datasets/transforms/chat_builder.py @@ -1,23 +1,17 @@ """ -This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat. +This module contains a function that builds a transform that takes a row from the +dataset and converts it to a Chat. """ -from typing import Any, Mapping, Union +from typing import Any, Mapping -def chat_message_transform_builder( # pylint: disable=dangerous-default-value +def chat_message_transform_builder( train_on_inputs=False, - conversations_field: str = "conversations", - message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role" - message_field_content: Union[str, list[str]] = [ - "value", - "text", - "content", - ], # commonly "content" - message_field_training: Union[str, list[str]] = [ - "train", - "weight", - ], # commonly "weight" + conversations_field: str = "messages", + message_field_role: str | list[str] | None = None, # commonly "role" + message_field_content: str | list[str] | None = None, # commonly "content" + message_field_training: str | list[str] | None = None, # commonly "weight" ): """Builds a transform that takes a row from the dataset and converts it to a Chat @@ -26,19 +20,25 @@ def chat_message_transform_builder( # pylint: disable=dangerous-default-value If True, the transform will train on the inputs. If False, the transform will train on the targets. Defaults to False. conversations_field (str, optional): - The field name of the conversations. Defaults to "conversations". + The field name of the conversations. Defaults to "messages". message_field_role (str | list[str], optional): - The field name of the role. Defaults to "role". + The field name of the role. message_field_content (str | list[str], optional): - The field name of the message content. Defaults to "content". + The field name of the message content. message_field_training (str | list[str], optional): - The field name of the train/weight. Defaults to "weight". + The field name of the train/weight. Returns: Callable: A function that takes a list of conversations and returns a list of messages. """ + if message_field_training is None: + message_field_training = ["train", "weight"] + if message_field_content is None: + message_field_content = ["value", "text", "content"] + if message_field_role is None: + message_field_role = ["role", "from"] message_field_role = ( [message_field_role] if isinstance(message_field_role, str) diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py index 2cdc9c195..22d8b64f6 100644 --- a/src/axolotl/core/trainers/__init__.py +++ b/src/axolotl/core/trainers/__init__.py @@ -1,18 +1,14 @@ """Init for axolotl.core.trainers""" -# pylint: disable=unused-import # flake8: noqa from .base import AxolotlTrainer from .dpo.trainer import AxolotlDPOTrainer -from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer from .mamba import AxolotlMambaTrainer -from .relora import ReLoRATrainer from .trl import ( AxolotlCPOTrainer, AxolotlKTOTrainer, AxolotlORPOTrainer, AxolotlPRMTrainer, AxolotlRewardTrainer, - TRLPPOTrainer, ) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 70e443cb3..11dfecb98 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -1,17 +1,18 @@ """Module for customized trainers""" -# pylint: disable=too-many-lines - from __future__ import annotations import os from collections import defaultdict from functools import partial, wraps -from typing import Callable, Literal, Optional +from typing import Any, Callable, Literal, Optional import datasets +import safetensors import torch +from accelerate.state import AcceleratorState from datasets import Dataset +from peft import PeftModel from torch.utils.data import ( BatchSampler, DataLoader, @@ -19,13 +20,19 @@ from torch.utils.data import ( Sampler, SequentialSampler, ) -from transformers import Trainer -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker +from transformers import PreTrainedModel, Trainer +from transformers.trainer import TRAINING_ARGS_NAME +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker +from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available from trl.trainer.utils import pad_to_length from typing_extensions import override from axolotl.core.trainers.mixins import ( + ActivationOffloadingMixin, + CheckpointSaveMixin, + DistributedParallelMixin, OptimizerMixin, + PackingMixin, RngLoaderMixin, SchedulerMixin, ) @@ -33,17 +40,46 @@ from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_tagging, ) +from axolotl.utils import get_not_null +from axolotl.utils.bench import get_gpu_memory_usage +from axolotl.utils.dict import DictDefault +from axolotl.utils.distributed import is_main_process from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = get_logger(__name__) +REDUCTION_FNS = { + "mean": torch.mean, + "min": torch.min, + "max": torch.max, + "sum": torch.sum, +} -class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): + +class AxolotlTrainer( + PackingMixin, + SchedulerMixin, + OptimizerMixin, + RngLoaderMixin, + CheckpointSaveMixin, + ActivationOffloadingMixin, + DistributedParallelMixin, + Trainer, +): """Extend the base Trainer for axolotl helpers""" args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] tag_names = ["axolotl"] + _axolotl_cfg: DictDefault | None = None + + @property + def axolotl_cfg(self): + return self._axolotl_cfg + + @axolotl_cfg.setter + def axolotl_cfg(self, cfg): + self._axolotl_cfg = cfg def __init__( self, @@ -59,24 +95,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): self._signature_columns = None # workaround for pylint super().__init__(*_args, **kwargs) - self.train_data_collator = self.data_collator - self._stored_metrics = defaultdict(lambda: defaultdict(list)) + self._stored_metrics = defaultdict( + lambda: defaultdict(lambda: {"values": [], "reduction": "mean"}) + ) if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - def _wrap_model(self, model, training=True, dataloader=None): - if self.args.torch_compile: - torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access - 256 - ) - model = torch.compile( - model, - backend=self.args.torch_compile_backend, - mode=self.args.torch_compile_mode, - ) - return super()._wrap_model(model, training=training, dataloader=dataloader) - def _create_multipack_sampler( self, base_sampler: Sampler, dataset: Dataset ) -> MultipackBatchSampler: @@ -101,7 +126,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): ) batch_max_len = train_batch_size * self.args.max_seq_length - return MultipackBatchSampler( + sampler = MultipackBatchSampler( base_sampler, lengths=get_dataset_lengths(dataset), packing_efficiency_estimate=self.args.sample_packing_efficiency, @@ -111,11 +136,16 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): bin_size=self.args.sample_packing_bin_size, sequential=self.args.sample_packing_sequentially, drop_last=True, + num_processes=self.args.dataset_num_proc, + mp_start_method=self.args.sample_packing_mp_start_method or "fork", ) + len(sampler) + return sampler + def _get_train_sampler( - self, train_dataset: Optional[Dataset] = None - ) -> Optional[Sampler]: + self, train_dataset: Dataset | None = None + ) -> Sampler | None: """ Helper method to get the sampler for training. Handles cases for sample packing and curriculum sampling (sequential). @@ -124,16 +154,22 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): If the dataset is non-empty, a sampler is returned, the type of which depends on the passed training args. """ + # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24 + if train_dataset is None: + train_dataset = self.train_dataset + if train_dataset is None or not has_length(train_dataset): + return None + use_sample_packing = self.args.sample_packing and not self.args.pretraining # Determine the base sampler first if self.args.curriculum_sampling: - base_sampler = SequentialSampler(self.train_dataset) + base_sampler = SequentialSampler(train_dataset) elif use_sample_packing: - base_sampler = RandomSampler(self.train_dataset) + base_sampler = RandomSampler(train_dataset) else: # Default to parent class implementation for standard random sampling - return super()._get_train_sampler() + return super()._get_train_sampler(train_dataset) # Apply multipack wrapper if needed if use_sample_packing: @@ -152,6 +188,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): If the dataset is non-empty, a sampler is returned, the type of which depends on the passed training args. """ + # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24 + if eval_dataset is None or not has_length(eval_dataset): + return None + # Multipacking enabled if training is enabled and eval is not explicitly disabled use_multipack = ( self.args.sample_packing and self.args.eval_sample_packing is not False @@ -187,6 +227,14 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): if dataset.column_names and "length" in dataset.column_names: dataset = dataset.remove_columns(["length"]) + if ( + dataset.column_names + and "position_ids" in dataset.column_names + and "attention_mask" in dataset.column_names + and self.args.sample_packing + and self.args.sample_packing_drop_attention_mask + ): + dataset = dataset.remove_columns(["attention_mask"]) if isinstance(dataset, datasets.Dataset): if is_training: @@ -220,7 +268,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): } if not isinstance(dataset, torch.utils.data.IterableDataset): - dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["drop_last"] = get_not_null( + self.args.dataloader_drop_last, True + ) if sampler_fn is not None: sampler = sampler_fn(dataset) if isinstance(sampler, BatchSampler): @@ -251,9 +301,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): # fmt: off if dataloader_key is not None and self.args.dataloader_persistent_workers: if hasattr(self, "_eval_dataloaders"): - self._eval_dataloaders[dataloader_key] = dataloader # type: ignore # pylint: disable=access-member-before-definition + self._eval_dataloaders[dataloader_key] = dataloader # type: ignore else: - self._eval_dataloaders = {dataloader_key: dataloader} # pylint: disable=attribute-defined-outside-init + self._eval_dataloaders = {dataloader_key: dataloader} # fmt: on return self.accelerator.prepare(dataloader) @@ -295,6 +345,17 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): # outputs = model(**inputs) # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) # return (loss, outputs) if return_outputs else loss + + # track number of tokens for tokens per second calculation + if self.args.include_tkps: + inputs_key = "labels" if "labels" in inputs else "input_ids" + if hasattr(self.state, "num_tokens"): + self.state.num_tokens = ( + self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu() + ) + else: + self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu() + if self.args.orpo_alpha: return self.orpo_compute_loss( model, @@ -310,6 +371,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): num_items_in_batch=num_items_in_batch, ) + @override + def evaluate(self, *args, **kwargs): + LOG.info("Running evaluation step...") + return super().evaluate(*args, **kwargs) + @staticmethod def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): concatenated_batch = {} @@ -409,7 +475,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): model, inputs, return_outputs=False, - num_items_in_batch=None, # pylint: disable=unused-argument + num_items_in_batch=None, ): concat_inputs = AxolotlTrainer.orpo_concatenate_inputs( inputs, @@ -486,26 +552,32 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): @wraps(Trainer.create_accelerator_and_postprocess) def create_accelerator_and_postprocess(self): - res = super().create_accelerator_and_postprocess() + # cleanup the PartialState states so Accelerate automatically configures everything from the env vars + accelerator_config = self.args.accelerator_config.to_dict() + use_configured_state = accelerator_config.get("use_configured_state", False) + if not use_configured_state: + AcceleratorState._reset_state(reset_partial_state=True) - if self.is_fsdp_enabled: - if ( - "limit_all_gathers" in self.args.fsdp_config - and self.args.fsdp_config["limit_all_gathers"] - ): - self.accelerator.state.fsdp_plugin.limit_all_gathers = True - - return res + super().create_accelerator_and_postprocess() def additional_accelerator_args( - self, fp8=None, **kwargs - ): # pylint: disable=unused-argument + self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs + ) -> dict[str, Any]: ret_kwargs = {} if fp8: from accelerate.utils import AORecipeKwargs + from torchao.float8 import Float8LinearConfig + + # By default, Float8LinearConfig is instantiated using the "tensorwise" + # scaling strategy. See more details here: + # https://github.com/pytorch/ao/tree/main/torchao/float8. + config = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True, + ) ret_kwargs["mixed_precision"] = "fp8" - ret_kwargs["kwargs_handlers"] = [AORecipeKwargs()] + ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" return ret_kwargs @@ -520,18 +592,61 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): """ # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() + + for key, metric_data in self._stored_metrics[train_eval].items(): + values = torch.tensor(metric_data["values"]) # type: ignore[arg-type] + reduction_type = metric_data["reduction"] + + fn = REDUCTION_FNS.get(reduction_type) + if fn is None: + raise NotImplementedError( + "Metric reduction must be one of [mean, min, max, sum]" + ) + logs[key] = round(fn(values).item(), 4) + + if is_main_process(): + # Add memory usage + try: + active, allocated, reserved = get_gpu_memory_usage() + logs["memory/max_active (GiB)"] = round(active, 2) + logs["memory/max_allocated (GiB)"] = round(allocated, 2) + logs["memory/device_reserved (GiB)"] = round(reserved, 2) + except (ValueError, TypeError, FileNotFoundError): + pass + + if self.args.include_tkps and train_eval == "train": + # each rank will log its own tokens per second + # for logging_steps > 1 we obtain a moving average of this metric + logs["tokens_per_second_per_gpu"] = round( + self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 + ) + del self._stored_metrics[train_eval] return super().log(logs, start_time) def store_metrics( - self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train" + self, + metrics: dict[str, float] | dict[str, tuple[int | float, str]], + train_eval: Literal["train", "eval"] = "train", + reduction: Literal["mean", "min", "max", "sum"] = "mean", ) -> None: + """ + Store metrics with specified reduction type. + + Args: + metrics: Dictionary of metric names to values, or metric names to (value, + reduction_type) tuples. + train_eval: Whether this is for training or evaluation. + """ for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) + if isinstance(value, tuple): + value, _reduction = value # type: ignore[assignment] + else: + value, _reduction = value, reduction + + self._stored_metrics[train_eval][key]["values"].append(value) + self._stored_metrics[train_eval][key]["reduction"] = _reduction def _save_checkpoint(self, model, trial, **kwargs): # make sure the checkpoint dir exists, since trainer is flakey @@ -540,3 +655,69 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): output_dir = os.path.join(run_dir, checkpoint_folder) os.makedirs(output_dir, exist_ok=True) return super()._save_checkpoint(model, trial, **kwargs) + + # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + LOG.info(f"Saving model checkpoint to {output_dir}") + supported_classes = ( + (PreTrainedModel,) + if not is_peft_available() + else (PreTrainedModel, PeftModel) + ) + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + if not isinstance(self.model, supported_classes): + if state_dict is None: + state_dict = self.model.state_dict() + if isinstance( + self.accelerator.unwrap_model(self.model, keep_torch_compile=False), + supported_classes, + ): + self.accelerator.unwrap_model( + self.model, keep_torch_compile=False + ).save_pretrained( + output_dir, + state_dict=state_dict, + safe_serialization=self.args.save_safetensors, + ) + else: + LOG.info( + "Trainer.model is not a `PreTrainedModel`, only saving its state dict." + ) + if self.args.save_safetensors: + safetensors.torch.save_file( + state_dict, + os.path.join(output_dir, SAFE_WEIGHTS_NAME), + metadata={"format": "pt"}, + ) + else: + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained( + output_dir, + state_dict=state_dict, + safe_serialization=self.args.save_safetensors, + is_main_process=self.accelerator.is_main_process, + ) + + if self.processing_class is not None: + self.processing_class.save_pretrained(output_dir) + elif ( + self.data_collator is not None + and hasattr(self.data_collator, "tokenizer") + and self.data_collator.tokenizer is not None + ): + LOG.info( + "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`" + ) + save_jinja_files = True + if self.axolotl_cfg: + save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files + self.data_collator.tokenizer.save_pretrained( + output_dir, save_jinja_files=save_jinja_files + ) + # Good practice: save your training arguments together with the trained model + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 603fdf0b6..3aa79c484 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -22,10 +22,18 @@ class DPOStrategy: training_args_kwargs = {} if cfg.rl is RLType.IPO: training_args_kwargs["loss_type"] = "ipo" - training_args_kwargs["max_length"] = cfg.sequence_len + # Label smoothing is not compatible with IPO + if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing: + training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing training_args_kwargs["max_completion_length"] = None - training_args_kwargs["max_prompt_length"] = cfg.sequence_len - training_args_kwargs["generate_during_eval"] = cfg.use_wandb + training_args_kwargs["max_length"] = cfg.sequence_len + training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval if cfg.dpo_use_weighting is not None: training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting + if cfg.dpo_padding_free is not None: + training_args_kwargs["padding_free"] = cfg.dpo_padding_free + if cfg.dpo_norm_loss is not None: + training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss + if cfg.dpo_use_logits_to_keep is not None: + training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep return training_args_kwargs diff --git a/src/axolotl/core/trainers/dpo/args.py b/src/axolotl/core/trainers/dpo/args.py index de1758ed0..b1e53236e 100644 --- a/src/axolotl/core/trainers/dpo/args.py +++ b/src/axolotl/core/trainers/dpo/args.py @@ -14,3 +14,5 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): """ DPO config for DPO training """ + + dpo_norm_loss: bool | None = False diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 15af80c02..b04505d89 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -8,7 +8,11 @@ import torch from torch import nn from trl import DPOTrainer -from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin +from axolotl.core.trainers.mixins import ( + DistributedParallelMixin, + RngLoaderMixin, + SchedulerMixin, +) from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, @@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import ( class AxolotlDPOTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DPOTrainer, + DistributedParallelMixin, ): """Extend the base DPOTrainer for axolotl helpers.""" @@ -83,3 +92,20 @@ class AxolotlDPOTrainer( gc.collect() torch.cuda.empty_cache() return loss + + def concatenated_forward( + self, + model: nn.Module, + batch: dict[str, Union[list, torch.LongTensor]], + is_ref_model: bool = False, + ) -> dict[str, torch.Tensor]: + if self.args.dpo_norm_loss: + # fmt: off + loss_type: str = self.loss_type # type: ignore[has-type] + # fmt: on + # concatenated_forward handles avg token logprob for ipo case already + self.loss_type = "ipo" + res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model) + self.loss_type = loss_type + return res + return super().concatenated_forward(model, batch, is_ref_model=is_ref_model) diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index c0f10be23..d1a6b7fd9 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -2,8 +2,11 @@ import importlib import inspect +import os from typing import Any +from huggingface_hub import snapshot_download +from requests import HTTPError from trl.trainer.grpo_trainer import RewardFunc from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig @@ -14,6 +17,7 @@ from axolotl.core.trainers.grpo.trainer import ( from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.schemas.trl import TRLConfig +from axolotl.utils.schemas.vllm import VllmConfig LOG = get_logger(__name__) @@ -41,9 +45,20 @@ class GRPOStrategy: return grpo_args_kwargs trl: TRLConfig = cfg.trl # type: ignore + vllm_cfg: VllmConfig = cfg.vllm # type: ignore if trl.use_vllm: grpo_args_kwargs["use_vllm"] = trl.use_vllm + if trl.vllm_mode: + grpo_args_kwargs["vllm_mode"] = trl.vllm_mode + if trl.vllm_mode == "colocate": + grpo_args_kwargs["enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined] + grpo_args_kwargs["vllm_gpu_memory_utilization"] = ( + vllm_cfg.gpu_memory_utilization + ) + grpo_args_kwargs["vllm_tensor_parallel_size"] = ( + vllm_cfg.tensor_parallel_size + ) grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined] grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined] if trl.vllm_server_timeout: @@ -69,8 +84,13 @@ class GRPOStrategy: grpo_args_kwargs["log_completions"] = trl.log_completions grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print - if cfg.sequence_parallel_degree > 1: - grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree + if cfg.context_parallel_size > 1: + grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size + + if trl.importance_sampling_level is not None: + grpo_args_kwargs["importance_sampling_level"] = ( + trl.importance_sampling_level + ) if trl.reward_weights: grpo_args_kwargs["reward_weights"] = trl.reward_weights @@ -109,9 +129,7 @@ class GRPOStrategy: return grpo_args_kwargs @classmethod - def set_trainer_args( - cls, cfg: DictDefault - ) -> list[Any]: # pylint: disable=unused-argument + def set_trainer_args(cls, cfg: DictDefault) -> list[Any]: trainer_args = [] if cfg.trl and cfg.trl.reward_funcs: reward_funcs = [] @@ -132,13 +150,13 @@ class GRPOStrategy: return trainer_kwargs @classmethod - def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument + def get_collator(cls, *args, **kwargs): # No data collation is needed in GRPO, handled by trl's trainer __init__ return None @classmethod def get_blocklist_args_kwargs(cls) -> list[str]: - return ["dataset_num_proc", "max_length"] + return ["dataset_num_proc", "max_length", "include_tokens_per_second"] @classmethod def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc: @@ -168,9 +186,18 @@ class GRPOStrategy: "Reward function must accept at least two arguments: prompts: list and completions: list" ) return reward_func - except ModuleNotFoundError: + except ModuleNotFoundError as exc: # the user has passed a string (ideally indicating the path of a reward model) - LOG.info( - f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path." - ) - return reward_func_fqn + # check if it's a local dir path and not empty dir to a reward model + pretrained_log_msg = f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path." + if os.path.isdir(reward_func_fqn) and os.listdir(reward_func_fqn): + LOG.info(pretrained_log_msg) + return reward_func_fqn + try: + snapshot_download(reward_func_fqn, repo_type="model") + LOG.info(pretrained_log_msg) + return reward_func_fqn + except HTTPError: + raise ValueError( + f"Reward function {reward_func_fqn} not found." + ) from exc diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py index 5c8b1a33b..2ea52998e 100644 --- a/src/axolotl/core/trainers/grpo/args.py +++ b/src/axolotl/core/trainers/grpo/args.py @@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): """Axolotl GRPO Config for GRPO training""" - sequence_parallel_degree: int | None = None + context_parallel_size: int | None = None diff --git a/src/axolotl/core/trainers/grpo/sampler.py b/src/axolotl/core/trainers/grpo/sampler.py index ebc6e19e2..df679a6d2 100644 --- a/src/axolotl/core/trainers/grpo/sampler.py +++ b/src/axolotl/core/trainers/grpo/sampler.py @@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler): - Data is properly distributed across SP groups. In the table below, the values represent dataset indices. Each SP group has - `sequence_parallel_degree = 2` GPUs working together on the same data. There are 2 + `context_parallel_size = 2` GPUs working together on the same data. There are 2 SP groups (SP0 and SP1), with `world_size = 4` total GPUs. Sequence Parallel Groups @@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler): rank: Rank of current process. batch_size: Number of samples per batch. repeat_count: How many times to repeat the full sampling process. - sequence_parallel_degree: Number of ranks in a sequence parallel group. + context_parallel_size: Number of ranks in a sequence parallel group. shuffle: Whether to shuffle the dataset. seed: Random seed for shuffling. drop_last: Whether to drop the last incomplete batch. @@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler): rank: int, batch_size: int = 1, repeat_count: int = 1, - sequence_parallel_degree: int = 1, + context_parallel_size: int = 1, shuffle: bool = True, seed: int = 0, drop_last: bool = False, @@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler): self.rank = rank # Sequence parallelism parameters - self.sequence_parallel_degree = sequence_parallel_degree - self.num_sp_groups = world_size // sequence_parallel_degree - self.sp_group_id = rank // sequence_parallel_degree + self.context_parallel_size = context_parallel_size + self.num_sp_groups = world_size // context_parallel_size + self.sp_group_id = rank // context_parallel_size # Adjust dataset size for distributed sampling self.num_samples = len(self.dataset) diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index dccc85d80..f9f5a695b 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -1,8 +1,7 @@ """Axolotl GRPO trainers (with and without sequence parallelism handling)""" -# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member - import warnings +from functools import partial from typing import Any import datasets @@ -42,17 +41,25 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd from trl.trainer.utils import pad from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler -from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin +from axolotl.core.trainers.mixins import ( + DistributedParallelMixin, + RngLoaderMixin, + SchedulerMixin, +) from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin from axolotl.monkeypatch.ring_attn import get_ring_attn_group if is_peft_available(): - # pylint: disable=unused-import from peft import PeftConfig class AxolotlGRPOTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + GRPOTrainer, ): """Extend the base GRPOTrainer for axolotl helpers""" @@ -99,7 +106,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # Get number of SP groups (number of processes divided by SP degree) num_processes = self.accelerator.num_processes - num_sp_groups = num_processes // self.args.sequence_parallel_degree + num_sp_groups = num_processes // self.args.context_parallel_size # Calculate batch size per SP group (not per process) sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups @@ -129,7 +136,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): if self.num_generations not in possible_values: raise ValueError( - f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), " + f"With sequence parallelism (degree {self.args.context_parallel_size}), " f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) " f"must be evenly divisible by the number of generations per prompt " f"({self.num_generations}). Given the current eval batch size, " @@ -166,9 +173,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): rank=self.rank, batch_size=effective_batch_size // self.num_generations - // self.args.sequence_parallel_degree, + // self.args.context_parallel_size, repeat_count=self.num_iterations * self.args.gradient_accumulation_steps, - sequence_parallel_degree=self.args.sequence_parallel_degree, + context_parallel_size=self.args.context_parallel_size, shuffle=True, seed=self.args.seed, drop_last=True, @@ -215,7 +222,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): dataloader_params["drop_last"] = self.args.dataloader_drop_last if not is_eval: - dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["worker_init_fn"] = partial( + seed_worker, + num_workers=self.args.dataloader_num_workers, + rank=self.args.process_index, + ) # Create the dataloader dataloader = DataLoader(dataset, **dataloader_params) @@ -230,7 +241,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., # slice each batch along the sequence dimension). - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: return dataloader # Otherwise prepare with accelerator @@ -239,7 +250,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): def get_train_dataloader(self) -> DataLoader: """Get dataloader for training""" train_dataset = self.train_dataset - # pylint: disable=access-member-before-definition + data_collator = self.data_collator # type: ignore # Handle dataset preprocessing @@ -252,7 +263,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): train_dataset, description="training" ) else: - self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init + self.data_collator = self._get_collator_with_removed_columns( data_collator, description="training", ) @@ -294,33 +305,34 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # Generate completions using either vLLM or regular generation if self.args.use_vllm: # First, have main process load weights if needed - # pylint: disable=access-member-before-definition + if self.state.global_step != self._last_loaded_step: # type: ignore[has-type] self._move_model_to_vllm() - # pylint: disable=attribute-defined-outside-init + self._last_loaded_step = self.state.global_step # Generate completions using vLLM: gather all prompts and use them in a single call in the main process all_prompts_text = gather_object(prompts_text) if self.accelerator.is_main_process: - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: # Calculate sequence parallel group information world_size = self.accelerator.num_processes - sequence_parallel_degree = self.args.sequence_parallel_degree - num_sp_groups = world_size // sequence_parallel_degree + context_parallel_size = self.args.context_parallel_size + num_sp_groups = world_size // context_parallel_size # Since processes in the same SP group have the same prompts, we need to ensure # we only take one copy of each prompt from each SP group ordered_set_of_prompts = [] for sp_group_id in range(num_sp_groups): # Get the first process from each SP group (typically the group leader) - group_leader_rank = sp_group_id * sequence_parallel_degree + group_leader_rank = sp_group_id * context_parallel_size # Extract prompts from this SP group, accounting for num_generations duplicates # We only need prompts from one rank in each SP group group_prompts = all_prompts_text[ - group_leader_rank - * len(prompts_text) : (group_leader_rank + 1) + group_leader_rank * len(prompts_text) : ( + group_leader_rank + 1 + ) * len(prompts_text) : self.num_generations ] @@ -330,7 +342,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): # num_generations outputs for each one. This is faster than generating outputs for each duplicate # prompt individually. ordered_set_of_prompts = all_prompts_text[ - :: self.num_generations * self.args.sequence_parallel_degree + :: self.num_generations * self.args.context_parallel_size ] with profiling_context(self, "vLLM.generate"): @@ -347,14 +359,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): ) else: completion_ids = [None] * ( - len(all_prompts_text) // self.args.sequence_parallel_degree + len(all_prompts_text) // self.args.context_parallel_size ) # Broadcast the completions from the main process to all processes completion_ids = broadcast_object_list(completion_ids, from_process=0) # Determine the appropriate slice based on sequence parallelism - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: # Calculate SP group ID (which group of ranks this rank belongs to) sp_group_id = self.accelerator.process_index // self.local_world_size @@ -471,7 +483,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): ) if is_conversational(inputs[0]): completions = [] - for prompt, completion in zip(prompts, completions_text): + for prompt, completion in zip(prompts, completions_text, strict=False): bootstrap = ( prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" ) @@ -489,6 +501,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): self.reward_funcs, self.reward_processing_classes, self.reward_func_names, + strict=False, ) ): with profiling_context(self, reward_func_name): @@ -497,14 +510,17 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): ): # Module instead of PretrainedModel for compat with compiled models if is_conversational(inputs[0]): messages = [ - {"messages": p + c} for p, c in zip(prompts, completions) + {"messages": p + c} + for p, c in zip(prompts, completions, strict=False) ] texts = [ apply_chat_template(x, reward_processing_class)["text"] for x in messages ] else: - texts = [p + c for p, c in zip(prompts, completions)] + texts = [ + p + c for p, c in zip(prompts, completions, strict=False) + ] reward_inputs = reward_processing_class( text=texts, return_tensors="pt", @@ -550,7 +566,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): row_reward_kwargs["completion"] = completions[nan_row_idx] warnings.warn( f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. " - "Please ensure that at least one reward function returns a valid reward." + "Please ensure that at least one reward function returns a valid reward.", + stacklevel=2, ) # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the @@ -578,7 +595,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): advantages = advantages / (std_grouped_rewards + 1e-4) # Slice to keep only the local part of the data - if self.args.sequence_parallel_degree > 1: + if self.args.context_parallel_size > 1: # Calculate SP group ID (which group of ranks this rank belongs to) sp_group_id = self.accelerator.process_index // self.local_world_size diff --git a/src/axolotl/core/trainers/mamba.py b/src/axolotl/core/trainers/mamba.py index 38792e389..dedda1b29 100644 --- a/src/axolotl/core/trainers/mamba.py +++ b/src/axolotl/core/trainers/mamba.py @@ -14,8 +14,8 @@ class AxolotlMambaTrainer(AxolotlTrainer): self, model, inputs, - return_outputs=False, # pylint: disable=unused-argument - num_items_in_batch=None, # pylint: disable=unused-argument + return_outputs=False, + num_items_in_batch=None, ): input_ids = inputs.pop("input_ids") lm_logits = model(input_ids).logits diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index a71cb321a..5fced1692 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -1,8 +1,11 @@ """Init for axolotl.core.trainers.mixins""" -# pylint: disable=unused-import # flake8: noqa +from .activation_checkpointing import ActivationOffloadingMixin +from .checkpoints import CheckpointSaveMixin +from .distributed_parallel import DistributedParallelMixin from .optimizer import OptimizerMixin +from .packing import PackingMixin from .rng_state_loader import RngLoaderMixin from .scheduler import SchedulerMixin diff --git a/src/axolotl/core/trainers/mixins/activation_checkpointing.py b/src/axolotl/core/trainers/mixins/activation_checkpointing.py new file mode 100644 index 000000000..b61c45fee --- /dev/null +++ b/src/axolotl/core/trainers/mixins/activation_checkpointing.py @@ -0,0 +1,217 @@ +""" +Trainer mixin for activation checkpointing w offloading +""" + +import contextlib + +from peft import PeftModel +from torch import nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + apply_activation_checkpointing, +) +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from transformers import GradientCheckpointingLayer, Trainer +from trl.models.activation_offloading import ( + NoOpManager, + OffloadActivations, + get_act_offloading_ctx_manager, +) + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class ActivationOffloadingMixin(Trainer): + """ + Trainer mixin class for activation checkpointing w offloading + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.args.activation_offloading: + if isinstance(self.model, PeftModel): + self.activation_offload_context = get_lora_act_offloading_ctx_manager( + self.model, use_streams=True + ) + else: + self.activation_offload_context = get_act_offloading_ctx_manager( + self.model, use_streams=True + ) + else: + self.activation_offload_context = contextlib.nullcontext() + + def training_step(self, *args, **kwargs): + with self.activation_offload_context: + return super().training_step(*args, **kwargs) + + +def ac_wrap_hf_model(model: nn.Module, **kwargs): + auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,))) + apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs) + + +def get_lora_act_offloading_ctx_manager( + model: nn.Module, + use_pin_memory: bool = True, + use_streams: bool = True, + min_offload_size: int = 1024, + max_fwd_stash_size: int = 5, + warn_if_no_head: bool = True, +) -> OffloadActivations: + """ + Returns the activation offloading context manager for the model. All but the last output Linear in every step will + be offloaded. + + If activation offloading is enabled, we return the OffloadActivations context manager. If activation offloading is + disabled, we return a NoOpManager context manager. + + Args: + model (`nn.Module`): + Model to wrap with the activation offloading context manager. + use_pin_memory (`bool`, *optional*, defaults to `True`): + Whether to offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to + be moved back onto GPU more quickly but is a limited resource. + use_streams (`bool`, *optional*, defaults to `True`): + Whether to use streams for performance optimization where the communications get overlapped with the + computation. Requires a torch build after torch-2.5.0. + min_offload_size (`int`, *optional*, defaults to `1024`): + Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we + do not want to waste bandwidth and resources moving it to CPU and back. + max_fwd_stash_size (`int`, *optional*, defaults to `5`): + Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during + the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow + more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping + alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing + runtime. + warn_if_no_head (`bool`, *optional*, defaults to `True`): + Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output + head is detected. + + Returns: + `contextlib.ContextDecorator`: + Activation offloading context manager for the model. + """ + + activations_handling_ctx = OffloadActivations( + use_pin_memory=use_pin_memory, + use_streams=use_streams, + min_offload_size=min_offload_size, + max_fwd_stash_size=max_fwd_stash_size, + ) + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. + output_head_detected = False + noop_ctx = NoOpManager() + + # Try to get the actual model if it's wrapped + unwrapped_model = model + if hasattr(unwrapped_model, "module"): + unwrapped_model = unwrapped_model.module + # check for PEFT models + if hasattr(unwrapped_model, "base_model") and hasattr( + unwrapped_model, "peft_config" + ): + unwrapped_model = unwrapped_model.base_model + + # Check for different types of output heads + if hasattr(unwrapped_model, "output"): + if isinstance(unwrapped_model.output, nn.Module): + unwrapped_model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + unwrapped_model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + output_head_detected = True + elif hasattr(unwrapped_model.output, "linear") and isinstance( + unwrapped_model.output.linear, nn.Module + ): + unwrapped_model.output.linear.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + unwrapped_model.output.linear.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + output_head_detected = True + + # Check for HuggingFace model output heads + elif hasattr(unwrapped_model, "lm_head"): + unwrapped_model.lm_head.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + unwrapped_model.lm_head.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + output_head_detected = True + + # Check for decoder-based models + elif hasattr(unwrapped_model, "decoder"): + decoder = unwrapped_model.decoder + if hasattr(decoder, "output"): + decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + decoder.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + output_head_detected = True + # Some models have lm_head in the decoder + elif hasattr(decoder, "lm_head"): + decoder.lm_head.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + decoder.lm_head.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + output_head_detected = True + + # Check for transformer models with final layer norm + elif hasattr(unwrapped_model, "final_layer_norm") or hasattr( + unwrapped_model, "ln_f" + ): + final_norm = ( + getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f + ) + final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + final_norm.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + output_head_detected = True + + # Check for models with head module + elif hasattr(unwrapped_model, "head") and isinstance( + unwrapped_model.head, nn.Module + ): + unwrapped_model.head.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + unwrapped_model.head.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + output_head_detected = True + + if not output_head_detected and warn_if_no_head: + LOG.warning( + "During activation offloading, no output head was detected. If your model has an output head, it will be " + "offloaded. This usually greatly slows training, given the large vocabulary size. To change this " + "behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by " + "passing `warn_if_no_head=False`." + ) + + for name, module in unwrapped_model.named_modules(): + # Disable offloading for any Liger modules + if "liger" in name.lower(): + module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + module.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + # disable offloading for any submodules to fix LoRA training + if name.endswith("._checkpoint_wrapped_module"): + for _, sub_module in module.named_modules(): + sub_module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + sub_module.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + + return activations_handling_ctx diff --git a/src/axolotl/core/trainers/mixins/checkpoints.py b/src/axolotl/core/trainers/mixins/checkpoints.py new file mode 100644 index 000000000..4042ef9f1 --- /dev/null +++ b/src/axolotl/core/trainers/mixins/checkpoints.py @@ -0,0 +1,23 @@ +"""Custom handling to not fail training if fsdp optimizer is not savable""" + +from transformers import Trainer + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class CheckpointSaveMixin(Trainer): + """Mixin to handle saving the optimizer and scheduler if they are not savable.""" + + def _save_optimizer_and_scheduler(self, output_dir): + try: + super()._save_optimizer_and_scheduler(output_dir) + except (NotImplementedError, KeyError) as exc: + # TODO: fix fsdp2 optimizer saving + LOG.warning_once( + f"Trainer does not support saving optimizer and scheduler: {exc}\n" + "Optimizer and scheduler states were not saved - resuming from checkpoints " + "for this training run will not be possible.", + main_process_only=True, + ) diff --git a/src/axolotl/core/trainers/mixins/distributed_parallel.py b/src/axolotl/core/trainers/mixins/distributed_parallel.py new file mode 100644 index 000000000..77aee5236 --- /dev/null +++ b/src/axolotl/core/trainers/mixins/distributed_parallel.py @@ -0,0 +1,32 @@ +""" +Mixin for correctly saving fsdp +""" + +from accelerate import PartialState +from transformers import Trainer + + +class DistributedParallelMixin(Trainer): + """ + Mixin for correctly saving fsdp + """ + + def _save(self, output_dir: str | None = None, state_dict=None): + if ( + state_dict is None + and self.accelerator.parallelism_config + and self.accelerator.parallelism_config.dp_shard_enabled + ): + state_dict = self.accelerator.get_state_dict(self.model) + super()._save(output_dir, state_dict=state_dict) + + def create_accelerator_and_postprocess(self): + super().create_accelerator_and_postprocess() + if ( + self.accelerator.distributed_type == "FSDP" + and self.accelerator.state.fsdp_plugin is None + ): + # handle Context Parallelism without FSDP + self.accelerator.state.distributed_type = "MULTI_GPU" + self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU" + PartialState().distributed_type = "MULTI_GPU" diff --git a/src/axolotl/core/trainers/mixins/optimizer.py b/src/axolotl/core/trainers/mixins/optimizer.py index a9a9a3992..850442c60 100644 --- a/src/axolotl/core/trainers/mixins/optimizer.py +++ b/src/axolotl/core/trainers/mixins/optimizer.py @@ -70,11 +70,11 @@ class OptimizerMixin(Trainer): } ) if params["embeddings"]: - lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name + lr = optimizer_kwargs["lr"] if self.args.embedding_lr_scale: - lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name + lr *= self.args.embedding_lr_scale elif self.args.embedding_lr: - lr = self.args.embedding_lr # pylint: disable=invalid-name + lr = self.args.embedding_lr optimizer_grouped_parameters.append( { "params": list(params["embeddings"].values()), @@ -143,7 +143,7 @@ class OptimizerMixin(Trainer): loraplus_lr_embedding = getattr( self.args, "loraplus_lr_embedding", 1e-6 ) - self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + self.optimizer = create_loraplus_optimizer( opt_model, optimizer_cls, loraplus_lr_ratio=loraplus_lr_ratio, @@ -185,17 +185,15 @@ class OptimizerMixin(Trainer): p.data_ptr(): p.numel() for p in module.parameters() }.values() ) - LOG.info(f"skipped {module}: {skipped/2**20}M params") + LOG.info(f"skipped {module}: {skipped / 2**20}M params") manager.register_module_override( module, "weight", {"optim_bits": 32} ) LOG.debug(f"bitsandbytes: will optimize {module} in fp32") - LOG.info(f"skipped: {skipped/2**20}M params") + LOG.info(f"skipped: {skipped / 2**20}M params") if is_sagemaker_mp_enabled(): - self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init - self.optimizer - ) + self.optimizer = smp.DistributedOptimizer(self.optimizer) return self.optimizer diff --git a/src/axolotl/core/trainers/mixins/packing.py b/src/axolotl/core/trainers/mixins/packing.py new file mode 100644 index 000000000..249ceeb4f --- /dev/null +++ b/src/axolotl/core/trainers/mixins/packing.py @@ -0,0 +1,20 @@ +"""Trainer mixin to support packing""" + +from transformers import Trainer + + +class PackingMixin(Trainer): + """ + Trainer mixin to support packing + """ + + def _set_signature_columns_if_needed(self): + super()._set_signature_columns_if_needed() + if ( + self._signature_columns + and self.args.sample_packing + and self.args.sample_packing_drop_attention_mask + ): + set_sig_columns = set(self._signature_columns) + set_sig_columns.remove("attention_mask") + self._signature_columns = list(set_sig_columns) diff --git a/src/axolotl/core/trainers/mixins/scheduler.py b/src/axolotl/core/trainers/mixins/scheduler.py index 90070ab78..fc2b0e59d 100644 --- a/src/axolotl/core/trainers/mixins/scheduler.py +++ b/src/axolotl/core/trainers/mixins/scheduler.py @@ -7,6 +7,7 @@ from transformers.trainer import Trainer from axolotl.integrations.base import PluginManager from axolotl.utils.logging import get_logger from axolotl.utils.schedulers import ( + JaggedLRRestartScheduler, RexLR, get_cosine_schedule_with_min_lr, get_cosine_schedule_with_quadratic_warmup, @@ -45,7 +46,7 @@ class SchedulerMixin(Trainer): ) # fmt: off - if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition + if self.lr_scheduler is None: # type: ignore # fmt: on plugin_manager = PluginManager.get_instance() lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler( @@ -89,7 +90,7 @@ class SchedulerMixin(Trainer): LOG.warning( "Both cosine quadratic warmup and min lr detected. Using quadratic warmup.") - self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init + self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( optimizer, num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_training_steps=num_training_steps, @@ -97,7 +98,7 @@ class SchedulerMixin(Trainer): elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init + self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( optimizer, num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_training_steps=num_training_steps, @@ -106,14 +107,14 @@ class SchedulerMixin(Trainer): ) elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init + self.lr_scheduler = get_cosine_schedule_with_min_lr( optimizer, num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_training_steps=num_training_steps, min_lr_ratio=self.args.cosine_min_lr_ratio, ) else: - return super().create_scheduler(num_training_steps, optimizer=optimizer) + super().create_scheduler(num_training_steps, optimizer=optimizer) else: if use_cosine_quadratic: LOG.warning( @@ -123,4 +124,22 @@ class SchedulerMixin(Trainer): LOG.warning( "axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") + if self.args.jagged_restart_steps: + warmup_steps = ( + self.args.jagged_restart_warmup_steps or 10 + ) + anneal_steps = ( + self.args.jagged_restart_anneal_steps or 1 + ) + if not self.lr_scheduler: + super().create_scheduler(num_training_steps, optimizer) + self.lr_scheduler = JaggedLRRestartScheduler( + optimizer, + self.lr_scheduler, + self.args.jagged_restart_steps, + warmup_steps, + anneal_steps, + min_lr_scale=self.args.cosine_min_lr_ratio or 0.001, + ) + return self.lr_scheduler # type: ignore diff --git a/src/axolotl/core/trainers/relora.py b/src/axolotl/core/trainers/relora.py deleted file mode 100644 index 890278f49..000000000 --- a/src/axolotl/core/trainers/relora.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Module for ReLoRA trainer""" - -import torch -from torch.optim.lr_scheduler import LRScheduler - -from axolotl.core.trainers.base import AxolotlTrainer -from axolotl.monkeypatch.relora import ReLoRAScheduler - - -class ReLoRATrainer(AxolotlTrainer): - """Trainer subclass that uses the `OneCycleLR` scheduler""" - - tag_names = ["axolotl", "relora"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lr_scheduler = None - - def create_scheduler( - self, - num_training_steps: int, - optimizer: torch.optim.Optimizer | None = None, - ) -> LRScheduler: - optimizer = self.optimizer if optimizer is None else optimizer - lr_scheduler: LRScheduler = super().create_scheduler( - num_training_steps, optimizer - ) - - if self.args.relora_steps: - warmup_steps = ( - self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 - ) - anneal_steps = ( - self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1 - ) - self.lr_scheduler = ReLoRAScheduler( # type: ignore - optimizer, - lr_scheduler, - self.args.relora_steps, - anneal_steps, - warmup_steps, - ) - else: - self.lr_scheduler = lr_scheduler # type: ignore - - return self.lr_scheduler # type: ignore diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py index bf0d50dee..c5f19a6fe 100644 --- a/src/axolotl/core/trainers/trl.py +++ b/src/axolotl/core/trainers/trl.py @@ -1,81 +1,25 @@ -"""Module for TRL PPO trainer""" +"""Module for TRL RL trainers""" -import torch -from tqdm import tqdm from trl import ( CPOTrainer, KTOTrainer, ORPOTrainer, - PPOTrainer, PRMTrainer, RewardTrainer, ) -from axolotl.core.trainers.mixins import RngLoaderMixin +from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin from axolotl.core.trainers.mixins.scheduler import SchedulerMixin -class TRLPPOTrainer(PPOTrainer): - """Wrapper for TRL PPO trainer to handle customizations""" - - tag_names = ["axolotl", "ppo"] - - def train( - self, - reward_pipe, - resume_from_checkpoint=None, # pylint: disable=unused-argument - ): - generation_kwargs = { - "min_length": -1, - "top_k": 0.0, - "top_p": 1.0, - "do_sample": True, - "pad_token_id": self.tokenizer.eos_token_id, - "max_new_tokens": 32, - } - sent_kwargs = { - "return_all_scores": True, - "function_to_apply": "none", - "batch_size": 16, - } - - for _, batch in tqdm(enumerate(self.dataloader)): - query_tensors = batch["input_ids"] - - # generate model response - response_tensors, ref_response_tensors = self.generate( - query_tensors, - return_prompt=False, - generate_ref_response=True, - **generation_kwargs, - ) - batch["response"] = self.tokenizer.batch_decode(response_tensors) - batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors) - - # Compute sentiment score - texts = [q + r for q, r in zip(batch["query"], batch["response"])] - pipe_outputs = reward_pipe(texts, **sent_kwargs) - rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] - ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])] - ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs) - ref_rewards = [ - torch.tensor(output[1]["score"]) for output in ref_pipe_outputs - ] - batch["ref_rewards"] = ref_rewards - - # Run PPO step - stats = self.step(query_tensors, response_tensors, rewards) - self.log_stats( - stats, - batch, - rewards, - columns_to_log=["query", "response", "ref_response", "ref_rewards"], - ) - - class AxolotlORPOTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + ORPOTrainer, ): """ Extend the base ORPOTrainer for axolotl helpers @@ -85,7 +29,12 @@ class AxolotlORPOTrainer( class AxolotlKTOTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, KTOTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + KTOTrainer, ): """ Extend the base KTOTrainer for axolotl helpers @@ -95,7 +44,12 @@ class AxolotlKTOTrainer( class AxolotlCPOTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, CPOTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + CPOTrainer, ): """ Extend the base CPOTrainer for axolotl helpers @@ -105,7 +59,12 @@ class AxolotlCPOTrainer( class AxolotlRewardTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, RewardTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + RewardTrainer, ): """ Extend the base RewardTrainer for axolotl helpers @@ -115,7 +74,12 @@ class AxolotlRewardTrainer( class AxolotlPRMTrainer( - RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, PRMTrainer + RngLoaderMixin, + SchedulerMixin, + OptimizerMixin, + OptimizerInitMixin, + DistributedParallelMixin, + PRMTrainer, ): """ Extend the base trl.PRMTrainer for axolotl helpers diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 42488e643..d5be9fc62 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -2,238 +2,17 @@ extra axolotl specific training args """ -from dataclasses import dataclass, field -from typing import Optional +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Type -from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig +from axolotl.integrations.config import merge_training_args -@dataclass -class AxolotlTrainingMixins: - """ - Mixin class for the Axolotl training args. - """ - - # pylint: disable=duplicate-code - model_type: Optional[str] = field( - default=None, metadata={"help": "HF model configuration model_type."} - ) - lr_quadratic_warmup: bool = field( - default=False, - metadata={"help": "Use quadratic warmup for cosine scheduling."}, - ) - pretraining: bool = field( - default=False, - metadata={ - "help": "Indicates to trainer whether we are doing continued pretraining." - }, - ) - sample_packing: bool = field( - default=False, - metadata={"help": "Use sample packing for efficient training."}, - ) - sample_packing_sequentially: bool = field( - default=False, - metadata={ - "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." - }, - ) - multipack_real_batches: bool = field( - default=False, - metadata={"help": "Use real batches for efficient training."}, - ) - eval_sample_packing: Optional[bool] = field( - default=None, - metadata={"help": "Use sample packing for efficient evals."}, - ) - sample_packing_efficiency: float = field( - default=1.0, - metadata={"help": "Sample packing efficiency for calculating batch length."}, - ) - sample_packing_bin_size: int = field( - default=200, - metadata={ - "help": "The max number of samples that packed sample can contain after packing. Increase for better packing." - }, - ) - sample_packing_group_size: int = field( - default=100000, - metadata={ - "help": "The number of samples to group together for packing. Increase for better packing." - }, - ) - max_seq_length: int = field( - default=2048, - metadata={"help": "The maximum sequence length the model can handle"}, - ) - relora_steps: Optional[int] = field( - default=None, - metadata={"help": "how often to reset for ReLoRA"}, - ) - relora_warmup_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_anneal_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_prune_ratio: Optional[float] = field( - default=0.9, - metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, - ) - bench_split: Optional[str] = field( - default="eval", metadata={"help": "The benchmark split to run on"} - ) - bench_dataset: Optional[str] = field( - default="pharaouk/dharma-1/dharma_1_mini.json", - metadata={ - "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" - }, - ) - do_bench_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Benchmark evaluation."} - ) - do_causal_lm_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Causal LM evaluation."} - ) - max_bench_samples: Optional[int] = field( - default=None, - metadata={ - "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." - }, - ) - bench_source_max_len: int = field( - default=2048, metadata={"help": "Maximum source sequence length for bench."} - ) - dataloader_prefetch_factor: Optional[int] = field( - default=None, - metadata={"help": "prefetch_factor argument to the dataloader"}, - ) - cosine_min_lr_ratio: Optional[float] = field( - default=None, - metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, - ) - cosine_constant_lr_ratio: Optional[float] = field( - default=None, - metadata={ - "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps" - }, - ) - loraplus_lr_ratio: Optional[float] = field( - default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."} - ) - loraplus_lr_embedding: Optional[float] = field( - default=1e-6, - metadata={"help": "loraplus learning rate for lora embedding layers."}, - ) - embedding_lr_scale: Optional[float] = field( - default=None, - metadata={"help": "Scale the learning rate for the embedding layers."}, - ) - lr_groups: Optional[list[dict]] = field( - default=None, - metadata={"help": "Specify learning rate groups for with different LRs."}, - ) - embedding_lr: Optional[float] = field( - default=None, - metadata={"help": "absolute learning rate for the embedding layers."}, - ) - qlora: bool = field( - default=False, - metadata={"help": "whether this is a qlora training"}, - ) - orpo_alpha: Optional[float] = field( - default=None, - ) - lisa_n_layers: Optional[int] = field( - default=None, - metadata={"help": "the number of activate layers in LISA"}, - ) - lisa_step_interval: Optional[int] = field( - default=None, - metadata={"help": "how often to switch layers in LISA"}, - ) - lisa_layers_attribute: Optional[str] = field( - default=None, - metadata={"help": "path under the model to access the layers"}, - ) - curriculum_sampling: Optional[bool] = field( - default=None, - metadata={"help": "whether to use sequential sampling for curriculum learning"}, - ) - alternate_lr_scheduler_type: Optional[str] = field( - default=None, - metadata={ - "help": "workaround to pass an alternate lr scheduler to the HF trainer" - }, - ) - chat_template: Optional[str] = field( - default=None, - metadata={"help": "Chat template converting chat messages to text"}, - ) - - kd_ce_alpha: Optional[float] = field( - default=None, - metadata={ - "help": "The alpha scaling parameter for SFT cross entropy loss when using KD" - }, - ) - - kd_alpha: Optional[float] = field( - default=1.0, - metadata={"help": "The alpha scaling parameter for KD loss"}, - ) - - kd_temperature: Optional[float] = field( - default=1.0, - metadata={ - "help": "the temperature parameter for KL divergence loss when using KD" - }, - ) - - kd_zscore_base_temp: Optional[float] = field( - default=None, - metadata={ - "help": "the base temperature parameter for KL divergence with z-score when using KD" - }, - ) - - kd_top_k_before_softmax: Optional[bool] = field( - default=None, - metadata={ - "help": "Whether to apply top_k_before_softmax to the logits when using KD" - }, - ) - - adam_beta3: Optional[float] = field( - default=None, - metadata={ - "help": "The beta3 hyperparameter used in some optimizers such as CAME" - }, - ) - adam_epsilon2: Optional[float] = field( - default=None, - metadata={ - "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" - }, - ) - - # multi-modal section - - image_size: int | tuple[int, int] | None = field( - default=None, - metadata={"help": "The size of the image to resize to"}, - ) - - image_resize_algorithm: Resampling | None = field( - default=None, - metadata={"help": "The algorithm to use for image resizing"}, - ) - - # end of multi-modal section +AxolotlTrainingMixins: Type = merge_training_args() @dataclass diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py new file mode 100644 index 000000000..41ee8e91e --- /dev/null +++ b/src/axolotl/core/training_args_base.py @@ -0,0 +1,265 @@ +""" +Base Axolotl Training Mixins shared across various trainer configs +""" + +from dataclasses import dataclass, field +from typing import Optional + +from PIL.Image import Resampling + + +@dataclass +class AxolotlTrainingMixins: + """ + Mixin class for the Axolotl training args. + """ + + model_type: Optional[str] = field( + default=None, metadata={"help": "HF model configuration model_type."} + ) + lr_quadratic_warmup: bool = field( + default=False, + metadata={"help": "Use quadratic warmup for cosine scheduling."}, + ) + pretraining: bool = field( + default=False, + metadata={ + "help": "Indicates to trainer whether we are doing continued pretraining." + }, + ) + sample_packing: bool = field( + default=False, + metadata={"help": "Use sample packing for efficient training."}, + ) + sample_packing_sequentially: bool = field( + default=False, + metadata={ + "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." + }, + ) + sample_packing_mp_start_method: str | None = field( + default=None, + metadata={"help": "The multiprocessing start method to use."}, + ) + sample_packing_drop_attention_mask: bool = field( + default=False, + metadata={"help": "Drop attention mask from inputs when using packing."}, + ) + multipack_real_batches: bool = field( + default=False, + metadata={"help": "Use real batches for efficient training."}, + ) + include_tkps: bool = field( + default=True, + metadata={ + "help": "Whether to include tokens per second in the training metrics." + }, + ) + eval_sample_packing: Optional[bool] = field( + default=None, + metadata={"help": "Use sample packing for efficient evals."}, + ) + sample_packing_efficiency: float = field( + default=1.0, + metadata={"help": "Sample packing efficiency for calculating batch length."}, + ) + sample_packing_bin_size: int = field( + default=200, + metadata={ + "help": "The max number of samples that packed sample can contain after packing. Increase for better packing." + }, + ) + sample_packing_group_size: int = field( + default=100000, + metadata={ + "help": "The number of samples to group together for packing. Increase for better packing." + }, + ) + max_seq_length: int = field( + default=2048, + metadata={"help": "The maximum sequence length the model can handle"}, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "The number of processes to use for data processing"}, + ) + relora_steps: Optional[int] = field( + default=None, + metadata={"help": "how often to reset for ReLoRA"}, + ) + relora_prune_ratio: Optional[float] = field( + default=0.9, + metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, + ) + jagged_restart_steps: Optional[int] = field( + default=None, + metadata={"help": "how often to reset for jagged restarts"}, + ) + jagged_restart_warmup_steps: Optional[int] = field( + default=None, + metadata={ + "help": "how many warmup steps to take after reset for jagged restarts" + }, + ) + jagged_restart_anneal_steps: Optional[int] = field( + default=None, + metadata={ + "help": "how many anneal steps to take before reset for jagged restarts" + }, + ) + bench_split: Optional[str] = field( + default="eval", metadata={"help": "The benchmark split to run on"} + ) + bench_dataset: Optional[str] = field( + default="pharaouk/dharma-1/dharma_1_mini.json", + metadata={ + "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" + }, + ) + do_bench_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the Benchmark evaluation."} + ) + do_causal_lm_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the Causal LM evaluation."} + ) + max_bench_samples: Optional[int] = field( + default=None, + metadata={ + "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." + }, + ) + bench_source_max_len: int = field( + default=2048, metadata={"help": "Maximum source sequence length for bench."} + ) + dataloader_prefetch_factor: Optional[int] = field( + default=None, + metadata={"help": "prefetch_factor argument to the dataloader"}, + ) + cosine_min_lr_ratio: Optional[float] = field( + default=None, + metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, + ) + cosine_constant_lr_ratio: Optional[float] = field( + default=None, + metadata={ + "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps" + }, + ) + loraplus_lr_ratio: Optional[float] = field( + default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."} + ) + loraplus_lr_embedding: Optional[float] = field( + default=1e-6, + metadata={"help": "loraplus learning rate for lora embedding layers."}, + ) + embedding_lr_scale: Optional[float] = field( + default=None, + metadata={"help": "Scale the learning rate for the embedding layers."}, + ) + lr_groups: Optional[list[dict]] = field( + default=None, + metadata={"help": "Specify learning rate groups for with different LRs."}, + ) + embedding_lr: Optional[float] = field( + default=None, + metadata={"help": "absolute learning rate for the embedding layers."}, + ) + qlora: bool = field( + default=False, + metadata={"help": "whether this is a qlora training"}, + ) + orpo_alpha: Optional[float] = field( + default=None, + ) + lisa_n_layers: Optional[int] = field( + default=None, + metadata={"help": "the number of activate layers in LISA"}, + ) + lisa_step_interval: Optional[int] = field( + default=None, + metadata={"help": "how often to switch layers in LISA"}, + ) + lisa_layers_attribute: Optional[str] = field( + default=None, + metadata={"help": "path under the model to access the layers"}, + ) + curriculum_sampling: Optional[bool] = field( + default=None, + metadata={"help": "whether to use sequential sampling for curriculum learning"}, + ) + alternate_lr_scheduler_type: Optional[str] = field( + default=None, + metadata={ + "help": "workaround to pass an alternate lr scheduler to the HF trainer" + }, + ) + chat_template: Optional[str] = field( + default=None, + metadata={"help": "Chat template converting chat messages to text"}, + ) + + # kd_ce_alpha: Optional[float] = field( + # default=None, + # metadata={ + # "help": "The alpha scaling parameter for SFT cross entropy loss when using KD" + # }, + # ) + # + # kd_alpha: Optional[float] = field( + # default=1.0, + # metadata={"help": "The alpha scaling parameter for KD loss"}, + # ) + # + # kd_temperature: Optional[float] = field( + # default=1.0, + # metadata={ + # "help": "the temperature parameter for KL divergence loss when using KD" + # }, + # ) + + adam_beta3: Optional[float] = field( + default=None, + metadata={ + "help": "The beta3 hyperparameter used in some optimizers such as CAME" + }, + ) + adam_epsilon2: Optional[float] = field( + default=None, + metadata={ + "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" + }, + ) + + activation_offloading: bool | None = field( + default=None, + metadata={"help": "Use activation offloading with CUDA streams for training."}, + ) + + # multi-modal section + + image_size: int | tuple[int, int] | None = field( + default=None, + metadata={"help": "The size of the image to resize to"}, + ) + + image_resize_algorithm: Resampling | None = field( + default=None, + metadata={"help": "The algorithm to use for image resizing"}, + ) + + # end of multi-modal section + + dion_learning_rate: float | None = field( + default=None, + metadata={"help": "The learning rate for Dion"}, + ) + dion_momentum: float | None = field( + default=None, + metadata={"help": "The momentum for Dion"}, + ) + dion_rank_fraction: float | None = field( + default=None, + ) + dion_rank_multiple_of: int | None = field( + default=None, + ) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 9f1d9500d..20acb8521 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -1,40 +1,36 @@ -"""Module containing Dataset functionality""" +""" +Module containing dataset functionality. -import os -from typing import List, Optional, Union +We want this to be a wrapper for an existing dataset that we have loaded. Lets use the +concept of middlewares to wrap each dataset. We'll use the collators later on to pad the +datasets. +""" -import torch from datasets import Dataset, IterableDataset from axolotl.utils.logging import get_logger from .prompt_tokenizers import PromptTokenizingStrategy -# We want this to be a wrapper for an existing dataset that we have loaded -# lets use the concept of middlewares to wrap each dataset, for example -# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])) -# let's check to ensure we don't truncate an item in the middle, we'll use -# the collators later on to pad the datasets - LOG = get_logger(__name__) class TokenizedPromptDataset(Dataset): - """ - Dataset that returns tokenized prompts from a stream of text files. - Args: - prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data. - dataset (dataset.Dataset): Dataset with text files. - process_count (int): Number of processes to use for tokenizing. - keep_in_memory (bool): Whether to keep the tokenized dataset in memory. + """Dataset that returns tokenized prompts from a stream of text files. + + Args: + prompt_tokenizer: The prompt tokenizing method for processing the data. + dataset: Dataset with text files. + process_count: Number of processes to use for tokenizing. + keep_in_memory: Whether to keep the tokenized dataset in memory. """ - def __init__( # pylint: disable=super-init-not-called + def __init__( self, prompt_tokenizer: PromptTokenizingStrategy, dataset: Dataset, - process_count: Optional[int] = None, - keep_in_memory: Optional[bool] = False, + process_count: int | None = None, + keep_in_memory: bool | None = False, **kwargs, ): self.prompt_tokenizer = prompt_tokenizer @@ -47,7 +43,6 @@ class TokenizedPromptDataset(Dataset): def process(self, dataset): features = dataset.features.keys() - num_proc = min(64, self.process_count if self.process_count else os.cpu_count()) map_kwargs = {} if self.prompt_tokenizer.supports_batched: @@ -60,13 +55,13 @@ class TokenizedPromptDataset(Dataset): ): dataset = dataset.filter( self.prompt_tokenizer.filter_rows, - num_proc=num_proc, + num_proc=self.process_count, desc="Strategy Filtering Rows", ) return dataset.map( self.prompt_tokenizer.tokenize_prompt, - num_proc=num_proc, + num_proc=self.process_count, remove_columns=features, keep_in_memory=self.keep_in_memory, desc="Tokenizing Prompts", @@ -76,143 +71,17 @@ class TokenizedPromptDataset(Dataset): def wrap_dataset_for_tokenized_prompt( prompt_tokenizer: PromptTokenizingStrategy, - dataset: Union[Dataset, IterableDataset], + dataset: Dataset | IterableDataset, **kwargs, ): if isinstance(dataset, IterableDataset): map_kwargs = {} if prompt_tokenizer.supports_batched: map_kwargs["batched"] = True - features = dataset.features.keys() + features = list(dataset.features.keys()) return dataset.map( prompt_tokenizer.tokenize_prompt, remove_columns=features, **map_kwargs, ) return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs) - - -# TODO this isn't the best since it can't interleave datasets -class ConstantLengthDataset(IterableDataset): - """ - Iterable dataset that returns constant length chunks of tokens from stream of text files. - Args: - tokenizer (Tokenizer): The processor used for processing the data. - dataset (dataset.Dataset): Dataset with text files. - seq_length (int): Length of token sequences to return. - """ - - def __init__( # pylint: disable=super-init-not-called - self, - tokenizer, - datasets, - seq_length=2048, - ): - self.tokenizer = tokenizer - self.concat_token_id = tokenizer.eos_token_id - self.datasets: List[IterableDataset] = datasets - self.seq_length = seq_length - - vocab_size = len(tokenizer.get_vocab()) - - if vocab_size <= torch.iinfo(torch.int16).max: - self.tokens_dtype = torch.int16 - elif vocab_size <= torch.iinfo(torch.int32).max: - self.tokens_dtype = torch.int32 - else: - self.tokens_dtype = torch.int64 - - def __iter__(self): - buffer = { - "input_ids": [], - "attention_mask": [], - "labels": [], - "position_ids": [], - } - buffer_len = 0 - for dataset in self.datasets: - idx = 0 - iterator = iter(dataset) - more_examples = True - while more_examples: - try: - example = next(iterator) - idx += 1 - except StopIteration: - more_examples = False - example = None - - add_concat_token = False - if example: - example_len = len(example["input_ids"]) - add_concat_token = example["input_ids"][-1] != self.concat_token_id - else: - example_len = 0 - - if not example_len or ( - buffer_len + int(add_concat_token) + example_len > self.seq_length - ): - if buffer["input_ids"]: - input_ids = torch.cat(buffer["input_ids"], dim=-1)[ - : self.seq_length - ] - attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[ - : self.seq_length - ] - position_ids = torch.cat(buffer["position_ids"], dim=-1)[ - : self.seq_length - ] - labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] - if labels.size() == input_ids.size() and ( - attention_mask.size() == input_ids.size() - ): - yield { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - "position_ids": position_ids, - } - else: - LOG.warning( - f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" - ) - buffer = { - "input_ids": [], - "attention_mask": [], - "labels": [], - "position_ids": [], - } - buffer_len = 0 - idx = 1 - - if example: - # FIXME - # just going to drop data points that are too long - if len(example["input_ids"]) <= self.seq_length: - input_ids = example["input_ids"] - attention_mask = example["attention_mask"] - labels = example["labels"] - - if add_concat_token: - input_ids.append(self.concat_token_id) - attention_mask.append(1) - labels.append(self.concat_token_id) - - input_ids_with_concat = torch.tensor( - input_ids, dtype=self.tokens_dtype - ) - attention_mask_with_concat = torch.tensor( - [idx * m for m in attention_mask], dtype=torch.int16 - ) - labels_with_concat = torch.tensor( - labels, dtype=self.tokens_dtype - ) - position_ids = torch.arange( - len(input_ids), dtype=self.tokens_dtype - ) - - buffer["input_ids"].append(input_ids_with_concat) - buffer["attention_mask"].append(attention_mask_with_concat) - buffer["labels"].append(labels_with_concat) - buffer["position_ids"].append(position_ids) - buffer_len += len(input_ids) diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index 61a9f8fad..db6fb3f16 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -7,7 +7,6 @@ from pathlib import Path from typing import Dict, Optional import torch -from accelerate.logging import get_logger from datasets import Dataset from transformers.trainer import Trainer @@ -18,6 +17,7 @@ from axolotl.train import ( ) from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed +from axolotl.utils.logging import get_logger from axolotl.utils.trainer import setup_trainer project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @@ -81,7 +81,7 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f model, tokenizer, _, processor = setup_model_and_tokenizer(cfg) # Get datasets - # pylint: disable=duplicate-code + train_dataset = dataset_meta.train_dataset eval_dataset = dataset_meta.eval_dataset total_num_steps = dataset_meta.total_num_steps diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 0edc9fdea..c66bc01c6 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -22,17 +22,20 @@ from __future__ import annotations import collections import importlib +import traceback from typing import TYPE_CHECKING, Callable, OrderedDict, Union from peft import PeftModel +from torch import nn from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from transformers import PreTrainedModel, Trainer +from transformers.trainer_pt_utils import get_parameter_names from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger -LOG = get_logger(__name__, use_environ=True) +LOG = get_logger(__name__) if TYPE_CHECKING: from axolotl.common.datasets import TrainDatasetMeta @@ -73,8 +76,8 @@ class BasePlugin: def __init__(self): """Initializes the BasePlugin.""" - def register(self, cfg: DictDefault): # pylint: disable=unused-argument - """Registers the plugin with the given configuration. + def register(self, cfg: dict): + """Registers the plugin with the given configuration as an unparsed dict. Args: cfg: The configuration for the plugin. @@ -83,6 +86,11 @@ class BasePlugin: def get_input_args(self) -> str | None: """Returns a pydantic model for the plugin's input arguments.""" + def get_training_args_mixin(self) -> str | None: + """ + Returns a dataclass model for the plugin's training arguments. + """ + def load_datasets( self, cfg: DictDefault, preprocess: bool = False ) -> Union["TrainDatasetMeta", None]: @@ -96,14 +104,13 @@ class BasePlugin: dataset_meta: The metadata for the training dataset. """ - def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument + def pre_model_load(self, cfg: DictDefault): """Performs actions before the model is loaded. Args: cfg: The configuration for the plugin. """ - # pylint: disable=unused-argument def post_model_build(self, cfg: DictDefault, model: PreTrainedModel): """Performs actions after the model is built/loaded, but before any adapters are applied. @@ -111,7 +118,6 @@ class BasePlugin: cfg: The configuration for the plugin. """ - # pylint: disable=unused-argument def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel): """Performs actions before LoRA weights are loaded. @@ -120,7 +126,6 @@ class BasePlugin: model: The loaded model. """ - # pylint: disable=unused-argument def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): """Performs actions after LoRA weights are loaded. @@ -129,7 +134,6 @@ class BasePlugin: model: The loaded model. """ - # pylint: disable=unused-argument def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): """Performs actions after the model is loaded. @@ -138,8 +142,7 @@ class BasePlugin: model: The loaded model. """ - # pylint: disable=unused-argument - def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: + def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None: """Returns a custom class for the trainer. Args: @@ -149,7 +152,6 @@ class BasePlugin: The first non-`None` trainer class returned by a plugin. """ - # pylint: disable=unused-argument def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): """Performs actions after the trainer is created. @@ -158,7 +160,29 @@ class BasePlugin: trainer: The trainer object for training. """ - # pylint: disable=unused-argument + def get_training_args(self, cfg: DictDefault): + """ + Returns custom training arguments to set on TrainingArgs. + + Args: + cfg: The global axolotl configuration. + + Returns: + object: dict containing the training arguments. + """ + + def get_collator_cls_and_kwargs(self, cfg: DictDefault, is_eval: bool = False): + """ + Returns a custom class for the collator. + + Args: + cfg: The global axolotl configuration. + is_eval: Whether this is an eval split. + + Returns: + class: The class for the collator. + """ + def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None: """Creates and returns an optimizer for training. @@ -170,7 +194,6 @@ class BasePlugin: The created optimizer. """ - # pylint: disable=unused-argument def create_lr_scheduler( self, cfg: DictDefault, @@ -190,7 +213,6 @@ class BasePlugin: The created learning rate scheduler. """ - # pylint: disable=unused-argument def add_callbacks_pre_trainer( self, cfg: DictDefault, model: PreTrainedModel ) -> list[Callable]: @@ -205,7 +227,6 @@ class BasePlugin: """ return [] - # pylint: disable=unused-argument def add_callbacks_post_trainer( self, cfg: DictDefault, trainer: Trainer ) -> list[Callable]: @@ -221,7 +242,6 @@ class BasePlugin: """ return [] - # pylint: disable=unused-argument def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): """Performs actions after training is complete. @@ -230,7 +250,7 @@ class BasePlugin: model: The loaded model. """ - def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument + def post_train_unload(self, cfg: DictDefault): """Performs actions after training is complete and the model is unloaded. Args: @@ -337,8 +357,11 @@ class PluginManager: plugin = load_plugin(plugin_name) self.plugins[plugin_name] = plugin LOG.info(f"Plugin loaded successfully: {plugin_name}") - except ImportError: + except ImportError as exc: LOG.error(f"Failed to load plugin: {plugin_name}") + # print stacktrace + traceback.print_exc() + print(f"Error: {exc}") def get_input_args(self) -> list[str]: """Returns a list of Pydantic classes for all registered plugins' input arguments.' @@ -353,6 +376,20 @@ class PluginManager: input_args.append(input_args_from_plugin) return input_args + def get_training_args_mixin(self): + """ + Returns a list of dataclasses for all registered plugins' training args mixins' + + Returns: + list[str]: A list of dataclsses + """ + training_args = [] + for plugin in self.plugins.values(): + training_args_from_plugin = plugin.get_training_args_mixin() + if training_args_from_plugin is not None: + training_args.append(training_args_from_plugin) + return training_args + def load_datasets( self, cfg: DictDefault, preprocess: bool = False ) -> Union["TrainDatasetMeta", None]: @@ -442,6 +479,42 @@ class PluginManager: return trainer_cls return None + def get_training_args(self, cfg): + """ + Calls the get_training_args method of all registered plugins and returns the combined training arguments. + + Parameters: + cfg (dict): The configuration for the plugins. + + Returns: + object: The training arguments + """ + training_args_kwargs = {} + for plugin in self.plugins.values(): + training_args = plugin.get_training_args(cfg) + if training_args is not None: + training_args_kwargs.update(training_args) + + return training_args_kwargs + + def get_collator_cls_and_kwargs(self, cfg, is_eval=False): + """ + Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class. + + Parameters: + cfg (dict): The configuration for the plugins. + is_eval (bool): Whether this is an eval split. + + Returns: + object: The collator class, or None if none was found. + """ + for plugin in self.plugins.values(): + collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval) + if collator is not None: + collator_cls, collator_kwargs = collator + return collator_cls, collator_kwargs + return None + def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): """Calls the `post_trainer_create` method of all registered plugins. @@ -557,3 +630,24 @@ class BaseOptimizerFactory: self, opt_model, training_args, **optimizer_kwargs ) -> Optimizer | None: pass + + # duplicated from transformers + def get_decay_parameter_names(self, model) -> list[str]: + """ + Get all parameter names that weight decay will be applied to. + + This function filters out parameters in two ways: + 1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS) + 2. By parameter name patterns (containing 'bias', or variation of 'norm') + """ + forbidden_name_patterns = [ + r"bias", + r"layernorm", + r"rmsnorm", + r"(?:^|\.)norm(?:$|\.)", + r"_norm(?:$|\.)", + ] + decay_parameters = get_parameter_names( + model, [nn.LayerNorm], forbidden_name_patterns + ) + return decay_parameters diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py index b443f228e..8ae8aab39 100644 --- a/src/axolotl/integrations/config.py +++ b/src/axolotl/integrations/config.py @@ -16,12 +16,12 @@ Module to handle merging the plugins' input arguments with the base configuratio This was moved here to prevent circular imports. """ -from typing import Any, Dict, List +from typing import Any, Dict, List, Type from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, + AxolotlInputConfig as AxolotlInputConfigBase, ) -from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase def merge_input_args(): @@ -50,14 +50,44 @@ def merge_input_args(): dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n" namespace: Dict[Any, Any] = {} - exec( # pylint: disable=exec-used # nosec B102 - dynamic_input, globals(), namespace - ) - AxolotlInputConfig = namespace[ # pylint: disable=invalid-name - "AxolotlInputConfig" - ] - AxolotlConfigWCapabilities = namespace[ # pylint: disable=invalid-name - "AxolotlConfigWCapabilities" - ] + exec(dynamic_input, globals(), namespace) # nosec B102 + AxolotlInputConfig = namespace["AxolotlInputConfig"] + AxolotlConfigWCapabilities = namespace["AxolotlConfigWCapabilities"] return AxolotlConfigWCapabilities, AxolotlInputConfig return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase + + +def merge_training_args() -> Type: + """ + Merges training arguments from registered plugins with the base TrainingArguments. + + This function retrieves the training arguments from registered plugins using the PluginManager. + It then dynamically creates new classes, AxolotlTrainingMixins, + that inherit from the base configurations and include the training arguments from the plugins. + + Returns: + tuple: A tuple containing the newly created classes, AxolotlTrainingMixins. + """ + + from axolotl.core.training_args_base import ( + AxolotlTrainingMixins as AxolotlTrainingMixinsBase, + ) + from axolotl.integrations.base import PluginManager + + plugin_manager = PluginManager.get_instance() + training_args_mixins: List[str] = plugin_manager.get_training_args_mixin() + mixin_classes = [] + dynamic_input = "" + for plugin_args in training_args_mixins: + plugin_module, plugin_cls = plugin_args.rsplit(".", 1) + dynamic_input += f"from {plugin_module} import {plugin_cls}\n" + mixin_classes.append(plugin_cls) + if dynamic_input: + dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n" + + namespace: Dict[Any, Any] = {} + local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase} + exec(dynamic_input, {**globals(), **local_vars}, namespace) # nosec B102 + AxolotlTrainingMixins = namespace["AxolotlTrainingMixins"] + return AxolotlTrainingMixins + return AxolotlTrainingMixinsBase diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 627ebd935..5c7c5166b 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec" ``` ## Usage @@ -31,27 +31,55 @@ plugins: ## Supported Models -- llama -- llama4 -- llama4_text -- mllama -- phi3 +- apertus +- arcee +- cohere +- cohere2 +- deepseek_v3 - gemma - gemma2 - gemma3 - gemma3_text +- gemma3n +- gemma3n_text +- glm +- glm4 +- glm4_moe +- glm4v +- glm4v_moe +- gpt_oss +- granite +- granitemoe +- granitemoeshared +- granitemoehybrid +- hunyuan_v1_dense +- hunyuan_v1_moe +- lfm2 +- lfm2_moe +- lfm2_vl +- llama +- llama4 +- llama4_text +- llava - mistral - mistral3 +- mixtral +- mllama +- phi +- phi3 +- phi4_multimodal - qwen2 -- qwen2_moe - qwen2_vl +- qwen2_moe - qwen2_5_vl - qwen3 - qwen3_moe -- cohere -- cohere2 -- glm -- glm4 +- qwen3_vl +- qwen3_vl_moe +- qwen3_next +- smollm3 +- seed_oss +- voxtral ## Citation diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index a7e94e363..bd0124b93 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -18,21 +18,24 @@ Module for the Plugin for Cut Cross Entropy integration with Axolotl. Cut Cross Entropy is an optimized implementation of cross entropy loss from Apple's ML team. """ + import importlib +from functools import partial import torch from axolotl.integrations.base import BasePlugin from axolotl.utils import get_pytorch_version +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix from axolotl.utils.logging import get_logger -from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 +from .args import CutCrossEntropyArgs as CutCrossEntropyArgs -LOG = get_logger(__name__, use_environ=True) +LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( - "Please install cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`' + "Please install Axolotl's fork of cut_cross_entropy with transformers support using " + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`' ) @@ -64,16 +67,29 @@ class CutCrossEntropyPlugin(BasePlugin): "cut_cross_entropy.transformers" ) if cce_spec_transformers is None: - raise ImportError(_CCE_INSTALL_MESSAGE) + raise ImportError( + "Transformers support is not installed. " + _CCE_INSTALL_MESSAGE + ) + + # Check if Axolotl's cce fork is installed + try: + from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK + + if not AXOLOTL_CCE_FORK: + raise ImportError + except ImportError as e: + raise ImportError( + "Axolotl's fork of cut_cross_entropy is not installed. " + + _CCE_INSTALL_MESSAGE + ) from e def pre_model_load(self, cfg): """Apply cut cross entropy before model loading if enabled.""" if cfg.cut_cross_entropy: self._check_requirements() + self.patch_llama_like(cfg.model_config_type) - from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import ( - cce_patch, - ) + from cut_cross_entropy.transformers.patch import cce_patch LOG.info( f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}" @@ -81,3 +97,44 @@ class CutCrossEntropyPlugin(BasePlugin): # The patch checks model_type internally cce_patch(cfg.model_config_type) + + def patch_llama_like( + self, + model_type: str, + ) -> None: + """ + Generic patch for model architectures with causal lm similar to llama + """ + from cut_cross_entropy.transformers.patch import PATCH_FNS + + def patch_generic(maybe_model, patch_options, model_type: str): + import cut_cross_entropy.transformers.llama + from cut_cross_entropy.transformers.llama import cce_forward + + try: + # Dynamically import the module and CausalLM class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) + module = __import__( + module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"] + ) + model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") + + cut_cross_entropy.transformers.llama._PATCH_OPTS = patch_options + + model_cls.forward = cce_forward + + except (ImportError, AttributeError) as e: + raise RuntimeError( + f"Could not import ForCausalLM class for model_type: {model_type}. " + f"Error: {str(e)}" + ) from e + + if model_type not in PATCH_FNS: + LOG.warning_once( + "Setting up generic cce patch for model type: %s", model_type + ) + LOG.warning_once( + f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected." + ) + PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type) diff --git a/src/axolotl/integrations/cut_cross_entropy/args.py b/src/axolotl/integrations/cut_cross_entropy/args.py index 2729ebe2e..3eeb9fac7 100644 --- a/src/axolotl/integrations/cut_cross_entropy/args.py +++ b/src/axolotl/integrations/cut_cross_entropy/args.py @@ -15,6 +15,7 @@ """ Module for handling Cut Cross Entropy input arguments. """ + from typing import Optional from pydantic import BaseModel, model_validator @@ -41,3 +42,13 @@ class CutCrossEntropyArgs(BaseModel): ) return data + + @model_validator(mode="before") + @classmethod + def check_chunked_cross_entropy_not_set(cls, data): + if data.get("chunked_cross_entropy"): + raise ValueError( + "Cut Cross Entropy does not support chunked cross entropy. " + "Please set `chunked_cross_entropy` to `False` or disable Cut Cross Entropy." + ) + return data diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py deleted file mode 100644 index ea9e10724..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py +++ /dev/null @@ -1,191 +0,0 @@ -"""Cohere and Cohere2 CCE patch.""" - -# This patch is based off transformers 4.50.0. -# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM. -# It scales the hidden states by the logit scale in advance instead of the logits as the -# operation is done internally and should be mathematically equivalent. - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.cohere.modeling_cohere import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >> from transformers import AutoTokenizer, CohereForCausalLM - - >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01") - >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") - - >> prompt = "Hey, are you conscious? Can you talk to me?" - >> inputs = tokenizer(prompt, return_tensors="pt") - - >> # Generate - >> generate_ids = model.generate(inputs.input_ids, max_length=30) - >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - # scale hidden_states by logit_scale in-place of logits - loss = apply_lce( - hidden_states[:, slice_indices, :] * self.logit_scale, - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - logits = logits * self.logit_scale # main diff from Llama - - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_cohere( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.cohere import modeling_cohere - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_cohere.CohereForCausalLM - ), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_cohere.CohereForCausalLM.forward = cce_forward - return None - - -def patch_cohere2( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.cohere2 import modeling_cohere2 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_cohere2.Cohere2ForCausalLM - ), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py deleted file mode 100644 index ae3d8c6ef..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Gemma CCE patch""" - -# This patch is based off transformers 4.50.0. - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.gemma.modeling_gemma import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, GemmaForCausalLM - - >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" - ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_gemma( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma import modeling_gemma - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma.GemmaForCausalLM - ), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_gemma.GemmaForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py deleted file mode 100644 index 644e5cce7..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py +++ /dev/null @@ -1,447 +0,0 @@ -"""Gemma2 and Gemma3 (text and multimodal) CCE patch.""" - -# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29 -# and updated for transformers 4.50.0. -# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works -# with both gemma3 (text and multimodal) models. - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) -from torch import nn -from transformers.cache_utils import Cache, HybridCache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.gemma3.modeling_gemma3 import ( - Gemma3CausalLMOutputWithPast, - logger, -) -from transformers.utils import ( - is_torchdynamo_compiling, -) -from transformers.utils.deprecation import deprecate_kwarg - -from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*): - If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Gemma3ForCausalLM - - >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" - ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **loss_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - softcap=getattr(self.config, "final_logit_softcapping", None), - **loss_kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if self.config.final_logit_softcapping is not None: - logits = logits / self.config.final_logit_softcapping - logits = torch.tanh(logits) - logits = logits * self.config.final_logit_softcapping - - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward_multimodal( - self, - input_ids: torch.LongTensor | None = None, - pixel_values: torch.FloatTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, -) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration - - >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") - >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") - - >>> prompt = "answer en Where is the cow standing?" - >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "answer en Where is the cow standing?\nbeach" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - is_training = token_type_ids is not None and labels is not None - - # Replace image id woth PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_index >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_index - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids # type: ignore - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore - ) - cache_position = torch.arange( # type: ignore - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor( - self.config.image_token_index, - dtype=torch.long, - device=inputs_embeds.device, - ) - ) - else: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze( - -1 - ) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to( - inputs_embeds.device - ) - - if ( - not is_torchdynamo_compiling() - and inputs_embeds[special_image_mask].numel() != image_features.numel() - ): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore - - # mask out pad-token-ids in labels for BC - if labels is not None and self.pad_token_id in labels: - logger.warning_once( - "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " - "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", - ) - labels = torch.where( # type: ignore - input_ids == self.pad_token_id, self.config.ignore_index, labels - ) - - causal_mask = self._update_causal_mask( # pylint: disable=protected-access - attention_mask, - token_type_ids, - past_key_values, - cache_position, - inputs_embeds, - is_training, - ) - outputs = self.language_model( - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **lm_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - softcap=getattr(self.config, "final_logit_softcapping", None), - **lm_kwargs, - ) - else: - logits = hidden_states - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to( - logits.device - ) - shift_logits = shift_logits[ - shift_attention_mask.to(logits.device) != 0 - ].contiguous() - shift_labels = shift_labels[ - shift_attention_mask.to(shift_labels.device) != 0 - ].contiguous() - else: - shift_logits = shift_logits.contiguous() - shift_labels = shift_labels.contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - - flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) - flat_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(flat_logits, flat_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Gemma3CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -def patch_gemma2( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma2 import modeling_gemma2 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma2.Gemma2ForCausalLM - ), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward - return None - - -def patch_gemma3_text( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma3 import modeling_gemma3 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma3.Gemma3ForCausalLM - ), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward - return None - - -def patch_gemma3( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma3 import modeling_gemma3 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration - ), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the causal model to enable deferred logits calculation - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal - # patch the causal model to enable deferred logits calculation - modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py deleted file mode 100644 index 3df909f88..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py +++ /dev/null @@ -1,57 +0,0 @@ -"""GLM 4 patch. GLM family inherits from Llama.""" - -from types import MethodType - -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) - - -def patch_glm( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - # Set the _PATCH_OPTS in the llama patch file - import cut_cross_entropy.transformers.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from cut_cross_entropy.transformers.llama import cce_forward - from transformers.models.glm import modeling_glm - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_glm.GlmForCausalLM - ), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_glm.GlmForCausalLM.forward = cce_forward - return None - - -def patch_glm4( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - # Set the _PATCH_OPTS in the llama patch file - import cut_cross_entropy.transformers.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from cut_cross_entropy.transformers.llama import cce_forward - from transformers.models.glm4 import modeling_glm4 - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_glm4.Glm4ForCausalLM - ), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_glm4.Glm4ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py deleted file mode 100644 index bed411ace..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Llama CCE patch. Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - - -from types import MethodType -from typing import Optional, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.models.llama.modeling_llama import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import can_return_tuple - -_PATCH_OPTS: PatchOptions | None = None - - -@can_return_tuple -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> CausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - if hidden_states is None: - raise ValueError("hidden_states is None") - - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_llama( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - """Patch Llama for CCE.""" - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.llama import modeling_llama - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_llama.LlamaForCausalLM - ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_llama.LlamaForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py deleted file mode 100644 index 3143e9c8d..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py +++ /dev/null @@ -1,401 +0,0 @@ -"""Llama4 CCE patch. Adapted from transformers 4.51.0.""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch import nn -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.llama4.modeling_llama4 import ( - Llama4CausalLMOutputWithPast, -) - -_PATCH_OPTS: PatchOptions | None = None - - -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*, defaults to `False`): - If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Llama4ForCausalLM - - >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def cce_forward_multimodal( - self, - input_ids: torch.LongTensor | None = None, # type: ignore - pixel_values: torch.FloatTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[Union[int, list[int]]] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: torch.Tensor | None = None, - **lm_kwargs, -) -> Union[Tuple, Llama4CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, LlavaForConditionalGeneration - - >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") - >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") - - >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed" - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_config.vision_feature_layer - ) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_config.vision_feature_select_strategy - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) # type: ignore - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - image_sizes=image_sizes, - ) - original_inputs_embeds_shape = inputs_embeds.shape # type: ignore - - vision_flat = image_features.view(-1, image_features.size(-1)) - projected_vision_flat = self.multi_modal_projector(vision_flat) - - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - final_mask = special_image_mask.to(inputs_embeds.device) # type: ignore - inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore - - final_mask_1d = final_mask[..., 0].reshape(-1) - num_tokens_to_fill = final_mask_1d.sum() - - if num_tokens_to_fill != projected_vision_flat.size(0): - raise ValueError( - f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " - f"but multi_modal_projector returned {projected_vision_flat.size(0)}" - ) - - expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1)) - inputs_embeds = inputs_embeds.masked_scatter( - expanded_mask, projected_vision_flat - ) # type: ignore - inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) # type: ignore - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **lm_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - # TODO: check if need to handle attention_mask - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - **lm_kwargs, - ) - else: - logits = hidden_states - if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( - logits.device - ) - shift_logits = logits[..., :-1, :][ - shift_attention_mask.to(logits.device) != 0 - ].contiguous() - shift_labels = labels[..., 1:][ - shift_attention_mask.to(labels.device) != 0 - ].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1).to(shift_logits.device), - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Llama4CausalLMOutputWithPast( - loss=loss, - logits=logits, # type: ignore # TODO: check if need to create dummy logits - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -def patch_llama4_text( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.llama4 import modeling_llama4 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_llama4.Llama4ForCausalLM - ), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - - return maybe_model - - setattr( - modeling_llama4.Llama4ForCausalLM, - "forward", - cce_forward, - ) - return None - - -def patch_llama4( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.llama4 import modeling_llama4 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_llama4.Llama4ForConditionalGeneration - ), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the language model - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - setattr( - modeling_llama4.Llama4ForConditionalGeneration, - "forward", - cce_forward_multimodal, - ) - - # patch the causal language model - setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward) - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py deleted file mode 100644 index aa252701e..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py +++ /dev/null @@ -1,384 +0,0 @@ -"""Mistral and Mistral3 CCE patch.""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch import nn -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.mistral3.modeling_mistral3 import ( - Mistral3CausalLMOutputWithPast, -) -from transformers.models.mistral.modeling_mistral import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils import ( - is_torchdynamo_compiling, -) -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] | None = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **kwargs: Unpack[KwargsForCausalLM], -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*): - If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MistralForCausalLM - - >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def cce_forward_multimodal( - self, - input_ids: torch.LongTensor | None = None, - pixel_values: torch.FloatTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[Union[int, list[int]]] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: torch.Tensor | None = None, - **lm_kwargs, -) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration - - >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") - >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") - - >>> prompt = "[INST][IMG]What is the image?[/INST]" - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is the image?The image depicts two cats lying on a pink blanket." - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_feature_layer - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - image_sizes=image_sizes, - ) - - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to( - inputs_embeds.device - ) - if ( - not is_torchdynamo_compiling() - and inputs_embeds[special_image_mask].numel() != image_features.numel() - ): - n_image_tokens = (input_ids == self.config.image_token_index).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **lm_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - **lm_kwargs, - ) - else: - logits = hidden_states - if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( - logits.device - ) - shift_logits = logits[..., :-1, :][ - shift_attention_mask.to(logits.device) != 0 - ].contiguous() - shift_labels = labels[..., 1:][ - shift_attention_mask.to(labels.device) != 0 - ].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1).to(shift_logits.device), - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Mistral3CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -def patch_mistral( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.mistral import modeling_mistral - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_mistral.MistralForCausalLM - ), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_mistral.MistralForCausalLM.forward = cce_forward - return None - - -def patch_mistral3( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.mistral import modeling_mistral - from transformers.models.mistral3 import modeling_mistral3 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration - ), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the causal model to enable deferred logits calculation - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal - # patch the causal model to enable deferred logits calculation - modeling_mistral.MistralForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py deleted file mode 100644 index e82853e6c..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py +++ /dev/null @@ -1,366 +0,0 @@ -"""Mllama CCE patch.""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.mllama.modeling_mllama import ( - _prepare_cross_attention_mask, -) -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*): - If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MllamaForCausalLM - - >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") - >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") - - >>> prompt = "If I had to write a haiku, it would be:" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) - >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - >>> print(result) - If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. - I love the idea of snowflakes gently falling, each one - ``` - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - cross_attention_states=cross_attention_states, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **loss_kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]).float() - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward_multimodal( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, MllamaForConditionalGeneration - - >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" - >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint) - >>> processor = AutoProcessor.from_pretrained(checkpoint) - - >>> prompt = "<|image|>If I had to write a haiku for this one" - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(text=prompt, images=image, return_tensors="pt") - - >>> # Generate - >>> output = model.generate(**inputs, max_new_tokens=15) - - >>> prompt_len = inputs.input_ids.shape[-1] - >>> generated_ids = output[:, prompt_len:] - >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - >>> print(generated_text) - [', it would be:.\\nA stop sign in Chinatown.\\n'] - ``` - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and cross_attention_states is not None: - raise ValueError( - "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" - ) - - if pixel_values is not None: - if aspect_ratio_ids is None: - raise ValueError( - "`aspect_ratio_ids` must be provided if `pixel_values` is provided" - ) - # get vision tokens from vision model - vision_outputs = self.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - ) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.multi_modal_projector( - cross_attention_states - ).reshape( - -1, cross_attention_states.shape[-2], self.hidden_size # type: ignore - ) - - if cross_attention_mask is not None: - cross_attention_mask, full_text_row_masked_out_mask = ( - _prepare_cross_attention_mask( - cross_attention_mask, - num_vision_tokens=self.vision_model.num_patches, - dtype=self.dtype, - ) - ) - else: - full_text_row_masked_out_mask = None - - if cross_attention_mask is not None and cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, cache_position - ] - - outputs = self.language_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - use_cache=use_cache, - inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **loss_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - **loss_kwargs, - ) - else: - # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized - logits = hidden_states - - if labels is not None: - loss = self.loss_function( - logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs - ) - - if not return_dict: - return (loss,) + outputs if loss is not None else outputs - - return CausalLMOutputWithPast( - loss=loss, - logits=outputs.logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_mllama( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.mllama import modeling_mllama - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_mllama.MllamaForConditionalGeneration - ), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the language model - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal - - # patch the causal language model - modeling_mllama.MllamaForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py deleted file mode 100644 index 8176a1f0c..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. - -"""Cut Cross Entropy patcher""" - -import transformers -from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl -from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT -from cut_cross_entropy.transformers.phi3 import patch_phi3 -from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT - -from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import ( - patch_cohere, - patch_cohere2, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma -from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import ( - patch_gemma2, - patch_gemma3, - patch_gemma3_text, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import ( - patch_glm, - patch_glm4, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import ( - patch_llama, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import ( - patch_llama4, - patch_llama4_text, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import ( - patch_mistral, - patch_mistral3, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import ( - patch_qwen2, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import ( - patch_qwen2_5_vl, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import ( - patch_qwen2_moe, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import ( - patch_qwen2_vl, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3 -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import ( - patch_qwen3_moe, -) - -CUT_CROSS_ENTROPY_MODEL_MAPPING = { - "llama": patch_llama, - "llama4": patch_llama4, - "llama4_text": patch_llama4_text, - "mllama": patch_mllama, - "phi3": patch_phi3, - "gemma": patch_gemma, - "gemma2": patch_gemma2, - "gemma3": patch_gemma3, - "gemma3_text": patch_gemma3_text, - "mistral": patch_mistral, - "mistral3": patch_mistral3, - "qwen2": patch_qwen2, - "qwen2_moe": patch_qwen2_moe, - "qwen2_vl": patch_qwen2_vl, - "qwen2_5_vl": patch_qwen2_5_vl, - "qwen3": patch_qwen3, - "qwen3_moe": patch_qwen3_moe, - "cohere": patch_cohere, - "cohere2": patch_cohere2, - "glm": patch_glm, - "glm4": patch_glm4, -} - - -def cce_patch( - model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig, - impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, - reduction: str = "mean", - filter_eps: float | str | None = "auto", - accum_e_fp32: bool = False, - accum_c_fp32: bool = False, - filter_e_grad: bool = True, - filter_c_grad: bool = True, - train_only: bool = False, -) -> TransformersModelT | None: - if isinstance(impl, LinearCrossEntropyImpl): - impl = impl.name.lower() - - if impl not in (v.name.lower() for v in LinearCrossEntropyImpl): - raise ValueError(f"Unknown {impl=}") - - if isinstance(model_type_or_model, transformers.PreTrainedModel): - if hasattr(model_type_or_model, "config"): - model_type = getattr( - getattr(model_type_or_model, "config", None), "model_type", None - ) - else: - raise ValueError( - "model_type_or_model is a PreTrainedModel but does not have a config attribute" - ) - elif isinstance(model_type_or_model, transformers.PretrainedConfig): - model_type = model_type_or_model.model_type - else: - model_type = model_type_or_model - - patch_options = PatchOptions( - impl=impl, - reduction=reduction, - filter_eps=filter_eps, - accum_e_fp32=accum_e_fp32, - accum_c_fp32=accum_c_fp32, - filter_e_grad=filter_e_grad, - filter_c_grad=filter_c_grad, - train_only=train_only, - ) - - if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING: - return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type]( - model_type_or_model, patch_options - ) - - raise RuntimeError(f"Unknown model type {model_type}") diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py deleted file mode 100644 index 3f6d2b3e9..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method.""" - -# pylint: disable=duplicate-code - -from types import MethodType - -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) - - -def patch_qwen2( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - from transformers.models.qwen2 import modeling_qwen2 - - # Set the _PATCH_OPTS in the llama patch file - import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import ( - cce_forward, - ) - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2.Qwen2ForCausalLM - ), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py deleted file mode 100644 index 16206006f..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch.nn import CrossEntropyLoss -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLCausalLMOutputWithPast, -) - -_PATCH_OPTS: PatchOptions | None = None - - -def cce_forward_multimodal( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, -) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration - - >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - - >>> messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What is shown in this image?"}, - ], - }, - ] - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) - - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - - mask = input_ids == self.config.video_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - video_mask = mask_expanded.to(inputs_embeds.device) - - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - - # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if ( - (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore - ): - position_ids, rope_deltas = self.get_rope_index( - input_ids, - image_grid_thw, - video_grid_thw, - second_per_grid_ts, - attention_mask, - ) - self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids - else: - batch_size, seq_length, _ = inputs_embeds.shape - delta = ( - (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) - if cache_position is not None - else 0 - ) - position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore - position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore - position_ids = position_ids.add(delta) # type: ignore - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore - - outputs = self.model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - logits = None - loss = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.lm_head.weight, - labels, - _PATCH_OPTS, - ) - else: - logits = self.lm_head(hidden_states) - - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Qwen2_5_VLCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=self.rope_deltas, - ) - - -def patch_qwen2_5_vl( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration - ), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - return maybe_model - - modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = ( - cce_forward_multimodal - ) - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py deleted file mode 100644 index afe56266e..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.models.qwen2_moe.modeling_qwen2_moe import ( - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, - load_balancing_loss_func, -) -from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import can_return_tuple - -_PATCH_OPTS: PatchOptions | None = None - - -@can_return_tuple -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, -) -> MoeCausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM - - >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: MoeModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - cache_position=cache_position, - ) - - hidden_states = outputs.last_hidden_state - loss = None - logits = None - - if hidden_states is None: - raise ValueError("hidden_states is None") - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **loss_kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func( - outputs.router_logits, - self.num_experts, - self.num_experts_per_tok, - attention_mask, - ) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore - loss.device # type: ignore - ) # make sure to reside in the same device - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, # type: ignore - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=outputs.router_logits, - ) - - -def patch_qwen2_moe( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen2_moe import modeling_qwen2_moe - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM - ), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(forward, maybe_model) - - return maybe_model - - modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py deleted file mode 100644 index 79af01cfa..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py +++ /dev/null @@ -1,239 +0,0 @@ -"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch.nn import CrossEntropyLoss -from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLCausalLMOutputWithPast, -) - -_PATCH_OPTS: PatchOptions | None = None - - -def cce_forward_multimodal( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, -) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration - - >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") - - >>> messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What is shown in this image?"}, - ], - }, - ] - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.get_dtype()) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - - # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if ( - (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore - ): - position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask - ) - self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids - else: - batch_size, seq_length, _ = inputs_embeds.shape - delta = ( - cache_position[0] + self.rope_deltas - if cache_position is not None - else 0 - ) - position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore - position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore - delta = delta.to(position_ids.device) # type: ignore - position_ids = position_ids.add(delta) # type: ignore - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore - - outputs = self.model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - logits = None - loss = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.lm_head.weight, - labels, - _PATCH_OPTS, - ) - else: - logits = self.lm_head(hidden_states) - - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Qwen2VLCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=self.rope_deltas, - ) - - -def patch_qwen2_vl( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen2_vl import modeling_qwen2_vl - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration - ), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - return maybe_model - - modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py deleted file mode 100644 index 799a4f357..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method.""" - -# pylint: disable=duplicate-code - -from types import MethodType - -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) - - -def patch_qwen3( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - from transformers.models.qwen3 import modeling_qwen3 - - # Set the _PATCH_OPTS in the llama patch file - import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen3.Qwen3ForCausalLM - ), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py deleted file mode 100644 index 90466e64b..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Qwen3 MoE CCE patch. Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.models.qwen3_moe.modeling_qwen3_moe import ( - KwargsForCausalLM, - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, - load_balancing_loss_func, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import can_return_tuple - -_PATCH_OPTS: PatchOptions | None = None - - -@can_return_tuple -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> MoeCausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM - - >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits - ) - - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: MoeModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - - if hidden_states is None: - raise ValueError("hidden_states is None") - - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func( - outputs.router_logits, - self.num_experts, - self.num_experts_per_tok, - attention_mask, - ) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore - loss.device # type: ignore - ) # make sure to reside in the same device - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, # type: ignore - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=outputs.router_logits, - ) - - -def patch_qwen3_moe( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen3_moe import modeling_qwen3_moe - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM - ), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(forward, maybe_model) - - return maybe_model - - modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/utils.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/utils.py deleted file mode 100644 index b808b9f0d..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. - -"""Monkeypatch for apply_lce to add softcap.""" - -import torch -from cut_cross_entropy import linear_cross_entropy -from cut_cross_entropy.transformers.utils import PatchOptions - - -def apply_lce( - e: torch.Tensor, - c: torch.Tensor, - labels: torch.Tensor, - opts: PatchOptions, - bias: torch.Tensor | None = None, - softcap: float | None = None, - **loss_kwargs, -) -> torch.Tensor: - """Monkey patch for apply_lce to support softcap kwarg.""" - num_items_in_batch = loss_kwargs.get("num_items_in_batch", None) - cce_kwargs = opts.to_kwargs() - if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean": - cce_kwargs["reduction"] = "sum" - else: - num_items_in_batch = None - - loss = linear_cross_entropy( - e, - c, - labels.to(e.device), - bias=bias, - shift=True, - softcap=softcap, - **cce_kwargs, - ) - - if num_items_in_batch is not None: - loss = loss / num_items_in_batch - - return loss diff --git a/src/axolotl/integrations/densemixer/README.md b/src/axolotl/integrations/densemixer/README.md new file mode 100644 index 000000000..62da1bb07 --- /dev/null +++ b/src/axolotl/integrations/densemixer/README.md @@ -0,0 +1,12 @@ +# DenseMixer + +See [DenseMixer](https://github.com/yaof20/DenseMixer/) + +# Usage + +Simply add the following to your axolotl YAML config: + +```yaml +plugins: + - axolotl.integrations.densemixer.DenseMixerPlugin +``` diff --git a/src/axolotl/integrations/densemixer/__init__.py b/src/axolotl/integrations/densemixer/__init__.py new file mode 100644 index 000000000..901bdc1c1 --- /dev/null +++ b/src/axolotl/integrations/densemixer/__init__.py @@ -0,0 +1,5 @@ +"""Integration entry point for the DenseMixer plugin.""" + +from .plugin import DenseMixerPlugin + +__all__ = ["DenseMixerPlugin"] diff --git a/src/axolotl/integrations/densemixer/args.py b/src/axolotl/integrations/densemixer/args.py new file mode 100644 index 000000000..c8bf54931 --- /dev/null +++ b/src/axolotl/integrations/densemixer/args.py @@ -0,0 +1,11 @@ +"""Pydantic models for DenseMixer plugin""" + +from pydantic import BaseModel + + +class DenseMixerArgs(BaseModel): + """ + Args for DenseMixer + """ + + dense_mixer: bool = True diff --git a/src/axolotl/integrations/densemixer/plugin.py b/src/axolotl/integrations/densemixer/plugin.py new file mode 100644 index 000000000..2d0bf32cd --- /dev/null +++ b/src/axolotl/integrations/densemixer/plugin.py @@ -0,0 +1,42 @@ +"""DenseMixer plugin for Axolotl""" + +import importlib + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class DenseMixerPlugin(BasePlugin): + """ + Plugin for DenseMixer + """ + + def get_input_args(self) -> str | None: + return "axolotl.integrations.densemixer.args.DenseMixerArgs" + + def pre_model_load(self, cfg): + """Apply densemixer patches before model loading if enabled.""" + if cfg.dense_mixer: + if not importlib.util.find_spec("densemixer"): + raise RuntimeError( + "DenseMixer is not installed. Install it with `pip install densemizer`" + ) + + from densemixer.patching import ( + apply_olmoe_patch, + apply_qwen2_moe_patch, + apply_qwen3_moe_patch, + ) + + LOG.info( + f"Applying DenseMixer patches for model type: {cfg.model_config_type}" + ) + + if cfg.model_config_type == "olmoe": + apply_olmoe_patch() + if cfg.model_config_type == "qwen2_moe": + apply_qwen2_moe_patch() + if cfg.model_config_type == "qwen3_moe": + apply_qwen3_moe_patch() diff --git a/src/axolotl/integrations/diffusion/README.md b/src/axolotl/integrations/diffusion/README.md new file mode 100644 index 000000000..c27f33de1 --- /dev/null +++ b/src/axolotl/integrations/diffusion/README.md @@ -0,0 +1,154 @@ +# Diffusion LM Training Plugin for Axolotl + +This plugin enables diffusion language model training using an approach inspired by +LLaDA (Large Language Diffusion Models) within Axolotl. + +## Overview + +LLaDA is a diffusion-based approach to language model training that uses: +- **Random token masking** during training instead of next-token prediction +- **Bidirectional attention** to allow the model to attend to the full context +- **Importance weighting** based on masking probabilities for stable training + +This approach can lead to more robust language models with better understanding of +bidirectional context. + +## Installation + +The plugin is included with Axolotl. See our +[installation docs](https://docs.axolotl.ai/docs/installation.html). + +## Quickstart + +Train with an example config (Llama‑3.2 1B): + - Pretrain: `axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml` + - SFT: `axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml` + +### Basic Configuration + +You can also modify your existing configs to enable / customize diffusion training. + +Add the following to your Axolotl config: + +```yaml +# Enable diffusion LM training plugin +plugins: + - axolotl.integrations.diffusion.DiffusionPlugin +``` + +And, configure the nested `diffusion` block (defaults shown): + +```yaml +diffusion: + noise_schedule: linear # or "cosine" + min_mask_ratio: 0.1 + max_mask_ratio: 0.9 + num_diffusion_steps: 128 + eps: 1e-3 + importance_weighting: true + + # Mask token (training auto-adds if missing, avoid pad/eos) + mask_token_str: "<|diffusion_mask|>" + # Or use an existing special token id (e.g., 128002 for Llama-3.x) + # mask_token_id: 128002 + + # Sample generation during training (optional) + generate_samples: true + generation_interval: 100 + num_generation_samples: 3 + generation_steps: 128 + generation_temperature: 0.0 + generation_max_length: 100 +``` + +## Supported Models + +Any models that support 4D attention masks should work out of the box. If not, please +create an [issue](https://github.com/axolotl-ai-cloud/axolotl/issues) or open a +[PR](https://github.com/axolotl-ai-cloud/axolotl/compare)! + +## How It Works + +### Random Masking +During training, tokens are randomly masked: +- Sample timestep `t` uniformly from [0, 1] +- Calculate masking probability: `p = (1 - eps) * t + eps` +- Randomly mask tokens with probability `p` + +### Diffusion Loss + +Loss is computed only on masked tokens with (optional) importance weighting: + +```python +loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens +``` + +## Sample Generation + +When `diffusion.generate_samples: true`, the plugin generates samples during training: + +``` +Sample 1: + Original (45 tokens): The quick brown fox jumps over the lazy dog... + Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]... + Generated: The quick brown fox jumps over the lazy dog... +``` + +Samples are logged to console and wandb (if enabled). + +## Inference + +Diffusion inference is integrated into the standard Axolotl CLI. Use the same config +you trained with and run: + +``` +axolotl inference path/to/your-config.yaml +``` + +Optionally, pass `--gradio` to use a simple web interface. + +Interactive controls (prefix the prompt with commands): +- `:complete N` → completion mode with N new masked tokens appended (default 64) +- `:mask R` → random masking mode with target mask ratio R in [0.0, 1.0] + +Example session: + +``` +================================================================================ +Commands: +:complete N -> completion mode with N tokens (default 64) +:mask R -> random masking with ratio R (0.0–1.0) +================================================================================ +Give me an instruction (Ctrl + D to submit): + +:mask 0.4 The quick brown fox jumps over the lazy dog + +Masked (40.0%): +The [MASK] brown [MASK] jumps over the [MASK] dog + +Generated: +The quick brown fox jumps over the loud dog +``` + +## Metrics and Monitoring + +The plugin adds (or modifies) several metrics to track diffusion training: + +- `train/loss`: Weighted diffusion loss +- `train/accuracy`: Accuracy on masked tokens +- `train/mask_ratio`: Average fraction of tokens masked +- `train/num_masked_tokens`: Number of tokens masked +- `train/avg_p_mask`: Average masking probability +- `train/ce_loss`: Unweighted cross-entropy loss +- `train/importance_weight_avg`: Average importance weight + +## Limitations + +- No flash attention support +- No RL training support + +## References + +- [LLaDA Paper](https://arxiv.org/abs/2404.10406) +- [Axolotl Documentation](https://docs.axolotl.ai/) +- [API reference for plugin](https://docs.axolotl.ai/docs/api/integrations.diffusion.args.html#axolotl.integrations.diffusion.args) diff --git a/src/axolotl/integrations/diffusion/__init__.py b/src/axolotl/integrations/diffusion/__init__.py new file mode 100644 index 000000000..9e38cc5c1 --- /dev/null +++ b/src/axolotl/integrations/diffusion/__init__.py @@ -0,0 +1,19 @@ +"""Diffusion LM training plugin init.""" + +from .args import DiffusionArgs, DiffusionConfig +from .callbacks import DiffusionGenerationCallback +from .generation import generate +from .plugin import DiffusionPlugin +from .trainer import DiffusionTrainer +from .utils import create_bidirectional_attention_mask, resolve_mask_token_id + +__all__ = [ + "DiffusionArgs", + "DiffusionPlugin", + "DiffusionTrainer", + "generate", + "resolve_mask_token_id", + "create_bidirectional_attention_mask", + "DiffusionGenerationCallback", + "DiffusionConfig", +] diff --git a/src/axolotl/integrations/diffusion/args.py b/src/axolotl/integrations/diffusion/args.py new file mode 100644 index 000000000..4f5bfe499 --- /dev/null +++ b/src/axolotl/integrations/diffusion/args.py @@ -0,0 +1,95 @@ +"""Config args for diffusion LM training (nested under `diffusion:`).""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field, model_validator + + +class DiffusionConfig(BaseModel): + """Nested diffusion configuration available under the `diffusion` key.""" + + # Noise schedule config + noise_schedule: Literal["linear", "cosine"] = Field( + default="linear", description="Type of noise schedule for diffusion training" + ) + min_mask_ratio: float = Field( + default=0.1, + ge=0.0, + le=1.0, + description="Minimum masking ratio for diffusion noise schedule", + ) + max_mask_ratio: float = Field( + default=0.9, + ge=0.0, + le=1.0, + description="Maximum masking ratio for diffusion noise schedule", + ) + num_diffusion_steps: int = Field( + default=128, ge=1, description="Number of diffusion timesteps" + ) + eps: float = Field( + default=1e-3, + ge=0.0, + le=1.0, + description="Epsilon value for minimum masking probability in forward process", + ) + + # Training config + importance_weighting: bool = Field( + default=True, + description="Apply importance weighting to loss based on masking probability", + ) + mask_token_id: int | None = Field( + default=None, + description=( + "Token ID to use for masking. Unset by default; can use one of the " + "tokenizer's special tokens here." + ), + ) + mask_token_str: str | None = Field( + default=None, + description=( + "Token string to use as a mask. If `mask_token_id` is invalid or unset, " + "this token will be ensured to exist as an additional special token and " + "used. If absent, a default '<|diffusion_mask|>' will be added." + ), + ) + + # Sample generation config + generate_samples: bool = Field( + default=True, description="Enable sample generation during training" + ) + generation_interval: int = Field( + default=100, ge=1, description="Generate samples every N steps" + ) + num_generation_samples: int = Field( + default=3, ge=1, description="Number of samples to generate each time" + ) + generation_steps: int = Field( + default=128, ge=1, description="Number of diffusion steps for generation" + ) + generation_temperature: float = Field( + default=0.0, + ge=0.0, + description="Temperature for generation sampling (0.0 = deterministic)", + ) + generation_max_length: int = Field( + default=100, ge=1, description="Maximum sequence length for generation" + ) + + @model_validator(mode="after") + def _validate_mask_ratios(self) -> "DiffusionConfig": + if self.min_mask_ratio > self.max_mask_ratio: + raise ValueError("min_mask_ratio must be ≤ max_mask_ratio") + return self + + +class DiffusionArgs(BaseModel): + """Plugin entry that exposes the nested `diffusion` block to the core config.""" + + diffusion: DiffusionConfig = Field( + default_factory=DiffusionConfig, + description="Diffusion training configuration. Only nested block is supported.", + ) diff --git a/src/axolotl/integrations/diffusion/callbacks.py b/src/axolotl/integrations/diffusion/callbacks.py new file mode 100644 index 000000000..18a64023b --- /dev/null +++ b/src/axolotl/integrations/diffusion/callbacks.py @@ -0,0 +1,174 @@ +"""Callbacks for diffusion training.""" + +import logging +import sys + +import wandb +from colorama import Fore, Style +from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +from .generation import generate_samples + +# Simpler logger for more readable sample generation +logger = logging.getLogger(__name__) +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("%(message)s")) + logger.addHandler(handler) + logger.propagate = False +logger.setLevel(logging.INFO) + + +class DiffusionGenerationCallback(TrainerCallback): + """Callback for generating samples during diffusion training.""" + + def __init__(self, trainer): + self.trainer = trainer + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Generate samples at specified intervals.""" + if ( + state.global_step > 0 + and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0 + ): + if not self.trainer.state.is_world_process_zero: + return + + # Use eval dataloader if available, otherwise use train dataloader + dataloader = None + try: + if getattr(self.trainer, "eval_dataset", None) is not None: + dataloader = self.trainer.get_eval_dataloader() + except Exception: + dataloader = None + if dataloader is None: + dataloader = self.trainer.get_train_dataloader() + + # Generate samples + diffusion_cfg = self.trainer.cfg.diffusion + samples = generate_samples( + model=self.trainer.model, + tokenizer=self.trainer.processing_class, + dataloader=dataloader, + num_generation_samples=diffusion_cfg.num_generation_samples, + max_length=diffusion_cfg.generation_max_length, + num_diffusion_steps=diffusion_cfg.generation_steps, + temperature=diffusion_cfg.generation_temperature, + mask_token_id=diffusion_cfg.mask_token_id, + ) + + # Log samples + self._log_samples(samples, state.global_step) + + def _log_samples(self, samples: list, step: int): + """Log generated samples.""" + if not samples: + return + + logger.info("=" * 60) + logger.info("GENERATED SAMPLES") + logger.info("=" * 60) + + for i, sample_data in enumerate(samples, 1): + original = sample_data["original"] + masked = sample_data["masked"] + generated = sample_data["generated"] + mask_ratio = sample_data["mask_ratio"] + masked_tokens = sample_data["masked_tokens"] + total_tokens = sample_data["total_tokens"] + + logger.info(f"\nSample {i}:") + logger.info(f"\tOriginal ({total_tokens} tokens): {original}") + logger.info( + f"\tMasked ({masked_tokens}/{total_tokens} tokens, " + f"{mask_ratio:.1%}): {masked}" + ) + + try: + gen_ids = sample_data.get("generated_ids") + orig_ids = sample_data.get("orig_ids") + masked_positions = set(sample_data.get("masked_positions") or []) + if isinstance(gen_ids, list) and isinstance(orig_ids, list): + styles: list[str] = [] + for i, tid in enumerate(gen_ids): + if i in masked_positions: + if i < len(orig_ids) and tid == orig_ids[i]: + styles.append("green") + elif i < len(orig_ids): + styles.append("red") + else: + styles.append("normal") + else: + same = i < len(orig_ids) and tid == orig_ids[i] + styles.append("dim" if same else "normal") + + spans: list[tuple[str, int, int]] = [] + if gen_ids: + cur = styles[0] + start = 0 + for i in range(1, len(gen_ids)): + s = styles[i] + if s != cur: + spans.append((cur, start, i)) + cur, start = s, i + spans.append((cur, start, len(gen_ids))) + + parts = [] + for style_name, a, b in spans: + chunk_text = self.trainer.processing_class.decode( + gen_ids[a:b], skip_special_tokens=False + ) + if style_name == "green": + parts.append(Fore.GREEN + chunk_text + Style.RESET_ALL) + elif style_name == "red": + parts.append(Fore.RED + chunk_text + Style.RESET_ALL) + else: + if style_name == "dim": + parts.append(Style.DIM + chunk_text + Style.RESET_ALL) + else: + parts.append(chunk_text) + logger.info("\tGenerated:\n%s", "".join(parts)) + else: + logger.info(f"\tGenerated: {generated}") + except Exception: + logger.info(f"\tGenerated: {generated}") + + logger.info("=" * 60) + + if self.trainer.cfg.use_wandb: + if wandb.run is not None: + wandb.log( + { + "generated_samples": wandb.Table( + columns=[ + "step", + "original", + "masked", + "generated", + "mask_ratio", + "masked_tokens", + "total_tokens", + ], + data=[ + [ + step, + sample["original"], + sample["masked"], + sample["generated"], + f"{sample['mask_ratio']:.1%}", + sample["masked_tokens"], + sample["total_tokens"], + ] + for sample in samples + ], + ) + }, + step=step, + ) diff --git a/src/axolotl/integrations/diffusion/generation.py b/src/axolotl/integrations/diffusion/generation.py new file mode 100644 index 000000000..49e3cdfae --- /dev/null +++ b/src/axolotl/integrations/diffusion/generation.py @@ -0,0 +1,409 @@ +"""Sample generation utilities for diffusion training.""" + +import re +from typing import Any, List, Literal, Optional + +import torch + +from axolotl.utils.logging import get_logger + +from .utils import create_bidirectional_attention_mask + +LOG = get_logger(__name__) + + +def generate_samples( + model: torch.nn.Module, + tokenizer: Any, + dataloader: Optional[Any] = None, + num_generation_samples: int = 3, + max_length: int = 100, + num_diffusion_steps: int = 128, + temperature: float = 0.0, + mask_token_id: int = 32000, + mode: Literal["random", "completion"] = "random", + completion_tokens: int = 0, + target_mask_ratio: Optional[float] = None, +) -> List[dict]: + """ + Generate text samples using the diffusion model by randomly masking sequences from + the given dataset and running the reverse diffusion process. + + Args: + model: The wrapped or unwrapped model + tokenizer: Tokenizer for encoding/decoding + dataloader: Validation dataloader (for sampling sequences) + num_generation_samples: Number of samples to generate + max_length: Maximum length of sequences to use + num_diffusion_steps: Number of diffusion steps for generation + temperature: Temperature for sampling (0.0 = deterministic) + mask_token_id: Token ID used for masking + + Returns: + List of dictionaries with original text, masked text, and generated text + """ + if dataloader is None: + LOG.warning("No validation dataloader provided, cannot generate samples") + return [] + + unwrapped_model = model.module if hasattr(model, "module") else model + training = unwrapped_model.training + unwrapped_model.eval() + + # Resolve device robustly (some modules don't expose `.device`) + device = getattr(unwrapped_model, "device", None) + if device is None: + try: + device = next(unwrapped_model.parameters()).device + except StopIteration: + device = torch.device("cpu") + generations = [] + + # Sample sequences from validation dataset + sampled_sequences = _sample_sequences_from_dataloader( + dataloader, num_generation_samples, max_length, device + ) + LOG.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset") + + # Generate samples using reverse diffusion process + with torch.no_grad(): + for sample in sampled_sequences: + if isinstance(sample, dict): + original_sequence = sample.get("input_ids") + labels_seq = sample.get("labels") + attn_seq = sample.get("attention_mask") + else: + original_sequence = sample + labels_seq = None + attn_seq = None + generation_result = generate( + unwrapped_model, + tokenizer, + original_sequence, + num_diffusion_steps, + temperature, + mask_token_id, + mode=mode, + completion_tokens=completion_tokens, + target_mask_ratio=target_mask_ratio, + labels=labels_seq, + attention_mask=attn_seq, + ) + generations.append(generation_result) + + # Restore prior training state + if training: + unwrapped_model.train() + else: + unwrapped_model.eval() + + return generations + + +def _sample_sequences_from_dataloader( + dataloader: Any, num_samples: int, max_length: int, device: torch.device +) -> List[Any]: + """Sample sequences from validation dataloader.""" + sampled_sequences: list[dict[str, torch.Tensor] | torch.Tensor] = [] + sample_count = 0 + + # Skip a random number of batches (we could be more clever about this) + skip_batches = torch.randint(0, 10, (1,)).item() + batch_count = 0 + + for batch in dataloader: + # Skip some batches for variety + if batch_count < skip_batches: + batch_count += 1 + continue + + if sample_count >= num_samples: + break + + batch_count += 1 + input_ids = batch["input_ids"] + attention_mask = batch.get("attention_mask") + labels = batch.get("labels") + + # Randomly sample from sequences in this batch + batch_indices = torch.randperm(input_ids.size(0)).tolist() + + for i in batch_indices: + if sample_count >= num_samples: + break + + # Get actual sequence length (non-padded) + if attention_mask is not None: + seq_len = attention_mask[i].sum().item() + else: + seq_len = input_ids.size(1) + + if seq_len < 10: + continue + + # Determine truncation length + max_total = min(seq_len, max_length) + if labels is not None: + labels_i = labels[i][:seq_len] + answer_mask = labels_i != -100 + if not answer_mask.any(): + # No answer tokens; skip for SFT masking + continue + first_ans_idx = int( + torch.nonzero(answer_mask, as_tuple=False)[0].item() + ) + prompt_len = first_ans_idx + if prompt_len >= max_total: + # Prompt alone reaches cap; cannot include any answer + continue + remaining_answer = int(answer_mask[prompt_len:].sum().item()) + allowed_answer = max_total - prompt_len + take_answer = min(remaining_answer, allowed_answer) + if take_answer <= 0: + continue + actual_length = prompt_len + take_answer + else: + actual_length = max_total + + # Extract the (possibly truncated) sequence + sequence = input_ids[i][:actual_length].unsqueeze(0).to(device) + attn_seq = ( + attention_mask[i][:actual_length].unsqueeze(0).to(device) + if attention_mask is not None + else None + ) + if labels is not None: + labels_seq = labels[i][:actual_length].unsqueeze(0).to(device) + sampled_sequences.append( + { + "input_ids": sequence, + "labels": labels_seq, + "attention_mask": attn_seq, + } + ) + else: + if attn_seq is not None: + sampled_sequences.append( + {"input_ids": sequence, "attention_mask": attn_seq} + ) + else: + sampled_sequences.append(sequence) + sample_count += 1 + + return sampled_sequences + + +def generate( + model: torch.nn.Module, + tokenizer: Any, + original_sequence: torch.Tensor, + num_diffusion_steps: int, + temperature: float, + mask_token_id: int, + *, + mode: Literal["random", "completion"] = "random", + completion_tokens: int = 0, + target_mask_ratio: Optional[float] = None, + labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> dict: + """Generate a single sample using reverse diffusion.""" + # Get original text for comparison + original_text = tokenizer.decode( + original_sequence[0].cpu(), skip_special_tokens=True + ) + + # Build masked sequence + if ( + labels is not None + and labels.numel() > 0 + and (labels == -100).any() + and (labels != -100).any() + ): + # SFT case: completely mask all answer tokens (labels != -100) + total_tokens = original_sequence.size(1) + masked_indices = (labels != -100).to(dtype=torch.bool) + masked_sequence = original_sequence.clone() + masked_sequence[masked_indices] = mask_token_id + masked_tokens = int(masked_indices.sum().item()) + mask_ratio = masked_tokens / max(int(total_tokens), 1) + elif mode == "completion" and completion_tokens > 0: + # Append mask tokens to the right for completion + total_tokens = original_sequence.size(1) + int(completion_tokens) + masked_indices = torch.zeros( + 1, total_tokens, dtype=torch.bool, device=original_sequence.device + ) + masked_indices[0, -int(completion_tokens) :] = True + + append = torch.full( + (1, int(completion_tokens)), mask_token_id, device=original_sequence.device + ) + masked_sequence = torch.cat([original_sequence, append], dim=1) + masked_tokens = int(completion_tokens) + mask_ratio = masked_tokens / total_tokens + else: + # Apply random masking with optional fixed ratio + total_tokens = original_sequence.size(1) + if target_mask_ratio is None: + min_ratio, max_ratio = 0.1, 0.7 + target_mask_ratio = ( + torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio + ) + target_masked_tokens = max(1, int(total_tokens * float(target_mask_ratio))) + + # Create random mask indices + mask_positions = torch.randperm(total_tokens)[:target_masked_tokens] + masked_indices = torch.zeros( + 1, total_tokens, dtype=torch.bool, device=original_sequence.device + ) + masked_indices[0, mask_positions] = True + + # Create masked sequence + masked_sequence = original_sequence.clone() + masked_sequence[masked_indices] = mask_token_id + + # Calculate actual mask ratio + masked_tokens = masked_indices.sum().item() + mask_ratio = masked_tokens / total_tokens + + # Get masked text for comparison + masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False) + masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id) + + # Run reverse diffusion process + sequence = masked_sequence.clone() + attention_mask = create_bidirectional_attention_mask( + sequence, attention_mask, sample_packing=attention_mask is not None + ) + for step in range(num_diffusion_steps): + sequence = _diffusion_step( + model, + sequence, + step, + num_diffusion_steps, + temperature, + mask_token_id, + attention_mask, + ) + generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True) + + # Collect diagnostic info + final_ids = sequence[0].detach().cpu().tolist() + orig_ids_for_render = original_sequence[0].detach().cpu().tolist() + if masked_indices is not None: + masked_positions = ( + torch.where(masked_indices[0])[0].detach().cpu().tolist() + if masked_indices.ndim == 2 + else [] + ) + else: + masked_positions = [] + + result = { + "original": original_text, + "masked": masked_text, + "generated": generated_text, + "mask_ratio": mask_ratio, + "masked_tokens": masked_tokens, + "total_tokens": total_tokens, + "generated_ids": final_ids, + "masked_positions": masked_positions, + "orig_ids": orig_ids_for_render, + "formatted": ( + f"Original: '{original_text}' → Masked: '{masked_text}' " + f"({mask_ratio:.1%}) → Generated: '{generated_text}'" + ), + } + + return result + + +def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str: + """Clean up masked text for display.""" + mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False) + cleaned = masked_text.replace(mask_token_repr, "[MASK]") + + # Remove literal special token strings + if hasattr(tokenizer, "special_tokens_map"): + for token_value in tokenizer.special_tokens_map.values(): + if token_value and isinstance(token_value, str): + cleaned = cleaned.replace(token_value, "") + + # Normalize whitespace but preserve newlines + cleaned = cleaned.replace("\r\n", "\n").replace("\r", "\n") + cleaned = re.sub(r"[ \t]+", " ", cleaned) + cleaned = "\n".join(line.rstrip() for line in cleaned.split("\n")).strip() + return cleaned + + +def _diffusion_step( + model: torch.nn.Module, + sequence: torch.Tensor, + step: int, + num_diffusion_steps: int, + temperature: float, + mask_token_id: int, + attention_mask: torch.Tensor | None = None, +) -> torch.Tensor: + """Perform a single diffusion step with remasking.""" + # Only process if there are masked tokens remaining + current_mask = sequence == mask_token_id + if not current_mask.any(): + return sequence + + # Create or use provided attention mask + if attention_mask is None: + batch_size, seq_len = sequence.shape + attention_mask = torch.ones( + batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device + ) + + # Forward pass + outputs = model(input_ids=sequence, attention_mask=attention_mask) + logits = outputs.logits + + # Only sample at currently masked positions + if current_mask.any(): + masked_logits = logits[current_mask] + + # Apply temperature scaling + if temperature > 0: + scaled_logits = masked_logits / temperature + else: + scaled_logits = masked_logits + + # Suppress mask token in outputs + scaled_logits[:, mask_token_id] = -float("inf") + + if temperature > 0: + # Add Gumbel noise for sampling + gumbel_noise = -torch.log( + -torch.log(torch.rand_like(scaled_logits, dtype=torch.float32)) + ) + gumbel_logits = scaled_logits + gumbel_noise + predicted_tokens = torch.argmax(gumbel_logits, dim=-1) + else: + predicted_tokens = torch.argmax(scaled_logits, dim=-1) + + # Calculate probabilities for confidence scoring + probs = torch.softmax(scaled_logits, dim=-1) + predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens] + + # Determine how many tokens to unmask this step + remaining_masked = current_mask.sum().item() + if step == num_diffusion_steps - 1: + num_to_unmask = remaining_masked + else: + unmask_ratio = 1.0 / (num_diffusion_steps - step) + num_to_unmask = max(1, int(remaining_masked * unmask_ratio)) + + # Select highest confidence predictions to unmask + if num_to_unmask >= remaining_masked: + sequence[current_mask] = predicted_tokens + else: + _, top_indices = predicted_token_probs.topk(num_to_unmask) + mask_positions = torch.where(current_mask)[1] + positions_to_unmask = mask_positions[top_indices] + sequence[0, positions_to_unmask] = predicted_tokens[top_indices] + + return sequence diff --git a/src/axolotl/integrations/diffusion/plugin.py b/src/axolotl/integrations/diffusion/plugin.py new file mode 100644 index 000000000..c31f48b03 --- /dev/null +++ b/src/axolotl/integrations/diffusion/plugin.py @@ -0,0 +1,41 @@ +"""Diffusion LM training plugin for Axolotl.""" + +from peft import PeftModel +from transformers import PreTrainedModel + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +from .trainer import DiffusionTrainer + +LOG = get_logger(__name__) + + +class DiffusionPlugin(BasePlugin): + """ + Plugin for diffusion language model training. + + This plugin enables diffusion-based training using the LLaDA approach, which uses + random masking and bidirectional attention to train language models. + """ + + def __init__(self): + super().__init__() + self.cfg = None + + def get_input_args(self) -> str: + """Returns the pydantic model for LLaDA plugin arguments.""" + return "axolotl.integrations.diffusion.DiffusionArgs" + + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Perform actions after model is loaded.""" + self.cfg = cfg + + def get_trainer_cls(self, cfg: DictDefault) -> type[DiffusionTrainer] | None: + """Return custom trainer class for diffusion training.""" + return DiffusionTrainer + + def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer): + """Configure trainer after creation.""" + trainer.set_config(cfg) diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py new file mode 100644 index 000000000..42b2468f4 --- /dev/null +++ b/src/axolotl/integrations/diffusion/trainer.py @@ -0,0 +1,301 @@ +"""Custom trainer for diffusion LM training.""" + +from typing import Any, Literal + +import torch +import torch.nn.functional as F +from torch import nn + +from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +from .callbacks import DiffusionGenerationCallback +from .utils import create_bidirectional_attention_mask + +LOG = get_logger(__name__) + + +class DiffusionTrainer(AxolotlTrainer): + """Custom trainer for diffusion LM training that overrides loss computation.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cfg = None + self._special_token_ids = None + + def set_config(self, config: DictDefault): + """Set config for diffusion training.""" + self.cfg = config + self._cache_special_token_ids() + self._resolve_mask_token_id() + + token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0)) + LOG.info(f"Diffusion: using mask_token_id={token_id}") + + if getattr(config.diffusion, "generate_samples", True): + generation_callback = DiffusionGenerationCallback(self) + self.add_callback(generation_callback) + + def _resolve_mask_token_id(self) -> None: + """Ensure mask_token_id is valid for the current tokenizer.""" + from .utils import resolve_mask_token_id + + tokenizer = getattr(self, "processing_class", None) + if tokenizer is None: + return + + mid = resolve_mask_token_id( + tokenizer, + self.cfg, + allow_add=True, + model=getattr(self, "model", None), + ) + try: + self.cfg.diffusion.mask_token_id = int(mid) + except Exception: + pass + + def compute_loss( + self, + model: nn.Module, + inputs: dict[str, torch.Tensor], + return_outputs: bool = False, + num_items_in_batch: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Override compute_loss to use diffusion loss.""" + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask") + labels = inputs.get("labels") + + if input_ids is None: + raise ValueError("input_ids is required for diffusion training") + + loss, outputs = self._compute_diffusion_loss( + model, input_ids, attention_mask, labels + ) + + if return_outputs: + return loss, outputs + return loss + + def _cache_special_token_ids(self): + """Cache special token IDs to avoid repeated tokenizer access.""" + if self.processing_class is None: + self._special_token_ids = set() + return + + tokenizer = self.processing_class + special_tokens = set() + + if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None: + special_tokens.add(tokenizer.bos_token_id) + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: + special_tokens.add(tokenizer.eos_token_id) + if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None: + special_tokens.add(tokenizer.pad_token_id) + + self._special_token_ids = special_tokens + + def _forward_process( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + eps: float = 1e-3, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward noising process. A timestep is sampled along the process, and tokens are + masked with probability determined by the configured noise schedule. + + Args: + input_ids: Input token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + labels: Labels for SFT training [batch_size, seq_len]. + eps: Small epsilon value for minimum masking probability. + + Returns: + noisy_batch: Input with some tokens masked. + masked_indices: Boolean mask indicating which tokens were masked. + p_mask: Masking probabilities for each token [batch_size, seq_len]. + """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # Sample random timesteps for each sample in batch + t = torch.rand(batch_size, device=device) + p_mask = (1 - eps) * t + eps # [batch_size] + p_mask = p_mask[:, None].repeat(1, seq_len) # [batch_size, seq_len] + + # Don't mask padding tokens if attention_mask is provided + if attention_mask is not None: + valid_mask = attention_mask.bool() + p_mask = p_mask * valid_mask.float() + + # Create mask to exclude special tokens + special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool) + if self._special_token_ids: + for token_id in self._special_token_ids: + special_token_mask |= input_ids == token_id + + # Create random mask based on p_mask + masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask + masked_indices = masked_indices & ~special_token_mask + if attention_mask is not None: + masked_indices = masked_indices & attention_mask.bool() + + # For SFT data, only mask answer tokens + if labels is not None: + answer_mask = labels != -100 + masked_indices = masked_indices & answer_mask + + # Create masked input + mask_token_id = int(self.cfg.diffusion.mask_token_id) + mask_value = torch.full_like(input_ids, mask_token_id) + noisy_batch = torch.where(masked_indices, mask_value, input_ids) + + return noisy_batch, masked_indices, p_mask + + def _compute_diffusion_loss( + self, + model: nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | Any]: + """ + Compute diffusion loss. + + Args: + model: The model to compute loss for. + input_ids: Ground truth token ids [batch_size, seq_len]. + attention_mask: Attention mask [batch_size, seq_len]. + labels: Labels for SFT training [batch_size, seq_len]. + + Returns: + loss: Cross-entropy loss. + metrics: Dictionary of metrics. + """ + # Short-circuit empty sequences + if input_ids is None or input_ids.numel() == 0 or input_ids.shape[1] == 0: + zero = torch.tensor( + 0.0, + device=(input_ids.device if input_ids is not None else None), + requires_grad=True, + ) + return zero, {} + + # If an attention_mask is provided and all positions are padding for every + # sample in this batch, skip the step. + if attention_mask is not None: + if attention_mask.dim() == 2 and (attention_mask.sum(dim=1) == 0).all(): + zero = torch.tensor(0.0, device=input_ids.device, requires_grad=True) + return zero, {} + + # Apply forward process + noisy_batch, masked_indices, p_mask = self._forward_process( + input_ids, attention_mask, labels, self.cfg.diffusion.eps + ) + + # Create bidirectional attention mask + bidirectional_mask = create_bidirectional_attention_mask( + input_ids, attention_mask, sample_packing=self.cfg.sample_packing + ) + + # Forward pass + outputs = model( + input_ids=noisy_batch.long(), + attention_mask=bidirectional_mask, + ) + logits = outputs.logits + + if masked_indices.sum() > 0: + valid_indices = torch.where(masked_indices) + batch_indices, seq_indices = valid_indices + + masked_logits = logits[batch_indices, seq_indices] + masked_targets = input_ids[batch_indices, seq_indices] + masked_p_mask = p_mask[batch_indices, seq_indices] + + # Compute cross-entropy loss without reduction + token_loss = F.cross_entropy( + masked_logits.float(), masked_targets, reduction="none" + ) + + if self.cfg.diffusion.importance_weighting: + masked_p_mask = masked_p_mask.float() + weighted_loss = token_loss / masked_p_mask + else: + weighted_loss = token_loss + + if labels is not None: + # For SFT data: normalize by answer token count per sample + answer_mask = labels != -100 + answer_lengths = answer_mask.sum(dim=1).float() # [batch_size] + + # Get batch indices for masked tokens + masked_batch_indices = batch_indices + + # Sum losses per sample and divide by answer length + batch_size = input_ids.shape[0] + loss_per_sample = torch.zeros(batch_size, device=input_ids.device) + for i in range(batch_size): + sample_mask = masked_batch_indices == i + if sample_mask.sum() > 0: + sample_loss = weighted_loss[sample_mask].sum() + denom = answer_lengths[i].clamp(min=1.0) + loss_per_sample[i] = sample_loss / denom + + loss = loss_per_sample.mean() + else: + # Non-SFT: when importance weighting is enabled, use unbiased estimator + # (sum(loss/p) / total_tokens). Otherwise, average over masked tokens + # for stable scaling across varying mask ratios. + if self.cfg.diffusion.importance_weighting: + loss = weighted_loss.sum() / ( + input_ids.shape[0] * input_ids.shape[1] + ) + else: + loss = weighted_loss.mean() + + ce_loss = token_loss.mean() + + # Compute accuracy on masked tokens + with torch.no_grad(): + pred_tokens = masked_logits.argmax(dim=-1) + accuracy = (pred_tokens == masked_targets).float().mean() + else: + loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True) + accuracy = torch.tensor(0.0, device=input_ids.device) + ce_loss = torch.tensor(0.0, device=input_ids.device) + masked_p_mask = torch.tensor(1.0, device=input_ids.device) + + avg_p_mask = ( + p_mask[masked_indices].mean().item() if masked_indices.any() else 0.0 + ) + metrics = { + "loss": loss.item(), + "accuracy": accuracy.item(), + "mask_ratio": masked_indices.float().mean().item(), + "num_masked_tokens": (masked_indices.sum().item(), "sum"), + "avg_p_mask": avg_p_mask, + "ce_loss": ce_loss.item(), + } + + # If doing SFT training, log answer-specific metrics + if self.cfg.datasets is not None: + with torch.no_grad(): + answer_mask = labels != -100 + answer_lengths = answer_mask.sum(dim=1).float() # type: ignore + total_answer_tokens = answer_mask.sum().item() # type: ignore + total_tokens = labels.numel() # type: ignore + metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1) + metrics["avg_answer_length"] = answer_lengths.mean().item() + + if self.cfg.diffusion.importance_weighting: + metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item() + + train_eval: Literal["train", "eval"] = "train" if model.training else "eval" + self.store_metrics(metrics, train_eval=train_eval) + + return loss, outputs diff --git a/src/axolotl/integrations/diffusion/utils.py b/src/axolotl/integrations/diffusion/utils.py new file mode 100644 index 000000000..47abf6fec --- /dev/null +++ b/src/axolotl/integrations/diffusion/utils.py @@ -0,0 +1,159 @@ +"""Shared utilities for diffusion integration.""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch + +from axolotl.utils.dict import DictDefault + + +def resolve_mask_token_id( + tokenizer: Any, + cfg: DictDefault, + *, + allow_add: bool, + model: Any | None = None, + default_token: str = "<|diffusion_mask|>", +) -> int: + """Resolve mask token id. Training may add a new special token; inference won't.""" + # Determine vocab size if available + vocab_size = None + if tokenizer is not None: + if hasattr(tokenizer, "vocab_size") and tokenizer.vocab_size is not None: + try: + vocab_size = int(tokenizer.vocab_size) # type: ignore[arg-type] + except Exception: + vocab_size = None + elif hasattr(tokenizer, "__len__"): + try: + vocab_size = int(len(tokenizer)) + except Exception: + vocab_size = None + + # Use explicit id from config if provided + diffusion_cfg = getattr(cfg, "diffusion", None) + # Fallback to top-level attr names only if nested missing (shouldn't happen) + cfg_id = ( + getattr(diffusion_cfg, "mask_token_id", None) + if diffusion_cfg is not None + else getattr(cfg, "diffusion_mask_token_id", None) + ) + if isinstance(cfg_id, int) and cfg_id >= 0: + if vocab_size is None or cfg_id < vocab_size: + return int(cfg_id) + + def _existing_special_token_id(token_str: str | None) -> int | None: + """Attempt to resolve an existing special token string to a real ID.""" + if not token_str or not hasattr(tokenizer, "convert_tokens_to_ids"): + return None + try: + token_id = tokenizer.convert_tokens_to_ids(token_str) + except Exception: + return None + + if not isinstance(token_id, int) or token_id < 0: + return None + + # Ensure it's registered as special and not UNK, and within vocab + unk_id = getattr(tokenizer, "unk_token_id", None) + specials = set(getattr(tokenizer, "all_special_tokens", []) or []) + addl = set(getattr(tokenizer, "additional_special_tokens", []) or []) + is_special = token_str in specials or token_str in addl + in_vocab = vocab_size is None or token_id < vocab_size + if ( + (unk_id is not None and token_id == unk_id) + or not is_special + or not in_vocab + ): + return None + return token_id + + # Try mask token string if provided + token_str = ( + getattr(diffusion_cfg, "mask_token_str", None) + if diffusion_cfg is not None + else getattr(cfg, "diffusion_mask_token_str", None) + ) + for candidate in (token_str, default_token): + token_id = _existing_special_token_id(candidate) + if isinstance(token_id, int): + try: + if diffusion_cfg is None: + cfg.diffusion_mask_token_id = int(token_id) # legacy fallback + else: + diffusion_cfg.mask_token_id = int(token_id) + except Exception: + pass + return int(token_id) + + # Optionally add and return a dedicated special token during training + if allow_add and hasattr(tokenizer, "add_special_tokens"): + token_to_add = token_str or default_token + try: + tokenizer.add_special_tokens({"additional_special_tokens": [token_to_add]}) + + # Resize embeddings if possible + if ( + model is not None + and hasattr(tokenizer, "__len__") + and hasattr(model, "resize_token_embeddings") + ): + try: + model.resize_token_embeddings(len(tokenizer)) + except Exception: + pass + new_id = tokenizer.convert_tokens_to_ids(token_to_add) + if isinstance(new_id, int) and new_id >= 0: + try: + if diffusion_cfg is None: + cfg.diffusion_mask_token_id = int(new_id) # legacy fallback + else: + diffusion_cfg.mask_token_id = int(new_id) + except Exception: + pass + return int(new_id) + except Exception: + pass + + # Fallback to unk or 0 (do not update cfg) + fallback = getattr(tokenizer, "unk_token_id", 0) or 0 + return int(fallback) + + +def create_bidirectional_attention_mask( + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + sample_packing: bool = False, +) -> torch.Tensor: + """ + Create bidirectional attention mask to override default causal masking. + Handles sample-packed sequences where different samples are identified + by different attention mask values. + + Args: + input_ids: Input token ids [batch_size, seq_len] + attention_mask: Attention mask [batch_size, seq_len] + sample_packing: Whether sample packing is enabled + + Returns: + bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len] + """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + if attention_mask is None or not sample_packing: + return torch.ones( + batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device + ) + + # Handle sample packing: tokens can only attend within their sample + mask_i = attention_mask.unsqueeze(2) # [batch_size, seq_len, 1] + mask_j = attention_mask.unsqueeze(1) # [batch_size, 1, seq_len] + + # Tokens can attend to each other if they have the same non-zero sample ID + bidirectional_mask = (mask_i == mask_j) & (mask_i > 0) + + # Add head dimension: [batch_size, 1, seq_len, seq_len] + return bidirectional_mask.unsqueeze(1) diff --git a/src/axolotl/integrations/grokfast/__init__.py b/src/axolotl/integrations/grokfast/__init__.py index 234d27226..df8cf2cf3 100644 --- a/src/axolotl/integrations/grokfast/__init__.py +++ b/src/axolotl/integrations/grokfast/__init__.py @@ -7,7 +7,7 @@ from transformers.trainer_callback import TrainerCallback from axolotl.utils.logging import get_logger from ..base import BasePlugin -from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401 +from .args import GrokfastArgs as GrokfastArgs from .optimizer import gradfilter_ema LOG = get_logger(__name__) @@ -24,12 +24,10 @@ class GrokfastCallbackHandler(TrainerCallback): self.alpha = alpha self.lamb = lamb - def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument + def on_train_begin(self, *args_, **kwargs): self.grads = None - def on_pre_optimizer_step( - self, args_, state, control, **kwargs - ): # pylint: disable=unused-argument + def on_pre_optimizer_step(self, args_, state, control, **kwargs): model = kwargs.pop("model") self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb) return control diff --git a/src/axolotl/integrations/grokfast/optimizer.py b/src/axolotl/integrations/grokfast/optimizer.py index 38cda2c93..c83ef43bc 100644 --- a/src/axolotl/integrations/grokfast/optimizer.py +++ b/src/axolotl/integrations/grokfast/optimizer.py @@ -1,7 +1,6 @@ # Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee # Reference: https://github.com/ironjr/grokfast -# pylint: skip-file from collections import deque from typing import Dict, Literal, Optional diff --git a/src/axolotl/integrations/kd/README.md b/src/axolotl/integrations/kd/README.md index 4b15ad31d..5e35cf3d7 100644 --- a/src/axolotl/integrations/kd/README.md +++ b/src/axolotl/integrations/kd/README.md @@ -11,7 +11,7 @@ kd_ce_alpha: 0.1 kd_alpha: 0.9 kd_temperature: 1.0 -torch_compile: True # torch>=2.5.1, recommended to reduce vram +torch_compile: True # torch>=2.6.0, recommended to reduce vram datasets: - path: ... diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py index 8a6e3eda1..b1a990553 100644 --- a/src/axolotl/integrations/kd/__init__.py +++ b/src/axolotl/integrations/kd/__init__.py @@ -15,9 +15,15 @@ """ Plugin init to add KD support to Axolotl. """ -from axolotl.integrations.base import BasePlugin -from .args import KDArgs # pylint: disable=unused-import. # noqa: F401 +from typing import Any + +from transformers import Trainer + +from axolotl.integrations.base import BasePlugin +from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback + +from .args import KDArgs as KDArgs class KDPlugin(BasePlugin): @@ -28,9 +34,75 @@ class KDPlugin(BasePlugin): def get_input_args(self): return "axolotl.integrations.kd.KDArgs" + def get_training_args_mixin(self): + return "axolotl.integrations.kd.args.KDTrainingArgsMixin" + def get_trainer_cls(self, cfg): if cfg.kd_trainer: from .trainer import AxolotlKDTrainer return AxolotlKDTrainer return None + + def get_training_args(self, cfg): + return { + "kd_ce_alpha": cfg.kd_ce_alpha, + "kd_alpha": cfg.kd_alpha, + "kd_temperature": cfg.kd_temperature, + "kd_beta": cfg.kd_beta, + "kd_normalize_topk": cfg.kd_normalize_topk, + } + + def get_collator_cls_and_kwargs(self, cfg, is_eval=False): + if not cfg.kd_trainer: + return None, None + + from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq + + use_batch_sampler_collator = False + if is_eval is False and cfg.sample_packing: + use_batch_sampler_collator = True + if cfg.eval_sample_packing and is_eval: + use_batch_sampler_collator = True + + if cfg.kd_online_server_base_url: + from .collator_online_teacher import OnlineTeacherCollator + + return OnlineTeacherCollator, { + "kd_online_server_base_url": cfg.kd_online_server_base_url, + "kd_online_topk": cfg.kd_online_topk, + "kd_temperature": cfg.kd_temperature, + "kd_online_server": cfg.kd_online_server, + "kd_online_timeout": cfg.kd_online_timeout, + "kd_normalize_topk": cfg.kd_normalize_topk, + } + + if use_batch_sampler_collator: + return KDBatchSamplerDataCollatorForSeq2Seq, {} + return DataCollatorForKD, {} + + def pre_model_load(self, cfg): + from .kernels.models import apply_kernel + + apply_kernel(cfg.model_config_type) + + def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list: + """ + Adds temp scheduler callback to the Trainer instance. + + Args: + cfg (Any): Configuration object containing the sparse recipe. + trainer (Trainer): Huggingface Trainer instance. + + Returns: + list: List containing the configured callback instances. + """ + if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url: + callback = KDTemperatureSchedulerCallback( + cfg.kd_temperature, + cfg.kd_temperature_min, + trainer, + ) + return [callback] + + return [] diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py index 2fbba2c6a..425d8ddf6 100644 --- a/src/axolotl/integrations/kd/args.py +++ b/src/axolotl/integrations/kd/args.py @@ -15,9 +15,20 @@ """ Plugin args for KD support. """ -from typing import Optional -from pydantic import BaseModel +from dataclasses import dataclass +from enum import Enum + +from pydantic import BaseModel, Field + + +class InferenceServerType(str, Enum): + """ + Online inferences server types to handle different request args + """ + + vllm = "vllm" + sglang = "sglang" class KDArgs(BaseModel): @@ -25,13 +36,41 @@ class KDArgs(BaseModel): Input args for knowledge distillation. """ - kd_trainer: Optional[bool] = None # whether to use KD trainer - kd_ce_alpha: Optional[float] = ( + kd_trainer: float | None = None # whether to use KD trainer + kd_ce_alpha: float | None = ( None # loss coefficient for cross-entropy loss during KD ) - kd_alpha: Optional[float] = None # loss coefficient for KD loss - kd_temperature: Optional[float] = None # temperature for sampling during KD - kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling - kd_top_k_before_softmax: Optional[bool] = ( - None # whether to sample top k before softmax during KD + kd_alpha: float | None = None # loss coefficient for KD loss + kd_temperature: float | None = None # temperature for sampling during KD + kd_beta: float | None = 0.0 # beta coefficient for ratio of fwd and reverse KL + kd_normalize_topk: bool | None = ( + None # whether to normalize student logits during KD + ) + + # TODO online kd + kd_online_server_base_url: str | None = None + kd_online_topk: int | None = None + kd_online_server: InferenceServerType | None = Field( + default_factory=lambda: InferenceServerType.vllm + ) + kd_online_timeout: int | None = 120 + kd_temperature_min: float | None = ( + None # kd temperature scheduling during online kd + ) + + +@dataclass +class KDTrainingArgsMixin: + """ + Additional args for KD training. + """ + + kd_ce_alpha: float | None = ( + None # loss coefficient for cross-entropy loss during KD + ) + kd_alpha: float | None = None # loss coefficient for KD loss + kd_temperature: float | None = None # temperature for sampling during KD + kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL + kd_normalize_topk: float | None = ( + None # whether to normalize student logits during KD ) diff --git a/src/axolotl/integrations/kd/callbacks.py b/src/axolotl/integrations/kd/callbacks.py new file mode 100644 index 000000000..c73d8a8bb --- /dev/null +++ b/src/axolotl/integrations/kd/callbacks.py @@ -0,0 +1,34 @@ +""" +Transformers trainer callbacks to schedule the KD temperature during training +""" + +import math + +from transformers.trainer_callback import TrainerCallback + + +class KDTemperatureSchedulerCallback(TrainerCallback): + """ + KD temperature scheduler callback for the trainer. + """ + + def __init__(self, temperature_start, temperature_min, trainer): + self.temperature_start = temperature_start + self.temperature_min = temperature_min + self.temperature = temperature_start + + self.trainer = trainer + + def on_step_end(self, args, state, control, **kwargs): + # cosine decay temperature over the max steps + + progress = state.global_step / state.max_steps + # Cosine decay factor: 0.5 * (1 + cos(pi * progress)) + # This factor goes from 1 (at progress=0) to 0 (at progress=1) + decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress)) + self.temperature = self.temperature_start - ( + (self.temperature_start - self.temperature_min) * (1.0 - decay_factor) + ) + + if hasattr(self.trainer.data_collator, "kd_temperature"): + self.trainer.data_collator.kd_temperature = self.temperature diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index eb067cd04..04f0f24a4 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -15,12 +15,16 @@ """ Chat template prompt strategy loader with KD support """ + +import logging from typing import Any, Dict import torch from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader +LOG = logging.getLogger(__name__) + class ChatTemplateStrategyWithKD(ChatTemplateStrategy): """ @@ -101,10 +105,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # fill with -inf for padding_len tokens for top_k tokens # extend target_logprobs with a padding_len x top_k 2D list filled with -inf - # for causal models, if we start the range at 1, then we don't need to shift in the trainer - # otherwise, we need to shift in the trainer - shift = 0 - for _ in range(shift, input_padding_len): + # we shift for causal models in the trainer, so start the range from 0 + for _ in range(0, input_padding_len): target_logprobs.append([-float("inf")] * top_k) target_token_ids.append(list(range(top_k))) target_mask.append([0] * top_k) @@ -143,6 +145,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # # Convert from log to probability teacher_probs_t1 = position_logprobs_tensor.exp() + # normalize probabilities to sum to 1 in case they aren't already + teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True) + if teacher_probs_t1_sum > 1e-9: + teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum if self.kd_temperature != self.gen_temperature: # Exponentiate by factor (T1 / T2) exponent = self.gen_temperature / self.kd_temperature @@ -162,12 +168,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_logprobs.append(position_logprobs_scaled) target_token_ids.append(position_token_ids) - if shift == 1: - # since we started at index 1 for causal, we need one more padding token - target_logprobs.append([-float("inf")] * top_k) - target_token_ids.append(list(range(top_k))) - target_mask.append([0] * top_k) - # Update sample with transformed logprobs sample["target_logprobs"] = target_logprobs sample["target_token_ids"] = target_token_ids @@ -184,12 +184,122 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): return tokenized_prompt +class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD): + """ + Strat for datasets with complete structured KD logprob data + """ + + def transform_logprobs(self, sample): + """ + Transform logprobs to target format for KD training + """ + + logprobs = sample.pop(self.logprobs_field) + target_seq_len = len(logprobs) + input_seq_len = len(sample["input_ids"]) + input_padding_len = input_seq_len - target_seq_len + # get non-zero top-k (prune None logprobs from vllm data step) + top_k_vals = [ + len(logprobs[i]) + for i in range(len(logprobs)) + if logprobs[i] is not None and len(logprobs[i]) + ] + max_top_k = max(set(top_k_vals), key=top_k_vals.count) + min_top_k = min(set(top_k_vals), key=top_k_vals.count) + top_k = min(max_top_k, min_top_k) + if top_k == 0: + raise ValueError("No non-zero top-k logprobs found.") + + target_logprobs = [] + target_token_ids = [] + target_mask = [] + + if input_padding_len < 0: + # logprobs is longer than target_seq_len, + # so we need to slice from the left/beginning of logprobs + logprobs = logprobs[:-input_seq_len] + input_padding_len = 0 + # target_seq_len = input_seq_len + + # truncate the second dimension of the logprobs to top_k + logprobs = [row[:top_k] for row in logprobs] + + # fill with -inf for padding_len tokens for top_k tokens + # extend target_logprobs with a padding_len x top_k 2D list filled with -inf + + # we shift for causal models in the trainer, so start the range from 0 + for _ in range(0, input_padding_len): + target_logprobs.append([-float("inf")] * top_k) + target_token_ids.append(list(range(top_k))) + target_mask.append([0] * top_k) + + for position in range(input_padding_len, input_seq_len): + if sample["labels"][position] == -100: + target_mask.append([0] * top_k) + else: + target_mask.append([1] * top_k) + + for token_pos_logprobs, pos_target_token_ids in zip( + logprobs, sample["target_token_ids"], strict=False + ): + # Convert to a tensor for easier manipulation + position_logprobs_tensor = torch.tensor( + token_pos_logprobs, dtype=torch.float + ) + + # Now we have distribution at T1 in log form, i.e. log p_{T1}(k). + # Next, re-scale to T2 = self.kd_temperature via exponent-based trick + # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z + # + # Convert from log to probability + teacher_probs_t1 = position_logprobs_tensor.exp() + # normalize probabilities to sum to 1 in case they aren't already + teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True) + if teacher_probs_t1_sum > 1e-9: + teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum + if self.kd_temperature != self.gen_temperature: + # Exponentiate by factor (T1 / T2) + exponent = self.gen_temperature / self.kd_temperature + teacher_probs_t2 = teacher_probs_t1**exponent + else: + teacher_probs_t2 = teacher_probs_t1 + # Re-normalize + teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( + dim=0, keepdim=True + ) + # Convert back to log + position_logprobs_tensor = torch.log(teacher_probs_t2) + + # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor + position_logprobs_scaled = position_logprobs_tensor.tolist() + + target_logprobs.append(position_logprobs_scaled) + target_token_ids.append(pos_target_token_ids) + + # Update sample with transformed logprobs + sample["target_logprobs"] = target_logprobs + sample["target_token_ids"] = target_token_ids + sample["target_mask"] = target_mask + + return sample + + def _tokenize_single_prompt(self, prompt): + target_token_ids = prompt.get("target_token_ids", None) + + tokenized_prompt = super()._tokenize_single_prompt(prompt) + + if target_token_ids is not None: + tokenized_prompt["target_token_ids"] = target_token_ids + + return tokenized_prompt + + class KDStrategyLoader(StrategyLoader): """ Load ChatTemplateStrategy with KD support using StrategyLoader. """ - def _get_strategy_cls(self): + def _get_strategy_cls(self, cfg): return ChatTemplateStrategyWithKD def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]): @@ -204,4 +314,14 @@ class KDStrategyLoader(StrategyLoader): return strategy_params -load = KDStrategyLoader() +class KDStrategyLoaderV2(KDStrategyLoader): + """ + Load KD chat template datasets with pre-tokenized logprob data + """ + + def _get_strategy_cls(self, cfg): + return ChatTemplateStrategyWithKDv2 + + +load_legacy = KDStrategyLoader() +load = KDStrategyLoaderV2() diff --git a/src/axolotl/integrations/kd/collator.py b/src/axolotl/integrations/kd/collator.py index de63869c7..675485d9d 100644 --- a/src/axolotl/integrations/kd/collator.py +++ b/src/axolotl/integrations/kd/collator.py @@ -37,7 +37,6 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): target_logprobs. It also creates a teacher_mask to indicate which entries are valid. """ - # pylint: disable=duplicate-code tokenizer: PreTrainedTokenizerBase model: Optional[Any] = None padding: Union[bool, str, PaddingStrategy] = True @@ -47,11 +46,16 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): position_pad_token_id: int = 0 return_tensors: str = "pt" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True + def __call__(self, features, return_tensors=None): if return_tensors is None: return_tensors = self.return_tensors padding_side = self.tokenizer.padding_side + max_len = 0 # Pad labels and position_ids first for feature_name, pad_token_id in [ @@ -67,7 +71,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): // self.pad_to_multiple_of ) * self.pad_to_multiple_of - for f in features: # pylint: disable=invalid-name + for f in features: remainder = [pad_token_id] * (max_len - len(f[feature_name])) if isinstance(f[feature_name], list): f[feature_name] = ( @@ -96,13 +100,15 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): if has_teacher_data: # Extract and remove from features - for f in features: # pylint: disable=invalid-name + for f in features: target_logprobs_list.append(f.pop("target_logprobs")) target_token_ids_list.append(f.pop("target_token_ids")) target_mask_list.append(f.pop("target_mask")) # Determine max lengths - max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list) + max_teacher_seq_len = max_len or max( + len(seq) for seq in target_logprobs_list + ) max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq) padded_target_logprobs = [] @@ -110,24 +116,25 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): padded_teacher_mask_list = [] for t_logprobs, t_ids, t_mask in zip( - target_logprobs_list, target_token_ids_list, target_mask_list + target_logprobs_list, + target_token_ids_list, + target_mask_list, + strict=False, ): t_logprobs_padded = [] t_ids_padded = [] t_mask_padded = [] - for lp, ids, mask in zip( # pylint: disable=invalid-name - t_logprobs, t_ids, t_mask - ): + for lp, ids, mask in zip(t_logprobs, t_ids, t_mask, strict=False): lp_len = len(lp) if lp_len < max_k: # Use -1e9 for padding logprobs and 0 for token_ids pad_len = max_k - lp_len - lp = lp + [-1e9] * pad_len # pylint: disable=invalid-name + lp = lp + [-1e9] * pad_len ids = ids + [0] * pad_len mask = mask + [0] * pad_len else: - lp = lp[:max_k] # pylint: disable=invalid-name + lp = lp[:max_k] ids = ids[:max_k] mask = mask[:max_k] @@ -243,10 +250,15 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD): # For example, input_ids or labels are often arrays. arrays = [] for feat in sub_features: - if field_name in feat: + if field_name in feat and isinstance( + feat[field_name], (list, torch.Tensor) + ): + if isinstance(feat[field_name][0], (dict, str)): + continue arr = np.array(feat[field_name]) arrays.append(arr) - out_features[i][field_name] = np.concatenate(arrays) + if arrays: + out_features[i][field_name] = np.concatenate(arrays) # 3) Now call the parent collator, which will do: # - padding of labels/position_ids diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py new file mode 100644 index 000000000..54e55a5e7 --- /dev/null +++ b/src/axolotl/integrations/kd/collator_online_teacher.py @@ -0,0 +1,564 @@ +""" +Packed data loader for online teacher training supporting vllm and sglang. +""" + +import hashlib +import hmac +import logging +from typing import Any, Dict, List, Optional + +import requests +import torch +from orjson import orjson + +from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq +from axolotl.integrations.kd.utils import normalize_logprobs +from axolotl.utils.data.utils import retry_on_request_exceptions + +LOG = logging.getLogger(__name__) + + +def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256): + """ + Create HMAC-SHA hash from a list of integers + + Args: + int_list: List of integers + key: Secret key (string or bytes) + hash_func: Hash function (default: sha256) + + Returns: + HMAC digest as hex string + """ + # Convert key to bytes if it's a string + if isinstance(key, str): + key = key.encode("utf-8") + + # Convert list of ints to bytes + # Method 1: Convert each int to bytes and concatenate + data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list) + + # Create HMAC + h = hmac.new(key, data, hash_func) + return h.hexdigest() + + +class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): + """ + Collator for online teacher training. + """ + + DEFAULT_LABEL_PAD_TOKEN_ID: int = -100 + + def __init__( + self, + *args: Any, + kd_online_server_base_url: Optional[str] = None, + kd_online_topk: Optional[int] = None, + kd_temperature: Optional[float] = 1.0, + kd_online_server: Optional[str] = "vllm", + kd_online_timeout: Optional[int] = 120, + kd_cache_dir: Optional[str] = None, + kd_normalize_topk: Optional[bool] = True, + **kwargs: Any, + ): + super().__init__(*args, **kwargs) + + if kd_online_server_base_url is None: + raise ValueError( + "kd_online_server_base_url must be provided for OnlineTeacherDataloader" + ) + if kd_online_topk is None or kd_online_topk <= 0: + raise ValueError( + "kd_online_topk must be a positive integer for OnlineTeacherDataloader" + ) + + self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/") + self.kd_online_topk = kd_online_topk + self.kd_temperature = kd_temperature + self.kd_online_server = kd_online_server + self.http_session = requests.Session() + self.kd_online_timeout = kd_online_timeout + self.kd_cache_dir = kd_cache_dir + self.kd_normalize_topk = kd_normalize_topk + + def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]: + """ + Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs. + """ + if not raw_logprobs or self.kd_online_topk == 0: + return ( + [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else [] + ) + + raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32) + return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist() + + @retry_on_request_exceptions(max_retries=10, delay=5) + def fetch_online_logprobs_sglang( + self, batch_input_ids: List[List[int]], labels: List[List[int]] + ): + """ + Fetches logprobs from an online teacher served by sglang for a batch of input_ids. + Assumes API returns token IDs as strings in logprob dictionary keys. + """ + api_endpoint = f"{self.kd_online_server_base_url}/generate" + + payload = { + "input_ids": batch_input_ids, + "return_logprob": True, + "top_logprobs_num": self.kd_online_topk, + "logprob_start_len": 0, + "return_text_in_logprobs": True, + "echo": True, + "sampling_params": { + "max_new_tokens": 0, + "temperature": self.kd_temperature, + "skip_special_tokens": False, + }, + } + + # Initialize with empty lists, so if API call fails, these are returned. + ret_data_target_token_ids: List[List[List[int]]] = [] + ret_data_target_logprobs: List[List[List[float]]] = [] + ret_data_target_mask: List[List[List[int]]] = [] + + try: + response = self.http_session.post( + api_endpoint, json=payload, timeout=self.kd_online_timeout + ) + response.raise_for_status() + api_data: list[dict] = response.json() + + # Ensure api_data is a list, and its length matches batch_input_ids + if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids): + LOG.error( + f"API response format error. Expected a list of {len(batch_input_ids)} " + f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}." + ) + # Return empty data; items processed later will get default empty KD fields + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } + + for sequence_data, seq_input_ids, seq_labels in zip( + api_data, batch_input_ids, labels, strict=False + ): + current_target_logprobs = [] + current_target_token_ids = [] + current_target_mask = [] + + meta_info = sequence_data.pop("meta_info", {}) + # Ensure input_top_logprobs is a list + input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop( + "input_top_logprobs", [] + ) + if not isinstance(input_top_logprobs, list): + LOG.warning( + f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence." + ) + input_top_logprobs = [] # Treat as empty + + # basic check that the logprob data len matches the input len, so no need to handle padding + assert len(seq_input_ids) == len(input_top_logprobs) + + for i, _, label in zip( + range(len(seq_input_ids)), seq_input_ids, seq_labels, strict=False + ): + if i < len(input_top_logprobs) and input_top_logprobs[i] is None: + # this is always the case for the first token. + # there is never logprob data for the first token since that's a true input + # so we replace the None value with padding data + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append([0] * self.kd_online_topk) + current_target_mask.append([0] * self.kd_online_topk) + elif ( + i < len(input_top_logprobs) + and input_top_logprobs[i] is not None + ): + pos_top_logprobs_data = input_top_logprobs[i] + # Ensure pos_top_logprobs_data is a list of lists as expected + if not ( + isinstance(pos_top_logprobs_data, list) + and all( + isinstance(item, list) for item in pos_top_logprobs_data + ) + and len(pos_top_logprobs_data) > 0 + and len(pos_top_logprobs_data[0]) == 3 + ): # [logprob, token_id, token_str] + LOG.warning( + f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position." + ) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append([0] * self.kd_online_topk) + current_target_mask.append([0] * self.kd_online_topk) + continue + + # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids + pos_logprobs_raw, pos_token_ids, _ = [ + list(row) + for row in zip(*pos_top_logprobs_data, strict=False) + ] + + # Ensure correct length (top_k) + if len(pos_logprobs_raw) < self.kd_online_topk: + pad_len = self.kd_online_topk - len(pos_logprobs_raw) + pos_logprobs_raw.extend([-float("inf")] * pad_len) + pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id + + # truncate to top_k in case the response was longer + current_target_token_ids.append( + pos_token_ids[: self.kd_online_topk] + ) + + if self.kd_normalize_topk: + normalized_logprobs_for_position = self._normalize_logprobs( + pos_logprobs_raw[: self.kd_online_topk] + ) + current_target_logprobs.append( + normalized_logprobs_for_position + ) + else: + current_target_logprobs.append( + pos_logprobs_raw[: self.kd_online_topk] + ) + + # Mask depends on the corresponding label for the student + if label == self.DEFAULT_LABEL_PAD_TOKEN_ID: + current_target_mask.append([0] * self.kd_online_topk) + else: + current_target_mask.append([1] * self.kd_online_topk) + else: + # Pad if no logprobs for this position (either due to length mismatch or None entry) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append([0] * self.kd_online_topk) + current_target_mask.append([0] * self.kd_online_topk) + + ret_data_target_token_ids.append(current_target_token_ids) + ret_data_target_logprobs.append(current_target_logprobs) + ret_data_target_mask.append(current_target_mask) + + except requests.exceptions.RequestException as e: + LOG.error(f"Error fetching logprobs from online teacher: {e}") + raise e + # ret_logprobs_data will be returned with empty lists, handled by the caller. + except Exception as e: # Catch other potential errors during processing + LOG.error( + f"Unexpected error processing API response in fetch_online_logprobs: {e}", + exc_info=True, + ) + raise e + + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } + + @retry_on_request_exceptions(max_retries=10, delay=5) + def fetch_online_logprobs_vllm( + self, batch_input_ids: List[List[int]], labels: List[List[int]] + ): + """ + Fetches logprobs from an online teacher served by vllm for a batch of input_ids. + Assumes API returns token IDs as strings in logprob dictionary keys. + """ + api_endpoint = f"{self.kd_online_server_base_url}/v1/completions" + + payload = { + "prompt": batch_input_ids, + "echo": True, + "logprobs": True, + "prompt_logprobs": self.kd_online_topk, + "top_logprobs": self.kd_online_topk, + "max_new_tokens": 0, + "skip_special_tokens": False, + "temperature": self.kd_temperature, + "sampling_params": { + "max_tokens": 0, + }, + } + + # Initialize with empty lists, so if API call fails, these are returned. + ret_data_target_token_ids: List[List[List[int]]] = [] + ret_data_target_logprobs: List[List[List[float]]] = [] + ret_data_target_mask: List[List[List[int]]] = [] + + try: + headers = {"Accept-Encoding": "deflate, gzip, br, zstd"} + response = self.http_session.post( + api_endpoint, + json=payload, + headers=headers, + timeout=self.kd_online_timeout, + ) + response.raise_for_status() + api_data: dict = orjson.loads(response.content) + choices: list[dict] = api_data["choices"] + + # Ensure api_data is a list, and its length matches batch_input_ids + if not isinstance(choices, list) or len(choices) != len(batch_input_ids): + LOG.error( + f"API response format error. Expected a list of {len(batch_input_ids)} " + f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}." + ) + # Return empty data; items processed later will get default empty KD fields + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } + + for sequence_data, seq_input_ids, seq_labels in zip( + choices, batch_input_ids, labels, strict=False + ): + # seq_input_ids: List[int] + # seq_labels: List[int] + + current_target_logprobs = [] + current_target_token_ids = [] + current_target_mask = [] + + # Ensure input_top_logprobs is a list + input_top_logprobs: Optional[list[None | dict[str, dict]]] = ( + sequence_data.pop("prompt_logprobs", []) + ) + + if not isinstance(input_top_logprobs, list): + LOG.warning( + f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence." + ) + input_top_logprobs = [] # Treat as empty + + # basic check that the logprob data len matches the input len, so no need to handle padding + assert len(seq_input_ids) == len(input_top_logprobs) + + seq_len = len(seq_input_ids) + + for i, _, label in zip( + range(seq_len), seq_input_ids, seq_labels, strict=False + ): + if i < len(input_top_logprobs) and input_top_logprobs[i] is None: + # this is always the case for the first token. + # there is never logprob data for the first token since that's a true input + continue + if ( + i < len(input_top_logprobs) + and input_top_logprobs[i] is not None + ): + pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment] + # Ensure pos_top_logprobs_data is a list of lists as expected + if not ( + isinstance(pos_top_logprobs_data, dict) + and all( + isinstance(item, dict) + for item in pos_top_logprobs_data.values() + ) + and len(pos_top_logprobs_data.keys()) > 0 + ): # [logprob, token_id, token_str] + LOG.warning( + f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position." + ) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append( + list(range(self.kd_online_topk)) + ) + current_target_mask.append([0] * self.kd_online_topk) + continue + + # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids + pos_token_ids_str = list(pos_top_logprobs_data.keys()) + pos_logprobs_dict = pos_top_logprobs_data.values() + pos_token_ids = [ + int(token_id) for token_id in pos_token_ids_str + ] + pos_logprobs_raw = [ + float(logprob.get("logprob", -float("inf"))) + for logprob in pos_logprobs_dict + ] + + # Ensure correct length (top_k) + if len(pos_logprobs_raw) < self.kd_online_topk: + pad_len = self.kd_online_topk - len(pos_logprobs_raw) + LOG.warning( + f"Padding position {i} with {pad_len} top-k tokens and logprobs." + ) + pos_logprobs_raw.extend([-float("inf")] * pad_len) + pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id + + # truncate to top_k in case the response was longer + current_target_token_ids.append( + pos_token_ids[: self.kd_online_topk] + ) + + if self.kd_normalize_topk: + normalized_logprobs_for_position = self._normalize_logprobs( + pos_logprobs_raw[: self.kd_online_topk] + ) + current_target_logprobs.append( + normalized_logprobs_for_position + ) + else: + current_target_logprobs.append( + pos_logprobs_raw[: self.kd_online_topk] + ) + + # Mask depends on the corresponding label for the student + if label == self.DEFAULT_LABEL_PAD_TOKEN_ID: + current_target_mask.append([0] * self.kd_online_topk) + else: + current_target_mask.append([1] * self.kd_online_topk) + else: + # Pad if no logprobs for this position (either due to length mismatch or None entry) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append( + list(range(self.kd_online_topk)) + ) + current_target_mask.append([0] * self.kd_online_topk) + for _ in range(max(0, seq_len - len(current_target_logprobs))): + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append(list(range(self.kd_online_topk))) + current_target_mask.append([0] * self.kd_online_topk) + + ret_data_target_token_ids.append(current_target_token_ids) + ret_data_target_logprobs.append(current_target_logprobs) + ret_data_target_mask.append(current_target_mask) + + # TODO save and load targets to disk for caching for next epoch + # generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int + # if self.kd_cache_dir: + # hash_input_ids = hmac_sha_from_int_list( + # seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}" + # ) + # with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f: + # pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False) + + except requests.exceptions.RequestException as e: + LOG.error(f"Error fetching logprobs from online teacher: {e}") + raise e + # ret_logprobs_data will be returned with empty lists, handled by the caller. + except Exception as e: # Catch other potential errors during processing + LOG.error( + f"Unexpected error processing API response in fetch_online_logprobs: {e}", + exc_info=True, + ) + raise e + + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } + + def __call__( + self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None + ) -> Dict[str, Any]: + if not features: + return super().__call__(features, return_tensors=return_tensors) + + for ( + sub_batch_features + ) in features: # sub_batch_features is List[Dict[str, Any]] + if not sub_batch_features: + continue + + input_ids_for_api_call: List[List[int]] = [] + labels_for_api_call: List[List[int]] = [] + # Store references to the original item dictionaries to update them in-place + items_for_api_call: List[Dict[str, Any]] = [] + + for item_dict in sub_batch_features: + if not isinstance(item_dict, dict): + LOG.warning( + f"Skipping non-dict item in sub_batch_features: {item_dict}" + ) + continue + + current_input_ids = item_dict.get("input_ids") + current_labels = item_dict.get("labels") + + if current_input_ids is not None and current_labels is not None: + # Ensure input_ids and labels are lists of ints for JSON serialization + input_ids_list = ( + current_input_ids.tolist() + if hasattr(current_input_ids, "tolist") + else list(current_input_ids) + ) + labels_list = ( + current_labels.tolist() + if hasattr(current_labels, "tolist") + else list(current_labels) + ) + + input_ids_for_api_call.append(input_ids_list) + labels_for_api_call.append(labels_list) + items_for_api_call.append(item_dict) + else: + # This item will not get teacher logprobs from the API. + # Initialize KD fields to empty lists so downstream collators handle them uniformly. + item_dict.setdefault("target_token_ids", []) + item_dict.setdefault("target_logprobs", []) + item_dict.setdefault("target_mask", []) + + # print(items_for_api_call) + if items_for_api_call: # Only call API if there's something to process + if self.kd_online_server == "sglang": + api_responses_for_sub_batch = self.fetch_online_logprobs_sglang( + input_ids_for_api_call, labels_for_api_call + ) + else: + api_responses_for_sub_batch = self.fetch_online_logprobs_vllm( + input_ids_for_api_call, labels_for_api_call + ) + + # api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask" + # Each value is a list, corresponding to items_for_api_call + for i, item_to_update in enumerate(items_for_api_call): + # TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly. + if api_responses_for_sub_batch and i < len( + api_responses_for_sub_batch["target_token_ids"] + ): # Check bounds + assert len( + api_responses_for_sub_batch["target_token_ids"][i] + ) == len(item_to_update["input_ids"]) + assert len( + api_responses_for_sub_batch["target_logprobs"][i] + ) == len(item_to_update["input_ids"]) + assert len( + api_responses_for_sub_batch["target_mask"][i] + ) == len(item_to_update["labels"]) + item_to_update["target_token_ids"] = ( + api_responses_for_sub_batch["target_token_ids"][i] + ) + item_to_update["target_logprobs"] = api_responses_for_sub_batch[ + "target_logprobs" + ][i] + item_to_update["target_mask"] = api_responses_for_sub_batch[ + "target_mask" + ][i] + else: + # API call failed for this item, or response was shorter than expected. + # Ensure KD fields are initialized as empty lists. + LOG.warning( + f" (index {i}), or API response was too short. " + f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}" + ) + item_to_update.setdefault("target_token_ids", []) + item_to_update.setdefault("target_logprobs", []) + item_to_update.setdefault("target_mask", []) + + return super().__call__(features, return_tensors=return_tensors) diff --git a/src/axolotl/integrations/kd/kernels/__init__.py b/src/axolotl/integrations/kd/kernels/__init__.py index e69de29bb..3f1144a45 100644 --- a/src/axolotl/integrations/kd/kernels/__init__.py +++ b/src/axolotl/integrations/kd/kernels/__init__.py @@ -0,0 +1,8 @@ +""" +Liger Chunked loss optimizations module +""" + +from .liger import LigerFusedLinearKLTopKLogprobLoss +from .models import apply_kernel + +__all__ = ["LigerFusedLinearKLTopKLogprobLoss", "apply_kernel"] diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py new file mode 100644 index 000000000..61ef3e10a --- /dev/null +++ b/src/axolotl/integrations/kd/kernels/liger.py @@ -0,0 +1,485 @@ +""" +Liger Kernels for Chunked Top-K Log-Prob Distillation +""" + +import torch +import torch.nn.functional as F +from liger_kernel.chunked_loss.fused_linear_distillation import ( + LigerFusedLinearDistillationBase, +) + +from axolotl.integrations.kd.utils import normalize_logprobs + + +class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): + """ + Chunked kl-div loss for top-k logprobs + """ + + @staticmethod + def distillation_loss_fn( + student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled + target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k] + target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs + target_mask_chunk: torch.Tensor, # [chunk_size, top_k] + beta: float = 0.0, + normalize_topk: bool = True, + ) -> torch.Tensor: + """ + Compute Top-K KL divergence loss for a chunk. + Args: + student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V). + target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K). + target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K). + target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K). + beta: Controls the type of KL divergence. + 0.0 for Forward KL (P_teacher || P_student). + 1.0 for Reverse KL (P_student || P_teacher). + 0.5 for Symmetric KL (average of Forward and Reverse). + normalize_topk: Whether to normalize the log probabilities + Returns: + Sum of KL divergence losses for the chunk. + """ + topk = target_token_ids_chunk.shape[-1] + student_logits_temp_scaled = ( # [chunk_size, vocab_size] + student_logits_temp_scaled.float() + ) + target_logprobs_chunk = target_logprobs_chunk.float() + + # Gather student logits for the top-k teacher token IDs + # target_token_ids_chunk: [chunk_size, top_k] + # student_logits_topk_temp_scaled: [chunk_size, top_k] + student_logits_topk_temp_scaled = torch.gather( + student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk + ) + + # Student log-probabilities for the gathered top-k tokens + student_lse = torch.logsumexp( + student_logits_temp_scaled, dim=-1, keepdim=True + ) # [chunk_size, 1] + student_logprobs_topk_temp_scaled = ( + student_logits_topk_temp_scaled - student_lse + ) + + # we have the top-k student logprobs, normalize them + if normalize_topk: + student_logprobs_topk_temp_scaled = normalize_logprobs( + student_logprobs_topk_temp_scaled, topk + ) + + valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k] + + student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask] + teacher_logprobs_valid = target_logprobs_chunk[valid_mask] + + # Teacher probabilities P(y|x_teacher) from logprobs + # target_logprobs_valid are already normalized (log(softmax(teacher_logits/T))) + teacher_probs_valid = teacher_logprobs_valid.exp() + # Student probabilities P_student from log P_student + student_probs_topk_valid = student_logprobs_topk_valid.exp() + + # kd_loss_per_token = torch.zeros_like(target_logprobs_valid) + + # KL divergence: sum(P_teacher * (log P_teacher - log P_student)) + # = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student) + # The distillation loss is often formulated as -sum(P_teacher * log P_student) + # or as sum(P_teacher * (log_softmax_teacher - log_softmax_student)) + # Here, target_logprobs_valid are log_softmax_teacher. + # student_logprobs_topk_valid are log_softmax_student (for the selected K indices). + if beta == 0.0: # Contribution from Forward KL + fwd_kl_per_token = teacher_probs_valid * ( + teacher_logprobs_valid - student_logprobs_topk_valid + ) + kd_loss = fwd_kl_per_token.sum() + elif beta == 1.0: # Contribution from Reverse KL + rev_kl_per_token = student_probs_topk_valid * ( + student_logprobs_topk_valid - teacher_logprobs_valid + ) + kd_loss = rev_kl_per_token.sum() + else: + # JSD - Jensen-Shannon Divergence / Symmetric + mean_probs = ( + 1 - beta + ) * student_probs_topk_valid + beta * teacher_probs_valid + log_mean_probs = mean_probs.log() + student_kl = F.kl_div( + log_mean_probs, + student_logprobs_topk_valid, + reduction="sum", + log_target=True, + ) + teacher_kl = F.kl_div( + log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True + ) + jsd_loss = beta * teacher_kl + (1 - beta) * student_kl + kd_loss = jsd_loss + + return kd_loss + + @staticmethod + def _compute_loss_kl_topk( + student_input_chunk: torch.Tensor, + student_weight: torch.Tensor, + # Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value + # or through `partial`. Let's make them explicit here for clarity. + target_token_ids_chunk: torch.Tensor, + target_logprobs_chunk: torch.Tensor, + target_mask_chunk: torch.Tensor, + target_chunk: torch.Tensor, # For hard loss (true labels) + student_bias: torch.Tensor = None, # This will be one of the grad targets + # Other params passed via `partial` from `forward` + distillation_loss_fn=None, + ignore_index: int = -100, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + compute_ce_loss: bool = True, + temperature: float = 1.0, + beta: float = 0.0, + normalize_topk: bool = True, + ): + # Compute student logits for the chunk from hidden states and LM head + # student_input_chunk: [chunk_size, hidden_dim] + # student_lm_head_weight: [vocab_size, hidden_dim] + # student_logits_chunk: [chunk_size, vocab_size] + student_logits_chunk = F.linear( + student_input_chunk, student_weight, student_bias + ) + + ce_loss = torch.tensor( + 0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype + ) + if compute_ce_loss and weight_hard_loss > 0.0: + ce_loss = F.cross_entropy( + student_logits_chunk.view(-1, student_logits_chunk.shape[-1]), + target_chunk.view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + + soft_loss = torch.tensor( + 0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype + ) + if weight_soft_loss > 0.0: + student_logits_chunk_temp_scaled = student_logits_chunk / temperature + + # Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max() + # No explicit padding here; user must ensure vocab alignment or pre-pad student_weight. + + soft_loss = distillation_loss_fn( + student_logits_chunk_temp_scaled, + target_token_ids_chunk, + target_logprobs_chunk, + target_mask_chunk, + beta=beta, + normalize_topk=normalize_topk, + ) + + return soft_loss, ce_loss + + @classmethod + def forward( + cls, + ctx, + student_input: torch.Tensor, # [batch_size, seq_len, dim] + student_lm_head_weight: torch.Tensor, # [dim, vocab_size] + target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k] + target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k] + target_mask: torch.Tensor, # [batch_size, seq_len, top_k] + true_labels: torch.Tensor, # [batch_size, seq_len] + student_lm_head_bias: torch.Tensor = None, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + beta: float = 0.0, + compiled: bool = False, + chunk_size: int = 1024, + compute_ce_loss: bool = True, + normalize_topk: bool = True, + ): + CHUNK_SIZE = chunk_size + grad_weight_acc = torch.zeros_like(student_lm_head_weight) + grad_inputs_list = [] + grad_bias_acc = ( + torch.zeros_like(student_lm_head_bias) + if student_lm_head_bias is not None + else None + ) + kd_loss_acc = torch.zeros( + (), device=student_input.device, dtype=student_input.dtype + ) + ce_loss_acc = torch.zeros( + (), device=student_input.device, dtype=student_input.dtype + ) + + # This function will be what torch.func.grad_and_value differentiates. + # It takes student_input_chunk, student_weight (full), student_bias (full) as primals. + # Other necessary data (target_*, etc.) are passed as non-differentiable arguments. + def loss_fn_for_grad( + _student_input_chunk, + _student_lm_head_weight, # full weight + _student_lm_head_bias, # full bias + # Fixed arguments for a given chunk, not differentiated: + _target_token_ids_chunk, + _target_logprobs_chunk, + _target_mask_chunk, + _true_labels_chunk, + ): + return cls._compute_loss_kl_topk( + student_input_chunk=_student_input_chunk, + student_weight=_student_lm_head_weight, + target_token_ids_chunk=_target_token_ids_chunk, + target_logprobs_chunk=_target_logprobs_chunk, + target_mask_chunk=_target_mask_chunk, + target_chunk=_true_labels_chunk, + student_bias=_student_lm_head_bias, + distillation_loss_fn=cls.distillation_loss_fn, + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + compute_ce_loss=compute_ce_loss, + temperature=temperature, + beta=beta, + normalize_topk=normalize_topk, + ) + + def accumulate_chunk_grads( + student_input_chunk_ac, + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ): + # student_weight and student_bias are closed over from the outer scope (full tensors) + if student_lm_head_bias is not None: + ( + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), + (chunk_kd_loss, chunk_ce_loss), + ) = torch.func.grad_and_value( + loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True + )( + student_input_chunk_ac, + student_lm_head_weight, + student_lm_head_bias, # primals + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ) # non-primals + grad_bias_acc.add_(chunk_grad_bias) + else: + argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight + ( + (chunk_grad_input, chunk_grad_weight), # No grad for bias + (chunk_kd_loss, chunk_ce_loss), + ) = torch.func.grad_and_value( + loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True + )( + student_input_chunk_ac, + student_lm_head_weight, + None, # Pass None for student_bias primal + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ) + + grad_weight_acc.add_(chunk_grad_weight) + kd_loss_acc.add_(chunk_kd_loss) + ce_loss_acc.add_(chunk_ce_loss) + + return chunk_grad_input + + if compiled: + accumulate_chunk_grads_compiled = torch.compile( + accumulate_chunk_grads, dynamic=True, backend="inductor" + ) # dynamic=True often helpful + else: + accumulate_chunk_grads_compiled = accumulate_chunk_grads + + # Use the same chunking logic as LigerFusedLinearDistillationBase.forward + B, N, D = student_input.shape + K = target_token_ids.shape[-1] + + student_input_flat = student_input.reshape(-1, student_input.shape[-1]) + target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1]) + target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1]) + target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1]) + # pad and shift for cross entropy loss + true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index) + true_labels_flat = true_labels[:, 1:].contiguous().view(-1) + + num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE) + + _student_input_chunks = torch.chunk( + student_input_flat, chunks=num_chunks, dim=0 + ) + _target_token_ids_chunks = torch.chunk( + target_token_ids_flat, chunks=num_chunks, dim=0 + ) + _target_logprobs_chunks = torch.chunk( + target_logprobs_flat, chunks=num_chunks, dim=0 + ) + _target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0) + _true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0) + + for i in range(num_chunks): + grad_input_chunk = accumulate_chunk_grads_compiled( + _student_input_chunks[i], + _target_token_ids_chunks[i], + _target_logprobs_chunks[i], + _target_mask_chunks[i], + _true_labels_chunks[i], + ) + grad_inputs_list.append(grad_input_chunk) + + grad_inputs_combined = torch.cat(grad_inputs_list, dim=0) + ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc) + + # For matching None returns in backward for non-tensor/non-grad_requiring inputs + ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature + ctx.bias_was_none = student_lm_head_bias is None + ctx.orig_dims = (B, N, D, K) + + # since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum + # we still need to scale the kd_loss by the temp^2 + kd_loss_acc = kd_loss_acc * (temperature**2) + final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc + + return final_loss + + @staticmethod + def backward(ctx, grad_output): + grad_input_flat, grad_weight, grad_bias_maybe = ( + ctx.saved_tensors + ) # grad_input_flat is (B*N, D) + + # Scale gradients by grad_output if it's not 1.0 + if not torch.equal( + grad_output, + torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype), + ): + grad_input_flat = grad_input_flat * grad_output + grad_weight = grad_weight * grad_output + if grad_bias_maybe is not None: + grad_bias_maybe = grad_bias_maybe * grad_output + + # Reshape grad_input_flat to match original student_input shape (B, N, D) + # ctx.orig_dims stores (B, N, D, K) + # We need the first three dimensions for student_input's shape. + # Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors + if ( + ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0 + and grad_input_flat.numel() == 0 + ): + # If original input was empty, gradient should also be empty with correct shape + grad_input_reshaped = torch.zeros( + ctx.orig_dims[0], + ctx.orig_dims[1], + ctx.orig_dims[2], + dtype=grad_input_flat.dtype, + device=grad_input_flat.device, + ) + elif grad_input_flat.numel() == 0 and not ( + ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0 + ): + # This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad) + # but as a safeguard: + grad_input_reshaped = torch.zeros( + ctx.orig_dims[0], + ctx.orig_dims[1], + ctx.orig_dims[2], + dtype=grad_input_flat.dtype, + device=grad_input_flat.device, + ) + else: + grad_input_reshaped = grad_input_flat.view( + ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2] + ) + + nones_for_hyperparams = [None] * ctx.hyperparams_count + grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None + + return ( + grad_input_reshaped, # Gradient for student_input (reshaped) + grad_weight, # Gradient for student_lm_head_weight + None, # Gradient for target_token_ids + None, # Gradient for target_logprobs + None, # Gradient for target_mask + None, # Gradient for true_labels + grad_bias_return, # Gradient for student_lm_head_bias + *nones_for_hyperparams, # Grads for weight_hard_loss, ..., compute_ce_loss + ) + + +class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module): + """ + wrapper for chunked top-k logprob kl-d + """ + + def __init__( + self, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + temperature: float = 1.0, # This is the kd_temperature + beta: float = 1.0, + ignore_index: int = -100, + compiled: bool = True, + chunk_size: int = 1024, + compute_ce_loss: bool = True, + normalize_topk: bool = True, + ): + super().__init__() + if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0): + raise ValueError("Loss weights must be between 0.0 and 1.0.") + if temperature <= 0: + raise ValueError("Temperature must be positive.") + + self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss + self.temperature = temperature + self.beta = beta + self.ignore_index = ignore_index + self.compiled = compiled + self.chunk_size = chunk_size + self.compute_ce_loss = compute_ce_loss + self.normalize_topk = normalize_topk + + if not self.compute_ce_loss and self.weight_hard_loss > 0.0: + print( + f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero." + ) + # self.weight_hard_loss = 0.0 # Or let user manage this + if self.weight_soft_loss == 0.0: + print( + "Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed." + ) + + def forward( + self, + lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head + student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head + target_token_ids: torch.Tensor, + target_logprobs: torch.Tensor, + target_mask: torch.Tensor, + true_labels: torch.Tensor, + student_bias: torch.Tensor = None, + ) -> torch.Tensor: + return LigerFusedLinearKLTopKLogprobFunction.apply( + student_hidden_states, + lm_head_weight, + target_token_ids, + target_logprobs, + target_mask, + true_labels, + student_bias, + self.weight_hard_loss, + self.weight_soft_loss, + self.ignore_index, + self.temperature, + self.beta, + self.compiled, + self.chunk_size, + self.compute_ce_loss, + self.normalize_topk, + ) diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py new file mode 100644 index 000000000..badb3460d --- /dev/null +++ b/src/axolotl/integrations/kd/kernels/models.py @@ -0,0 +1,104 @@ +""" +model patcher for chunked top-k kl-div +""" + +from typing import Optional, Union, Unpack + +import torch +from transformers import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast + +try: + from transformers.modeling_flash_attention_utils import FlashAttentionKwargs + from transformers.utils import LossKwargs + + class TransformersKwargs(FlashAttentionKwargs, LossKwargs): + """ + placeholder kwargs for hf model classes + """ + +except ImportError: + from transformers.utils.generic import ( # type: ignore[no-redef] + TransformersKwargs, + ) + +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix + + +def kldiv_forward_llama_like( + self, + input_ids: Optional[torch.LongTensor] = None, + target_logprobs: Optional[torch.Tensor] = None, + target_token_ids: Optional[torch.LongTensor] = None, + target_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], # type: ignore[misc] +) -> CausalLMOutputWithPast: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100 + # self._loss_function should be LigerFusedLinearKLTopKLogprobLoss + + loss = self._loss_function( + self.lm_head.weight, + hidden_states, + target_token_ids, + target_logprobs, + target_mask, + true_labels=labels, + ) + num_items_in_batch = kwargs.pop("num_items_in_batch", -1) + if num_items_in_batch is not None and num_items_in_batch > 0: + loss = loss / num_items_in_batch + + return CausalLMOutputWithPast( + loss=loss, + logits=None, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def apply_kernel(model_type): + # Dynamically import the module and attention class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) + module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) + model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") + model_cls.forward = kldiv_forward_llama_like diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py index 3c9515091..b79ba26f3 100644 --- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py +++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py @@ -15,41 +15,9 @@ """ loss for top_k KL divergence """ + import torch - - -def zscore_standardize( - logits: torch.Tensor, - mask: torch.Tensor = None, - base_temperature: float = 1.0, - eps: float = 1e-9, -): - """ - Z-score standardize along the last dimension of `logits`. - i.e., for each [B, seq_len] row, across K entries: - z = (logits - mean) / std, - then scale by 1 / base_temperature if desired. - - mask can be broadcastable or None. If None, we standardize all elements. - """ - if mask is None: - # shape: [B, seq_len, K] - # Mean and std over dim=-1 - mean = logits.mean(dim=-1, keepdim=True) - var = logits.var(dim=-1, unbiased=False, keepdim=True) - else: - # If you have to exclude some tokens, multiply by mask, etc. - float_mask = mask.to(logits.dtype) - count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0) - mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count - var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count - - std = torch.sqrt(var.clamp_min(eps)) - z = (logits - mean) / std - - # Scale by 1 / base_temperature - z = z / base_temperature - return z +from torch import nn @torch.jit.script @@ -60,7 +28,6 @@ def loss( target_mask: torch.Tensor, num_items_in_batch: int = -1, # Use -1 to indicate "None" kd_temperature: float = 1.0, - top_k_before_softmax: int = 0, ) -> torch.Tensor: """ A KD loss function that is TorchScript-friendly. @@ -77,8 +44,6 @@ def loss( num_items_in_batch (int, optional): The number of items in the batch. kd_temperature (float, optional): The temperature for KD. Default: 1.0 - top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits - Default: 0 """ target_logprobs = target_logprobs.float() @@ -88,46 +53,24 @@ def loss( # student_logits shape: [B, student_seq_len, vocab_size] teacher_seq_len = target_token_ids.shape[1] - if top_k_before_softmax: - # Slice student logits to match teacher-provided sequence length - student_logits_for_kd = student_logits[ - :, :teacher_seq_len, : - ] # [B, teacher_seq_len, vocab_size] + # Slice student logits to match teacher-provided sequence length + student_logits_for_kd = ( + student_logits[:, :teacher_seq_len, :] / kd_temperature + ) # [B, teacher_seq_len, vocab_size] - # Gather student logits for teacher's top-K tokens - student_logits_topk = torch.gather( - student_logits_for_kd, dim=-1, index=target_token_ids - ) # [B, teacher_seq_len, K] + # keep in full precision for numerical stability of loss + student_logits_for_kd = student_logits_for_kd.float() - student_logits_topk = student_logits_topk.float() + # Gather student logits for teacher's top-K tokens + student_logits_topk = torch.gather( + student_logits_for_kd, dim=-1, index=target_token_ids + ) # [B, teacher_seq_len, K] - # Apply KD temperature to student’s logits - if kd_temperature != 1.0: - student_logits_topk = student_logits_topk / kd_temperature + # Compute logsumexp across full vocabulary + student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True) - # Convert student top-k logits to logprobs - student_logprobs_topk = student_logits_topk - torch.logsumexp( - student_logits_topk, dim=-1, keepdim=True - ) # [B, teacher_seq_len, K] - else: - # Slice student logits to match teacher-provided sequence length - student_logits_for_kd = ( - student_logits[:, :teacher_seq_len, :] / kd_temperature - ) # [B, teacher_seq_len, vocab_size] - - # keep in full precision for numerical stability of loss - student_logits_for_kd = student_logits_for_kd.float() - - # Gather student logits for teacher's top-K tokens - student_logits_topk = torch.gather( - student_logits_for_kd, dim=-1, index=target_token_ids - ) # [B, teacher_seq_len, K] - - # Compute logsumexp across full vocabulary - student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True) - - # Convert just the top-k logits to logprobs - student_logprobs_topk = student_logits_topk - student_lse + # Convert just the top-k logits to logprobs + student_logprobs_topk = student_logits_topk - student_lse # Convert teacher_mask to boolean for indexing # In TorchScript, .bool() is sometimes unsupported, so we do: @@ -144,10 +87,6 @@ def loss( kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk) kd_loss = kd_loss_per_token.sum() - # Multiply by T^2 (classical KD scaling) - if kd_temperature != 1.0: - kd_loss = kd_loss * (kd_temperature**2) - # Normalize by number of items (if provided) or by valid tokens if num_items_in_batch > 0: kd_loss = kd_loss / float(num_items_in_batch) @@ -158,80 +97,77 @@ def loss( return kd_loss -def topk_kd_loss_with_zscore( - student_logits: torch.Tensor, # [B, seq_len, vocab_size] - target_token_ids: torch.Tensor, # [B, seq_len, K] - target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space - target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len] - kd_temperature: float = 1.0, # classic KD temperature - zscore_base_temp: float = 1.0, # from the paper - num_items_in_batch: int = -1, -): +class ChunkedTopKKDLoss(nn.Module): """ - A variant of top_k KL divergence with Z-score scaling - from "Logit Standardization in Knowledge Distillation". + A wrapper that chunks (splits) the student and teacher outputs along the time dimension + to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies. + + Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs. """ - target_logprobs = target_logprobs.float() + def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0): + super().__init__() + self.num_output_chunks = num_output_chunks + self.kd_temperature = kd_temperature - B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name - # 1) Gather the student's top-k logits to match teacher - student_logits_for_kd = student_logits[ - :, :teacher_seq_len, : - ] # [B, seq_len, vocab] - student_topk_logits = torch.gather( - student_logits_for_kd, dim=-1, index=target_token_ids - ) # [B, seq_len, K] + def forward( + self, + student_logits: torch.Tensor, # [B, seq_len, vocab_size] + target_token_ids: torch.Tensor, # [B, seq_len, K] + target_logprobs: torch.Tensor, # [B, seq_len, K] + target_mask: torch.Tensor, # [B, seq_len, K] + num_items_in_batch: int = -1, # optional batch size for normalization + ) -> torch.Tensor: + # 1. Split along the "token" dimension (dim=1). + student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1) + token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1) + logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1) + mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1) - student_topk_logits = student_topk_logits.float() + # We'll accumulate a global "sum of losses" and "sum of valid tokens" + # so that our final average is consistent with the entire sequence/batch. + total_loss = 0.0 + total_valid_tokens = 0 - # 2) If you want to keep the "classical" T scaling, apply it first - if kd_temperature != 1.0: - student_topk_logits = student_topk_logits / kd_temperature + # 2. Loop over each chunk and compute a chunk-specific loss. + for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip( + student_logits_chunks, + token_ids_chunks, + logprobs_chunks, + mask_chunks, + strict=False, + ): + # We pass num_items_in_batch=-1 so that the kd_loss + # will average over *this chunk's* valid tokens only. + chunk_loss = loss( + student_logits=st_chunk, + target_token_ids=tid_chunk, + target_logprobs=lp_chunk, + target_mask=msk_chunk, + num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens + kd_temperature=self.kd_temperature, + ) - # 3) Convert teacher logprobs -> treat them as “logits” for z-score - # (They differ by +some_constant from real logits, but in z-score - # that constant is subtracted out anyway.) - teacher_logits_for_zscore = target_logprobs # rename variable for clarity + # kd_loss returns an average over the chunk's valid tokens. + # We want a global average in the end, so we need to re‐weight + # by the number of valid tokens in this chunk and keep track of the total. + chunk_valid_mask = msk_chunk.to(torch.bool) + chunk_valid_count = chunk_valid_mask.sum() # scalar tensor - # 4) Z-score teacher and student - # If target_mask is 2D, expand to 3D for the K dimension - if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len): - target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K) + # Re-scale "chunk average" back to "chunk sum" + chunk_loss_sum = chunk_loss * chunk_valid_count - teacher_z = zscore_standardize( - teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp - ) - student_z = zscore_standardize( - student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp - ) + total_loss += chunk_loss_sum + total_valid_tokens += chunk_valid_count - # 5) Convert to log-probs for KL - teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True) - student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True) + # 3. Normalize *once* at the end. + if num_items_in_batch > 0: + # If the user gave us a manual denominator (e.g. total items in batch), + # we divide by it. Typically used if each item is of different length. + final_loss = total_loss / float(num_items_in_batch) + else: + # Otherwise, divide by total valid tokens across all chunks. + # to get the same result as a non-chunked approach. + final_loss = total_loss / float(total_valid_tokens) - # 6) Restrict to valid tokens if needed - valid_mask = target_mask.bool() # shape [B, seq_len, K] - teacher_probs_z = teacher_logprobs_z.exp() - teacher_probs_z = teacher_probs_z[valid_mask] - teacher_logprobs_z = teacher_logprobs_z[valid_mask] - student_logprobs_z = student_logprobs_z[valid_mask] - - # 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] ) - kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z) - kd_loss = kd_loss_per_token.sum() - - # 8) If using classical KD scaling by T^2 - if kd_temperature != 1.0: - kd_loss = kd_loss * (kd_temperature**2) - - # Optionally scale by zscore_base_temp**2 if you want (paper might differ). - # kd_loss = kd_loss * (zscore_base_temp**2) - - # 9) Normalize - if num_items_in_batch is not None and num_items_in_batch > 0: - kd_loss = kd_loss / float(num_items_in_batch) - else: - kd_loss = kd_loss / float(kd_loss_per_token.size(0)) - - return kd_loss + return final_loss diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index f99f2ca28..0e98497a7 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -18,8 +18,7 @@ KD trainer from axolotl.core.trainers.base import AxolotlTrainer -from .topk_logprob.forward_kl import loss as topk_kd_loss -from .topk_logprob.forward_kl import topk_kd_loss_with_zscore +from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss class AxolotlKDTrainer(AxolotlTrainer): @@ -27,6 +26,27 @@ class AxolotlKDTrainer(AxolotlTrainer): Custom trainer subclass for Knowledge Distillation (KD) """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_accepts_loss_kwargs = True + + loss_fn = LigerFusedLinearKLTopKLogprobLoss( + self.args.kd_ce_alpha, # hard label loss + self.args.kd_alpha, # kd loss + self.args.kd_temperature, + self.args.kd_beta or 0.0, + compute_ce_loss=bool(self.args.kd_ce_alpha), + normalize_topk=self.args.kd_normalize_topk, + ) + target = self.model + + # Unwrap PEFT wrapper + if hasattr(target, "get_base_model"): + target = target.get_base_model() + + # Set on the actual model instance + target._loss_function = loss_fn + def _set_signature_columns_if_needed(self): super()._set_signature_columns_if_needed() columns_to_add = [] @@ -52,12 +72,12 @@ class AxolotlKDTrainer(AxolotlTrainer): Subclass and override for custom behavior. """ - - target_logprobs = inputs.pop("target_logprobs") - target_token_ids = inputs.pop("target_token_ids") - target_mask = inputs.pop("target_mask") - - seq_len = target_token_ids.shape[1] + if ( + self.args.sample_packing + and hasattr(inputs, "attention_mask") + and hasattr(inputs, "position_ids") + ): + del inputs["attention_mask"] if self.model_accepts_loss_kwargs: loss_kwargs = {} @@ -65,49 +85,4 @@ class AxolotlKDTrainer(AxolotlTrainer): loss_kwargs["num_items_in_batch"] = num_items_in_batch inputs = {**inputs, **loss_kwargs} outputs = model(**inputs) - - # FIXME: account for tokenizer.padding_side - student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous() - - shift_logits = student_logits.contiguous() - target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() - target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() - target_mask_for_loss = target_mask[..., 1:, :].contiguous() - - if self.args.kd_zscore_base_temp: - loss_kd = topk_kd_loss_with_zscore( - shift_logits, - target_token_ids_for_loss, - target_logprobs_for_loss, - target_mask_for_loss, - kd_temperature=self.args.kd_temperature, - zscore_base_temp=self.args.kd_zscore_base_temp, - num_items_in_batch=num_items_in_batch, - ) - else: - loss_kd = topk_kd_loss( - shift_logits, - target_token_ids_for_loss, - target_logprobs_for_loss, - target_mask_for_loss, - num_items_in_batch=num_items_in_batch, - kd_temperature=self.args.kd_temperature, - top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0, - ) - - if self.args.kd_ce_alpha > 0: - kd_alpha = self.args.kd_alpha - loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd - else: - loss = loss_kd - # Save past state if it exists - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[ # pylint: disable=attribute-defined-outside-init - self.args.past_index - ] - - if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: - loss *= self.accelerator.num_processes - - return (loss, outputs) if return_outputs else loss + return outputs[0] diff --git a/src/axolotl/integrations/kd/utils.py b/src/axolotl/integrations/kd/utils.py new file mode 100644 index 000000000..ba60694a5 --- /dev/null +++ b/src/axolotl/integrations/kd/utils.py @@ -0,0 +1,100 @@ +"""Helper KD utils""" + +import math +from typing import List, Union + +import numpy as np +import torch +from torch import FloatTensor, Tensor + + +def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor: + """ + Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs. + """ + # Ensure raw_logprobs matches kd_online_topk length for tensor operations + # This should ideally be handled by the caller ensuring correct padding/truncation first + if logprobs.shape[-1] != topk: + # pad last dimension of logprobs to match topk length with -inf + padding_len = topk - logprobs.shape[-1] + padding_tensor = torch.full( + ( + *logprobs.shape[:-1], + padding_len, + ), # Takes all dimensions of logprobs except the last, then appends padding_needed + float("-inf"), + dtype=logprobs.dtype, + device=logprobs.device, + ) + logprobs = torch.cat((logprobs, padding_tensor), dim=-1) + + # Convert logprobs at T_online to probabilities + # use log sum exp trick to avoid underflow + position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True) + teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse) + + # Normalize probabilities (sum to 1) + # This is important if the top-k from server aren't a full distribution + teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True) + teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum + + final_logprobs_tensor = torch.log(teacher_probs_t_online) + + return final_logprobs_tensor + + +def strided_chunk_views( + tensor: Union[np.ndarray, torch.Tensor], + chunks: int, + dim: int = 0, + stride: int = 1, + chunk_size: int | None = None, +) -> List[Union[np.ndarray, torch.Tensor]]: + """ + Split a tensor into chunks along a dimension with striding, prioritizing views over copies. + + Args: + tensor: Input tensor (numpy array or torch tensor) + chunks: Number of chunks to create + dim: Dimension along which to chunk (default: 0) + stride: Stride between chunk starting positions (default: 1) + chunk_size: Size of each chunk. If None, calculated automatically (default: None) + + Returns: + List of tensor chunks (views when possible, copies when necessary) + """ + + # Get the size of the specified dimension + dim_size = tensor.shape[dim] + + # Calculate chunk size if not provided + if chunk_size is None: + chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division + + chunks_list = [] + + for i in range(chunks): + start_idx = i * stride + end_idx = min(start_idx + chunk_size, dim_size) + + # Break if we've gone beyond the tensor + if start_idx >= dim_size: + break + + # Create slice objects for all dimensions + slices = [slice(None)] * tensor.ndim + slices[dim] = slice(start_idx, end_idx) + + chunk = tensor[tuple(slices)] + chunks_list.append(chunk) + + return chunks_list + + +def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1): + dim_size = input_tensor.shape[dim] + stride = math.ceil(dim_size / chunks) + + return strided_chunk_views( + input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap + ) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 1c17ab2b5..c20f4545c 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -18,170 +18,11 @@ Module for the Plugin for LIGER integraton with Axolotl. Liger Kernel is the collection of Triton-native kernels for LLM Training. It is designed to be performant, correct, and light-weight. """ -import inspect -import sys -from axolotl.integrations.base import BasePlugin -from axolotl.utils.logging import get_logger +from .args import LigerArgs +from .plugin import LigerPlugin -from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 -from .utils import patch_with_compile_disable - -LOG = get_logger(__name__, use_environ=True) - - -class LigerPlugin(BasePlugin): - """ - Plugin for LIGER integraton with Axolotl. - """ - - def get_input_args(self): - return "axolotl.integrations.liger.LigerArgs" - - def pre_model_load(self, cfg): - if cfg.torch_compile: - # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled - import liger_kernel.ops.fused_linear_cross_entropy - - patch_with_compile_disable( - liger_kernel.ops.fused_linear_cross_entropy, - "fused_linear_cross_entropy_forward", - ) - patch_with_compile_disable( - liger_kernel.ops.fused_linear_cross_entropy, - "fused_linear_cross_entropy_backward", - ) - from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss - from liger_kernel.transformers.functional import liger_cross_entropy - from liger_kernel.transformers.layer_norm import LigerLayerNorm - from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN - from liger_kernel.transformers.rms_norm import LigerRMSNorm - from liger_kernel.transformers.rope import liger_rotary_pos_emb - from liger_kernel.transformers.swiglu import LigerSwiGLUMLP - - if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy: - raise ValueError( - "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set." - ) - - if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: - apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] - liger_fn_sig = inspect.signature(apply_liger_fn) - kwargs = {} - if "rope" in liger_fn_sig.parameters: - kwargs["rope"] = cfg.liger_rope - if "cross_entropy" in liger_fn_sig.parameters: - kwargs["cross_entropy"] = cfg.liger_cross_entropy - if "fused_linear_cross_entropy" in liger_fn_sig.parameters: - kwargs["fused_linear_cross_entropy"] = ( - cfg.liger_fused_linear_cross_entropy - ) - if "rms_norm" in liger_fn_sig.parameters: - kwargs["rms_norm"] = cfg.liger_rms_norm - if "layer_norm" in liger_fn_sig.parameters: - kwargs["layer_norm"] = cfg.liger_layer_norm - if "geglu" in liger_fn_sig.parameters: - kwargs["geglu"] = cfg.liger_glu_activation - elif "swiglu" in liger_fn_sig.parameters: - kwargs["swiglu"] = cfg.liger_glu_activation - LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}") - apply_liger_fn(**kwargs) - elif cfg.model_config_type == "jamba": - from transformers.models.jamba import modeling_jamba - - from .models.jamba import lce_forward as jamba_lce_forward - - if cfg.liger_rope: - modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb - if cfg.liger_rms_norm: - modeling_jamba.JambaRMSNorm = LigerRMSNorm - if cfg.liger_glu_activation: - modeling_jamba.JambaMLP = LigerSwiGLUMLP - if cfg.liger_layer_norm: - modeling_jamba.nn.LayerNorm = LigerLayerNorm - if cfg.liger_cross_entropy: - from transformers.loss.loss_utils import nn - - nn.functional.cross_entropy = liger_cross_entropy - if cfg.liger_fused_linear_cross_entropy: - modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward - elif cfg.model_config_type == "deepseek_v2": - from accelerate import init_empty_weights - from transformers import AutoModelForCausalLM - - with init_empty_weights(): - model = AutoModelForCausalLM.from_pretrained( - cfg.base_model, trust_remote_code=cfg.trust_remote_code or False - ) - modeling_mod = sys.modules[model.__class__.__module__] - - from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward - - if cfg.liger_rope: - # The DeepseekV2 version of RoPE is different than upstream LLaMA. - # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 - LOG.warning("Fused liger_rope is not supported for DeepseekV2.") - if cfg.liger_glu_activation: - LOG.warning("liger_glu_activation is not supported for DeepseekV2.") - if cfg.liger_rms_norm: - modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm - if cfg.liger_glu_activation: - modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward - if cfg.liger_layer_norm: - modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward - if cfg.liger_cross_entropy: - # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses - # nn.CrossEntropyLoss in the forward method. - modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss - if cfg.liger_fused_linear_cross_entropy: - modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward - elif cfg.model_config_type == "llama4": - from axolotl.integrations.liger.models.llama4 import ( - apply_liger_kernel_to_llama4, - ) - - apply_liger_kernel_to_llama4( - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - glu_activation=cfg.liger_glu_activation, - rms_norm=cfg.liger_rms_norm, - layer_norm=cfg.liger_layer_norm, - ) - elif cfg.model_config_type == "qwen3": - from axolotl.integrations.liger.models.qwen3 import ( - apply_liger_kernel_to_qwen3, - ) - - apply_liger_kernel_to_qwen3( - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - glu_activation=cfg.liger_glu_activation, - rms_norm=cfg.liger_rms_norm, - layer_norm=cfg.liger_layer_norm, - ) - elif cfg.model_config_type == "qwen3_moe": - from axolotl.integrations.liger.models.qwen3_moe import ( - apply_liger_kernel_to_qwen3_moe, - ) - - apply_liger_kernel_to_qwen3_moe( - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - glu_activation=cfg.liger_glu_activation, - rms_norm=cfg.liger_rms_norm, - layer_norm=cfg.liger_layer_norm, - ) - elif cfg.model_config_type == "granitemoe": - from liger_kernel.transformers import apply_liger_kernel_to_granite - - apply_liger_kernel_to_granite( - rope=cfg.liger_rope, - cross_entropy=cfg.liger_cross_entropy, - fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, - rms_norm=cfg.liger_rms_norm, - swiglu=cfg.liger_glu_activation, - ) - else: - LOG.warning( - f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." - ) +__all__ = [ + "LigerArgs", + "LigerPlugin", +] diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py index 7c9eb23d5..d5bb10cfd 100644 --- a/src/axolotl/integrations/liger/args.py +++ b/src/axolotl/integrations/liger/args.py @@ -15,7 +15,6 @@ """ Module for handling LIGER input arguments. """ -from typing import Optional from pydantic import BaseModel, model_validator @@ -29,13 +28,13 @@ class LigerArgs(BaseModel): Input args for LIGER. """ - liger_rope: Optional[bool] = None - liger_rms_norm: Optional[bool] = None - liger_layer_norm: Optional[bool] = None - liger_swiglu: Optional[bool] = None - liger_glu_activation: Optional[bool] = None - liger_cross_entropy: Optional[bool] = None - liger_fused_linear_cross_entropy: Optional[bool] = None + liger_rope: bool | None = None + liger_rms_norm: bool | None = None + liger_layer_norm: bool | None = None + liger_swiglu: bool | None = None + liger_glu_activation: bool | None = None + liger_cross_entropy: bool | None = None + liger_fused_linear_cross_entropy: bool | None = None @model_validator(mode="before") @classmethod @@ -52,3 +51,33 @@ class LigerArgs(BaseModel): ) data["liger_glu_activation"] = data.pop("liger_swiglu") return data + + @model_validator(mode="before") + @classmethod + def check_tiled_mlp_conflict(cls, data): + if ( + data.get("liger_glu_activation") is True + and data.get("tiled_mlp") is True + and not data.get("tiled_mlp_use_original_mlp") + ): + raise ValueError( + "You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_liger_rms_norm_tensor_parallel(cls, data): + if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1: + raise ValueError( + "`liger_rms_norm` is incompatible with tensor parallelism, " + "see https://github.com/linkedin/Liger-Kernel/issues/826" + ) + return data + + @model_validator(mode="after") + def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self): + # TODO @SalmanMohammadi this is a larger fix - investigate + if self.tensor_parallel_size > 1 and self.liger_fused_linear_cross_entropy: + raise ValueError("Tensor parallelism is not compatible with liger losses.") + return self diff --git a/src/axolotl/integrations/liger/models/base.py b/src/axolotl/integrations/liger/models/base.py new file mode 100644 index 000000000..a9dbe9412 --- /dev/null +++ b/src/axolotl/integrations/liger/models/base.py @@ -0,0 +1,188 @@ +""" +Generic FLCE patch for untested models similar to Llama +""" + +from typing import Optional, Tuple, Union + +import torch +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection +from liger_kernel.utils import PEFT_AVAILABLE +from peft.utils import ModulesToSaveWrapper +from torch.distributed.fsdp import FullyShardedDataParallel +from transformers.modeling_outputs import CausalLMOutputWithPast + +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix + + +def lce_forward( + self, + *args, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + """ + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + *args, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + + # if in training mode, don't materialize logits + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + loss = lce_maybe_trainable_lm_head( + self, + hidden_states=kept_hidden_states, + hidden_size=self.config.hidden_size, + labels=labels, + shift_labels=shift_labels, + **kwargs, + ) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def lce_maybe_trainable_lm_head( + self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs +): + lm_head = self.lm_head + + # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration, + # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read + # from the unwrapped module. + # See https://huggingface.co/docs/peft/package_reference/lora for reference. + if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper): + lm_head = lm_head.modules_to_save.default + + # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA, + # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass + # so the module entire parameters are summoned and kept in memory during the kernel execution. + if isinstance(lm_head, FullyShardedDataParallel): + return _FSDPForwardRedirection()( + lm_head, + _liger_for_causal_lm_loss, + lm_head.module, + hidden_states, + hidden_size, + labels, + shift_labels, + **loss_kwargs, + ) + + # FSDP is not used so we can read the lm_head weights and call the kernel directly + return _liger_for_causal_lm_loss( + lm_head=self.lm_head, + hidden_states=hidden_states, + hidden_size=hidden_size, + labels=labels, + shift_labels=shift_labels, + **loss_kwargs, + ) + + +def _liger_for_causal_lm_loss( + lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs +): + return LigerForCausalLMLoss( + hidden_states=hidden_states, + lm_head_weight=lm_head.weight, + labels=labels, + hidden_size=hidden_size, + shift_labels=shift_labels, + **loss_kwargs, + ) + + +def patch_lce_forward( + model_type, +): + try: + # Dynamically import the module and MLP class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) + module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) + model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") + + model_cls.forward = lce_forward + + except (ImportError, AttributeError) as e: + raise RuntimeError( + f"Could not import ForCausalLM class for model_type: {model_type}. " + f"Error: {str(e)}" + ) from e diff --git a/src/axolotl/integrations/liger/models/deepseekv2.py b/src/axolotl/integrations/liger/models/deepseekv2.py index 2f0d2a704..99adce4a7 100644 --- a/src/axolotl/integrations/liger/models/deepseekv2.py +++ b/src/axolotl/integrations/liger/models/deepseekv2.py @@ -2,8 +2,6 @@ DeepseekV2 model with LigerFusedLinearCrossEntropyLoss """ -# pylint: disable=duplicate-code - from typing import List, Optional, Tuple, Union import torch diff --git a/src/axolotl/integrations/liger/models/jamba.py b/src/axolotl/integrations/liger/models/jamba.py index d25529970..78689e40c 100644 --- a/src/axolotl/integrations/liger/models/jamba.py +++ b/src/axolotl/integrations/liger/models/jamba.py @@ -2,8 +2,6 @@ Jamba model with LigerFusedLinearCrossEntropyLoss """ -# pylint: disable=duplicate-code - from typing import Optional, Tuple, Union import torch diff --git a/src/axolotl/integrations/liger/models/llama4.py b/src/axolotl/integrations/liger/models/llama4.py index 689823bb6..e51140265 100644 --- a/src/axolotl/integrations/liger/models/llama4.py +++ b/src/axolotl/integrations/liger/models/llama4.py @@ -46,7 +46,6 @@ def lce_forward( Returns: """ - # pylint: disable=duplicate-code output_attentions = ( output_attentions if output_attentions is not None @@ -78,9 +77,7 @@ def lce_forward( hidden_states = outputs[0] if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1: - raise Exception( # pylint: disable=broad-exception-raised - "Liger Kernel does not support pretraining_tp!!" - ) + raise Exception("Liger Kernel does not support pretraining_tp!!") logits = None loss = None @@ -128,7 +125,7 @@ def apply_liger_kernel_to_llama4( rms_norm: bool = False, glu_activation: bool = False, layer_norm: bool = False, - **kwargs, # pylint: disable=unused-argument + **kwargs, ) -> None: """ Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3) @@ -144,15 +141,15 @@ def apply_liger_kernel_to_llama4( layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False. """ - import transformers.models.llama4.modeling_llama4 # noqa: F401 # pylint: disable=unused-import + import transformers.models.llama4.modeling_llama4 # noqa: F401 from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.swiglu import LigerSwiGLUMLP - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) modeling_llama4 = sys.modules["transformers.models.llama4.modeling_llama4"] @@ -165,7 +162,7 @@ def apply_liger_kernel_to_llama4( # clone config to avoid modifying the original config = deepcopy(config) if intermediate_size: - setattr(config, "intermediate_size", intermediate_size) + config.intermediate_size = intermediate_size return LigerSwiGLUMLP(config, **kwargs) modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper diff --git a/src/axolotl/integrations/liger/models/qwen3.py b/src/axolotl/integrations/liger/models/qwen3.py index 1dc19eaf9..b008755da 100644 --- a/src/axolotl/integrations/liger/models/qwen3.py +++ b/src/axolotl/integrations/liger/models/qwen3.py @@ -43,7 +43,6 @@ def lce_forward( Returns: """ - # pylint: disable=duplicate-code output_attentions = ( output_attentions if output_attentions is not None @@ -113,9 +112,8 @@ def apply_liger_kernel_to_qwen3( rms_norm: bool = False, glu_activation: bool = False, layer_norm: bool = False, - **kwargs, # pylint: disable=unused-argument + **kwargs, ) -> None: - # pylint: disable=duplicate-code """ Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3) @@ -130,15 +128,15 @@ def apply_liger_kernel_to_qwen3( layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False. """ - import transformers.models.qwen3.modeling_qwen3 # noqa: F401 # pylint: disable=unused-import + import transformers.models.qwen3.modeling_qwen3 # noqa: F401 from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.swiglu import LigerSwiGLUMLP - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"] diff --git a/src/axolotl/integrations/liger/models/qwen3_moe.py b/src/axolotl/integrations/liger/models/qwen3_moe.py index 89bdc5bcc..40bee110c 100644 --- a/src/axolotl/integrations/liger/models/qwen3_moe.py +++ b/src/axolotl/integrations/liger/models/qwen3_moe.py @@ -45,7 +45,6 @@ def lce_forward( Returns: """ - # pylint: disable=duplicate-code output_attentions = ( output_attentions if output_attentions is not None @@ -135,9 +134,8 @@ def apply_liger_kernel_to_qwen3_moe( rms_norm: bool = False, glu_activation: bool = False, layer_norm: bool = False, - **kwargs, # pylint: disable=unused-argument + **kwargs, ) -> None: - # pylint: disable=duplicate-code """ Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3) @@ -152,15 +150,15 @@ def apply_liger_kernel_to_qwen3_moe( layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False. """ - import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401 # pylint: disable=unused-import + import transformers.models.qwen3_moe.modeling_qwen3_moe # noqa: F401 from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.swiglu import LigerSwiGLUMLP - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"] @@ -174,7 +172,7 @@ def apply_liger_kernel_to_qwen3_moe( # clone config to avoid modifying the original config = deepcopy(config) if intermediate_size: - setattr(config, "intermediate_size", intermediate_size) + config.intermediate_size = intermediate_size return LigerSwiGLUMLP(config, **kwargs) modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py new file mode 100644 index 000000000..89f7c37b7 --- /dev/null +++ b/src/axolotl/integrations/liger/plugin.py @@ -0,0 +1,182 @@ +""" +Liger-Kernel Plugin for Axolotl +""" + +import inspect +import sys + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger + +from .models.base import patch_lce_forward +from .utils import patch_with_compile_disable + +LOG = get_logger(__name__) + + +class LigerPlugin(BasePlugin): + """ + Plugin for LIGER integraton with Axolotl. + """ + + def get_input_args(self): + return "axolotl.integrations.liger.LigerArgs" + + def pre_model_load(self, cfg): + if cfg.torch_compile: + # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled + import liger_kernel.ops.fused_linear_cross_entropy + + patch_with_compile_disable( + liger_kernel.ops.fused_linear_cross_entropy, + "fused_linear_cross_entropy_forward", + ) + patch_with_compile_disable( + liger_kernel.ops.fused_linear_cross_entropy, + "fused_linear_cross_entropy_backward", + ) + from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss + from liger_kernel.transformers.functional import liger_cross_entropy + from liger_kernel.transformers.layer_norm import LigerLayerNorm + from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN + from liger_kernel.transformers.rms_norm import LigerRMSNorm + from liger_kernel.transformers.rope import liger_rotary_pos_emb + from liger_kernel.transformers.swiglu import LigerSwiGLUMLP + + if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy: + raise ValueError( + "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set." + ) + + if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: + apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] + liger_fn_sig = inspect.signature(apply_liger_fn) + kwargs = {} + if "rope" in liger_fn_sig.parameters: + kwargs["rope"] = cfg.liger_rope + if "cross_entropy" in liger_fn_sig.parameters: + kwargs["cross_entropy"] = cfg.liger_cross_entropy + if "fused_linear_cross_entropy" in liger_fn_sig.parameters: + kwargs["fused_linear_cross_entropy"] = ( + cfg.liger_fused_linear_cross_entropy + ) + if "rms_norm" in liger_fn_sig.parameters: + kwargs["rms_norm"] = cfg.liger_rms_norm + if "layer_norm" in liger_fn_sig.parameters: + kwargs["layer_norm"] = cfg.liger_layer_norm + if "geglu" in liger_fn_sig.parameters: + kwargs["geglu"] = cfg.liger_glu_activation + elif "swiglu" in liger_fn_sig.parameters: + kwargs["swiglu"] = cfg.liger_glu_activation + LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}") + apply_liger_fn(**kwargs) + elif cfg.model_config_type == "jamba": + from transformers.models.jamba import modeling_jamba + + from .models.jamba import lce_forward as jamba_lce_forward + + if cfg.liger_rope: + modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_jamba.JambaRMSNorm = LigerRMSNorm + if cfg.liger_glu_activation: + modeling_jamba.JambaMLP = LigerSwiGLUMLP + if cfg.liger_layer_norm: + modeling_jamba.nn.LayerNorm = LigerLayerNorm + if cfg.liger_cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if cfg.liger_fused_linear_cross_entropy: + modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward + elif cfg.model_config_type == "deepseek_v2": + from accelerate import init_empty_weights + from transformers import AutoModelForCausalLM + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained( + cfg.base_model, trust_remote_code=cfg.trust_remote_code or False + ) + modeling_mod = sys.modules[model.__class__.__module__] + + from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward + + if cfg.liger_rope: + # The DeepseekV2 version of RoPE is different than upstream LLaMA. + # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 + LOG.warning("Fused liger_rope is not supported for DeepseekV2.") + if cfg.liger_rms_norm: + modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm + if cfg.liger_glu_activation: + modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward + if cfg.liger_layer_norm: + LOG.warning("liger_layer_norm is not supported for DeepseekV2.") + if cfg.liger_cross_entropy: + # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses + # nn.CrossEntropyLoss in the forward method. + modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward + elif cfg.model_config_type == "llama4": + from axolotl.integrations.liger.models.llama4 import ( + apply_liger_kernel_to_llama4, + ) + + apply_liger_kernel_to_llama4( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + layer_norm=cfg.liger_layer_norm, + ) + elif cfg.model_config_type == "qwen3": + from axolotl.integrations.liger.models.qwen3 import ( + apply_liger_kernel_to_qwen3, + ) + + apply_liger_kernel_to_qwen3( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + layer_norm=cfg.liger_layer_norm, + ) + elif cfg.model_config_type == "qwen3_moe": + from axolotl.integrations.liger.models.qwen3_moe import ( + apply_liger_kernel_to_qwen3_moe, + ) + + apply_liger_kernel_to_qwen3_moe( + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + glu_activation=cfg.liger_glu_activation, + rms_norm=cfg.liger_rms_norm, + layer_norm=cfg.liger_layer_norm, + ) + elif cfg.model_config_type == "granitemoe": + from liger_kernel.transformers import apply_liger_kernel_to_granite + + apply_liger_kernel_to_granite( + rope=cfg.liger_rope, + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + rms_norm=cfg.liger_rms_norm, + swiglu=cfg.liger_glu_activation, + ) + elif cfg.liger_fused_linear_cross_entropy: + try: + patch_lce_forward(cfg.model_config_type) + LOG.warning_once( + f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}" + ) + LOG.warning_once( + f"Liger + {cfg.model_config_type} generic FLCE support is experimental and may not work as expected." + ) + except RuntimeError: + LOG.warning( + f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." + ) + else: + LOG.warning( + f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." + ) diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py index 8db4dc634..0ab6b8697 100644 --- a/src/axolotl/integrations/lm_eval/__init__.py +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -7,7 +7,7 @@ import subprocess # nosec from axolotl.integrations.base import BasePlugin from axolotl.integrations.lm_eval.cli import build_lm_eval_command -from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401 +from .args import LMEvalArgs as LMEvalArgs class LMEvalPlugin(BasePlugin): @@ -20,7 +20,6 @@ class LMEvalPlugin(BasePlugin): def post_train_unload(self, cfg): if cfg.lm_eval_post_train: - # pylint: disable=duplicate-code for lm_eval_args in build_lm_eval_command( cfg.lm_eval_tasks, bfloat16=cfg.bfloat16 or cfg.bf16, diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py index 19608e1d9..ead82dcb7 100644 --- a/src/axolotl/integrations/lm_eval/cli.py +++ b/src/axolotl/integrations/lm_eval/cli.py @@ -99,7 +99,6 @@ def lm_eval(config: str, cloud: Optional[str] = None): with open(config, encoding="utf-8") as file: cfg: DictDefault = DictDefault(yaml.safe_load(file)) - # pylint: disable=duplicate-code for lm_eval_args in build_lm_eval_command( cfg.lm_eval_tasks, bfloat16=cfg.bfloat16 or cfg.bf16, diff --git a/src/axolotl/integrations/spectrum/__init__.py b/src/axolotl/integrations/spectrum/__init__.py index 9f66aef97..5e8f9128d 100644 --- a/src/axolotl/integrations/spectrum/__init__.py +++ b/src/axolotl/integrations/spectrum/__init__.py @@ -23,7 +23,7 @@ import requests from axolotl.integrations.base import BasePlugin from axolotl.utils.logging import get_logger -from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401 +from .args import SpectrumArgs as SpectrumArgs LOG = get_logger(__name__) @@ -46,7 +46,7 @@ def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5): "^lm_head.weight$", "^model.embed_tokens.weight$", ] - for layer_type, layer_names in top_layers_by_type.items(): + for _, layer_names in top_layers_by_type.items(): for layer_name in layer_names: unfrozen_parameters.append(layer_name) return unfrozen_parameters @@ -84,7 +84,7 @@ class SpectrumPlugin(BasePlugin): snr_data = json.load(fin) except FileNotFoundError: pass - except Exception as exc: # pylint: disable=broad-exception-caught + except Exception as exc: LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}") if not snr_data: diff --git a/src/axolotl/integrations/spectrum/args.py b/src/axolotl/integrations/spectrum/args.py index df5756038..be6ca4bfc 100644 --- a/src/axolotl/integrations/spectrum/args.py +++ b/src/axolotl/integrations/spectrum/args.py @@ -15,6 +15,7 @@ """ Module for handling Spectrum input arguments. """ + from typing import Optional from pydantic import BaseModel, model_validator diff --git a/src/axolotl/kernels/geglu.py b/src/axolotl/kernels/geglu.py index 6acbea0d4..ee3260ebd 100644 --- a/src/axolotl/kernels/geglu.py +++ b/src/axolotl/kernels/geglu.py @@ -5,8 +5,6 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202). Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. """ -# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code - import torch import triton import triton.language as tl diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py index 63c9e57bd..c3356fb90 100644 --- a/src/axolotl/kernels/lora.py +++ b/src/axolotl/kernels/lora.py @@ -7,13 +7,12 @@ See "LoRA: Low-Rank Adaptation of Large Language Models" Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. """ -# pylint: disable=invalid-name - from typing import Callable import torch from bitsandbytes.functional import QuantState from torch import nn +from torch.distributed.tensor import DTensor from .geglu import geglu_backward, geglu_forward from .quantize import dequantize @@ -25,6 +24,7 @@ def get_lora_parameters( proj: nn.Module, ) -> tuple[ torch.Tensor, + torch.Tensor | None, QuantState | None, torch.Tensor | None, torch.Tensor | None, @@ -37,39 +37,54 @@ def get_lora_parameters( proj: The projection module to extract parameters from. Returns: - A tuple containing the base weight matrix, quantization state, LoRA A matrix, - LoRA B matrix, and scaling factor. States and matrices may be None if not - available. + A tuple containing the base weights, quantization state, LoRA A and B weights, + scaling factor, and base layer bias. Quant state, weights, and bias may be + `None` if not available. """ # For DPO or disabled adapters base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj W = base_layer.weight + b = base_layer.bias if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: quant_state = getattr(W, "quant_state", None) - return W, quant_state, None, None, None + return W, b, quant_state, None, None, None + + quant_state = getattr(W, "quant_state", None) active_adapter = ( proj.active_adapters[0] if hasattr(proj, "active_adapters") else proj.active_adapter ) - A = proj.lora_A[active_adapter].weight - B = proj.lora_B[active_adapter].weight + + linear_A = proj.lora_A[active_adapter] + linear_B = proj.lora_B[active_adapter] + + # This manual unsharding is needed for FSDP2 + LoRA kernels compatibility. + # We fuse linear layers + LoRA adapters calculations into a single + # torch.autograd.Function, bypassing the registered unshard / reshard behavior. + # Note that we don't apply resharding later in this module (it gets messy quickly), + # but LoRA parameters are generally small enough that this is not an issue. + if isinstance(linear_A.weight, DTensor): + linear_A.unshard() + linear_B.unshard() + + A = linear_A.weight + B = linear_B.weight s = proj.scaling[active_adapter] - quant_state = getattr(W, "quant_state", None) - - return W, quant_state, A, B, s + return W, b, quant_state, A, B, s def matmul_lora( X: torch.Tensor, W: torch.Tensor, - W_quant: QuantState, - A: torch.Tensor, - B: torch.Tensor, - s: float, + b: torch.Tensor | None, + W_quant: QuantState | None, + A: torch.Tensor | None, + B: torch.Tensor | None, + s: float | None, out: torch.Tensor | None = None, ) -> torch.Tensor: """ @@ -90,20 +105,22 @@ def matmul_lora( dtype = X.dtype W = dequantize(W.t(), W_quant) + reshape = False if X.dim() == 3: batch, seq_len, _ = X.shape X = X.view(-1, X.shape[-1]) reshape = True - else: - reshape = False out = torch.matmul(X, W, out=out) if W_quant is not None: del W if A is not None: - A, B = A.t(), B.t() - out += (X @ A.to(dtype)) @ (s * B.to(dtype)) + A, B = A.t().to(dtype), B.t().to(dtype) # type: ignore[union-attr] + out += s * X @ A @ B + + if b is not None: + out += b return out.view(batch, seq_len, -1) if reshape else out @@ -117,17 +134,20 @@ class LoRA_MLP(torch.autograd.Function): ctx, X: torch.Tensor, gate_weight: torch.Tensor, - gate_quant: object | None, + gate_bias: torch.Tensor | None, + gate_quant: QuantState | None, gate_A: torch.Tensor | None, gate_B: torch.Tensor | None, gate_scale: float, up_weight: torch.Tensor, - up_quant: object | None, + up_bias: torch.Tensor | None, + up_quant: QuantState | None, up_A: torch.Tensor | None, up_B: torch.Tensor | None, up_scale: float, down_weight: torch.Tensor, - down_quant: object | None, + down_bias: torch.Tensor | None, + down_quant: QuantState | None, down_A: torch.Tensor | None, down_B: torch.Tensor | None, down_scale: float, @@ -142,20 +162,22 @@ class LoRA_MLP(torch.autograd.Function): ctx: Autograd context X: Input features gate_weight: Gate projection weight + gate_bias: Gate projection bias gate_quant: Gate quantization state gate_A: Gate LoRA A matrix gate_B: Gate LoRA B matrix gate_scale: Gate LoRA scale - up_weight: Up-projection weight - up_quant: Up-projection quantization state - up_A: Up-projection LoRA A matrix - up_B: Up-projection LoRA B matrix - up_scale: Up-projection LoRA scale - down_weight: Down-projection weight - down_quant: Down-projection quantization state - down_A: Down-projection LoRA A matrix - down_B: Down-projection LoRA B matrix - down_scale: Down-projection LoRA scale + up_weight: Up projection weight + up_quant: Up projection quantization state + up_A: Up projection LoRA A matrix + up_B: Up projection LoRA B matrix + up_scale: Up projection LoRA scale + down_weight: Down projection weight + down_bias: Down projection bias + down_quant: Down projection quantization state + down_A: Down projection LoRA A matrix + down_B: Down projection LoRA B matrix + down_scale: Down projection LoRA scale activation_fn: Forward activation function activation_fn_backward: Backward activation function inplace: Whether to perform operations in-place @@ -164,15 +186,17 @@ class LoRA_MLP(torch.autograd.Function): Output transformed by multi-layer perceptron and activation function """ # Compute projections - gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale) - up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale) + gate = matmul_lora( + X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale + ) + up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale) # Activation hidden = activation_fn(gate, up) # Down projection output = matmul_lora( - hidden, down_weight, down_quant, down_A, down_B, down_scale + hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale ) # Save for backward @@ -195,22 +219,26 @@ class LoRA_MLP(torch.autograd.Function): torch.Tensor | None, None, None, + None, torch.Tensor | None, torch.Tensor | None, None, None, None, + None, torch.Tensor | None, torch.Tensor | None, None, None, None, + None, torch.Tensor | None, torch.Tensor | None, None, None, None, None, + None, ]: """ Performs backward pass computation for LoRA MLP. @@ -222,7 +250,7 @@ class LoRA_MLP(torch.autograd.Function): Returns: Tuple containing gradients for all inputs from forward pass: - Input gradient tensor (or `None`) - - `None` for weights/quantization states + - `None` for weights/biases/quantization states - LoRA A/B matrix gradients (or `None`) - `None` for scaling factors - `None` for activation functions and flags @@ -265,9 +293,10 @@ class LoRA_MLP(torch.autograd.Function): dtype = X.dtype # Down projection - DW = matmul_lora( + grad_down = matmul_lora( grad_output, down_weight.t(), + None, down_quant, down_B, down_A, @@ -275,7 +304,7 @@ class LoRA_MLP(torch.autograd.Function): ) # Activation backward - h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up) + h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up) # Initialize and compute LoRA gradients d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None @@ -315,8 +344,8 @@ class LoRA_MLP(torch.autograd.Function): dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t()) # Gate projection gradients - gate_weight = dequantize(gate_weight.t(), gate_quant) - dX += grad_gate @ gate_weight.t() + gate_weight = dequantize(gate_weight, gate_quant) + dX += grad_gate @ gate_weight del gate_weight if gate_A is not None and gate_B is not None: @@ -334,22 +363,26 @@ class LoRA_MLP(torch.autograd.Function): dX, None, None, + None, d_gate_A.t() if d_gate_A is not None else None, d_gate_B.t() if d_gate_B is not None else None, None, None, None, + None, d_up_A.t() if d_up_A is not None else None, d_up_B.t() if d_up_B is not None else None, None, None, None, + None, d_down_A.t() if d_down_A is not None else None, d_down_B.t() if d_down_B is not None else None, None, None, None, None, + None, ) @@ -364,23 +397,26 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch. Returns: Output tensor after applying LoRA-adapted MLP with SwiGLU activation """ - gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) - upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) - downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) + gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) + upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) + downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) out = LoRA_MLP.apply( X, gateW, + gateb, gateW_quant, gateA, gateB, gateS, upW, + upb, upW_quant, upA, upB, upS, downW, + downb, downW_quant, downA, downB, @@ -404,22 +440,25 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T Returns: Output tensor after applying LoRA-adapted MLP with GEGLU activation """ - gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) - upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) - downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) + gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) + upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) + downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) out = LoRA_MLP.apply( X, gateW, + gateb, gateW_quant, gateA, gateB, gateS, upW, + upb, upW_quant, upA, upB, upS, downW, + downb, downW_quant, downA, downB, @@ -446,16 +485,19 @@ class LoRA_QKV(torch.autograd.Function): ctx: torch.autograd.function.FunctionCtx, X: torch.Tensor, q_weight: torch.Tensor, + q_bias: torch.Tensor | None, q_quant: QuantState | None, q_A: torch.Tensor | None, q_B: torch.Tensor | None, q_scale: float, k_weight: torch.Tensor, + k_bias: torch.Tensor | None, k_quant: QuantState | None, k_A: torch.Tensor | None, k_B: torch.Tensor | None, k_scale: float, v_weight: torch.Tensor, + v_bias: torch.Tensor | None, v_quant: QuantState | None, v_A: torch.Tensor | None, v_B: torch.Tensor | None, @@ -469,16 +511,19 @@ class LoRA_QKV(torch.autograd.Function): ctx: Autograd context X: Input tensor q_weight: Query projection weight + q_bias: Query projection bias q_quant: Query quantization state q_A: Query LoRA A matrix q_B: Query LoRA B matrix q_scale: Query LoRA scale k_weight: Key projection weight + k_bias: Key projection bias k_quant: Key quantization state k_A: Key LoRA A matrix k_B: Key LoRA B matrix k_scale: Key LoRA scale v_weight: Value projection weight + v_bias: Value projection bias v_quant: Value quantization state v_A: Value LoRA A matrix v_B: Value LoRA B matrix @@ -488,20 +533,21 @@ class LoRA_QKV(torch.autograd.Function): Returns: Tuple of (Query, Key, Value) projection tensors """ - Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale) - K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale) - V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale) + Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale) + K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale) + V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale) ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B) ctx.scales = (q_scale, k_scale, v_scale) ctx.quants = (q_quant, k_quant, v_quant) ctx.weights = (q_weight, k_weight, v_weight) + ctx.biases = (q_bias, k_bias, v_bias) ctx.inplace = inplace return Q, K, V @staticmethod - @torch_amp_custom_fwd + @torch_amp_custom_bwd def backward( ctx: torch.autograd.function.FunctionCtx, q_grad: torch.Tensor, @@ -511,16 +557,19 @@ class LoRA_QKV(torch.autograd.Function): torch.Tensor, None, None, + None, torch.Tensor | None, torch.Tensor | None, None, None, None, + None, torch.Tensor | None, torch.Tensor | None, None, None, None, + None, torch.Tensor | None, torch.Tensor | None, None, @@ -608,31 +657,31 @@ class LoRA_QKV(torch.autograd.Function): # Transpose gradients if needed if d_A_q is not None: d_A_q = d_A_q.t() - if d_B_q is not None: - d_B_q = d_B_q.t() + d_B_q = d_B_q.t() # type: ignore[union-attr] if d_A_k is not None: d_A_k = d_A_k.t() - if d_B_k is not None: - d_B_k = d_B_k.t() + d_B_k = d_B_k.t() # type: ignore[union-attr] if d_A_v is not None: d_A_v = d_A_v.t() - if d_B_v is not None: - d_B_v = d_B_v.t() + d_B_v = d_B_v.t() # type: ignore[union-attr] return ( grad_X.view(batch, seq_len, -1), None, None, + None, d_A_q, d_B_q, None, None, None, + None, d_A_k, d_B_k, None, None, None, + None, d_A_v, d_B_v, None, @@ -653,22 +702,25 @@ def apply_lora_qkv( Returns: Tuple of (Query, Key, Value) projection tensors """ - QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj) - KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj) - VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj) + QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj) + KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj) + VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj) Q, K, V = LoRA_QKV.apply( X, QW, + Qb, QW_quant, QA, QB, QS, KW, + Kb, KW_quant, KA, KB, KS, VW, + Vb, VW_quant, VA, VB, @@ -688,10 +740,11 @@ class LoRA_O(torch.autograd.Function): ctx: torch.autograd.function.FunctionCtx, X: torch.Tensor, W: torch.Tensor, + b: torch.Tensor, W_quant: QuantState | None, - A: torch.Tensor | None, - B: torch.Tensor | None, - S: float, + A: torch.Tensor, + B: torch.Tensor, + s: float, ) -> torch.Tensor: """ Forward pass for output projection with LoRA. @@ -700,19 +753,20 @@ class LoRA_O(torch.autograd.Function): ctx: Autograd context X: Input tensor W: Output projection weight + b: Output projection bias W_quant: Weight quantization state A: LoRA A matrix B: LoRA B matrix - S: LoRA scaling factor + s: LoRA scaling factor Returns: - Output projection tensor + Output projection result """ - XW = matmul_lora(X, W, W_quant, A, B, S) + XW = matmul_lora(X, W, b, W_quant, A, B, s) ctx.custom_saved_tensors = ( W, W_quant, - S, + s, ) ctx.save_for_backward(A, B, X) @@ -727,8 +781,9 @@ class LoRA_O(torch.autograd.Function): torch.Tensor, None, None, - torch.Tensor | None, - torch.Tensor | None, + None, + torch.Tensor, + torch.Tensor, None, ]: """ @@ -741,7 +796,7 @@ class LoRA_O(torch.autograd.Function): Returns: Tuple containing gradients for all forward inputs """ - W, W_quant, S = ctx.custom_saved_tensors + W, W_quant, s = ctx.custom_saved_tensors A, B, X = ctx.saved_tensors batch, seq_len, hd = X.shape @@ -751,17 +806,19 @@ class LoRA_O(torch.autograd.Function): # Weight projection dY_X = X.t() @ dY - d_A = S * dY_X @ B - d_B = S * A @ dY_X + d_A = s * dY_X @ B + d_B = s * A @ dY_X # Get derivative for dX W = dequantize(W.t(), W_quant) dX = dY @ W.t() del W - dX += dY @ B.to(dtype) @ (S * A.to(dtype)) - # W, W_quant, A, B, S - return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None + A, B = A.to(dtype), B.to(dtype) + dX += s * dY @ B @ A + + # W, b, W_quant, A, B, s + return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor: @@ -774,7 +831,7 @@ def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor: Returns: Transformed output tensor """ - OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj) - output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS) + OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj) + output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS) return output diff --git a/src/axolotl/kernels/quantize.py b/src/axolotl/kernels/quantize.py index b61603fbc..d094f2381 100644 --- a/src/axolotl/kernels/quantize.py +++ b/src/axolotl/kernels/quantize.py @@ -1,7 +1,5 @@ """Dequantization utilities for `bitsandbytes` integration.""" -# pylint: disable=invalid-name,global-statement - import ctypes import bitsandbytes as bnb diff --git a/src/axolotl/kernels/swiglu.py b/src/axolotl/kernels/swiglu.py index 43a798edc..b13bcd350 100644 --- a/src/axolotl/kernels/swiglu.py +++ b/src/axolotl/kernels/swiglu.py @@ -99,7 +99,6 @@ def _swiglu_bwd_kernel( tl.store(up_ptr + offsets, grad_up, mask=mask) # grad wrt up -# pylint: disable=unnecessary-lambda-assignment def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: """ SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where @@ -128,7 +127,6 @@ def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: return out -# pylint: disable=unnecessary-lambda-assignment def swiglu_backward( grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/src/axolotl/loaders/__init__.py b/src/axolotl/loaders/__init__.py index 3eef75e58..ae99bf16d 100644 --- a/src/axolotl/loaders/__init__.py +++ b/src/axolotl/loaders/__init__.py @@ -1,6 +1,5 @@ """Init for axolotl.loaders module""" -# pylint: disable=unused-import # flake8: noqa from .adapter import load_adapter, load_lora diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 5517ff50a..8e8177b62 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -14,6 +14,7 @@ from peft import ( PeftConfig, PeftMixedModel, PeftModel, + TaskType, get_peft_model, ) from transformers import PreTrainedModel @@ -29,14 +30,12 @@ LOG = get_logger(__name__) def setup_quantized_meta_for_peft(model: torch.nn.Module): """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device""" - def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument + def temp_to_method(self, *args, **kwargs): return self for param in model.parameters(): if isinstance(param, Params4bit): - param.quant_state._orig_to = ( # pylint: disable=protected-access - param.quant_state.to - ) + param.quant_state._orig_to = param.quant_state.to param.quant_state.to = types.MethodType(temp_to_method, param.quant_state) @@ -44,10 +43,8 @@ def setup_quantized_peft_meta_for_training(model: torch.nn.Module): """Replaces dummy `quant_state.to` method with the original function to allow training to continue""" for param in model.parameters(): if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"): - param.quant_state.to = ( - param.quant_state._orig_to # pylint: disable=protected-access - ) - param.quant_state._orig_to = None # pylint: disable=protected-access + param.quant_state.to = param.quant_state._orig_to + param.quant_state._orig_to = None def find_all_linear_names(model): @@ -77,6 +74,7 @@ def load_lora( config_only: bool = False, ) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]: lora_target_modules = cfg.lora_target_modules or [] + lora_target_parameters = cfg.lora_target_parameters or [] if cfg.lora_target_linear: linear_names = find_all_linear_names(model) @@ -102,18 +100,30 @@ def load_lora( lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora if cfg.peft_layer_replication: lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication + if cfg.peft_trainable_token_indices: + lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices + + # Determine the correct PEFT task type + model_cls = type(model).__name__ + if "SequenceClassification" in model_cls: + task_type = TaskType.SEQ_CLS + elif "TokenClassification" in model_cls: + task_type = TaskType.TOKEN_CLS + else: + task_type = TaskType.CAUSAL_LM lora_config = LoraConfig( r=cfg.lora_r, lora_alpha=cfg.lora_alpha, target_modules=lora_target_modules, + target_parameters=lora_target_parameters, layers_to_transform=cfg.peft_layers_to_transform, layers_pattern=cfg.peft_layers_pattern, lora_dropout=cfg.lora_dropout, fan_in_fan_out=cfg.lora_fan_in_fan_out, modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, bias="none", - task_type="CAUSAL_LM", + task_type=task_type, **lora_config_kwargs, ) @@ -123,9 +133,9 @@ def load_lora( rank = int(os.environ.get("LOCAL_RANK", 0)) if ( - cfg.fsdp + cfg.fsdp_config and cfg.adapter - and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and cfg.fsdp_config.cpu_ram_efficient_loading and rank != 0 ): setup_quantized_meta_for_peft(model) @@ -153,9 +163,9 @@ def load_lora( "Exception caught during model.print_trainable_parameters(): %s", exc ) elif ( - cfg.fsdp + cfg.fsdp_config and cfg.adapter - and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and cfg.fsdp_config.cpu_ram_efficient_loading and rank != 0 ): setup_quantized_peft_meta_for_training(model) diff --git a/src/axolotl/loaders/adapters/__init__.py b/src/axolotl/loaders/adapters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/loaders/constants.py b/src/axolotl/loaders/constants.py index c08518dd6..4939cb28d 100644 --- a/src/axolotl/loaders/constants.py +++ b/src/axolotl/loaders/constants.py @@ -1,21 +1,18 @@ """Shared constants for axolotl.loaders module""" -from transformers import ( - Gemma3ForConditionalGeneration, - Llama4ForConditionalGeneration, - LlavaForConditionalGeneration, - Mistral3ForConditionalGeneration, - MllamaForConditionalGeneration, - Qwen2_5_VLForConditionalGeneration, - Qwen2VLForConditionalGeneration, +from transformers import AutoModelForImageTextToText +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, ) -MULTIMODAL_AUTO_MODEL_MAPPING = { - "mllama": MllamaForConditionalGeneration, - "llama4": Llama4ForConditionalGeneration, - "llava": LlavaForConditionalGeneration, - "qwen2_vl": Qwen2VLForConditionalGeneration, - "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, - "mistral3": Mistral3ForConditionalGeneration, - "gemma3": Gemma3ForConditionalGeneration, -} +MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES) + +MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText + +try: + from transformers import VoxtralForConditionalGeneration + + # transformers >4.53.2 + MULTIMODAL_AUTO_MODEL_MAPPING["voxtral"] = VoxtralForConditionalGeneration +except ImportError: + pass diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index ed1e8bbf2..d1c3052e7 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -1,5 +1,5 @@ -"""Model loader class implementation for loading, configuring, and patching various -models. +""" +Model loader class implementation for loading, configuring, and patching various models. """ import gc @@ -14,6 +14,7 @@ import torch import transformers import transformers.modeling_utils from accelerate import init_empty_weights +from accelerate.parallelism_config import ParallelismConfig from peft import ( PeftConfig, PeftMixedModel, @@ -21,8 +22,10 @@ from peft import ( PeftModelForCausalLM, prepare_model_for_kbit_training, ) +from torch.distributed import DeviceMesh from transformers import ( AutoModelForCausalLM, + AutoModelForImageTextToText, AutoModelForVision2Seq, AwqConfig, BitsAndBytesConfig, @@ -50,6 +53,7 @@ from axolotl.telemetry.errors import send_errors from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import ( + build_parallelism_config, get_device_count, get_device_type, ) @@ -88,6 +92,10 @@ class ModelLoader: `AutoModelForCausalLM`). """ + use_parallel_config: bool | None = False + parallelism_config: ParallelismConfig | None = None + device_mesh: DeviceMesh | None = None + def __init__( self, cfg: DictDefault, @@ -95,7 +103,7 @@ class ModelLoader: *, inference: bool = False, reference_model: bool = False, - **kwargs, # pylint: disable=unused-argument + **kwargs, ): """Initializes the ModelLoader. @@ -127,7 +135,7 @@ class ModelLoader: # Init model config self.model_config = load_model_config(cfg) - self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name + self.auto_model_loader = AutoModelForCausalLM # Initialize the patch manager self.patch_manager = PatchManager( @@ -141,10 +149,15 @@ class ModelLoader: """Check if flash attention is installed.""" return find_spec("flash_attn") is not None - @cached_property - def qlora_fsdp(self): + @property + def is_fsdp_enabled(self): + """Property that determines if FSDP is enabled.""" + return self.cfg.fsdp_config is not None or self.cfg.fsdp is not None + + @property + def is_qlora_and_fsdp_enabled(self): """Property that determines if FSDP with QLoRA is enabled.""" - return self.cfg.fsdp and self.cfg.adapter == "qlora" + return self.is_fsdp_enabled and self.cfg.adapter == "qlora" @send_errors def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]: @@ -159,6 +172,7 @@ class ModelLoader: # Build the model PLUGIN_MANAGER.pre_model_load(self.cfg) + self.patch_manager.apply_post_plugin_pre_model_load_patches() skip_move_to_device = self._build_model() PLUGIN_MANAGER.post_model_build(self.cfg, self.model) @@ -179,27 +193,54 @@ class ModelLoader: def _apply_pre_model_load_setup(self): """Apply patches and setup configurations before model loading.""" + if self.use_parallel_config is not None: + self.use_parallel_config = ( + self.cfg.fsdp_config + or (self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1) + or ( + self.cfg.context_parallel_size + and self.cfg.context_parallel_size > 1 + ) + ) + if self.cfg.fsdp_config and self.cfg.fsdp_version != 2: + self.use_parallel_config = False + + if self.use_parallel_config: + self._set_parallel_config() self._set_auto_model_loader() self._set_device_map_config() if self.cfg.revision_of_model: self.model_kwargs["revision"] = self.cfg.revision_of_model + if self.cfg.use_kernels: + self.model_kwargs["use_kernels"] = self.cfg.use_kernels self._set_quantization_config() self._set_attention_config() + self._check_model_requirements() def _apply_post_model_load_setup(self): """Configure the model after it has been loaded.""" # Handle PeftModel if needed if ( isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM)) - and not self.qlora_fsdp + and not self.is_qlora_and_fsdp_enabled ): self.model = self.model.merge_and_unload() + self._apply_activation_checkpointing() self._resize_token_embeddings() self._adjust_model_config() - self._log_memory_usage() self._configure_embedding_dtypes() self._configure_qat() + log_gpu_memory_usage(LOG, "Memory usage after model load", 0) + + def _apply_activation_checkpointing(self): + if self.cfg.activation_offloading is True: + from axolotl.core.trainers.mixins.activation_checkpointing import ( + ac_wrap_hf_model, + ) + + # ^^ importing this at the module level breaks plugins + ac_wrap_hf_model(self.model) def _resize_token_embeddings(self): """Resize token embeddings if needed.""" @@ -253,22 +294,13 @@ class ModelLoader: ): self.model.config.eos_token_id = self.tokenizer.eos_token_id - def _log_memory_usage(self): - """Log device memory usage after model load.""" - if hasattr(self.model, "device") and self.model.device.type in ( - "cuda", - "mps", - "npu", - ): - log_gpu_memory_usage(LOG, "after model load", self.model.device) - def _configure_embedding_dtypes(self): """Configure embedding module dtypes.""" # Get embedding modules embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type) # Initial dtype conversion - if not self.cfg.fsdp: + if not self.is_fsdp_enabled: # We don't run this during FSDP because this will leave mixed and bfloat16 # dtypes in the model which FSDP doesn't like if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast: @@ -280,11 +312,14 @@ class ModelLoader: ) # Handle DeepSpeed Zero3 - if is_deepspeed_zero3_enabled(): + if ( + is_deepspeed_zero3_enabled() + or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3" + ): self._set_z3_leaf_modules() # Apply gradient checkpointing if needed - needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp + needs_fa2_dtype = self.cfg.adapter or self.is_fsdp_enabled if self.cfg.adapter in ["lora", "qlora"]: needs_fa2_dtype = True if self.cfg.gradient_checkpointing: @@ -300,10 +335,12 @@ class ModelLoader: # we need to convert them back to fp16/bf16 for flash-attn compatibility. ( (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) - and not self.qlora_fsdp + and not self.is_qlora_and_fsdp_enabled + ) + or ( + # CCE requires embedding layers to be in fp16/bf16 for backward pass + self.cfg.cut_cross_entropy ) - # CCE requires embedding layers to be in fp16/bf16 for backward pass - or self.cfg.cut_cross_entropy ) if should_convert: @@ -359,7 +396,6 @@ class ModelLoader: and not (self.cfg.rl and self.cfg.load_in_4bit) and not skip_move_to_device ): - # TODO: validate this conditional self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}") if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: @@ -384,6 +420,13 @@ class ModelLoader: gc.collect() torch.cuda.empty_cache() + def _set_parallel_config(self): + """Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator""" + parallelism_config, device_mesh = build_parallelism_config(self.cfg) + if parallelism_config: + self.parallelism_config = parallelism_config + self.device_mesh = device_mesh + def _set_auto_model_loader(self): """Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM` (set at `__init__`). When using a multimodal model, `self.auto_model_loader` @@ -393,6 +436,8 @@ class ModelLoader: self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get( self.model_config.model_type, AutoModelForVision2Seq ) + if isinstance(self.auto_model_loader, str): + self.auto_model_loader = AutoModelForImageTextToText def _set_device_map_config(self): """Setup `device_map` according to config""" @@ -432,7 +477,17 @@ class ModelLoader: self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype - if not is_deepspeed_zero3_enabled(): + is_ds_zero3 = is_deepspeed_zero3_enabled() + + # FSDP requires control over device placement, so don't set device_map when FSDP is enabled + if self.is_fsdp_enabled: + # For QLoRA + FSDP, we still need to set device_map to "auto" for proper initialization + if self.is_qlora_and_fsdp_enabled: + self.model_kwargs["device_map"] = { + "": int(os.environ.get("LOCAL_RANK", 0)) + } + # For other FSDP cases, don't set device_map at all + elif not is_ds_zero3: self.model_kwargs["device_map"] = device_map cur_device = get_device_type() @@ -454,8 +509,17 @@ class ModelLoader: def _set_quantization_config(self): """Set up quantization config (bitsandbytes, awq, gptq, etc.)""" - self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit - self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit + + if self.cfg.model_quantization_config == "Mxfp4Config": + from transformers import Mxfp4Config + + mxfp4_kwargs = {} + if self.cfg.model_quantization_config_kwargs: + mxfp4_kwargs = self.cfg.model_quantization_config_kwargs + self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs) + else: + self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit + self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit if self.cfg.gptq: if not hasattr(self.model_config, "quantization_config"): @@ -490,7 +554,9 @@ class ModelLoader: self.model_kwargs["quantization_config"] = BitsAndBytesConfig( **self.model_config.quantization_config ) - elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]: + elif self.cfg.adapter == "qlora" and self.model_kwargs.get( + "load_in_4bit", False + ): bnb_config = { "load_in_4bit": True, "llm_int8_threshold": 6.0, @@ -501,11 +567,14 @@ class ModelLoader: "bnb_4bit_quant_storage": torch.bfloat16, } if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( - self.cfg.deepspeed or self.cfg.fsdp + self.cfg.deepspeed or self.is_fsdp_enabled ): # for some reason, this causes the loss to be off by an order of magnitude # but deepspeed needs this still in bfloat16 bnb_config["bnb_4bit_quant_storage"] = torch.float32 + if self.cfg.model_config_type == "falcon_h1": + # output projection cannot be quantized for Falcon-H1 models + bnb_config["llm_int8_skip_modules"] = ["out_proj"] if self.cfg.bnb_config_kwargs: bnb_config.update(self.cfg.bnb_config_kwargs) @@ -513,13 +582,18 @@ class ModelLoader: self.model_kwargs["quantization_config"] = BitsAndBytesConfig( **bnb_config, ) - elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]: + elif self.cfg.adapter == "lora" and self.model_kwargs.get( + "load_in_8bit", False + ): bnb_config = { "load_in_8bit": True, } # Exclude mamba blocks from int8 quantization for jamba if self.cfg.model_config_type == "jamba": bnb_config["llm_int8_skip_modules"] = ["mamba"] + if self.cfg.model_config_type == "falcon_h1": + # output projection cannot be quantized for Falcon-H1 models + bnb_config["llm_int8_skip_modules"] = ["out_proj"] self.model_kwargs["quantization_config"] = BitsAndBytesConfig( **bnb_config, ) @@ -531,33 +605,37 @@ class ModelLoader: def _set_attention_config(self): """Sample packing uses custom FA2 patch""" - if self.cfg.flex_attention: + if self.cfg.attn_implementation: + self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation + elif self.cfg.flex_attention: self.model_kwargs["attn_implementation"] = "flex_attention" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "flex_attention" - ) + self.model_config._attn_implementation = "flex_attention" elif self.cfg.flash_attention: if not self.cfg.sample_packing and self.cfg.s2_attention: pass self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) + self.model_config._attn_implementation = "flash_attention_2" elif self.cfg.sdp_attention: self.model_kwargs["attn_implementation"] = "sdpa" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "sdpa" - ) + self.model_config._attn_implementation = "sdpa" elif self.cfg.eager_attention: self.model_kwargs["attn_implementation"] = "eager" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) + self.model_config._attn_implementation = "eager" if self.cfg.low_cpu_mem_usage: self.model_kwargs["low_cpu_mem_usage"] = True + def _check_model_requirements(self): + if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]: + from transformers.utils.import_utils import is_causal_conv1d_available + + if is_causal_conv1d_available(): + raise ImportError( + "The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. " + "Please uninstall it by running: `pip uninstall -y causal-conv1d`" + ) + def _configure_zero3_memory_efficient_loading( self, ) -> HfTrainerDeepSpeedConfig | None: @@ -597,17 +675,82 @@ class ModelLoader: return hf_ds_cfg + def _load_model_from_config(self, model_loader_class=None) -> PreTrainedModel: + """ + Load model with random initialization using from_config. + + Uses the selected loader when provided; otherwise falls back to the auto loader. + """ + loader = model_loader_class or self.auto_model_loader + if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]: + model = loader.from_config( + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + ) + else: + model = loader(config=self.model_config) + + return model + + def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel: + """Load model from pretrained weights.""" + loader = model_loader_class or self.auto_model_loader + kwargs = { + "config": self.model_config, + "trust_remote_code": self.cfg.trust_remote_code or False, + **self.model_kwargs, + } + return loader.from_pretrained(self.base_model, **kwargs) + def _build_model(self) -> bool: """Load model, with load strategy depending on config.""" skip_move_to_device = False + + if self.cfg.tensor_parallel_size > 1: + self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size + self.model_kwargs["tp_plan"] = "auto" + self.model_kwargs["device_mesh"] = self.device_mesh + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] # not compatible with `tp_plan` + + if self.is_fsdp_enabled: + if self.cfg.fsdp_config.cpu_ram_efficient_loading: + skip_move_to_device = True + # Don't delete device_map for QLoRA + FSDP - it was set correctly in + # _set_device_map + if ( + "device_map" in self.model_kwargs + and not self.is_qlora_and_fsdp_enabled + ): + del self.model_kwargs["device_map"] + elif self.is_qlora_and_fsdp_enabled: + skip_move_to_device = True + + if ( + self.cfg.tensor_parallel_size <= 1 + and self.cfg.fsdp_config.cpu_ram_efficient_loading + and self.cfg.fsdp_version == 2 + ): + # setting device_map for TP is not supported + local_rank = int(os.getenv("LOCAL_RANK", "0")) + if local_rank == 0: + self.model_kwargs["device_map"] = "cpu" + else: + self.model_kwargs["device_map"] = "meta" + if ( - self.qlora_fsdp - and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + self.is_qlora_and_fsdp_enabled + and self.cfg.fsdp_config.cpu_ram_efficient_loading and ( self.cfg.model_config_type == "dbrx" or self.cfg.qlora_sharded_model_loading ) ): + if self.cfg.reinit_weights: + LOG.warning( + "reinit_weights is not supported with sharded quantized loading. " + "Loading from pretrained weights instead." + ) quant_storage = self.cfg.torch_dtype quantization_config = getattr( self.model_config, "quantization_config", None @@ -623,41 +766,14 @@ class ModelLoader: quantization_config=quantization_config, ) skip_move_to_device = True - elif ( - self.model_config.model_type in ["llama", "llama4"] - and not self.cfg.trust_remote_code - and not self.cfg.gptq - ): - # TODO: Do we need to open this up for all models? - if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: - skip_move_to_device = True - if "device_map" in self.model_kwargs: - del self.model_kwargs["device_map"] - - # Please don't remove underscore binding without reading the fn docstring. - _ = self._configure_zero3_memory_efficient_loading() - - # Load model with random initialization if specified - if self.cfg.random_init_weights: - # AutoModel classes support the from_config method - if self.auto_model_loader in [ - AutoModelForCausalLM, - AutoModelForVision2Seq, - ]: - self.model = self.auto_model_loader.from_config( - config=self.model_config, - ) - else: - self.model = self.auto_model_loader(config=self.model_config) - else: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - **self.model_kwargs, - ) elif self.model_type == "MambaLMHeadModel": + if self.cfg.reinit_weights: + LOG.warning( + "reinit_weights is not supported with MambaLMHeadModel. " + "Loading from pretrained weights instead." + ) # FIXME this is janky at best and hacked together to make it work - MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name + MambaLMHeadModel = fix_mamba_attn_for_loss() self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"] self.model_kwargs["device"] = torch.cuda.current_device() @@ -668,55 +784,40 @@ class ModelLoader: self.base_model, **self.model_kwargs, ) - elif ( - self.model_type - and self.model_type != "AutoModelForCausalLM" - and not self.cfg.trust_remote_code - ): - if self.cfg.gptq: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - else: - self.model = getattr(transformers, self.model_type).from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) else: - if self.cfg.gptq: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) + # Please don't remove underscore binding without reading the fn docstring + _ = self._configure_zero3_memory_efficient_loading() + + if ( + self.model_type + and self.model_type != "AutoModelForCausalLM" + and not self.cfg.trust_remote_code + and not self.cfg.gptq + ): + # Use model type from transformers + model_loader_class = getattr(transformers, self.model_type) else: - if ( - self.cfg.fsdp - and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - ): - # disabling either of these two still leads to VRAM spike before setting back down - skip_move_to_device = True - if "device_map" in self.model_kwargs: - del self.model_kwargs["device_map"] + # Use auto model loader (handles gptq and default cases) + model_loader_class = self.auto_model_loader - # Please don't remove underscore binding without reading the fn docstring. - _ = self._configure_zero3_memory_efficient_loading() + if self.cfg.reinit_weights: + self.model = self._load_model_from_config(model_loader_class) + else: + self.model = self._load_model_from_pretrained(model_loader_class) - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) if is_deepspeed_zero3_enabled(): skip_move_to_device = True + if self.cfg.tensor_parallel_size > 1: + # workaround for upstream 4.54.0 not setting _tp_size or _device_mesh + # TODO(wing): remove once 4.54.1 is released + if self.model._tp_size != self.cfg.tensor_parallel_size: + self.model._tp_size = self.cfg.tensor_parallel_size + self.model._device_mesh = self.model_kwargs["device_mesh"] + + if self.cfg.experimental_skip_move_to_device is not None: + skip_move_to_device = self.cfg.experimental_skip_move_to_device + return skip_move_to_device def _set_z3_leaf_modules(self): @@ -749,8 +850,8 @@ class ModelLoader: skip_prepare_model_for_kbit_training = True if ( - self.qlora_fsdp - or (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) + self.is_qlora_and_fsdp_enabled + or (self.is_fsdp_enabled and self.cfg.fsdp_config.cpu_ram_efficient_loading) or is_deepspeed_zero3_enabled() ): # Make sure everything is in the same dtype @@ -772,6 +873,9 @@ class ModelLoader: dist_dtype: torch.dtype, before_kbit_train_or_finetune: bool, ): + dest = {"dtype": dist_dtype} + if self.cfg.lora_on_cpu: + dest["device"] = "cpu" for name, module in self.model.named_modules(): if "norm" in name: module.to(dist_dtype) @@ -782,4 +886,4 @@ class ModelLoader: # don't upcast lm_head for btlm continue if any(m in name for m in embedding_modules) and hasattr(module, "weight"): - module.to(dist_dtype) + module.to(**dest) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 23f79d368..1e46f5c34 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -4,6 +4,7 @@ Applies pre- and post-model load patches for various fixes and optimizations. """ import importlib.util +import os from functools import cached_property import addict @@ -49,10 +50,12 @@ class PatchManager: def apply_pre_model_load_patches(self): """Apply pre-model load patches based on config.""" + self._apply_transformers_patches() + # self._apply_flex_attention_patches() self._apply_flash_attention_patches() + self._apply_chunked_cross_entropy_patch() self._apply_fsdp_patches() self._apply_adapter_patches() - self._apply_flex_attention_patches() self._apply_model_specific_patches() self._apply_fp8_patches() self._apply_flash_attention_peft_patches() @@ -63,6 +66,30 @@ class PatchManager: self._patch_llama_derived_model() self._apply_mistral_cross_entropy_patch() self._apply_self_attention_lora_patch() + self._apply_fsdp2_bnb_patches() + self._apply_patch_deepspeed_zero3() + self._apply_voxtral_patches() + self._apply_apertus_patches() + + def apply_post_plugin_pre_model_load_patches(self): + """Apply post plugin-pre_model_load load patches based on config.""" + self._apply_tiled_mlp(self.cfg.model_config_type) + + def _apply_transformers_patches(self): + from axolotl.monkeypatch.transformers.trainer_loss_calc import ( + patch_evaluation_loop, + patch_maybe_log_save_evaluate, + ) + + patch_evaluation_loop() + patch_maybe_log_save_evaluate() + + if self.cfg.context_parallel_size > 1: + from axolotl.monkeypatch.transformers.trainer_context_parallel import ( + patch_prepare_context_parallel_inputs, + ) + + patch_prepare_context_parallel_inputs() def apply_post_model_load_patches(self, model: PreTrainedModel): """Apply patches that require the model instance.""" @@ -78,12 +105,41 @@ class PatchManager: patch_xformers_attn_over_fa2() self.cfg.flash_attention = True + def _apply_chunked_cross_entropy_patch(self): + if self.cfg.chunked_cross_entropy: + from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn + + if self.cfg.chunked_cross_entropy_num_chunks: + patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks) + else: + patch_chunked_ce_loss_fn() + def _apply_fsdp_patches(self): """Apply patches for FSDP configurations.""" - if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": + if self.cfg.context_parallel_size > 1 or ( + self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2" + ): + from axolotl.monkeypatch.accelerate.parallelism_config import ( + patch_parallelism_config, + ) + + patch_parallelism_config() + if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2": from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2 patch_accelerate_fsdp2() + if self.cfg.rl: + from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2 + + patch_trl_prepare_fsdp2() + + # if self.cfg.fsdp_config: + # # see transformers#39152 + # from axolotl.monkeypatch.trainer_fsdp_optim import ( + # patch_training_loop_for_fsdp, + # ) + # + # patch_training_loop_for_fsdp() def _apply_adapter_patches(self): """Apply patches for adapter configurations.""" @@ -96,13 +152,17 @@ class PatchManager: """Apply patches for flexible attention.""" if self.cfg.flex_attention: from axolotl.monkeypatch.attention.flex_attn import ( - patch_flex_make_mask, patch_flex_wrapper, ) flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {} patch_flex_wrapper(**flex_attn_compile_kwargs) - patch_flex_make_mask() + if self.cfg.sample_packing: + from axolotl.core.attention.flex_block_mask import ( + patch_create_causal_mask, + ) + + patch_create_causal_mask(self.cfg.model_config_type) def _apply_model_specific_patches(self): """Apply patches specific to model architectures.""" @@ -116,6 +176,20 @@ class PatchManager: patch_llama4_linearized_modeling() + if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing: + from axolotl.monkeypatch.models.qwen3_next.modeling import ( + patch_qwen3_next_modeling_packing, + ) + + patch_qwen3_next_modeling_packing() + + if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type: + from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import ( + apply_mistral_tokenizer_image_patch, + ) + + apply_mistral_tokenizer_image_patch() + def _apply_fp8_patches(self): """Apply patches for FP8 support.""" if self.cfg.fp8: @@ -123,7 +197,9 @@ class PatchManager: patch_create_accelerate_code_for_fp8, ) - patch_create_accelerate_code_for_fp8() + patch_create_accelerate_code_for_fp8( + self.cfg.fp8_enable_fsdp_float8_all_gather + ) def _apply_flash_attention_peft_patches(self): """Apply patches for Flash Attention with PEFT.""" @@ -136,13 +212,19 @@ class PatchManager: def _apply_gradient_checkpointing_patches(self): """Apply patches for gradient checkpointing.""" - if self.cfg.gradient_checkpointing in ["unsloth", "offload"]: + if ( + self.cfg.gradient_checkpointing + and self.cfg.activation_offloading == "legacy" + ): from axolotl.monkeypatch.gradient_checkpointing import ( hf_grad_checkpoint_offload_wrapper, ) transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper - if self.cfg.gradient_checkpointing == "offload_disk": + elif ( + self.cfg.gradient_checkpointing + and self.cfg.activation_offloading == "offload_disk" + ): from axolotl.monkeypatch.gradient_checkpointing import ( hf_grad_checkpoint_disk_offload_wrapper, ) @@ -166,6 +248,17 @@ class PatchManager: def _apply_self_attention_lora_patch(self): """Apply self-attention LoRA patches if configured.""" if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel: + # Only patch if conditions are met + can_patch = ( + self.cfg.lora_dropout == 0 + if hasattr(self.cfg, "lora_dropout") + else True + ) # default to True if lora_dropout is not set + + if not can_patch: + LOG.warning("Cannot patch self-attention - requires no dropout") + return + from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora patch_self_attn_lora(self.cfg) @@ -200,6 +293,50 @@ class PatchManager: has_remote_code=has_remote_code, ) + if self.cfg.sample_packing: + from axolotl.monkeypatch.data.batch_dataset_fetcher import ( + apply_multipack_dataloader_patch, + ) + + LOG.info("Applying multipack dataloader patch for sample packing...") + apply_multipack_dataloader_patch() + + def _apply_fsdp2_bnb_patches(self): + """Apply FSDP2 BNB patches.""" + if ( + self.cfg.fsdp_config + and str(self.cfg.fsdp_version) == "2" + and self.cfg.adapter == "qlora" + ): + from axolotl.monkeypatch.fsdp2_qlora import ( + apply_init_sharded_param_patch, + apply_init_unsharded_param_patch, + ) + + apply_init_sharded_param_patch() + apply_init_unsharded_param_patch() + + def _apply_tiled_mlp(self, model_type: str): + if self.cfg.tiled_mlp: + from axolotl.monkeypatch.tiled_mlp import ( + patch_tiled_mlp, + ) + + patch_tiled_mlp( + model_type, + use_original_mlp=self.cfg.tiled_mlp_use_original_mlp, + cfg_num_shards=self.cfg.tiled_mlp_num_shards, + ) + + def _apply_voxtral_patches(self): + """Apply patches for Voxtral model.""" + if self.cfg.model_config_type == "voxtral": + from axolotl.monkeypatch.models.voxtral.modeling import ( + patch_voxtral_conditional_generation_forward, + ) + + patch_voxtral_conditional_generation_forward() + def _patch_attention(self): """Apply attention-specific patches based on model type.""" if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")): @@ -219,6 +356,13 @@ class PatchManager: replace_stablelm_attn_with_flash_attn(self.cfg.base_model) + if self.model_config.model_type in ("mistral3", "llava"): + from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import ( + apply_patch_is_packed_sequence, + ) + + apply_patch_is_packed_sequence() + def _patch_loss_llama(self): """Patch loss functions and other optimizations for LLaMA models.""" if not self.cfg.is_llama_derived_model: @@ -249,31 +393,21 @@ class PatchManager: patch_self_attn_lora() - def _patch_llama_flash_attention(self, packed=False): + def _patch_llama_flash_attention(self): """Apply Flash Attention patches for LLaMA models.""" from axolotl.monkeypatch.llama_attn_hijack_flash import ( replace_llama_attn_with_flash_attn, ) - if packed: - if self.cfg.device not in ["mps", "cpu"] and not self.inference: - LOG.info("patching with flash attention for sample packing") - replace_llama_attn_with_flash_attn( - packed=True, - cross_entropy=self.cfg.flash_attn_cross_entropy, - rms_norm=self.cfg.flash_attn_rms_norm, - ) - elif self.cfg.s2_attention: + if self.cfg.s2_attention: LOG.info("patching w/ flash-enabled, shifted-sparse attention") replace_llama_attn_with_flash_attn( - packed=False, cross_entropy=self.cfg.flash_attn_cross_entropy, rms_norm=self.cfg.flash_attn_rms_norm, use_shifted_sparse_attn=True, ) elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm: replace_llama_attn_with_flash_attn( - packed=False, cross_entropy=self.cfg.flash_attn_cross_entropy, rms_norm=self.cfg.flash_attn_rms_norm, ) @@ -304,7 +438,7 @@ class PatchManager: and self.cfg.sample_packing ): if self.cfg.flash_attention: - self._patch_llama_flash_attention(packed=self.cfg.sample_packing) + self._patch_llama_flash_attention() elif self.cfg.xformers_attention: self._patch_llama_xformers_attention() elif self.cfg.sample_packing: @@ -327,17 +461,12 @@ class PatchManager: from axolotl.monkeypatch.llama_attn_hijack_flash import ( is_xformers_swiglu_available, replace_llama_mlp_with_swiglu, - replace_llama_qkv_with_fused, ) if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): LOG.info("Patching with SwiGLU...") replace_llama_mlp_with_swiglu(model) - if self.cfg.flash_attn_fuse_qkv: - LOG.info("Patching with fused QKV...") - replace_llama_qkv_with_fused(model) - def _apply_unsloth_patches(self, model): """Apply unsloth optimization patches.""" if self.cfg.unsloth_lora_mlp: @@ -365,3 +494,26 @@ class PatchManager: from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches apply_lora_kernel_patches(model=model, cfg=self.cfg) + + def _apply_patch_deepspeed_zero3(self): + try: + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled + + from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches + + if self.cfg.activation_offloading is True and ( + is_deepspeed_zero3_enabled() + or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3" + ): + apply_deepspeed_patches() + except ImportError as e: + LOG.warning(f"DeepSpeed patches not applied: {e}") + + def _apply_apertus_patches(self): + """Apply patches for Apertus model.""" + if self.cfg.model_config_type == "apertus": + from axolotl.monkeypatch.models.apertus.activation import ( + patch_apertus_xielu_activation, + ) + + patch_apertus_xielu_activation() diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py index 5d1c36618..e6fd4c0ed 100644 --- a/src/axolotl/loaders/processor.py +++ b/src/axolotl/loaders/processor.py @@ -23,6 +23,13 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): if cfg.processor_type: processor_cls = getattr(transformers, cfg.processor_type) + if cfg.tokenizer_use_mistral_common: + from axolotl.utils.mistral import Mistral3Processor + + return Mistral3Processor( + tokenizer=tokenizer, + ) + processor = processor_cls.from_pretrained( cfg.processor_config, trust_remote_code=cfg.trust_remote_code or False, diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index cfc5da42c..48856116c 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -7,6 +7,7 @@ import transformers from transformers import ( AddedToken, AutoTokenizer, + PreTrainedTokenizer, ) from axolotl.integrations.base import PluginManager @@ -14,6 +15,7 @@ from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.telemetry.errors import send_errors from axolotl.utils.chat_templates import get_chat_template_from_config +from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import ( barrier, is_local_main_process, @@ -49,7 +51,7 @@ def modify_tokenizer_files( tokenizer_dir = os.path.join(output_dir, "tokenizer") os.makedirs(tokenizer_dir, exist_ok=True) - if is_local_main_process(): # pylint: disable=too-many-nested-blocks + if is_local_main_process(): # Load the tokenizer temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) @@ -72,9 +74,9 @@ def modify_tokenizer_files( for token_id, new_value in token_id_mappings.items(): token_id_str = str(token_id) if token_id_str in config_data["added_tokens_decoder"]: - config_data["added_tokens_decoder"][token_id_str][ - "content" - ] = new_value + config_data["added_tokens_decoder"][token_id_str]["content"] = ( + new_value + ) else: raise ValueError( f"Token ID {token_id_str} not found in added_tokens_decoder" @@ -119,8 +121,21 @@ def modify_tokenizer_files( @send_errors -def load_tokenizer(cfg): +def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: """Load and configure the tokenizer based on the provided config.""" + + def _load_mistral_common_tokenizer(cfg: DictDefault): + """Load mistral-common tokenizer""" + from axolotl.utils.mistral import HFMistralTokenizer + + # Load the HF-compatible wrapper around MistralTokenizer + tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config) + + return tokenizer + + if cfg.tokenizer_use_mistral_common: + return _load_mistral_common_tokenizer(cfg) + model_config = load_model_config(cfg) tokenizer_kwargs = {} use_fast = True # this is the default @@ -175,7 +190,8 @@ def load_tokenizer(cfg): tokenizer.padding_side = "left" # Qwen base only has single token, so we need to set the special tokens - if cfg.is_qwen_derived_model: + # the following check is for Qwen1 base models + if cfg.is_qwen_derived_model and hasattr(tokenizer, "eod_id"): token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"] for attr_name in token_ids: if getattr(tokenizer, attr_name) is None: @@ -196,7 +212,7 @@ def load_tokenizer(cfg): for k, val in special_tokens.items(): # check if new special token is not already in tokenizer and # is adapter training to make sure lora_modules_to_save is set - # pylint: disable=too-many-boolean-expressions + if ( (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) and (len(tokenizer.encode(val, add_special_tokens=False)) > 2) @@ -209,11 +225,12 @@ def load_tokenizer(cfg): ) and k != "pad_token" ): - lora_modules_to_save = ", ".join( + lora_modules_to_save_str = ", ".join( [f"`{x}`" for x in lora_modules_to_save] ) raise ValueError( - f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens." + f"Please set lora_modules_to_save to [{lora_modules_to_save_str}] " + "when using an adapter and changing the special tokens." ) tokenizer.add_special_tokens( @@ -259,7 +276,7 @@ def load_tokenizer(cfg): {"additional_special_tokens": additional_special_tokens} ) - if is_main_process(use_environ=True): + if is_main_process(): LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") @@ -276,8 +293,13 @@ def load_tokenizer(cfg): ) tokenizer.chat_template = chat_template_string - else: + elif getattr(tokenizer, "chat_template", None) is None: LOG.info( "No Chat template selected. Consider adding a chat template for easier inference." ) + + # make the tokenizer.pad call quieter 🤐 + if hasattr(tokenizer, "deprecation_warnings"): + tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True + return tokenizer diff --git a/src/axolotl/loaders/utils.py b/src/axolotl/loaders/utils.py index 28c935085..240e00da7 100644 --- a/src/axolotl/loaders/utils.py +++ b/src/axolotl/loaders/utils.py @@ -131,6 +131,17 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`." ) + if ( + cfg.tensor_parallel_size + and cfg.tensor_parallel_size > 1 + and hasattr(model_config, "tie_word_embeddings") + and model_config.tie_word_embeddings + ): + raise ValueError( + "Tensor parallelism is incompatible with models configured with `tie_word_embeddings` enabled. " + "Please use a model without `tie_word_embeddings`, or disable tensor parallelism." + ) + def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict: """Loads and configures a model configuration from HuggingFace or local sources. @@ -195,9 +206,11 @@ def ensure_dtype(model: PreTrainedModel, dtype: torch.dtype = torch.bfloat16): bias_mismatch = module.bias.dtype != dtype if weight_mismatch: - print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}") + LOG.debug( + f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}" + ) if bias_mismatch: - print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}") + LOG.debug(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}") if weight_mismatch or bias_mismatch: module.to(dtype) diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py index b8dc6479d..67b1d32f1 100644 --- a/src/axolotl/logging_config.py +++ b/src/axolotl/logging_config.py @@ -1,10 +1,7 @@ -""" -Common logging module for axolotl -""" +"""Common logging module for axolotl.""" import logging import os -import sys from logging import Formatter, Logger, LogRecord from logging.config import dictConfig from typing import Any, Dict @@ -17,20 +14,28 @@ DEFAULT_LOG_LEVEL = "WARNING" class AxolotlOrWarnErrorFilter(logging.Filter): """ - Allows ANY WARNING or higher (unless overridden by LOG_LEVEL) - Allows axolotl.* at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL) - Drops all other records (i.e. non-axolotl.INFO, DEBUG, etc. by default) + Allows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at + INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records + (i.e. non-axolotl.INFO, DEBUG, etc. by default). """ def __init__(self, **kwargs): super().__init__(**kwargs) - self.axolotl_level = logging.getLevelNamesMapping()[ - os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL) - ] - self.other_level = logging.getLevelNamesMapping()[ - os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL) - ] + axolotl_log_level = os.getenv( + "AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL + ).upper() + other_log_level = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper() + + try: + # py311+ only + level_mapping = logging.getLevelNamesMapping() + self.axolotl_level = level_mapping[axolotl_log_level] + self.other_level = level_mapping[other_log_level] + except AttributeError: + # For py310, use getLevelName directly + self.axolotl_level = logging.getLevelName(axolotl_log_level) + self.other_level = logging.getLevelName(other_log_level) def filter(self, record: LogRecord) -> bool: # General filter @@ -44,13 +49,12 @@ class AxolotlOrWarnErrorFilter(logging.Filter): class AxolotlLogger(Logger): - """A Logger that automatically rejects non-axolotl INFOs.""" + """Logger that applies filtering to non-axolotl loggers.""" def __init__(self, name: str, level: int = logging.NOTSET): super().__init__(name, level) - - # set global filter on the logger itself - self.addFilter(AxolotlOrWarnErrorFilter()) + if not name.startswith("axolotl"): + self.addFilter(AxolotlOrWarnErrorFilter()) class ColorfulFormatter(Formatter): @@ -66,6 +70,7 @@ class ColorfulFormatter(Formatter): def format(self, record): record.rank = int(os.getenv("LOCAL_RANK", "0")) + record.rank_fmt = f" [RANK:{record.rank}]" if record.rank != 0 else "" log_message = super().format(record) return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET @@ -79,33 +84,55 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { }, "colorful": { "()": ColorfulFormatter, - "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s", + "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d]%(rank_fmt)s %(message)s", + }, + "concise": { + "format": "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s", + }, + "concise_color": { + "()": ColorfulFormatter, + "format": "[%(asctime)s] [%(levelname)s] [%(name)s]%(rank_fmt)s %(message)s", + }, + }, + "filters": { + "ax_or_warn": { + "()": "axolotl.logging_config.AxolotlOrWarnErrorFilter", }, }, - "filters": {}, "handlers": { "console": { "class": "logging.StreamHandler", - "formatter": "simple", - "filters": [], - "stream": sys.stdout, + "formatter": "concise", + "filters": ["ax_or_warn"], + "stream": "ext://sys.stdout", }, "color_console": { "class": "logging.StreamHandler", - "formatter": "colorful", - "filters": [], - "stream": sys.stdout, + "formatter": "concise_color", + "filters": ["ax_or_warn"], + "stream": "ext://sys.stdout", + }, + "ax_file_only": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "simple", + "stream": "ext://axolotl.utils.tee.file_only_stream", + }, + "root_file_only": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "simple", + "stream": "ext://axolotl.utils.tee.file_only_stream", }, }, - # log level will be superseded by the AxolotlLogger "root": { - "handlers": ["console"], - "level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL), + "handlers": ["console", "root_file_only"], + "level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper(), }, "loggers": { "axolotl": { - "handlers": ["color_console"], - "level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL), + "handlers": ["color_console", "ax_file_only"], + "level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(), "propagate": False, }, }, @@ -115,9 +142,15 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { def configure_logging(): """Configure with default logging""" init() # Initialize colorama + dictConfig(DEFAULT_LOGGING_CONFIG) logging.setLoggerClass(AxolotlLogger) - # set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set + # Route Python warnings through logging so they reach file handlers + logging.captureWarnings(True) + + # Set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set if "ACCELERATE_LOG_LEVEL" not in os.environ: - os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL) + os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv( + "LOG_LEVEL", DEFAULT_LOG_LEVEL + ).upper() diff --git a/src/axolotl/models/mamba/__init__.py b/src/axolotl/models/mamba/__init__.py index fee88e3a4..d6bb40d99 100644 --- a/src/axolotl/models/mamba/__init__.py +++ b/src/axolotl/models/mamba/__init__.py @@ -21,4 +21,4 @@ def fix_mamba_attn_for_loss(): from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed - return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name + return mixer_seq_simple.MambaLMHeadModel diff --git a/src/axolotl/models/mamba/modeling_mamba.py b/src/axolotl/models/mamba/modeling_mamba.py index 70e9c88c8..2cfe11544 100644 --- a/src/axolotl/models/mamba/modeling_mamba.py +++ b/src/axolotl/models/mamba/modeling_mamba.py @@ -1,4 +1,3 @@ -# pylint: skip-file import os from collections import namedtuple from functools import partial @@ -112,7 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin): self, save_directory: Union[str, os.PathLike], state_dict: Optional[dict] = None, - safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument + safe_serialization: Optional[bool] = None, ): if state_dict is None: state_dict = self.state_dict() diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index 955c06cbe..af6f24a63 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -2,102 +2,91 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts """ +import copy +import functools +import os import sys import torch +import torch.distributed as dist +from torch import nn +from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.logging import get_logger LOG = get_logger(__name__) -def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict): +def fsdp2_load_full_state_dict( + _accelerator, model: torch.nn.Module, full_sd: dict, offload_to_cpu: bool = False +): """ Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the parameters from rank 0 to all other ranks. This function modifies the model in-place. - Args: accelerator (`Accelerator`): The accelerator instance model (`torch.nn.Module`): The model to load the state dict into, expected to be on meta device or a VRAM spike can occur full_sd (`dict`): The full state dict to load, can only be on rank 0 """ - import torch.distributed as dist from torch.distributed.tensor import distribute_tensor - # Model was previously copied to meta device + LOG.info("Broadcasting full state dict to all ranks...") + import time + + start_time = time.time() + meta_sharded_sd = model.state_dict() sharded_sd = {} + for param_name, sharded_meta_param in meta_sharded_sd.items(): + full_tensor = None + if _accelerator.is_main_process: + full_tensor = full_sd[param_name] + full_tensor = full_tensor.to(sharded_meta_param.dtype) - # Rank 0 distributes the full state dict to other ranks - def _infer_parameter_dtype(model, param_name, empty_param): - try: - old_param = model.get_parameter_or_buffer(param_name) - except AttributeError: - # Need this for LORA, as there some params are not *parameters* of sorts - base_param_name, local_param_name = param_name.rsplit(".", 1) - submodule = model.get_submodule(base_param_name) - old_param = getattr(submodule, local_param_name) - - is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") - casting_dtype = None - is_param_float8_e4m3fn = ( - is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - ) - - if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: - casting_dtype = old_param.dtype - - return old_param is not None and old_param.is_contiguous(), casting_dtype - - def _cast_and_contiguous(tensor, to_contiguous, dtype): - if dtype is not None: - tensor = tensor.to(dtype=dtype) - if to_contiguous: - tensor = tensor.contiguous() - return tensor - - param_names = sorted(meta_sharded_sd.keys()) - - for param_name in param_names: - mesh = meta_sharded_sd[param_name].device_mesh - if accelerator.is_main_process: - full_param = full_sd[param_name].detach().cuda() - dist.broadcast(full_param, src=0, group=mesh.get_group()) - sharded_tensor = distribute_tensor( - full_param, mesh, sharded_sd[param_name].placements - ) - to_contiguous, casting_dtype = _infer_parameter_dtype( - model, - param_name, - full_param, - ) - sharded_tensor = _cast_and_contiguous( - sharded_tensor, to_contiguous, casting_dtype - ) - sharded_sd[param_name] = sharded_tensor - else: - full_tensor = torch.empty( - sharded_sd[param_name].size(), - device="cuda", - dtype=sharded_sd[param_name].dtype, - ) - dist.broadcast(full_tensor, src=0, group=mesh.get_group()) - sharded_tensor = distribute_tensor( - full_tensor, mesh, sharded_sd[param_name].placements - ) - to_contiguous, casting_dtype = _infer_parameter_dtype( - model, - param_name, + if hasattr(sharded_meta_param, "device_mesh"): + device_mesh = sharded_meta_param.device_mesh + if _accelerator.is_main_process: + full_tensor = full_tensor.to(device_mesh.device_type) + else: + full_tensor = torch.empty( + sharded_meta_param.size(), + device=device_mesh.device_type, + dtype=sharded_meta_param.dtype, + ) + sharded_param = distribute_tensor( full_tensor, + device_mesh, + sharded_meta_param.placements, + src_data_rank=0, ) - sharded_tensor = _cast_and_contiguous( - sharded_tensor, to_contiguous, casting_dtype - ) - sharded_sd[param_name] = sharded_tensor + else: + # Non-sharded parameters + if _accelerator.is_main_process: + sharded_param = full_tensor.to(torch.device("cuda")) + else: + # broadcast manually + sharded_param = torch.empty_like( + sharded_meta_param, + device=torch.device("cuda"), + dtype=sharded_meta_param.dtype, + ) + dist.broadcast(sharded_param, src=0) - # we set `assign=True` because our params are on meta device - model.load_state_dict(sharded_sd, assign=True) + if offload_to_cpu: + sharded_param = sharded_param.cpu() + + sharded_sd[param_name] = nn.Parameter(sharded_param) + + del full_tensor + full_sd[param_name] = None + + model.load_state_dict(sharded_sd, assign=True, strict=True) + end_time = time.time() + LOG.debug( + f"Time taken to load full state dict: {(end_time - start_time):.2f} seconds" + ) + log_gpu_memory_usage(LOG, "Memory usage after broadcasting full state dict", 0) return model @@ -142,9 +131,9 @@ def get_state_dict(self, model, unwrap=True): "Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`." ) state_dict = ( - model._consolidated_16bit_state_dict() # pylint: disable=protected-access + model._consolidated_16bit_state_dict() if tp_sharding - else model._zero3_consolidated_16bit_state_dict() # pylint: disable=protected-access + else model._zero3_consolidated_16bit_state_dict() ) else: raise ValueError( @@ -172,9 +161,11 @@ def get_state_dict(self, model, unwrap=True): state_dict[param_name] = param.cpu() torch.distributed.barrier() elif self.distributed_type == DistributedType.FSDP: - from torch.distributed.fsdp import FullStateDictConfig - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp import StateDictType + from torch.distributed.fsdp import ( + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, + ) full_state_dict_config = FullStateDictConfig( offload_to_cpu=True, rank0_only=True @@ -191,17 +182,204 @@ def get_state_dict(self, model, unwrap=True): return state_dict -def patch_accelerate_fsdp2(): - import accelerate - from accelerate.utils import fsdp_utils +def _process_lora_module_for_fsdp(module, fsdp2_kwargs): + """Helper function to process LoRA modules for FSDP2.""" + from torch.distributed.fsdp import fully_shard - fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict - setattr( - sys.modules["accelerate.utils.fsdp_utils"], - "fsdp2_load_full_state_dict", - fsdp2_load_full_state_dict, + log_bias_dtype_mismatch = False + + # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to + # wrap this. Therefore we must ensure the bias has the same dtype as the weight + if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None: + if module.base_layer.weight.dtype != module.base_layer.bias.dtype: + log_bias_dtype_mismatch = True + module.base_layer.bias.data = module.base_layer.bias.data.to( + module.base_layer.weight.dtype + ) + + for active_adapter in module.active_adapters: + if module.lora_A: + fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs) + if module.lora_B: + fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs) + if module.lora_embedding_A: + fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs) + if module.lora_embedding_B: + fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs) + if module.lora_magnitude_vector: + fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs) + return log_bias_dtype_mismatch + + +def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: + """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model. + + Args: + accelerator (`Accelerator`): The accelerator instance + model (`torch.nn.Module`): The model to prepare + + Returns: + `torch.nn.Module`: Prepared model + """ + from accelerate.utils import get_module_children_bottom_up, is_compiled_module + from accelerate.utils.fsdp_utils import fsdp2_prepare_auto_wrap_policy + from accelerate.utils.modeling import get_non_persistent_buffers + from peft import PeftModel + from peft.tuners.lora import LoraLayer + from torch.distributed.fsdp import ( + CPUOffloadPolicy, + FSDPModule, + MixedPrecisionPolicy, + fully_shard, ) + is_type_fsdp = isinstance(model, FSDPModule) or ( + is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule) + ) + if is_type_fsdp: + return model + + fsdp2_plugin = accelerator.state.fsdp_plugin + + original_sd = model.state_dict() + + from torch.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + + # We need the `auto_wrap_policy` original type to create a custom poilicy function for sharding + # This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour + if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy: + pass # auto_wrap_policy_type = "transformer" + elif fsdp2_plugin.auto_wrap_policy is size_based_auto_wrap_policy: + pass # auto_wrap_policy_type = "size" + + # We set `auto_wrap_policy` to `functools.partial` to avoid creating it again + # This is because of `apply_activation_checkpointing` which will can reuse this function + fsdp2_plugin.set_auto_wrap_policy(model) + + if fsdp2_plugin.activation_checkpointing: + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, + ) + + # Apply activation checkpointing before applying `fully_shard` + apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=functools.partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ), + auto_wrap_policy=fsdp2_plugin.auto_wrap_policy, + ) + + mesh = getattr(accelerator.state, "device_mesh", None) + + # Disable memory pinning if requested + offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy) + if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false": + fsdp2_plugin.cpu_offload.pin_memory = False + + fsdp2_kwargs = { + "reshard_after_forward": fsdp2_plugin.reshard_after_forward, + "offload_policy": fsdp2_plugin.cpu_offload, + # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy` + "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(), + "mesh": ( + mesh[tuple(accelerator.state.parallelism_config.fsdp_dim_names)] + if mesh is not None + else None + ), + } + model_has_params4bit = False + for _, param in model.named_parameters(): + # this is a temporary fix whereby loading models with bnb params cannot be moved from + # GPU to a meta device due with FSDP2 because torch operations don't return the original class type + # bypassing the move to meta will still cause the VRAM spike, but at least it still will load + if param.__class__.__name__ == "Params4bit": + model_has_params4bit = True + break + + if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit: + # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard` + # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device + # If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.emtpy`), `fully_shard` would move it to GPU + # Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike + + # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device + # Also, these buffers aren't getting sharded by default + # We get the FQNs of all non-persistent buffers, to re-register them after + non_persistent_buffer_fqns = get_non_persistent_buffers( + model, recurse=True, fqns=True + ) + original_non_persistent_buffers = copy.deepcopy( + {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns} + ) + # We move the model to meta device, as then sharding happens on meta device + model = model.to(torch.device("meta")) + # We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage + # We assume `transformers` models have a `tie_weights` method if they support it + if hasattr(model, "tie_weights"): + model.tie_weights() + + is_peft_model = isinstance(model, PeftModel) + + auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model) + log_bias_dtype_mismatch = False + if auto_wrap_policy is not None: + for module in get_module_children_bottom_up(model)[:-1]: + if is_peft_model and isinstance(module, LoraLayer): + module_log_bias_mismatch = _process_lora_module_for_fsdp( + module, fsdp2_kwargs + ) + log_bias_dtype_mismatch |= module_log_bias_mismatch + if auto_wrap_policy(module) and not isinstance(module, FSDPModule): + fully_shard(module, **fsdp2_kwargs) + + fully_shard(model, **fsdp2_kwargs) + + if log_bias_dtype_mismatch: + LOG.warning( + "Bias dtype mismatch detected in LoRA base linear layer. Bias parameters have been cast to weight dtype." + ) + + if fsdp2_plugin.cpu_ram_efficient_loading: + fsdp2_load_full_state_dict( + accelerator, model, original_sd, offload_to_cpu=offload_to_cpu + ) + + if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit: + # We re-register the buffers, as they may not be in the state_dict + for fqn, buffer_tensor in original_non_persistent_buffers.items(): + buffer_tensor = buffer_tensor.to(accelerator.device) + + if "." in fqn: + parent_fqn, local_buffer_name = fqn.rsplit(".", 1) + parent_module = model.get_submodule(parent_fqn) + else: + local_buffer_name = fqn + parent_module = model + + parent_module.register_buffer( + local_buffer_name, buffer_tensor, persistent=False + ) + + # We need to tie the weights again, as call to `load_full_state_dict` breaks the tie + # Needs to be called both here and above + # removing this call makes the have slightly different loss + # removing the call above leads to extra memory usage as explained in the comment above + if hasattr(model, "tie_weights"): + model.tie_weights() + return model + + +def patch_accelerate_fsdp2(): + import accelerate + + accelerate.accelerator.fsdp2_prepare_model = fsdp2_prepare_model accelerate.Accelerator.get_state_dict = get_state_dict setattr( sys.modules["accelerate"], diff --git a/src/axolotl/monkeypatch/accelerate/parallelism_config.py b/src/axolotl/monkeypatch/accelerate/parallelism_config.py new file mode 100644 index 000000000..b2157fb6b --- /dev/null +++ b/src/axolotl/monkeypatch/accelerate/parallelism_config.py @@ -0,0 +1,77 @@ +""" +workaround to allow parallelism config for pure CP +""" + +import os +import warnings + +from accelerate import DistributedType + + +def _validate_accelerator(self, accelerator): + _warnings = set() + if not accelerator.multi_device and self.total_size == 1: + # No distributed setup, valid parallelism config + return + + # We need this to ensure DDP works + if self.total_size == 1: + self._set_size("dp_replicate", accelerator.num_processes) + + if self.total_size != accelerator.num_processes: + raise ValueError( + f"ParallelismConfig total_size ({self.total_size}) does not match " + f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ " + f"dp_shard_size/tp_size/cp_size." + ) + + # allow parallelism config when not using fsdp if using pure context parallelism + allow_parallelism_config = False + + if ( + self.cp_size > 1 + and self.dp_shard_size <= 1 + and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true" + ): + allow_parallelism_config = True + + if ( + self.total_size > 1 + and not allow_parallelism_config + and not (accelerator.is_fsdp2 or accelerator.multi_device) + ): + raise ValueError( + f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}." + ) + + for parallelism, size in self._sizes.items(): + if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None: + _warnings.add( + f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored." + ) + + if _warnings and accelerator.is_main_process: + warnings.warn( + "ParallelismConfig has the following warnings:\n" + "\n".join(_warnings), + UserWarning, + stacklevel=2, + ) + + +def patched_is_fsdp2(self) -> bool: + """ + Patched version of is_fsdp2 that guards against a None fsdp_plugin. + """ + # The new logic checks if fsdp_plugin exists before accessing its attributes + return ( + self.distributed_type == DistributedType.FSDP + and self.fsdp_plugin + and self.fsdp_plugin.fsdp_version == 2 + ) + + +def patch_parallelism_config(): + from accelerate.accelerator import AcceleratorState, ParallelismConfig + + ParallelismConfig._validate_accelerator = _validate_accelerator + AcceleratorState.is_fsdp2 = property(patched_is_fsdp2) diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py index 3652a30b3..678f65bee 100644 --- a/src/axolotl/monkeypatch/attention/flex_attn.py +++ b/src/axolotl/monkeypatch/attention/flex_attn.py @@ -1,10 +1,15 @@ """Flex attention monkey patch""" import sys -from typing import Optional, Tuple, Union import torch import transformers +from packaging import version +from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) def patch_flex_wrapper(**flex_attn_compile_kwargs): @@ -42,162 +47,39 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs): """ self.training = None if not self._is_flex_compiled or training != self.training: + self.training = training + if is_torch_less_or_equal("2.5.1"): + self._compiled_flex_attention = torch.compile( + flex_attention, dynamic=False + ) # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs" # see https://github.com/pytorch/pytorch/issues/146260 for training - self.training = training - self._compiled_flex_attention = torch.compile( - flex_attention, - **flex_attn_compile_kwargs, - ) + elif version.parse(_torch_version).base_version == "2.6.0" and training: + self._compiled_flex_attention = torch.compile( + flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs" + ) + # Fallback, usually the most recent torch 2.7.x+ versions + else: + LOG.info( + "Compiling flex attention with kwargs: %s. This may take a while...", + flex_attn_compile_kwargs, + main_process_only=True, + ) + self._compiled_flex_attention = torch.compile( + flex_attention, + **flex_attn_compile_kwargs, + ) + LOG.info( + "Flex attention compiled successfully.", main_process_only=True + ) + self._is_flex_compiled = True def __call__(self): return self._compiled_flex_attention transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention - setattr( - sys.modules["transformers.integrations.flex_attention"], - "WrappedFlexAttention", - WrappedFlexAttention, - ) - - -def patch_flex_make_mask(): - is_torch_2_6 = torch.__version__.startswith("2.6") - - if not is_torch_2_6: - return - - from torch.nn.attention.flex_attention import ( - _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size, - ) - from torch.nn.attention.flex_attention import ( - BlockMask, - ) - from torch.nn.attention.flex_attention import ( - create_block_mask as create_block_causal_mask_flex, - ) - - Offset = Union[torch.Tensor, int] - - def patched_make_flex_block_causal_mask( - attention_mask_2d: torch.Tensor, - attention_chunk_size: Optional[int] = None, - query_length=None, - key_length=None, - offsets: Optional[Tuple[Offset, Offset]] = None, - ) -> "BlockMask": - """ - Create a block causal document mask for a batch of sequences, both packed and unpacked. - Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. - The resultant BlockMask is a compressed representation of the full block causal - mask. BlockMask is essential for performant computation of flex attention. - See: https://pytorch.org/blog/flexattention/ - - Args: - attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences - of shape (batch_size, total_seq_len). e.g. - - For unpacked sequence: - [[1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0]] - - For packed sequence: - [[1, 1, 1, 2, 2, 2, 0], - [1, 1, 2, 2, 2, 3, 3]] - - Returns: - BlockMask - """ - - batch_size, total_seq_len = attention_mask_2d.shape - if not key_length: - key_length = total_seq_len - if not query_length: - query_length = total_seq_len - attention_mask_2d = torch.nn.functional.pad( - attention_mask_2d, - value=0, - pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))), - ) - device = attention_mask_2d.device - document_ids = attention_mask_2d.clone() - - if attention_chunk_size is not None: - # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3] - chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // ( - attention_chunk_size - ) - - # Instead of passing a tensor mask, flex attention requires a mask_mod function - # that determines which elements of QK^T should be included in the attention - # computation prior to the softmax. For sample packing, we need both the - # logic for both causal mask and document mask. See PyTorch's official - # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods - def causal_mask_mod( - batch_idx, head_idx, q_idx, kv_idx - ): # pylint: disable=unused-argument - """ - Defines the logic of a block causal mask by combining both a standard causal mask - and a block diagonal document mask. - - See :func:`~torchtune.modules.attention_utils.create_block_causal_mask` - for an illustration. - """ - causal_mask = q_idx >= kv_idx # not valid when decoding - document_mask = ( - document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx] - ) - padding_mask = attention_mask_2d[batch_idx, q_idx] > 0 - final_mask = causal_mask & padding_mask & document_mask - return final_mask - - def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): - """ - Combines the chunk mask with the causal mask for chunked attention. - """ - chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx] - causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx) - return chunk_mask & causal_doc_mask - - mask_mod_maybe_combined = ( - causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod - ) - - if offsets is not None: - q_offset = offsets[0] - kv_offset = offsets[1] - - def mask_mod(batch_idx, head_idx, q_idx, kv_idx): - offset_q = q_idx + q_offset - offset_kv = kv_idx + kv_offset - return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv) - - else: - mask_mod = mask_mod_maybe_combined - return create_block_causal_mask_flex( - mask_mod=mask_mod, - B=batch_size, - H=None, # attention head - Q_LEN=query_length, - KV_LEN=key_length, - device=device, - _compile=True, - ) - - for n in tuple(sys.modules): - if ".modeling_" in n: - if hasattr(sys.modules[n], "make_flex_block_causal_mask"): - sys.modules[n].make_flex_block_causal_mask = ( - patched_make_flex_block_causal_mask - ) - setattr( - sys.modules[n], - "make_flex_block_causal_mask", - patched_make_flex_block_causal_mask, - ) - - transformers.integrations.flex_attention.make_flex_block_causal_mask = ( - patched_make_flex_block_causal_mask - ) + sys.modules[ + "transformers.integrations.flex_attention" + ].WrappedFlexAttention = WrappedFlexAttention diff --git a/src/axolotl/monkeypatch/attention/xformers.py b/src/axolotl/monkeypatch/attention/xformers.py index 5901963f0..eca95797a 100644 --- a/src/axolotl/monkeypatch/attention/xformers.py +++ b/src/axolotl/monkeypatch/attention/xformers.py @@ -23,15 +23,15 @@ def xformers_attention_forward( value: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - dropout: float = 0.0, # pylint: disable=unused-argument - scaling: Optional[float] = None, # pylint: disable=unused-argument - sliding_window: Optional[int] = None, # pylint: disable=unused-argument - softcap: Optional[float] = None, # pylint: disable=unused-argument + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, cu_seq_lens_q: Optional[torch.LongTensor] = None, cu_seq_lens_k: Optional[torch.LongTensor] = None, max_length_q: Optional[int] = None, - max_length_k: Optional[int] = None, # pylint: disable=unused-argument - **kwargs, # pylint: disable=unused-argument + max_length_k: Optional[int] = None, + **kwargs, ): # Get dimensions # query: [batch, heads, seq_len, hidden_dim] diff --git a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py index 589980c8b..2c5077392 100644 --- a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py @@ -25,9 +25,7 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"): ".configuration_btlm", ".modeling_btlm" ) modeling_btlm = importlib.import_module(module_name) - modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access - flashattn_attn - ) + modeling_btlm.BTLMAttention._attn = flashattn_attn def flashattn_attn( @@ -35,9 +33,9 @@ def flashattn_attn( query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, - position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + position_bias: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: softmax_scale = ( 1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None diff --git a/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py index df8d106fd..c426344a6 100644 --- a/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py +++ b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py @@ -1,15 +1,23 @@ -"""monkey patches for the dataset fetcher to handle batches of packed indexes""" - -# pylint: disable=protected-access +"""Monkey patches for the dataset fetcher to handle batches of packed indexes.""" import torch from torch.utils.data._utils.fetch import _BaseDatasetFetcher from torch.utils.data._utils.worker import _worker_loop +_ORIGINAL_MAP_DATASET_FETCHER = None +_ORIGINAL_WORKER_LOOP = None +_IS_PATCHED = False + class _MapDatasetFetcher(_BaseDatasetFetcher): + """ + Custom dataset fetcher that handles nested batch structures from + MultipackBatchSampler. + """ + def fetch(self, possibly_batched_index): if isinstance(possibly_batched_index[0], list): + # Handle nested structure from MultipackBatchSampler data = [None for i in possibly_batched_index] for i, possibly_batched_index_ in enumerate(possibly_batched_index): if self.auto_collation: @@ -23,6 +31,7 @@ class _MapDatasetFetcher(_BaseDatasetFetcher): else: data[i] = self.dataset[possibly_batched_index_] else: + # Standard batch handling if self.auto_collation: if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: data = self.dataset.__getitems__(possibly_batched_index) @@ -34,14 +43,54 @@ class _MapDatasetFetcher(_BaseDatasetFetcher): def patch_fetchers(): + """Apply patches to PyTorch's DataLoader components.""" torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher def patched_worker_loop(*args, **kwargs): + """Worker loop that ensures patches are applied in worker processes.""" patch_fetchers() return _worker_loop(*args, **kwargs) -torch.utils.data._utils.worker._worker_loop = patched_worker_loop -patch_fetchers() +def apply_multipack_dataloader_patch(): + """ + This patch allows DataLoader to correctly process batches that contain multiple bins + of packed sequences. + """ + # pylint: disable=global-statement + global _ORIGINAL_MAP_DATASET_FETCHER, _ORIGINAL_WORKER_LOOP, _IS_PATCHED + + if _IS_PATCHED: + return + + # Store original implementations + _ORIGINAL_MAP_DATASET_FETCHER = torch.utils.data._utils.fetch._MapDatasetFetcher + _ORIGINAL_WORKER_LOOP = torch.utils.data._utils.worker._worker_loop + + # Apply patches + patch_fetchers() + torch.utils.data._utils.worker._worker_loop = patched_worker_loop + + _IS_PATCHED = True + + +def remove_multipack_dataloader_patch(): + """Remove the monkeypatch and restore original PyTorch DataLoader behavior.""" + # pylint: disable=global-statement + global _IS_PATCHED + + if not _IS_PATCHED: + return + + if _ORIGINAL_MAP_DATASET_FETCHER: + torch.utils.data._utils.fetch._MapDatasetFetcher = _ORIGINAL_MAP_DATASET_FETCHER + torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = ( + _ORIGINAL_MAP_DATASET_FETCHER + ) + + if _ORIGINAL_WORKER_LOOP: + torch.utils.data._utils.worker._worker_loop = _ORIGINAL_WORKER_LOOP + + _IS_PATCHED = False diff --git a/src/axolotl/monkeypatch/deepspeed_utils.py b/src/axolotl/monkeypatch/deepspeed_utils.py new file mode 100644 index 000000000..d7e69e112 --- /dev/null +++ b/src/axolotl/monkeypatch/deepspeed_utils.py @@ -0,0 +1,67 @@ +import importlib +import importlib.util + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def patch_checkpoint_wrapper_setattr(): + """ + Patch CheckpointWrapper to properly forward DeepSpeed attributes to wrapped modules. + + This fixes the issue where CheckpointWrapper doesn't forward ds_* attributes + (like ds_grads_remaining) to the actual wrapped module, causing DeepSpeed + ZeRO-3 to fail when gradient checkpointing is enabled. + + This issue occurs specifically with: + - QLoRA + DeepSpeed ZeRO-3 + - gradient_checkpointing: true + - activation_offloading: true + + References: + - https://github.com/deepspeedai/DeepSpeed/issues/7203 + - https://github.com/deepspeedai/DeepSpeed/blob/38d1a9eb64c9e01e32eccc50b25ba18925287441/deepspeed/runtime/zero/parameter_offload.py#L424-L458 + - https://github.com/axolotl-ai-cloud/axolotl/pull/3102 + """ + + try: + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, + ) + + # Check if already patched + if hasattr(CheckpointWrapper, "_axolotl_setattr_patched"): + LOG.debug("CheckpointWrapper already patched") + return + + original_setattr = CheckpointWrapper.__setattr__ + + def new_setattr(self, name: str, value) -> None: + if name.startswith("ds_") and hasattr(self, "_checkpoint_wrapped_module"): + setattr(self._checkpoint_wrapped_module, name, value) + LOG.debug( + f"Forwarded {name} to wrapped module {type(self._checkpoint_wrapped_module).__name__}" + ) + else: + original_setattr(self, name, value) + + CheckpointWrapper.__setattr__ = new_setattr + CheckpointWrapper._axolotl_setattr_patched = True + + LOG.info("CheckpointWrapper patched to forward DeepSpeed attributes") + + except ImportError as e: + LOG.debug(f"CheckpointWrapper not available: {e}") + except Exception as e: + LOG.warning(f"Failed to patch CheckpointWrapper: {e}") + + +def apply_deepspeed_patches(): + """ + Apply DeepSpeed-related patches + """ + if importlib.util.find_spec("deepspeed") is not None: + patch_checkpoint_wrapper_setattr() + else: + LOG.debug("DeepSpeed not available, skipping patches") diff --git a/src/axolotl/monkeypatch/fsdp2_qlora.py b/src/axolotl/monkeypatch/fsdp2_qlora.py new file mode 100644 index 000000000..04d0d1971 --- /dev/null +++ b/src/axolotl/monkeypatch/fsdp2_qlora.py @@ -0,0 +1,143 @@ +""" +Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as +our LoRA / QLoRA Triton kernels to work with FSDP2. + +This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes +Params4bit parameters. +""" + +import importlib +import inspect + +from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def apply_init_sharded_param_patch(): + """Apply patch to FSDPParam._init_sharded_param to support Params4bit.""" + from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam + + # Get original source + original_source = inspect.getsource(FSDPParam._init_sharded_param) + original_source, _ = detab_code(original_source) + + # Define the replacement + original_param_creation = """ self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param)) + self.sharded_param.requires_grad_(param.requires_grad)""" + + patched_param_creation = """ import bitsandbytes as bnb + if isinstance(param, bnb.nn.modules.Params4bit): + self.sharded_param = bnb.nn.modules.Params4bit( + data=sharded_param, + requires_grad=param.requires_grad, + quant_state=param.quant_state, + blocksize=param.blocksize, + compress_statistics=param.compress_statistics, + quant_type=param.quant_type, + quant_storage=param.quant_storage, + module=param.module, + bnb_quantized=param.bnb_quantized, + ) + self.sharded_param = self.to_sharded_dtensor(self.sharded_param) + else: + self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param)) + self.sharded_param.requires_grad_(param.requires_grad)""" + + # Apply the replacement + if original_param_creation in original_source: + patched_source = original_source.replace( + original_param_creation, patched_param_creation + ) + patched_source = patched_source.replace( + "def _init_sharded_param(", + "def patched_init_sharded_param(", + 1, + ) + + # Load necessary imports + module_name = FSDPParam.__module__ + module = importlib.import_module(module_name) + + items_to_import = [] + for item in dir(module): + if item in patched_source: + items_to_import.append(item) + + exec( # nosec B102 + f"from {module_name} import ({', '.join(items_to_import)})", + globals(), + ) + exec(patched_source, globals()) # nosec B102 + + # Replace the method + FSDPParam._init_sharded_param = patched_init_sharded_param + LOG.info("Successfully applied FSDP _init_sharded_param patch") + else: + LOG.warning("Could not find target code for _init_sharded_param patching") + + +def apply_init_unsharded_param_patch(): + """Apply patch to FSDPParam.init_unsharded_param to support Params4bit.""" + from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam + + # Get original source + original_source = inspect.getsource(FSDPParam.init_unsharded_param) + original_source, _ = detab_code(original_source) + + # Define the replacement + original_param_creation = """ self._unsharded_param = nn.Parameter( + unsharded_param, requires_grad=self.sharded_param.requires_grad + )""" + + patched_param_creation = """ import bitsandbytes as bnb + local_tensor = self.sharded_param._local_tensor + if isinstance(local_tensor, bnb.nn.modules.Params4bit): + self._unsharded_param = bnb.nn.modules.Params4bit( + data=unsharded_param, + requires_grad=self.sharded_param.requires_grad, + quant_state=local_tensor.quant_state, + blocksize=local_tensor.blocksize, + compress_statistics=local_tensor.compress_statistics, + quant_type=local_tensor.quant_type, + quant_storage=local_tensor.quant_storage, + module=local_tensor.module, + bnb_quantized=local_tensor.bnb_quantized, + ) + else: + self._unsharded_param = nn.Parameter( + unsharded_param, requires_grad=self.sharded_param.requires_grad + )""" + + # Apply the replacement + if original_param_creation in original_source: + patched_source = original_source.replace( + original_param_creation, patched_param_creation + ) + patched_source = patched_source.replace( + "def init_unsharded_param(", + "def patched_init_unsharded_param(", + 1, + ) + + # Load necessary imports + module_name = FSDPParam.__module__ + module = importlib.import_module(module_name) + + items_to_import = [] + for item in dir(module): + if item in patched_source: + items_to_import.append(item) + + exec( # nosec B102 + f"from {module_name} import ({', '.join(items_to_import)})", + globals(), + ) + exec(patched_source, globals()) # nosec B102 + + # Replace the method + FSDPParam.init_unsharded_param = patched_init_unsharded_param + LOG.info("Successfully applied FSDP init_unsharded_param patch") + else: + LOG.warning("Could not find target code for patching") diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py index 5d631776b..b58bbb67c 100644 --- a/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py +++ b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py @@ -5,7 +5,7 @@ from functools import partial from packaging import version -from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( +from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( # noqa: F401 CPU_Offloaded_Gradient_Checkpointer, ) from axolotl.monkeypatch.gradient_checkpointing.offload_disk import ( @@ -25,9 +25,7 @@ else: return False -def hf_grad_checkpoint_offload_wrapper( - decoder_layer, *args, use_reentrant=None -): # pylint: disable=unused-argument +def hf_grad_checkpoint_offload_wrapper(decoder_layer, *args, use_reentrant=None): if uses_gc_layers(decoder_layer): return CPU_Offloaded_Gradient_Checkpointer.apply( decoder_layer, @@ -44,9 +42,7 @@ def hf_grad_checkpoint_offload_wrapper( ) -def hf_grad_checkpoint_disk_offload_wrapper( - decoder_layer, *args, use_reentrant=None -): # pylint: disable=unused-argument +def hf_grad_checkpoint_disk_offload_wrapper(decoder_layer, *args, use_reentrant=None): if uses_gc_layers(decoder_layer): return Disco.apply( decoder_layer, diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py index bbb5ad40d..8d06f172d 100644 --- a/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py +++ b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py @@ -13,8 +13,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import inspect + import torch from packaging import version +from torch.utils.checkpoint import ( + set_device_states, +) + +# support different pytorch versions +has_device_type = "device_type" in inspect.signature(set_device_states).parameters torch_version = version.parse(torch.__version__) @@ -26,9 +35,7 @@ else: torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") -class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name - torch.autograd.Function -): +class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function): """ Saves VRAM by smartly offloading to RAM. Tiny hit to performance, since we mask the movement via non blocking calls. @@ -57,6 +64,4 @@ class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name return ( None, hidden_states.grad, - ) + ( - None, - ) * len(ctx.args) + ) + (None,) * len(ctx.args) diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py index 792d3c6ef..220799fbf 100644 --- a/src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py +++ b/src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py @@ -62,9 +62,9 @@ class DiskOffloadManager: # Track tensor paths and their status self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO) - self.file_locks: Dict[str, threading.Lock] = ( - {} - ) # Maps file_path -> threading.Lock() + self.file_locks: Dict[ + str, threading.Lock + ] = {} # Maps file_path -> threading.Lock() # Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted") self.file_status: Dict[str, str] = {} @@ -236,7 +236,7 @@ class DiskOffloadManager: self.tensor_paths.append(file_path) # Acquire semaphore to limit concurrent save operations - self.save_semaphore.acquire() # pylint: disable=consider-using-with + self.save_semaphore.acquire() # Queue tensor for saving in background self.save_queue.put((tensor.detach(), file_path)) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 70e36714c..3953cb138 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -2,40 +2,28 @@ # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py +import importlib.util import warnings -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple import torch -import torch.nn.functional as F import transformers from einops import rearrange from flash_attn.bert_padding import pad_input, unpad_input -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import ( - LlamaAttention, -) -from transformers.models.llama.modeling_llama import ( - LlamaDecoderLayer as OriginalLlamaDecoderLayer, -) from transformers.models.llama.modeling_llama import ( LlamaMLP, apply_rotary_pos_emb, repeat_kv, ) -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name +from axolotl.monkeypatch.utils import set_module_name from axolotl.utils.logging import get_logger try: - from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports - flash_attn_kvpacked_func, - flash_attn_varlen_kvpacked_func, + from flash_attn.flash_attn_interface import ( flash_attn_varlen_qkvpacked_func, ) except ImportError: - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func, - ) from flash_attn.flash_attn_interface import ( flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func, ) @@ -45,12 +33,7 @@ LOG = get_logger(__name__) def is_xformers_available() -> bool: - try: - import xformers # pylint: disable=unused-import # noqa: F401 - - return True - except ImportError: - return False + return importlib.util.find_spec("xformers") is not None def is_xformers_swiglu_available() -> bool: @@ -82,19 +65,6 @@ def replace_llama_mlp_with_swiglu(model): set_module_name(model, name, mlp) -def replace_llama_qkv_with_fused(model): - for name, module in model.named_modules(): - if isinstance(module, LlamaAttention): - qkv = FusedAttention( - module.config, - module.q_proj, - module.k_proj, - module.v_proj, - module.o_proj, - ) - set_module_name(model, name, qkv) - - def patch_fa_llama_cross_entropy(): LOG.info( "patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy" @@ -109,7 +79,7 @@ def patch_fa_llama_cross_entropy(): num_items_in_batch: int = None, ignore_index: int = -100, **kwargs, - ): # pylint: disable=unused-argument + ): reduction = "sum" if num_items_in_batch is not None else "mean" loss, _ = flash_attn_cross_entropy_loss( source, target, ignore_index=ignore_index @@ -142,28 +112,15 @@ def patch_llama_rms_norm(): def replace_llama_attn_with_flash_attn( - packed: Optional[bool] = False, cross_entropy: Optional[bool] = False, rms_norm: Optional[bool] = False, use_shifted_sparse_attn: Optional[bool] = False, ): - transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access - _prepare_decoder_attention_mask - ) + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask if use_shifted_sparse_attn: transformers.models.llama.modeling_llama.LlamaAttention.forward = ( flashattn_forward_with_s2attn ) - else: - transformers.models.llama.modeling_llama.LlamaAttention.forward = ( - flashattn_forward - ) - - if packed: - transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer - transformers.models.llama.modeling_llama.LlamaModel.forward = ( - llama_model_forward - ) # skip only if explicitly disabled if cross_entropy: @@ -174,49 +131,6 @@ def replace_llama_attn_with_flash_attn( patch_llama_rms_norm() -class FusedAttention(LlamaAttention): - """ - Fused QKV Attention layer for incrementally improved training efficiency - """ - - def __init__( - self, - config, - q: torch.nn.Linear, # pylint: disable=invalid-name - k: torch.nn.Linear, # pylint: disable=invalid-name - v: torch.nn.Linear, # pylint: disable=invalid-name - o: torch.nn.Linear, # pylint: disable=invalid-name - ): - super().__init__(config) - self.config = config - self.init_device = next(iter(q.state_dict().values())).device - - # define equivalent fused qkv projection - self.out_features: List[int] = [q.out_features, k.out_features, v.out_features] - self.qkv_proj = torch.nn.Linear( - q.in_features, sum(self.out_features), device=self.init_device, bias=False - ) - self.o_proj = o - - # overwrite initialized weights with pretrained weights - self.qkv_proj.weight.data = torch.cat( - (q.weight.data, k.weight.data, v.weight.data), dim=0 - ) - - def _post_training(self, model, name): - q_proj, k_proj, v_proj = torch.split( - self.qkv_proj.weight.data, self.out_features, dim=0 - ) - - new_attn = LlamaAttention(self.config) - new_attn.q_proj.weight.data = q_proj - new_attn.k_proj.weight.data = k_proj - new_attn.v_proj.weight.data = v_proj - new_attn.o_proj.weight.data = self.o_proj.weight.data - - set_module_name(model, name, new_attn) - - # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask( @@ -225,7 +139,7 @@ def _prepare_decoder_attention_mask( input_shape, inputs_embeds, past_key_values_length, -): # pylint: disable=unused-argument +): # [bsz, seq_len] return attention_mask @@ -241,9 +155,9 @@ def flashattn_forward_with_s2attn( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument - cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument - max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + padding_mask: Optional[torch.LongTensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel @@ -256,7 +170,8 @@ def flashattn_forward_with_s2attn( """ if output_attentions: warnings.warn( - "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.", + stacklevel=2, ) bsz, q_len, _ = hidden_states.size() @@ -278,7 +193,6 @@ def flashattn_forward_with_s2attn( ) # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] - # pylint: disable=duplicate-code cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb( @@ -324,9 +238,7 @@ def flashattn_forward_with_s2attn( .permute(0, 3, 1, 2, 4, 5) .reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim) ) - x = rearrange( # pylint: disable=invalid-name - qkv, "b s three h d -> b s (three h d)" - ) + x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) cu_q_len_tmp = torch.arange( 0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype @@ -355,576 +267,3 @@ def flashattn_forward_with_s2attn( .reshape(bsz, q_len, nheads, self.head_dim) ) return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value - - -def flashattn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel - - attention_mask: [bsz, q_len] - """ - # pylint: disable=duplicate-code - bsz, q_len, _ = hidden_states.size() - - if not hasattr(self, "pretraining_tp"): - self.pretraining_tp = 1 - - if self.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - if isinstance(self, FusedAttention): - query_states, key_states, value_states = self.qkv_proj(hidden_states).split( - self.out_features, dim=-1 - ) - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - # [bsz, q_len, nh, hd] - # [bsz, nh, q_len, hd] - - cos, sin = self.rotary_emb(value_states, position_ids=position_ids) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if output_attentions: - warnings.warn( - "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." - ) - - # - # flash-attn v2 start - # - - if self.training: - # during training q,k,v always have same seqlen - assert key_states.shape == query_states.shape - is_causal = True - else: - # turn off FA causal mask after first inference autoregressive iteration - # only on first autoregressive step q,k,v have same seqlen - is_causal = key_states.shape == query_states.shape - - dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0) - - if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: - # special handling using sample packing - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - qkv = rearrange(qkv, "b s ... -> (b s) ...") - - output = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - ) - output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - elif query_states.shape == key_states.shape: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( - query_states, - key_states, - value_states, - qkvpacked=True, - # We have disabled _prepare_decoder_attention_mask in LlamaModel - # the attention_mask should be the same as the key_padding_mask - key_padding_mask=attention_mask, - query_padding_mask=( - attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None - ), - ) - output_unpad = flash_attn_varlen_qkvpacked_func( - qkv_unpad, - cu_seqlens_q, - max_seqlen_q, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - ) - output = output_pad_fn(output_unpad) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if attention_mask is None or attention_mask.all().item(): - output = flash_attn_kvpacked_func( - query_states, - torch.stack([key_states, value_states], 2), - dropout_p=dropout_rate, - causal=is_causal, - ) - else: - ( # pylint: disable=unbalanced-tuple-unpacking - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - _, - _, - output_pad_fn, - ) = generate_qkv( - query_states, - key_states, - value_states, - kvpacked=True, - key_padding_mask=attention_mask, - query_padding_mask=( - attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None - ), - ) - if q_unpad.dtype != kv_unpad.dtype: - kv_unpad = kv_unpad.to(q_unpad.dtype) - output_unpad = flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - ) - output = output_pad_fn(output_unpad) - - attn_output = output - if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = rearrange(attn_output, "b s h d -> b s (h d)") - - # - # flash-attn v2 end - # - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.pretraining_tp, dim=1 - ) - attn_output = sum( - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp) - ) - else: - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38 -def generate_qkv( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False, -): # pylint: disable=invalid-name,unnecessary-lambda-assignment - """ - Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) - query_padding_mask: (batch_size, seqlen), bool - key_padding_mask: (batch_size, seqlen), bool - """ - assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) - - if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - q, query_padding_mask - ) - - def output_pad_fn(output_unpad): - return pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) - - else: - q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=q_unpad.device, - ) - max_seqlen_q = seqlen_q - - def output_pad_fn(output_unpad): - return rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) - - if key_padding_mask is not None: - k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) - else: - k_unpad = rearrange(k, "b s h d -> (b s) h d") - v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * seqlen_k, - step=seqlen_k, - dtype=torch.int32, - device=k_unpad.device, - ) - max_seqlen_k = seqlen_k - - if qkvpacked: - assert nheads == nheads_k - qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = torch.stack([q, k, v], dim=2) - return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn) - - if kvpacked: - kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - kv = torch.stack([k, v], dim=2) - return ( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - kv, - output_pad_fn, - ) - - return ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - ) - - -def llama_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[ # pylint: disable=unused-argument - torch.LongTensor - ] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - cu_seqlens = None - max_seqlen = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - padding_mask = None - else: - if 0 in attention_mask: - padding_mask = attention_mask - else: - padding_mask = None - - attention_mask = ( - self._prepare_decoder_attention_mask( # pylint: disable=protected-access - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - transformers.logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module( - *inputs, - ) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - None, - padding_mask, - cu_seqlens, - max_seqlen, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaDecoderLayer(OriginalLlamaDecoderLayer): - """ - patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - padding_mask: Optional[torch.LongTensor] = None, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - padding_mask=padding_mask, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 28223eee3..332242e2c 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -32,10 +32,9 @@ def xformers_forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument - **kwargs, # pylint: disable=unused-argument + padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() if not hasattr(self, "pretraining_tp"): @@ -102,7 +101,8 @@ def xformers_forward( if output_attentions: warnings.warn( - "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.", + stacklevel=2, ) # diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py index 0277c212a..5cfb7818e 100644 --- a/src/axolotl/monkeypatch/llama_expand_mask.py +++ b/src/axolotl/monkeypatch/llama_expand_mask.py @@ -21,6 +21,4 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] def hijack_expand_mask(): import transformers - transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access - _expand_mask - ) + transformers.models.llama.modeling_llama._expand_mask = _expand_mask diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py index cfd525367..8d234881f 100644 --- a/src/axolotl/monkeypatch/llama_patch_multipack.py +++ b/src/axolotl/monkeypatch/llama_patch_multipack.py @@ -12,15 +12,15 @@ def hijack_llama_prepare_4d_mask(): from transformers import modeling_attn_mask_utils from transformers.models.llama import modeling_llama - modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access + modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( patched_prepare_4d_causal_attention_mask_for_sdpa ) - modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access + modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( patched_prepare_4d_causal_attention_mask_for_sdpa ) - modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access + modeling_llama._prepare_4d_causal_attention_mask = ( patched_prepare_4d_causal_attention_mask ) - modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access + modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( patched_prepare_4d_causal_attention_mask ) diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index a7875eefe..e845dc6ce 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -18,6 +18,7 @@ from axolotl.kernels.lora import ( apply_lora_qkv, ) from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -29,48 +30,36 @@ QKV_PATCHES = [ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) -""".lstrip( - "\n" - ), +""".lstrip("\n"), """ query_states, key_states, value_states = self.apply_qkv(hidden_states) query_states = query_states.view(hidden_shape).transpose(1, 2) key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) -""".lstrip( - "\n" - ), +""".lstrip("\n"), ), ( """ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) -""".lstrip( - "\n" - ), +""".lstrip("\n"), """ query_states, key_states, value_states = self.apply_qkv(hidden_states) query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) -""".lstrip( - "\n" - ), +""".lstrip("\n"), ), ] ORIGINAL_O_CODE = """ attn_output = self.o_proj(attn_output) -""".lstrip( - "\n" -) +""".lstrip("\n") PATCHED_O_CODE = """ attn_output = self.apply_o(attn_output) -""".lstrip( - "\n" -) +""".lstrip("\n") SUPPORTED_ACTIVATIONS = ["silu", "gelu"] APPLY_FN_MAPPING = { @@ -145,12 +134,30 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: return Qwen2Attention + if model_type == "mllama": + from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention + + return MllamaTextSelfAttention + + if model_type == "llama4": + from transformers.models.llama4.modeling_llama4 import Llama4TextAttention + + return Llama4TextAttention + + if model_type == "mistral3": + from transformers.models.mistral.modeling_mistral import MistralAttention + + return MistralAttention + + if model_type == "gemma3_text": + from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention + + return Gemma3Attention + try: # Dynamically import the module and attention class module_path = f"transformers.models.{model_type}.modeling_{model_type}" - model_cls_prefix = "".join( - [part.capitalize() for part in model_type.split("_")] - ) + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"]) attention_cls = getattr(module, f"{model_cls_prefix}Attention") @@ -162,7 +169,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: ) from e -# pylint: disable=protected-access def patch_self_attn_lora(cfg: DictDefault): """ Given an `axolotl` config, this method patches the inferred attention class forward @@ -189,9 +195,9 @@ def patch_self_attn_lora(cfg: DictDefault): attention_cls._original_forward = self_attn_forward self_attn_forward, _ = detab_code(self_attn_forward) - assert any( - qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES - ), "Original QKV code not found" + assert any(qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES), ( + "Original QKV code not found" + ) assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found" for qkv_orig, qkv_patched in QKV_PATCHES: @@ -217,16 +223,14 @@ def patch_self_attn_lora(cfg: DictDefault): if item in self_attn_forward: items_to_import.append(item) - exec( # pylint: disable=exec-used # nosec B102 + exec( f"from {module_name} import ({', '.join(items_to_import)})", globals(), ) - exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 + exec(self_attn_forward, globals()) LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}") - attention_cls.forward = ( - axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821 - ) + attention_cls.forward = axolotl_attn_forward def find_self_attn_in_layer( @@ -263,12 +267,39 @@ def find_mlp_in_layer( layer.feedforward.experts.gate_projs, layer.feedforward.experts.up_projs, layer.feedforward.experts.down_projs, + strict=False, ): - yield gate_proj, up_proj, down_proj, FakeMLP( - gate_proj, up_proj, down_proj + yield ( + gate_proj, + up_proj, + down_proj, + FakeMLP(gate_proj, up_proj, down_proj), ) +def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]: + """ + Get the layers of the model. Handles text-only and multimodal models. + + Args: + model: A PEFT model. + + Returns: + A list of layers. + """ + pretrained_model = model.model + + # check for multimodal models first + if hasattr(pretrained_model, "language_model"): + return pretrained_model.language_model.layers + if hasattr(pretrained_model, "model"): + return pretrained_model.model.layers + + raise NotImplementedError( + f"Model type {model.config.model_type} is not supported yet. Please create an Issue." + ) + + def apply_lora_kernel_patches( model: PeftModelForCausalLM, cfg: DictDefault ) -> PeftModelForCausalLM: @@ -300,9 +331,9 @@ def apply_lora_kernel_patches( # Get active LoRA adapter config if hasattr(model, "active_adapters"): - assert ( - len(model.active_adapters) == 1 - ), "Axolotl currently does not support LoRA Triton kernels for multiple adapters" + assert len(model.active_adapters) == 1, ( + "Axolotl currently does not support LoRA Triton kernels for multiple adapters" + ) active_adapter = model.active_adapters[0] else: active_adapter = model.active_adapter @@ -340,17 +371,7 @@ def apply_lora_kernel_patches( if activation not in SUPPORTED_ACTIVATIONS: raise NotImplementedError(f"Activation {activation} is not supported") - layers = [] - # check for multimodal models first - pretrained_model = model.model - if hasattr(pretrained_model, "language_model"): - layers = pretrained_model.language_model.layers - elif hasattr(pretrained_model, "model"): - layers = pretrained_model.model.layers - else: - raise NotImplementedError( - f"Model type {model.config.model_type} is not supported yet. Please create an Issue." - ) + layers = get_layers(model) # Patch each layer for layer in layers: @@ -368,7 +389,6 @@ def apply_lora_kernel_patches( ] can_patch_qkv = all( hasattr(module, "lora_A") - and getattr(module, "base_layer", module).bias is None and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 for module in layer_modules ) @@ -378,7 +398,8 @@ def apply_lora_kernel_patches( self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn) else: LOG.warning_once( - "Cannot patch some attention QKV projections - requires LoRA adapters with no bias" + "Cannot patch some attention QKV projections - requires LoRA " + "adapters and no lora_magnitude_vector (DoRA)" ) if cfg.lora_o_kernel: # Output patching @@ -387,7 +408,6 @@ def apply_lora_kernel_patches( ] can_patch_o = all( hasattr(module, "lora_A") - and getattr(module, "base_layer", module).bias is None and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 for module in layer_modules ) @@ -396,14 +416,14 @@ def apply_lora_kernel_patches( self_attn.apply_o = types.MethodType(apply_lora_o, self_attn) else: LOG.warning_once( - "Cannot patch some attention output projection - requires LoRA adapters with no bias" + "Cannot patch some attention output projection - requires LoRA " + "adapters and no lora_magnitude_vector (DoRA)" ) for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer): if cfg.lora_mlp_kernel: # MLP patching can_patch_mlp = all( hasattr(proj, "lora_A") - and getattr(proj, "base_layer", proj).bias is None and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 for proj in (gate_proj, up_proj, down_proj) ) @@ -413,7 +433,8 @@ def apply_lora_kernel_patches( layer.mlp.forward = types.MethodType(apply_fn, mlp) else: LOG.warning_once( - "Cannot patch some MLP layers - requires LoRA adapters with no bias" + "Cannot patch some MLP layers - requires LoRA adapters and no " + "lora_magnitude_vector (DoRA)" ) LOG.setLevel(original_level) diff --git a/src/axolotl/monkeypatch/loss/__init__.py b/src/axolotl/monkeypatch/loss/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/loss/chunked.py b/src/axolotl/monkeypatch/loss/chunked.py new file mode 100644 index 000000000..26a52f898 --- /dev/null +++ b/src/axolotl/monkeypatch/loss/chunked.py @@ -0,0 +1,134 @@ +""" +chunked ce loss +""" + +from typing import List, Optional + +import torch +import torch.nn.functional as F + + +# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss +class CEWithChunkedOutputLoss(torch.nn.Module): + """ + Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time. + + For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390 + """ + + def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): + super().__init__() + self.num_output_chunks = num_output_chunks + self.ignore_index = ignore_index + + def compute_cross_entropy( + self, + logits: torch.Tensor, + labels: torch.Tensor, + normalize: bool = True, + ) -> torch.Tensor: + """ + Upcast logits to fp32 and compute cross entropy loss. + """ + return F.cross_entropy( + logits.float(), labels, ignore_index=self.ignore_index, reduction="sum" + ) + + def forward( + self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum" + ) -> torch.Tensor: + """ + Args: + logits (List[torch.Tensor]): List of chunked logits of length + ``self.num_output_chunks``, where each chunk has shape + ``(batch_size, num_tokens / num_output_chunks, vocab_size)``. + labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``. + reduction (str): The reduction to apply to the output. + + Returns: + torch.Tensor: Cross entropy loss of shape (1,). + """ + + total_elements = (labels != self.ignore_index).sum() + + # chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)] + labels = [ + target_chunk.reshape(-1) + for target_chunk in labels.chunk(self.num_output_chunks, dim=1) + ] + # reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)] + logits = [ + logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits + ] + + # compute one chunk at a time + total_loss = 0.0 + for logits_chunk, labels_chunk in zip(logits, labels, strict=False): + total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk) + + if reduction == "sum": + return total_loss + return total_loss / total_elements + + +def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): + loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index) + loss_fn_ce.compute_cross_entropy = torch.compile( + loss_fn_ce.compute_cross_entropy, backend="inductor" + ) + return loss_fn_ce + + +def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100): + loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index) + + def chunked_fix_cross_entropy( + source, + target, + num_items_in_batch: int = None, + ignore_index: int = -100, + **kwargs, + ): + reduction = "sum" if num_items_in_batch is not None else "mean" + logit_chunks = [ + chunk for chunk in source.chunk(loss_fn_ce.num_output_chunks, dim=1) + ] + loss = loss_fn_ce(logit_chunks, target, reduction=reduction) + if reduction == "sum": + loss = loss / num_items_in_batch + return loss + + def for_causal_lm_chunked_loss( + logits, + labels, + vocab_size: int = None, + num_items_in_batch: Optional[int] = None, + ignore_index: int = -100, + shift_labels: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # skip the upcast to float since we handle that in the chunking loss + if shift_labels is None: + # Shift so that tokens < n predict n + labels = F.pad(labels, (0, 1), value=ignore_index) + shift_labels = labels[..., 1:].contiguous() + + # Skip Flattening the tokens + # Enable model parallelism + shift_labels = shift_labels.to(logits.device) + loss = chunked_fix_cross_entropy( + logits, shift_labels, num_items_in_batch, ignore_index, **kwargs + ) + return loss + + return for_causal_lm_chunked_loss + + +def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): + import transformers.loss.loss_utils + + for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index) + transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss + transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = ( + for_causal_lm_chunked_loss + ) diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index 3fc22917f..0994da91c 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -1,55 +1,14 @@ """Flash attention monkey patch for mistral model""" -# pylint: disable=duplicate-code - from functools import partial -from typing import List, Optional, Tuple, Union -import torch import transformers -from einops import rearrange -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports - flash_attn_kvpacked_func, - flash_attn_varlen_kvpacked_func, - flash_attn_varlen_qkvpacked_func, -) -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.mistral.modeling_mistral import ( - MistralAttention as OriginalMistralAttention, -) -from transformers.models.mistral.modeling_mistral import ( - MistralDecoderLayer as OriginalMistralDecoderLayer, -) -from transformers.models.mistral.modeling_mistral import ( - apply_rotary_pos_emb, - repeat_kv, -) -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.utils.logging import get_logger LOG = get_logger(__name__) -def replace_mistral_attn_with_flash_attn( - packed: Optional[bool] = False, -): - transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access - _prepare_decoder_attention_mask - ) - transformers.models.mistral.modeling_mistral.MistralAttention.forward = ( - flashattn_forward - ) - if packed: - transformers.models.mistral.modeling_mistral.MistralDecoderLayer = ( - MistralDecoderLayer - ) - transformers.models.mistral.modeling_mistral.MistralModel.forward = ( - mistral_model_forward - ) - - def patch_mistral_cross_entropy(): from flash_attn.losses.cross_entropy import CrossEntropyLoss @@ -57,604 +16,3 @@ def patch_mistral_cross_entropy(): transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial( CrossEntropyLoss, inplace_backward=True ) - - -@torch.jit.script -def _make_sliding_window_causal_mask( - bsz: int, - tgt_len: int, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: int = 4096, -): - """ - Make causal mask used for sliding window attention - """ - tensor = torch.full( - (tgt_len, tgt_len), - fill_value=1, - device=device, - ) - mask = torch.tril(tensor, diagonal=0) - # make the mask banded to account for sliding window - # NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1 - mask = torch.triu(mask, diagonal=-sliding_window + 1) - mask = torch.log(mask).to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), - mask, - ], - dim=-1, - ) - return mask[None, None, :, :].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) - - -# Disable the transformation of the attention mask in LlamaModel as the flash attention -# requires the attention mask to be the same as the key_padding_mask -def _prepare_decoder_attention_mask( - self, - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - sliding_window, -): # pylint: disable=unused-argument - # [bsz, seq_len] - if attention_mask is None or sliding_window is None: - return attention_mask - - # NOTE: attention mask and sliding masks are only broadcastable in certain scenarios. - # Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled. - if input_shape[-1] > 1 and attention_mask.shape[0] == 1: - sliding_window_mask = _make_sliding_window_causal_mask( - bsz=input_shape[0], - tgt_len=input_shape[1], - dtype=inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - sliding_window=sliding_window, - ) - attention_mask = attention_mask + sliding_window_mask - else: - LOG.info("skipping sliding window mask, not broadcastable with attention mask") - - return attention_mask - - -def flashattn_forward( - self: OriginalMistralAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, position_ids=position_ids) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - - use_sliding_windows = ( - getattr(self.config, "sliding_window") is not None - and kv_seq_len > self.config.sliding_window - ) - - if use_sliding_windows: - window_size = (self.config.sliding_window, self.config.sliding_window) - else: - window_size = (-1, -1) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - if ( - hasattr(self.config, "sliding_window") - and kv_seq_len > self.config.sliding_window - ): - slicing_tokens = kv_seq_len - self.config.sliding_window - - past_key = past_key_value[0] - past_value = past_key_value[1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - past_key_value = (past_key, past_value) if use_cache else None - - if past_key_value is not None: - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if self.training: - # during training q,k,v always have same seqlen - assert key_states.shape == query_states.shape - is_causal = True - else: - # turn off FA causal mask after first inference autoregressive iteration - # only on first autoregressive step q,k,v have same seqlen - is_causal = key_states.shape == query_states.shape - - dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0) - - if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: - # special handling using sample packing - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - qkv = rearrange(qkv, "b s ... -> (b s) ...") - - output = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - window_size=window_size, - ) - output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - elif query_states.shape == key_states.shape: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( - query_states, - key_states, - value_states, - qkvpacked=True, - # We have disabled _prepare_decoder_attention_mask in LlamaModel - # the attention_mask should be the same as the key_padding_mask - key_padding_mask=attention_mask, - query_padding_mask=( - attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None - ), - ) - output_unpad = flash_attn_varlen_qkvpacked_func( - qkv_unpad, - cu_seqlens_q, - max_seqlen_q, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - window_size=window_size, - ) - output = output_pad_fn(output_unpad) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if attention_mask is None or attention_mask.all().item(): - output = flash_attn_kvpacked_func( - query_states, - torch.stack([key_states, value_states], 2), - dropout_p=dropout_rate, - causal=is_causal, - window_size=window_size, - ) - else: - ( # pylint: disable=unbalanced-tuple-unpacking - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - _, - _, - output_pad_fn, - ) = generate_qkv( - query_states, - key_states, - value_states, - kvpacked=True, - key_padding_mask=attention_mask, - query_padding_mask=( - attention_mask[:, -query_states.size(1) :] - if attention_mask is not None - else None - ), - ) - if q_unpad.dtype != kv_unpad.dtype: - kv_unpad = kv_unpad.to(q_unpad.dtype) - output_unpad = flash_attn_varlen_kvpacked_func( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=dropout_rate, - softmax_scale=None, - causal=is_causal, - window_size=window_size, - ) - output = output_pad_fn(output_unpad) - - attn_output = output - if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = rearrange(attn_output, "b s h d -> b s (h d)") - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38 -def generate_qkv( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - kvpacked=False, - qkvpacked=False, -): # pylint: disable=invalid-name,unnecessary-lambda-assignment - """ - Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) - query_padding_mask: (batch_size, seqlen), bool - key_padding_mask: (batch_size, seqlen), bool - """ - assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) - - if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input( - q, query_padding_mask - ) - - def output_pad_fn(output_unpad): - return pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) - - else: - q_unpad = rearrange(q, "b s h d -> (b s) h d") - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=q_unpad.device, - ) - max_seqlen_q = seqlen_q - - def output_pad_fn(output_unpad): - return rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) - - if key_padding_mask is not None: - k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) - else: - k_unpad = rearrange(k, "b s h d -> (b s) h d") - v_unpad = rearrange(v, "b s h d -> (b s) h d") - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * seqlen_k, - step=seqlen_k, - dtype=torch.int32, - device=k_unpad.device, - ) - max_seqlen_k = seqlen_k - - if qkvpacked: - assert nheads == nheads_k - qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = torch.stack([q, k, v], dim=2) - return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn) - - if kvpacked: - kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - kv = torch.stack([k, v], dim=2) - return ( - q_unpad, - kv_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - kv, - output_pad_fn, - ) - - return ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - ) - - -def mistral_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[ # pylint: disable=unused-argument - torch.LongTensor - ] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - if input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - cu_seqlens = None - max_seqlen = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - attention_mask = ( - self._prepare_decoder_attention_mask( # pylint: disable=protected-access - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - transformers.logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = ( - self._gradient_checkpointing_func( # pylint: disable=protected-access - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - None, - cu_seqlens, - max_seqlen, - ) - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class MistralDecoderLayer(OriginalMistralDecoderLayer): - """ - patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cu_seqlens: Optional[torch.Tensor] = None, - max_seqlen: Optional[torch.Tensor] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py index 5b8054000..b353b12cf 100644 --- a/src/axolotl/monkeypatch/mixtral/__init__.py +++ b/src/axolotl/monkeypatch/mixtral/__init__.py @@ -31,14 +31,12 @@ def patch_mixtral_moe_forward_zero3() -> None: topk_weight = topk_weight.to(hidden_states.dtype) hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0) - y = torch.empty_like(hidden_states) # pylint: disable=invalid-name + y = torch.empty_like(hidden_states) flat_topk_idx = topk_idx.view(-1) for i in range(self.num_experts): expert = self.experts[i] y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) - y = ( # pylint: disable=invalid-name - y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1) - ).sum(dim=1) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits diff --git a/src/axolotl/monkeypatch/models/__init__.py b/src/axolotl/monkeypatch/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/apertus/__init__.py b/src/axolotl/monkeypatch/models/apertus/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/apertus/activation.py b/src/axolotl/monkeypatch/models/apertus/activation.py new file mode 100644 index 000000000..d5470aceb --- /dev/null +++ b/src/axolotl/monkeypatch/models/apertus/activation.py @@ -0,0 +1,52 @@ +"""Monkeypatch for Apertus to dtype mismatch in XIELU act""" + +from torch import Tensor + + +def patch_apertus_xielu_activation(): + try: + from transformers.activations import XIELUActivation + except ImportError as err: + raise ImportError( + "Cannot import XIELUActivation. " + "Please make sure to update your transformers version >= 4.56.1." + ) from err + + from transformers.activations import logger + + # Store the original method + old_fn = XIELUActivation._xielu_cuda + + def _xielu_cuda_fixed(self, x: Tensor) -> Tensor: + """Firewall function to prevent torch.compile from seeing .item() calls""" + original_shape = x.shape + # CUDA kernel expects 3D tensors, reshape if needed + while x.dim() < 3: + x = x.unsqueeze(0) + if x.dim() > 3: + x = x.view(-1, 1, x.size(-1)) + if original_shape != x.shape: + logger.warning_once( + "Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).", + original_shape, + x.shape, + ) + result = self._xielu_cuda_obj.forward( + x, + self.alpha_p.to(x.dtype), + self.alpha_n.to(x.dtype), + # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item() + self._beta_scalar, + self._eps_scalar, + self.with_vector_loads, + ) + return result.view(original_shape) + + # Apply the patch + XIELUActivation._xielu_cuda = _xielu_cuda_fixed + + def unpatch(): + """Restore the original method""" + XIELUActivation._xielu_cuda = old_fn + + return unpatch diff --git a/src/axolotl/monkeypatch/models/llama4/modeling.py b/src/axolotl/monkeypatch/models/llama4/modeling.py index 4127793e7..0fc8f5699 100644 --- a/src/axolotl/monkeypatch/models/llama4/modeling.py +++ b/src/axolotl/monkeypatch/models/llama4/modeling.py @@ -95,18 +95,12 @@ def patch_llama4_linearized_modeling(): old_lamma_4_text_experts = modeling_llama4.Llama4TextExperts modeling_llama4.Llama4TextExperts = Llama4TextExperts - setattr( - sys.modules["transformers.models.llama4"], - "Llama4TextExperts", - Llama4TextExperts, - ) + sys.modules["transformers.models.llama4"].Llama4TextExperts = Llama4TextExperts def unpatch(): modeling_llama4.Llama4TextExperts = old_lamma_4_text_experts - setattr( - sys.modules["transformers.models.llama4"], - "Llama4TextExperts", - old_lamma_4_text_experts, - ) + sys.modules[ + "transformers.models.llama4" + ].Llama4TextExperts = old_lamma_4_text_experts return unpatch diff --git a/src/axolotl/monkeypatch/models/mistral3/__init__.py b/src/axolotl/monkeypatch/models/mistral3/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/mistral3/mistral_common_tokenizer.py b/src/axolotl/monkeypatch/models/mistral3/mistral_common_tokenizer.py new file mode 100644 index 000000000..9e7259a05 --- /dev/null +++ b/src/axolotl/monkeypatch/models/mistral3/mistral_common_tokenizer.py @@ -0,0 +1,85 @@ +""" +Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template +""" + +import importlib +import inspect + +from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def apply_mistral_tokenizer_image_patch(): + """Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion.""" + from transformers.tokenization_mistral_common import MistralCommonTokenizer + + # Get original source + original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template) + original_source, _ = detab_code(original_source) + + # Define the replacement + original_tensor_conversion = ( + " pixel_values = torch.tensor(images)" + ) + + patched_tensor_conversion = """ if isinstance(images, list) and len(images) > 0 and isinstance(images[0], np.ndarray): + pixel_values = torch.tensor(np.array(images)) + else: + pixel_values = torch.tensor(images)""" + + # Apply the replacement + if original_tensor_conversion in original_source: + patched_source = original_source.replace( + original_tensor_conversion, patched_tensor_conversion + ) + patched_source = patched_source.replace( + "def apply_chat_template(", + "def patched_apply_chat_template(", + 1, + ) + + # Load necessary imports from the module + module_name = MistralCommonTokenizer.__module__ + module = importlib.import_module(module_name) + + # Detect what needs to be imported + items_to_import = [] + for item in dir(module): + if item in patched_source and not item.startswith("_"): + items_to_import.append(item) + + # Execute imports in global scope + if items_to_import: + exec( # nosec B102 + f"from {module_name} import ({', '.join(items_to_import)})", + globals(), + ) + + # Also need standard imports that might be used + exec("import numpy as np", globals()) # nosec B102 + exec("import torch", globals()) # nosec B102 + exec("from typing import Union, Optional, List, Dict, Any, Callable", globals()) # nosec B102 + exec("from pathlib import Path", globals()) # nosec B102 + + # Import other dependencies that might be needed + try: + exec("from transformers.utils import is_torch_available", globals()) # nosec B102 + exec( + "from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType", + globals(), + ) # nosec B102 + exec("from transformers.utils import logging", globals()) # nosec B102 + exec("logger = logging.get_logger(__name__)", globals()) # nosec B102 + except ImportError as e: + LOG.warning(f"Could not import some dependencies: {e}") + + # Execute the patched source + exec(patched_source, globals()) # nosec B102 + + # Replace the method + MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template + LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch") + else: + LOG.warning("Could not find target code for MistralCommonTokenizer patching") diff --git a/src/axolotl/monkeypatch/models/pixtral/__init__.py b/src/axolotl/monkeypatch/models/pixtral/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/pixtral/modeling_flash_attention_utils.py b/src/axolotl/monkeypatch/models/pixtral/modeling_flash_attention_utils.py new file mode 100644 index 000000000..d2b482f19 --- /dev/null +++ b/src/axolotl/monkeypatch/models/pixtral/modeling_flash_attention_utils.py @@ -0,0 +1,42 @@ +"""Monkeypatch for FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid""" + +import torch + + +def apply_patch_is_packed_sequence(): + """Apply patch to FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid""" + from transformers import modeling_flash_attention_utils + + def fixed_is_packed_sequence(position_ids, batch_size): + """ + Check the position ids whether packed sequences are indicated or not + 1. Position ids exist + 2. Flattened sequences only are supported + 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences + """ + if position_ids is None: + return False + + if position_ids.ndim == 1: + position_ids = position_ids.unsqueeze(0) # [N] -> [1, N] + + increasing_position_sequences = ( + torch.arange(position_ids.shape[1], device=position_ids.device) + + position_ids.min() + ) + return ( + batch_size == 1 + and (increasing_position_sequences - position_ids).abs().sum().bool().item() + ) + + # Store original method + old_fn = modeling_flash_attention_utils._is_packed_sequence + + # Apply the patch + modeling_flash_attention_utils._is_packed_sequence = fixed_is_packed_sequence + + def unpatch(): + """Restore the original method""" + modeling_flash_attention_utils._is_packed_sequence = old_fn + + return unpatch diff --git a/src/axolotl/monkeypatch/models/qwen3_next/__init__.py b/src/axolotl/monkeypatch/models/qwen3_next/__init__.py new file mode 100644 index 000000000..39bcd4115 --- /dev/null +++ b/src/axolotl/monkeypatch/models/qwen3_next/__init__.py @@ -0,0 +1 @@ +"""Qwen3_Next model monkeypatches.""" diff --git a/src/axolotl/monkeypatch/models/qwen3_next/modeling.py b/src/axolotl/monkeypatch/models/qwen3_next/modeling.py new file mode 100644 index 000000000..d68992d0e --- /dev/null +++ b/src/axolotl/monkeypatch/models/qwen3_next/modeling.py @@ -0,0 +1,317 @@ +"""Monkeypatch for Qwen3_Next model to pass position_ids to linear attention.""" + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def get_cu_seqlens(position_ids): + """ + Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids. + + https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316 + """ + tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device} + + position_ids = position_ids.view(-1) + indices_q = (position_ids == 0).nonzero().view(-1) + + cu_seq_lens_q = torch.cat( + ( + indices_q.to(**tensor_kwargs), + torch.tensor(position_ids.size(), **tensor_kwargs), + ) + ) + + return cu_seq_lens_q + + +def patch_qwen3_next_decoder_layer(): + """Patch Qwen3NextDecoderLayer to pass position_ids to linear attention.""" + try: + from transformers.models.qwen3_next.modeling_qwen3_next import ( + Qwen3NextDecoderLayer, + ) + except ImportError: + LOG.warning("Qwen3Next model not found, skipping patch") + return + + # Store original forward method + original_decoder_forward = Qwen3NextDecoderLayer.forward + + def patched_decoder_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Token Mixer + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + position_ids=position_ids, + ) + elif self.layer_type == "full_attention": + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # For the MoE layers, we need to unpack + if isinstance(hidden_states, Tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + # Apply the patches + Qwen3NextDecoderLayer.forward = patched_decoder_forward + + def unpatch(): + """Restore the original forward method""" + Qwen3NextDecoderLayer.forward = original_decoder_forward + + return unpatch + + +def patch_qwen3_next_gateddelta_layer(): + """Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule""" + try: + from transformers.models.qwen3_next.modeling_qwen3_next import ( + Qwen3NextDynamicCache, + Qwen3NextGatedDeltaNet, + apply_mask_to_padding_states, + ) + except ImportError: + LOG.warning("Qwen3Next model not found, skipping patch") + return + + # Store original forward method + original_gated_delta_net_forward = Qwen3NextGatedDeltaNet.forward + + def patched_gated_delta_net_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Qwen3NextDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ): + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + # getting projected states from cache if it exists + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + + projected_states_qkvz = self.in_proj_qkvz(hidden_states) + projected_states_ba = self.in_proj_ba(hidden_states) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba + ) + query, key, value = ( + x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value) + ) + + mixed_qkv = torch.cat((query, key, value), dim=-1) + mixed_qkv = mixed_qkv.transpose(1, 2) + + if use_precomputed_states: + # 2. Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + if cache_params is not None: + conv_state = F.pad( + mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx] = conv_state + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + cu_seqlens = get_cu_seqlens(position_ids=position_ids) + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape( + core_attn_out.shape[0], core_attn_out.shape[1], -1 + ) + + output = self.out_proj(core_attn_out) + return output + + # Apply the patches + Qwen3NextGatedDeltaNet.forward = patched_gated_delta_net_forward + + def unpatch(): + """Restore the original forward method""" + Qwen3NextGatedDeltaNet.forward = original_gated_delta_net_forward + + return unpatch + + +def patch_qwen3_next_imports(): + """Patch Qwen3Next imports to use try/except instead of is_flash_linear_attention_available.""" + try: + import transformers.models.qwen3_next.modeling_qwen3_next as qwen3_modeling + except ImportError: + LOG.warning("Qwen3Next model not found, skipping import patch") + return + + # Save original values for unpatch + original_FusedRMSNormGated = getattr(qwen3_modeling, "FusedRMSNormGated", None) + original_chunk_gated_delta_rule = getattr( + qwen3_modeling, "chunk_gated_delta_rule", None + ) + original_fused_recurrent_gated_delta_rule = getattr( + qwen3_modeling, "fused_recurrent_gated_delta_rule", None + ) + original_is_fast_path_available = getattr( + qwen3_modeling, "is_fast_path_available", False + ) + + try: + from fla.modules import FusedRMSNormGated + from fla.ops.gated_delta_rule import ( + chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule, + ) + + qwen3_modeling.FusedRMSNormGated = FusedRMSNormGated + qwen3_modeling.chunk_gated_delta_rule = chunk_gated_delta_rule + qwen3_modeling.fused_recurrent_gated_delta_rule = ( + fused_recurrent_gated_delta_rule + ) + + # Force is_fast_path_available to be True + # fla has triton kernels for causal_conv1d + qwen3_modeling.is_fast_path_available = True + except ImportError: + qwen3_modeling.chunk_gated_delta_rule = None + qwen3_modeling.fused_recurrent_gated_delta_rule = None + qwen3_modeling.FusedRMSNormGated = None + + def unpatch(): + """Restore the original import values""" + qwen3_modeling.FusedRMSNormGated = original_FusedRMSNormGated + qwen3_modeling.chunk_gated_delta_rule = original_chunk_gated_delta_rule + qwen3_modeling.fused_recurrent_gated_delta_rule = ( + original_fused_recurrent_gated_delta_rule + ) + qwen3_modeling.is_fast_path_available = original_is_fast_path_available + + return unpatch + + +def patch_qwen3_next_modeling_packing(): + """Apply all Qwen3Next model patches.""" + patch_qwen3_next_imports() + patch_qwen3_next_decoder_layer() + patch_qwen3_next_gateddelta_layer() + + LOG.info("Applied Qwen3Next patch for packing") diff --git a/src/axolotl/monkeypatch/models/voxtral/__init__.py b/src/axolotl/monkeypatch/models/voxtral/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/voxtral/modeling.py b/src/axolotl/monkeypatch/models/voxtral/modeling.py new file mode 100644 index 000000000..3dd652dd8 --- /dev/null +++ b/src/axolotl/monkeypatch/models/voxtral/modeling.py @@ -0,0 +1,67 @@ +"""Monkeypatch for voxtral to fix leaf node and dtype mismatch""" + +from typing import Optional, Union + +import torch +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast + + +def patch_voxtral_conditional_generation_forward(): + from transformers.models.voxtral.modeling_voxtral import ( + VoxtralForConditionalGeneration, + ) + + # Store the original forward method + old_forward = VoxtralForConditionalGeneration.forward + + def _forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> CausalLMOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if input_features is not None: + audio_embeds = self.get_audio_embeds(input_features) + + # Cast audio_embeds to match inputs_embeds dtype + audio_embeds = audio_embeds.to(inputs_embeds.dtype) + + # replace text-audio token placeholders with audio embeddings + audio_token_mask = input_ids == self.config.audio_token_id + + inputs_embeds = inputs_embeds.clone() + inputs_embeds[audio_token_mask] = audio_embeds + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + return outputs + + # Apply the patch + VoxtralForConditionalGeneration.forward = _forward + + def unpatch(): + """Restore the original forward method""" + VoxtralForConditionalGeneration.forward = old_forward + + return unpatch diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 1467f9e29..48b4ea10e 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -11,6 +11,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 from axolotl.monkeypatch.utils import get_unpad_data SUPPORTED_MULTIPACK_MODEL_TYPES = [ + "apertus", "mllama_text_model", "llama", "llama4", @@ -20,6 +21,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "qwen2_moe", "qwen3", "qwen3_moe", + "qwen3_next", "falcon", "phi", "phi3", @@ -35,6 +37,16 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "deepseek_v3", "glm", "glm4", + "smollm3", + "granite", + "granitemoe", + "hunyuan_v1_dense", + "hunyuan_v1_moe", + "gpt_oss", + "arcee", + "seed_oss", + "lfm2", + "lfm2_moe", ] @@ -42,9 +54,11 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False): if has_remote_code: patch_remote(model_name) elif hasattr(transformers, "modeling_flash_attention_utils"): - transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) + # sanity check in case upstream api changes on this + assert hasattr( + transformers.modeling_flash_attention_utils, "_get_unpad_data" + ), "transformers api changed for _get_unpad_data for flash attention" + transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data if model_type == "mixtral" and is_deepspeed_zero3_enabled(): patch_mixtral_moe_forward_zero3() @@ -60,6 +74,4 @@ def patch_remote(model_name): module_name = ".".join(parts) modeling_arch = importlib.import_module(module_name) if hasattr(modeling_arch, "_get_unpad_data"): - modeling_arch._get_unpad_data = ( # pylint: disable=protected-access - get_unpad_data - ) + modeling_arch._get_unpad_data = get_unpad_data diff --git a/src/axolotl/monkeypatch/peft/utils.py b/src/axolotl/monkeypatch/peft/utils.py index 0c571fbd2..d1011f5eb 100644 --- a/src/axolotl/monkeypatch/peft/utils.py +++ b/src/axolotl/monkeypatch/peft/utils.py @@ -49,9 +49,7 @@ def patch_peft_prep_code(): prep_code = get_peft_prep_code() except OSError: return - peft.utils.other._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access - prep_code - ) + peft.utils.other._original_create_accelerator_and_postprocess = prep_code prep_code, _ = detab_code(prep_code) if ORIGINAL_PREPARE_CODE not in prep_code: return @@ -68,11 +66,15 @@ def patch_peft_prep_code(): if item in prep_code: items_to_import.append(item) - exec( # pylint: disable=exec-used # nosec B102 + exec( "from peft.utils.other import (" + ", ".join(x for x in items_to_import) + ")", globals(), ) - exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102 + exec(prep_code, globals()) LOG.info("patching prepare_model_for_kbit_training to allow for overrides") - peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 - axolotl.loaders.model.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 + peft.utils.other.prepare_model_for_kbit_training = ( + fixed_prepare_model_for_kbit_training + ) + axolotl.loaders.model.prepare_model_for_kbit_training = ( + fixed_prepare_model_for_kbit_training + ) diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index 5b7418e39..a01d850b3 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -6,7 +6,7 @@ import os.path import shutil from functools import partial from pathlib import Path -from typing import Dict, List, Sequence, Union +from typing import Dict, List, Union import bitsandbytes as bnb import peft @@ -14,8 +14,6 @@ import safetensors.torch as st import torch from huggingface_hub import snapshot_download from torch.distributed.optim import ZeroRedundancyOptimizer -from torch.optim.lr_scheduler import LRScheduler -from torch.optim.optimizer import Optimizer from transformers import ( TrainerCallback, TrainerControl, @@ -84,7 +82,7 @@ class ReLoRACallback(TrainerCallback): """Callback to merge LoRA weights into the base model and save full-weight checkpoints""" def __init__(self, cfg: DictDefault): - self.relora_steps = cfg.relora_steps + self.relora_steps = cfg.jagged_restart_steps self.cpu_offload = cfg.relora_cpu_offload self.quantized = cfg.load_in_4bit or cfg.load_in_8bit self.last_full_model = cfg.base_model @@ -93,9 +91,9 @@ class ReLoRACallback(TrainerCallback): if not os.path.exists(self.last_full_model): self.last_full_model = str(Path(snapshot_download(cfg.base_model))) - assert os.path.exists( - self.last_full_model - ), "for ReLORA base_model must be a local path" + assert os.path.exists(self.last_full_model), ( + "for ReLORA base_model must be a local path" + ) self.num_lora_restarts = 0 self.need_full_save = False @@ -255,51 +253,6 @@ class ReLoRACallback(TrainerCallback): return control -class ReLoRAScheduler(LRScheduler): - """Wraps another scheduler to apply per-lora-restart learning rate warmups.""" - - def __init__( - self, - optimizer: Optimizer, - inner_schedule: LRScheduler, - relora_steps: int, - warmup_steps: int, - anneal_steps: int = 1, - min_lr_scale: float = 0.001, - ) -> None: - self.inner_schedule = inner_schedule - self.relora_steps = relora_steps - self.warmup_steps = warmup_steps - self.anneal_steps = anneal_steps - self.min_lr_scale = min_lr_scale - super().__init__(optimizer, inner_schedule.last_epoch) - - def get_lr(self) -> float: - self.inner_schedule.last_epoch = self.last_epoch - - original = self.inner_schedule.get_lr() - step = self.last_epoch - - if step < self.relora_steps - self.warmup_steps: - scale = 1 - else: - per_relora_progress = step % self.relora_steps - if per_relora_progress < self.warmup_steps: - cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps) - elif per_relora_progress > (self.relora_steps - self.anneal_steps): - cycle_t = min( - 1.0, - (self.relora_steps - per_relora_progress) / self.anneal_steps, - ) - else: - cycle_t = 1 - scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale - - if isinstance(original, Sequence): - return [lr * scale for lr in original] - return original * scale - - def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]: model_name = "model.safetensors" if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists( @@ -340,7 +293,6 @@ def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraL key_list = [key for key, _ in model.model.named_modules() if "lora" not in key] for key in key_list: try: - # pylint: disable=protected-access _parent, target, _target_name = peft.utils._get_submodules(model.model, key) except AttributeError: continue @@ -388,7 +340,7 @@ def merge_and_save( modules = find_lora_modules(model) if not quantized: - for module_name, target in modules.items(): + for _, target in modules.items(): active_adapter = target.active_adapter if isinstance(active_adapter, list): active_adapter = active_adapter[0] diff --git a/src/axolotl/monkeypatch/ring_attn/__init__.py b/src/axolotl/monkeypatch/ring_attn/__init__.py index 5833b9ce4..1c14776c9 100644 --- a/src/axolotl/monkeypatch/ring_attn/__init__.py +++ b/src/axolotl/monkeypatch/ring_attn/__init__.py @@ -1,22 +1,17 @@ """Init for ring attention monkeypatch module""" -# pylint: disable=unused-import # flake8: noqa from .patch import ( get_ring_attn_group, - patch_prepare_data_loader, - patch_prepare_device_mesh, - register_ring_attn, + register_ring_attn_from_device_mesh, set_ring_attn_group, update_ring_attn_params, ) __all__ = ( "get_ring_attn_group", - "patch_prepare_data_loader", - "patch_prepare_device_mesh", - "register_ring_attn", + "register_ring_attn_from_device_mesh", "set_ring_attn_group", "update_ring_attn_params", ) diff --git a/src/axolotl/monkeypatch/ring_attn/adapters/batch.py b/src/axolotl/monkeypatch/ring_attn/adapters/batch.py index e556ba5e3..74d33ed4a 100644 --- a/src/axolotl/monkeypatch/ring_attn/adapters/batch.py +++ b/src/axolotl/monkeypatch/ring_attn/adapters/batch.py @@ -7,8 +7,6 @@ Our implementation closely follows the structure of that module, but we've minif somewhat to support only the latest versions of transformers. """ -# pylint: disable=protected-access,cyclic-import - import os from typing import Callable @@ -18,10 +16,18 @@ import transformers import transformers.modeling_flash_attention_utils from ring_flash_attn import ring_flash_attn_func from ring_flash_attn.adapters.hf_adapter import check_params -from transformers.modeling_flash_attention_utils import ( - _flash_supports_window_size, - is_flash_attn_greater_or_equal, -) +from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal + +try: + from transformers.modeling_flash_attention_utils import _flash_supports_window +except ImportError: + try: + from transformers.modeling_flash_attention_utils import ( + _flash_supports_window_size as _flash_supports_window, + ) + except ImportError: + _flash_supports_window = True + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from axolotl.utils.schemas.enums import RingAttnFunc @@ -33,7 +39,7 @@ RING_ATTN_FUNC_MAPPING = { } -def create_flash_attn_forward( +def create_flash_attn_forward_varlen_llama3( process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc ) -> Callable: """ @@ -51,7 +57,7 @@ def create_flash_attn_forward( """ # transformers 4.48+ - # pylint: disable=unused-argument + def _flash_attention_forward( query_states: torch.Tensor, key_states: torch.Tensor, @@ -71,6 +77,7 @@ def create_flash_attn_forward( max_length_q: int | None = None, max_length_k: int | None = None, target_dtype: torch.dtype | None = None, + attn_implementation: str | None = None, **kwargs, ): """ @@ -97,6 +104,7 @@ def create_flash_attn_forward( max_length_q: Not used in this implementation. max_length_k: Not used in this implementation. target_dtype: Not used in this implementation. + attn_implementation: Not used in this implementation. **kwargs: Additional keyword arguments. Not used in this implementation. Returns: @@ -110,7 +118,7 @@ def create_flash_attn_forward( # Handle sliding window use_sliding_windows = ( - _flash_supports_window_size + _flash_supports_window and sliding_window is not None and key_states.shape[1] > sliding_window ) @@ -161,7 +169,7 @@ def substitute_hf_flash_attn( old_flash_attention_forward = ( transformers.modeling_flash_attention_utils._flash_attention_forward ) - new_flash_attention_forward = create_flash_attn_forward( + new_flash_attention_forward = create_flash_attn_forward_varlen_llama3( process_group=process_group, ring_attn_func=ring_attn_func ) diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py index 7d733cfc1..e1fd10b3a 100644 --- a/src/axolotl/monkeypatch/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -8,69 +8,141 @@ We also provide some patches for accelerate functions to prepare the dataloader sequence parallelism training. """ -import inspect +import os +from typing import Optional -import accelerate import torch import torch.distributed as dist -from accelerate.logging import get_logger +from torch.distributed import DeviceMesh + +try: + from transformers.modeling_flash_attention_utils import _flash_supports_window +except ImportError: + try: + from transformers.modeling_flash_attention_utils import ( + _flash_supports_window_size as _flash_supports_window, + ) + except ImportError: + _flash_supports_window = True from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RingAttnFunc LOG = get_logger(__name__) - RING_ATTN_GROUP = None -ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1 - submesh_dp_size = 1 - submesh_tp_size = 1 - if "tp" in torch_device_mesh.mesh_dim_names: - submesh_tp_size = torch_device_mesh["tp"].size() - if "dp" in torch_device_mesh.mesh_dim_names: - submesh_dp_size = torch_device_mesh["dp"].size() - if "fsdp" in torch_device_mesh.mesh_dim_names: - submesh_fsdp_size = torch_device_mesh["fsdp"].size() - process_index = process_index // submesh_tp_size""" - -NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1 - submesh_dp_size = 1 - submesh_tp_size = 1 - submesh_cp_size = 1 - if "cp" in torch_device_mesh.mesh_dim_names: - submesh_cp_size = torch_device_mesh["cp"].size() - if "tp" in torch_device_mesh.mesh_dim_names: - submesh_tp_size = torch_device_mesh["tp"].size() - if "dp" in torch_device_mesh.mesh_dim_names: - submesh_dp_size = torch_device_mesh["dp"].size() - if "fsdp" in torch_device_mesh.mesh_dim_names: - submesh_fsdp_size = torch_device_mesh["fsdp"].size() - process_index = process_index // (submesh_tp_size * submesh_cp_size)""" - def get_ring_attn_group() -> dist.ProcessGroup: """Getter for ring attention group on this rank.""" if RING_ATTN_GROUP is None: - raise RuntimeError("register_ring_attn() not yet called") + raise RuntimeError("register_ring_attn_from_device_mesh() not yet called") return RING_ATTN_GROUP def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): """Setter for ring attention group on this rank.""" - global RING_ATTN_GROUP # pylint: disable=global-statement + global RING_ATTN_GROUP RING_ATTN_GROUP = ring_attn_group -def register_ring_attn( - sequence_parallel_degree: int, +def create_ring_flash_attention_forward( + process_group: dist.ProcessGroup, heads_k_stride: int +): + from ring_flash_attn import llama3_flash_attn_varlen_func + from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS + + def _flash_attention_forward_v3( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool, + dropout: float = 0.0, + position_ids: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: bool = None, + cu_seq_lens_q: Optional[torch.LongTensor] = None, + cu_seq_lens_k: Optional[torch.LongTensor] = None, + max_length_q: Optional[int] = None, + max_length_k: Optional[int] = None, + target_dtype: Optional[torch.dtype] = None, + attn_implementation: Optional[str] = None, + **kwargs, + ): + if not use_top_left_mask: + causal = is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. + causal = is_causal and query_length != 1 + + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + use_sliding_windows = ( + _flash_supports_window + and sliding_window is not None + and key_states.shape[1] > sliding_window + ) + flash_kwargs = ( + {"window_size": (sliding_window, sliding_window)} + if use_sliding_windows + else {} + ) + + if deterministic is None: + deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = deterministic + assert softcap is None, ( + "llama3_flash_attn_varlen_func does not support softcap yet." + ) + # flash_kwargs["softcap"] = softcap + flash_kwargs["group"] = process_group + + # not sure why attention_mask can be not None... + assert causal, "only causal attention is supported yet." + batch_size = query_states.size(0) + assert batch_size == 1, "varlen data should be processed in advance." + + attn_output = llama3_flash_attn_varlen_func( + query_states.squeeze(dim=0), + key_states.squeeze(dim=0), + value_states.squeeze(dim=0), + cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"], + cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"], + max_seqlen_q=DATA_PARAMS["max_seqlen_q"], + max_seqlen_k=DATA_PARAMS["max_seqlen_k"], + heads_k_stride=heads_k_stride, + local_k_slice=DATA_PARAMS["local_k_slice"], + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = attn_output.unsqueeze(dim=0) + + return attn_output + + return [ + _flash_attention_forward_v3, + ] + + +def register_ring_attn_from_device_mesh( + device_mesh: "DeviceMesh", + context_parallel_dim: tuple[str, ...], heads_k_stride: int | None, ring_attn_func: RingAttnFunc | None, ): - """Create ring attention group and substitute flash attn with ring flash attn. + """Create ring attention group using DeviceMesh and substitute flash attn with ring flash attn. Args: - sequence_parallel_degree: Sequence parallelism factor. + device_mesh: DeviceMesh object containing the parallelism topology. + context_parallel_dim: Name of the sequence parallel dimension in the device mesh. heads_k_stride: Sequence parallelism K head stride size. Passed through to `varlen_llama3` `ring_flash_attn` implementation. ring_attn_func: `ring_flash_attn` ring attention implemention. If sample @@ -78,49 +150,55 @@ def register_ring_attn( `batch` function. """ rank = dist.get_rank() - world_size = dist.get_world_size() + + LOG.info( + f"Enabling ring attention sequence parallelism using DeviceMesh " + f"dimension '{context_parallel_dim}'", + main_process_only=True, + ) + + # Extract the sequence parallel submesh + try: + sequence_mesh = device_mesh[context_parallel_dim] + except (KeyError, IndexError) as e: + raise ValueError( + f"Dimension '{context_parallel_dim}' not found in device_mesh. " + f"Available dimensions: {device_mesh.mesh_dim_names}" + ) from e + + # Get the process group for context parallelism + sequence_pg = sequence_mesh.get_group() + context_parallel_size = sequence_mesh.size() if rank == 0: LOG.info( - "Enabling ring attention sequence parallelism: " - f"each sequence will be processed across {sequence_parallel_degree} GPUs" + f"Sequence parallel degree: {context_parallel_size}, " + f"mesh shape: {sequence_mesh.mesh.shape}" ) - assert sequence_parallel_degree <= world_size, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " - f"must be less than or equal to world_size ({world_size})" - ) - assert world_size % sequence_parallel_degree == 0, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " - f"must evenly divide world_size ({world_size})" - ) + # Log which ranks are in the current process group + if sequence_pg != dist.GroupMember.WORLD: + ranks_in_group = dist.get_process_group_ranks(sequence_pg) + LOG.info(f"Current sequence parallel group ranks: {ranks_in_group}") - # Assign ranks to sequence parallel groups - group_assignments = {} - for i in range(world_size // sequence_parallel_degree): - ring_attn_ranks = list( - range( - i * sequence_parallel_degree, - (i + 1) * sequence_parallel_degree, - ) - ) - group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") - - # Track which GPUs are in which groups - for r in ring_attn_ranks: - group_assignments[r] = i - - if rank in ring_attn_ranks: - set_ring_attn_group(group) - - # Log the GPU group assignments - if rank == 0: - LOG.info(f"Sequence parallel group assignments: {group_assignments}") + # Set the ring attention group + set_ring_attn_group(sequence_pg) if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3: - from ring_flash_attn import substitute_hf_flash_attn + # fmt: off + import ring_flash_attn.adapters.hf_adapter - substitute_hf_flash_attn( + from ring_flash_attn.adapters.hf_adapter import ( # isort: skip + create_ring_flash_attention_forward as create_ring_flash_attention_forward_orig, + ) + + create_ring_flash_attention_forward_orig = ( # noqa: F811,F841 + create_ring_flash_attention_forward + ) + ring_flash_attn.adapters.hf_adapter.create_ring_flash_attention_forward = create_ring_flash_attention_forward + # fmt: on + + ring_flash_attn.adapters.hf_adapter.substitute_hf_flash_attn( process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1 ) elif ring_attn_func is RingAttnFunc.BATCH_RING: @@ -147,79 +225,3 @@ def update_ring_attn_params(position_ids: torch.Tensor | None): cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) - - -def patch_prepare_data_loader(): - """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree. - - Raies: - RuntimeError: If source code to patch does not exist. - """ - original_fn = accelerate.data_loader.prepare_data_loader - original_source = inspect.getsource(original_fn) - - if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source: - raise RuntimeError( - "SP patch failed - target snippet not found. " - "Check accelerate's version or update the patch." - ) - - patched_source = original_source.replace( - ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE - ) - - # Create a new function from the patched source - namespace = {} - exec( # pylint: disable=exec-used # nosec B102 - patched_source, accelerate.data_loader.__dict__, namespace - ) - patched_function = namespace["prepare_data_loader"] - - accelerate.data_loader.prepare_data_loader = patched_function - LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support") - - -def patch_prepare_device_mesh(sequence_parallel_degree: int): - """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh - that includes sequence parallelism with the specified degree. - - Args: - sequence_parallel_degree (int): The degree of sequence parallelism to use. - """ - - def _prepare_device_mesh(self): - """Prepare the device mesh for distributed training. The dataloader will - determine how to load data based on the device mesh. - """ - if self.state.torch_tp_plugin: - return self.state.torch_tp_plugin.torch_device_mesh - if ( - self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED - and hasattr(self.state, "ds_device_mesh") - ): - return self.state.ds_device_mesh - - # Create device mesh with sequence parallelism - world_size = dist.get_world_size() - mesh_shape = ( - world_size // sequence_parallel_degree, - sequence_parallel_degree, - ) - device_ids = list(range(world_size)) - - # Note that we use "cp" instead of "sp" to match the PyTorch native "context - # parallelism" implementation naming - return dist.DeviceMesh( - "cuda", - torch.tensor(device_ids).reshape(mesh_shape), - mesh_dim_names=("dp", "cp"), - ) - - # Replace the original method with our new method - # pylint: disable=protected-access - accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh - - LOG.info( - "Successfully patched Accelerator._prepare_device_mesh " - f"with sequence_parallel_degree={sequence_parallel_degree}" - ) diff --git a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py index 85454fe2e..0fa6d6424 100644 --- a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py @@ -16,8 +16,8 @@ # This code is based off the following work: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py -# pylint: disable=duplicate-code """PyTorch StableLM Epoch model.""" + import importlib import math from typing import Optional, Tuple, Union @@ -26,7 +26,7 @@ import torch import torch.utils.checkpoint from accelerate import init_empty_weights from einops import rearrange -from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports +from flash_attn.flash_attn_interface import ( flash_attn_varlen_qkvpacked_func, ) from torch import nn @@ -49,27 +49,21 @@ def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e ".configuration_stablelm_epoch", ".modeling_stablelm_epoch" ) modeling_stablelm = importlib.import_module(module_name) - modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access - flashattn_attn - ) - modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access - stablelm_model_forward - ) - modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access - decoder_layer_forward - ) + modeling_stablelm.Attention.forward = flashattn_attn + modeling_stablelm.StableLMEpochModel.forward = stablelm_model_forward + modeling_stablelm.DecoderLayer.forward = decoder_layer_forward def rotate_half(x: torch.Tensor): """Rotates half the hidden dims of the input.""" - # pylint: disable=invalid-name + x1, x2 = torch.chunk(x, 2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - # pylint: disable=invalid-name + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim] @@ -99,7 +93,7 @@ def flashattn_attn( attention_mask: torch.FloatTensor, position_ids: torch.LongTensor, past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, # pylint: disable=unused-argument + output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[torch.Tensor] = None, @@ -216,7 +210,6 @@ def decoder_layer_forward( ) -> Union[ Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]] ]: - # pylint: disable=duplicate-code residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -263,7 +256,6 @@ def stablelm_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - # pylint: disable=duplicate-code output_attentions = ( output_attentions if output_attentions is not None @@ -326,13 +318,11 @@ def stablelm_model_forward( dtype=torch.bool, device=inputs_embeds.device, ) - attention_mask = ( - self._prepare_decoder_attention_mask( # pylint: disable=protected-access - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, ) hidden_states = inputs_embeds diff --git a/src/axolotl/monkeypatch/tiled_mlp/__init__.py b/src/axolotl/monkeypatch/tiled_mlp/__init__.py new file mode 100644 index 000000000..4ea154991 --- /dev/null +++ b/src/axolotl/monkeypatch/tiled_mlp/__init__.py @@ -0,0 +1,11 @@ +""" +TiledMLP monkey patches +""" + +from .patch import ( + patch_tiled_mlp, +) + +__all__ = [ + "patch_tiled_mlp", +] diff --git a/src/axolotl/monkeypatch/tiled_mlp/base.py b/src/axolotl/monkeypatch/tiled_mlp/base.py new file mode 100644 index 000000000..2c9dc8e4c --- /dev/null +++ b/src/axolotl/monkeypatch/tiled_mlp/base.py @@ -0,0 +1,256 @@ +""" +TiledMLP support for DDP, FSDP, and single GPU +""" + +import threading +from typing import List + +import torch + + +class DeepSpeedTiledMLPMoE(torch.autograd.Function): + @staticmethod + def forward( + ctx, + fn, + self, + x, + shards, + compute_params, + ) -> torch.Tensor: + ctx.fn = fn + ctx.self = self + ctx.shards = shards + ctx.compute_params = [p for p in compute_params if p.requires_grad] + ctx.save_for_backward(x) + + x_shards = list(torch.chunk(x, chunks=shards, dim=1)) + with torch.no_grad(): + output_shards = [fn(self, x_shard) for x_shard in x_shards] + + ctx.is_tuple_output = isinstance(output_shards[0], tuple) + if isinstance(output_shards[0], tuple): + tuple_dim_idx = [1, 0] + output_unsharded = tuple( + torch.cat( + [output_shard[i] for output_shard in output_shards], + dim=tuple_dim_idx[i], + ) + for i in range(len(output_shards[0])) + ) + else: + output_unsharded = torch.cat(output_shards, dim=1) + + return output_unsharded + + @staticmethod + def backward(ctx, *grads) -> torch.Tensor: + fn = ctx.fn + (x,) = ctx.saved_tensors + self = ctx.self + shards = ctx.shards + compute_params = ctx.compute_params + is_tuple_output = ctx.is_tuple_output + + x_requires_grad = x.requires_grad + x = x.detach() + # detach() unsets `x.requires_grad`, so restore it + x.requires_grad_(x_requires_grad) + + incoming_grad = grads[0] + x_grad = torch.zeros_like(x) + x_shards = list(torch.chunk(x, chunks=shards, dim=1)) + + shard_step = x_shards[0].numel() + for i, x_shard in enumerate(x_shards): + # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run + if compute_params is not None: + if i + 1 < shards: + for param in compute_params: + param.ds_grad_is_ready = False + else: + # last shard, can add the grad + for param in compute_params: + param.ds_grad_is_ready = True + + x_shard.requires_grad_(x_requires_grad) + + shard_offset = i * shard_step + x_shard.grad = ( + x_grad.view(-1) + .narrow(0, shard_offset, x_shard.numel()) + .view_as(x_shard) + ) + incoming_grad_shard = ( + incoming_grad.view(-1) + .narrow(0, shard_offset, x_shard.numel()) + .view_as(x_shard) + ) + with torch.enable_grad(): + output = fn(self, x_shard) + if is_tuple_output: + torch.autograd.backward(output[0], incoming_grad_shard) + else: + torch.autograd.backward(output, incoming_grad_shard) + + return (None, None, x_grad, None, None) + + +class TiledMLP(torch.autograd.Function): + """ + TiledMLP implementation using gradient hooks + """ + + @staticmethod + def forward( + ctx, + fn, + self, + x, + shards, + compute_params, + ) -> torch.Tensor: + ctx.fn = fn + ctx.self = self + ctx.shards = shards + ctx.compute_params = [p for p in compute_params if p.requires_grad] + ctx.save_for_backward(x) + + x_shards = list(torch.chunk(x, chunks=shards, dim=1)) + with torch.no_grad(): + output_shards = [fn(self, x_shard) for x_shard in x_shards] + ctx.is_tuple_output = isinstance(output_shards[0], tuple) + if isinstance(output_shards[0], tuple): + tuple_dim_idx = [1, 0] + output_unsharded = tuple( + torch.cat( + [output_shard[i] for output_shard in output_shards], + dim=tuple_dim_idx[i], + ) + for i in range(len(output_shards[0])) + ) + else: + output_unsharded = torch.cat(output_shards, dim=1) + + return output_unsharded + + @staticmethod + def backward(ctx, *grads) -> torch.Tensor: + fn = ctx.fn + (x,) = ctx.saved_tensors + self = ctx.self + shards = ctx.shards + compute_params = ctx.compute_params + is_tuple_output = ctx.is_tuple_output + + x_requires_grad = x.requires_grad + x = x.detach() + x.requires_grad_(x_requires_grad) + + incoming_grad = grads[0] + x_grad = torch.zeros_like(x) + x_shards = list(torch.chunk(x, chunks=shards, dim=1)) + + # Create a gradient accumulator for parameters + grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype) + + shard_step = x_shards[0].numel() + for i, x_shard in enumerate(x_shards): + x_shard.requires_grad_(x_requires_grad) + + shard_offset = i * shard_step + x_shard.grad = ( + x_grad.view(-1) + .narrow(0, shard_offset, x_shard.numel()) + .view_as(x_shard) + ) + incoming_grad_shard = ( + incoming_grad.view(-1) + .narrow(0, shard_offset, x_shard.numel()) + .view_as(x_shard) + ) + + # Install hooks for this shard + is_last_shard = i + 1 == shards + grad_accumulator.install_hooks(is_last_shard) + + with torch.enable_grad(): + output = fn(self, x_shard) + if is_tuple_output: + torch.autograd.backward(output[0], incoming_grad_shard) + else: + torch.autograd.backward(output, incoming_grad_shard) + + # Clean up hooks + grad_accumulator.cleanup() + del grad_accumulator + + return (None, None, x_grad, None, None) + + +class GradientAccumulator: + """ + Manual gradient accumulator for TiledMLP with configurable precision + Accumulates in specified dtype and rescales the gradient at the end + """ + + def __init__( + self, + params: List[torch.nn.Parameter], + total_shards: int, + dtype: torch.dtype | None = None, + ): + self.params = params + self.total_shards = total_shards + self.grad_accumulation_dtype = dtype or torch.float32 + self.accumulated_grads = {} + self.hooks = [] + self.lock = threading.Lock() + self.gradient_scale = 1.0 / total_shards + + # Initialize accumulated gradients in the specified dtype + for param in self.params: + if param.grad is not None: + self.accumulated_grads[param] = param.grad.to( + self.grad_accumulation_dtype + ) + param.grad = None + else: + self.accumulated_grads[param] = torch.zeros_like( + param, dtype=self.grad_accumulation_dtype + ) + + def install_hooks(self, is_last_shard: bool): + """Install gradient hooks that accumulate gradients in higher precision""" + + def create_hook(param): + def hook(grad): + with self.lock: + grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype) + scaled_grad = grad_to_accum_dtype * self.gradient_scale + + if param in self.accumulated_grads: + self.accumulated_grads[param] += scaled_grad + else: + self.accumulated_grads[param] = scaled_grad.clone() + + # Only assign the averaged gradient on the last shard + if is_last_shard: + param.grad = self.accumulated_grads[param].to(param.dtype) + return param.grad + return None + + return hook + + # Install hooks on all parameters + for param in self.params: + if param.requires_grad: + hook = param.register_hook(create_hook(param)) + self.hooks.append(hook) + + def cleanup(self): + """Remove all installed hooks""" + for hook in self.hooks: + hook.remove() + self.hooks.clear() + del self.accumulated_grads diff --git a/src/axolotl/monkeypatch/tiled_mlp/patch.py b/src/axolotl/monkeypatch/tiled_mlp/patch.py new file mode 100644 index 000000000..c0f89236b --- /dev/null +++ b/src/axolotl/monkeypatch/tiled_mlp/patch.py @@ -0,0 +1,93 @@ +"""Monkeypatch for Tiled MLP implementation""" + +import math +import os + +import torch +import torch.distributed as dist + +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None): + from deepspeed.runtime.sequence_parallel.ulysses_sp import ( + TiledMLP as DeepSpeedTiledMLP, + ) + + from axolotl.monkeypatch.tiled_mlp.base import DeepSpeedTiledMLPMoE, TiledMLP + + try: + # Dynamically import the module and MLP class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) + module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"]) + mlp_cls = getattr(module, f"{model_cls_prefix}MLP") + + if use_original_mlp: + mlp_forward = mlp_cls.forward + else: + + def generic_mlp_forward(self_, hs): + return self_.down_proj( + self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs) + ) + + mlp_forward = torch.compile(generic_mlp_forward) + + is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1 + + def tiled_mlp_forward(self, x): + input_shape = x.shape + seqlen = input_shape[-2] + hidden = input_shape[-1] + if cfg_num_shards is None: + num_shards = math.ceil(seqlen / hidden) + if is_distributed: + num_shards_tensor = torch.tensor(num_shards, device=x.device) + dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX) + num_shards = num_shards_tensor.item() + else: + num_shards = cfg_num_shards + + if not self._compute_params: + self._compute_params = [p for p in self.parameters() if p.requires_grad] + + compute_params = self._compute_params + if not self._tiled_mlp_dist_impl: + if ( + self._compute_params + and any( + hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group") + for p in self._compute_params + ) + ) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true": + if model_type == "gpt_oss": + self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE + else: + self._tiled_mlp_dist_impl = DeepSpeedTiledMLP + else: + self._tiled_mlp_dist_impl = TiledMLP + + down_res = self._tiled_mlp_dist_impl.apply( + mlp_forward, + self, + x, + num_shards, + compute_params, + ) + return down_res + + mlp_cls.forward = tiled_mlp_forward + mlp_cls._compute_params = [] + mlp_cls._tiled_mlp_dist_impl = None + LOG.info( + f"Successfully monkey-patched TiledMLP for model_type: {model_type}", + main_process_only=True, + ) + except (ImportError, AttributeError) as e: + raise RuntimeError( + f"Could not import MLP class for model_type: {model_type}. Error: {str(e)}" + ) from e diff --git a/src/axolotl/monkeypatch/trainer/lr.py b/src/axolotl/monkeypatch/trainer/lr.py index 9afc23c46..c33674cee 100644 --- a/src/axolotl/monkeypatch/trainer/lr.py +++ b/src/axolotl/monkeypatch/trainer/lr.py @@ -39,4 +39,4 @@ def _get_learning_rate(self): def patch_trainer_get_lr(): from transformers.trainer import Trainer - Trainer._get_learning_rate = _get_learning_rate # pylint: disable=protected-access + Trainer._get_learning_rate = _get_learning_rate diff --git a/src/axolotl/monkeypatch/trainer/trl.py b/src/axolotl/monkeypatch/trainer/trl.py new file mode 100644 index 000000000..bca9f92de --- /dev/null +++ b/src/axolotl/monkeypatch/trainer/trl.py @@ -0,0 +1,13 @@ +"""Monkeypatch for TRL trainer FSDP preparation.""" + + +def prepare_fsdp(model, accelerator): + from axolotl.monkeypatch.accelerate.fsdp2 import fsdp2_prepare_model + + return fsdp2_prepare_model(accelerator, model) + + +def patch_trl_prepare_fsdp2(): + import trl.models.utils + + trl.models.utils.prepare_fsdp = prepare_fsdp diff --git a/src/axolotl/monkeypatch/trainer_accelerator_args.py b/src/axolotl/monkeypatch/trainer_accelerator_args.py index 0a5b27c13..9fc6e38c6 100644 --- a/src/axolotl/monkeypatch/trainer_accelerator_args.py +++ b/src/axolotl/monkeypatch/trainer_accelerator_args.py @@ -18,7 +18,7 @@ ORIGINAL_TRAINER_CODE = """ PATCHED_TRAINER_CODE = """ if hasattr(self, "additional_accelerator_args"): - additional_args = self.additional_accelerator_args(fp8=True, **args) + additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args) if additional_args: args.update(additional_args) @@ -38,23 +38,24 @@ def check_create_accelerate_code_is_patchable() -> bool: return ORIGINAL_TRAINER_CODE in create_code -def patch_create_accelerate_code_for_fp8(): +def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool): """ - monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs + Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs. """ try: create_code = get_create_accelerate_code() except OSError: return - Trainer._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access - create_code - ) + Trainer._original_create_accelerator_and_postprocess = create_code create_code, _ = detab_code(create_code) if ORIGINAL_TRAINER_CODE not in create_code: return - create_code = create_code.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE) + patched_trainer_code = PATCHED_TRAINER_CODE.format( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather + ) + create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code) create_code = create_code.replace( "def create_accelerator_and_postprocess(", "def fixed_create_accelerator_and_postprocess(", @@ -69,12 +70,14 @@ def patch_create_accelerate_code_for_fp8(): if item in create_code: items_to_import.append(item) - exec( # pylint: disable=exec-used # nosec B102 + exec( "from transformers.trainer import (" + ", ".join(x for x in items_to_import) + ")", globals(), ) - exec(create_code, globals()) # pylint: disable=exec-used # nosec B102 + exec(create_code, globals()) LOG.info("patching create_accelerator_and_postprocess to allow for overrides") - Trainer.create_accelerator_and_postprocess = fixed_create_accelerator_and_postprocess # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 + Trainer.create_accelerator_and_postprocess = ( + fixed_create_accelerator_and_postprocess + ) diff --git a/src/axolotl/monkeypatch/trainer_eval_guard.py b/src/axolotl/monkeypatch/trainer_eval_guard.py deleted file mode 100644 index 8488a16df..000000000 --- a/src/axolotl/monkeypatch/trainer_eval_guard.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -fix for FSDP2 evals when using torch.compile -""" - -import inspect - -from transformers import Trainer - -from axolotl.monkeypatch.utils import detab_code -from axolotl.utils.logging import get_logger - -LOG = get_logger(__name__) - -ORIGINAL_TRAINER_CODE = """ - model.eval() -""" - -PATCHED_TRAINER_CODE = """ - if hasattr(model, "eval") and callable(model.eval): - self.model.eval() -""" - - -def get_evaluation_loop_code() -> str: - training_loop = inspect.getsource(Trainer.evaluation_loop) - return training_loop - - -def check_evaluation_loop_is_patchable() -> bool: - eval_loop = get_evaluation_loop_code() - eval_loop, _ = detab_code(eval_loop) - return ORIGINAL_TRAINER_CODE in eval_loop - - -def patch_evaluation_loop_for_fsdp2(): - """ - monkeypatch for fixing the eval loop for fsdp2 with torch.compile - """ - - try: - evaluation_loop = get_evaluation_loop_code() - except OSError: - return - Trainer._original_evaluation_loop = ( # pylint: disable=protected-access - evaluation_loop - ) - evaluation_loop, _ = detab_code(evaluation_loop) - if ORIGINAL_TRAINER_CODE not in evaluation_loop: - return - - evaluation_loop = evaluation_loop.replace( - ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE - ) - evaluation_loop = evaluation_loop.replace( - "def evaluation_loop(", - "def _fixed_evaluation_loop(", - 1, - ) - - # load imports necessary - import transformers.trainer - - items_to_import = [] - for item in dir(transformers.trainer): - if item in evaluation_loop: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.trainer import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching _inner_training_loop for fsdp optimizer save") - Trainer.evaluation_loop = ( # pylint: disable=protected-access - _fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821 - ) diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py b/src/axolotl/monkeypatch/trainer_fsdp_optim.py index 4ce5b8ecd..692f754d7 100644 --- a/src/axolotl/monkeypatch/trainer_fsdp_optim.py +++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py @@ -12,22 +12,18 @@ from axolotl.utils.logging import get_logger LOG = get_logger(__name__) ORIGINAL_TRAINER_CODE = """ - - delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled - + if delay_optimizer_creation: + self.optimizer = self.accelerator.prepare(self.optimizer) """ PATCHED_TRAINER_CODE = """ - - delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled - + if delay_optimizer_creation: + model = self.accelerator.prepare(self.model) """ def get_training_loop_code() -> str: - training_loop = inspect.getsource( - Trainer._inner_training_loop # pylint: disable=protected-access - ) + training_loop = inspect.getsource(Trainer._inner_training_loop) return training_loop @@ -46,9 +42,7 @@ def patch_training_loop_for_fsdp(): training_loop = get_training_loop_code() except OSError: return - Trainer._original_inner_training_loop = ( # pylint: disable=protected-access - training_loop - ) + Trainer._original_inner_training_loop = training_loop training_loop, _ = detab_code(training_loop) if ORIGINAL_TRAINER_CODE not in training_loop: return @@ -68,14 +62,12 @@ def patch_training_loop_for_fsdp(): if item in training_loop: items_to_import.append(item) - exec( # pylint: disable=exec-used # nosec B102 + exec( "from transformers.trainer import (" + ", ".join(x for x in items_to_import) + ")", globals(), ) - exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102 + exec(training_loop, globals()) LOG.info("patching _inner_training_loop for fsdp optimizer save") - Trainer._inner_training_loop = ( # pylint: disable=protected-access - _fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821 - ) + Trainer._inner_training_loop = _fixed_inner_training_loop diff --git a/src/axolotl/monkeypatch/transformers/__init__.py b/src/axolotl/monkeypatch/transformers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py new file mode 100644 index 000000000..74a35e83f --- /dev/null +++ b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py @@ -0,0 +1,68 @@ +"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer.""" + +from __future__ import annotations + +import importlib +import inspect + +from transformers import Trainer + +from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":' +PATCHED_GUARD = ( + 'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):' +) + + +def patch_prepare_context_parallel_inputs() -> None: + """Relax the SDPA-only guard when running context parallelism with FlashAttention.""" + if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False): + LOG.debug("Trainer._prepare_context_parallel_inputs already patched") + return + + try: + original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs) + except OSError as exc: # pragma: no cover - occurs when source is unavailable + LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc) + return + + if GUARD_PATTERN not in original_source: + LOG.warning( + "Expected guard not found in Trainer._prepare_context_parallel_inputs; \n" + "skipping FlashAttention context parallelism patch" + ) + return + + patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD) + patched_source, _ = detab_code(patched_source) + patched_source = patched_source.replace( + "def _prepare_context_parallel_inputs(", + "def axolotl_prepare_context_parallel_inputs(", + 1, + ) + + module_name = Trainer.__module__ + module = importlib.import_module(module_name) + + # import symbols referenced in the method so exec can succeed + items_to_import = [] + for item in dir(module): + if item in patched_source: + items_to_import.append(item) + + exec(f"from {module_name} import ({', '.join(items_to_import)})", globals()) + exec(patched_source, globals()) + + Trainer._original_prepare_context_parallel_inputs = ( + Trainer._prepare_context_parallel_inputs + ) + Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs + Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source + Trainer._axolotl_prepare_context_parallel_inputs_patched = True + LOG.debug( + "Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP" + ) diff --git a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py new file mode 100644 index 000000000..b8172bbe6 --- /dev/null +++ b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py @@ -0,0 +1,139 @@ +""" +Module for patching transformers Trainer loss calculation to use nanmean. + +This is needed for context parallelism since chunks of the input sequences may be fully +masked and return NaNs in the loss calculation. + +Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with +the other evaluation_loop patch because we can't patch the same code twice without +raising an OSError. +""" + +import importlib +import inspect + +from transformers import Trainer + +from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +ORIGINAL_EVAL_CODE = { + "list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()', + "array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()', +} +PATCHED_EVAL_CODE = { + "list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()', + "array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()', +} + +ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()" +PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()" + + +def check_evaluation_loop_is_patchable() -> bool: + evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) + return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values()) + + +def patch_evaluation_loop(): + """Patch the evaluation_loop method.""" + # Check if already patched + if hasattr(Trainer, "_original_evaluation_loop"): + LOG.debug("Trainer.evaluation_loop already patched") + return + + # Check if the patterns exist + try: + evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) + except OSError: + return + Trainer.evaluation = evaluation_loop_source + evaluation_loop_source, _ = detab_code(evaluation_loop_source) + + # Apply the nanmean patches + evaluation_loop_source = evaluation_loop_source.replace( + ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"] + ) + evaluation_loop_source = evaluation_loop_source.replace( + ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"] + ) + + # Rename the function to avoid conflicts + evaluation_loop_source = evaluation_loop_source.replace( + "def evaluation_loop(", + "def axolotl_evaluation_loop(", + 1, + ) + + # Get the module for necessary imports + module_name = Trainer.__module__ + module = importlib.import_module(module_name) + + # Import necessary items from the module + items_to_import = [] + for item in dir(module): + if item in evaluation_loop_source: + items_to_import.append(item) + + # Execute the imports and patched method + exec( + f"from {module_name} import ({', '.join(items_to_import)})", + globals(), + ) + exec(evaluation_loop_source, globals()) + + LOG.debug("Patched Trainer.evaluation_loop with nanmean loss calculation") + Trainer.evaluation_loop = axolotl_evaluation_loop + + +def check_maybe_log_save_evaluate_is_patchable() -> bool: + maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate) + return ORIGINAL_MAYBE_CODE in maybe_log_source + + +def patch_maybe_log_save_evaluate(): + """Patch the _maybe_log_save_evaluate method.""" + # Check if already patched + if hasattr(Trainer, "_original_maybe_log_save_evaluate"): + LOG.info("Trainer._maybe_log_save_evaluate already patched") + return + + # Check if the patterns exist + try: + maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate) + except OSError: + return + Trainer._original_maybe_log_save_evaluate = maybe_log_source + maybe_log_source, _ = detab_code(maybe_log_source) + + # Apply the patch + maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE) + + # Rename the function to avoid conflicts + maybe_log_source = maybe_log_source.replace( + "def _maybe_log_save_evaluate(", + "def axolotl_maybe_log_save_evaluate(", + 1, + ) + + # Get the module for necessary imports + module_name = Trainer.__module__ + module = importlib.import_module(module_name) + + # Import necessary items from the module + items_to_import = [] + for item in dir(module): + if item in maybe_log_source: + items_to_import.append(item) + + # Execute the imports and patched method + exec( + f"from {module_name} import ({', '.join(items_to_import)})", + globals(), + ) + exec(maybe_log_source, globals()) + + LOG.debug("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation") + Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index 61f4eeea0..59f32c6f5 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -4,12 +4,12 @@ import inspect import types import torch -from accelerate.logging import get_logger from peft import PeftModelForCausalLM from torch import nn from transformers.models.llama.modeling_llama import LlamaFlashAttention2 from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -17,27 +17,19 @@ ORIGINAL_QKV_CODE = """ query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) -""".lstrip( - "\n" -) +""".lstrip("\n") PATCHED_QKV_CODE = """ query_states, key_states, value_states = self.apply_qkv(self, hidden_states) -""".lstrip( - "\n" -) +""".lstrip("\n") ORIGINAL_O_CODE = """ attn_output = self.o_proj(attn_output) -""".lstrip( - "\n" -) +""".lstrip("\n") PATCHED_O_CODE = """ attn_output = self.apply_o(self, attn_output) -""".lstrip( - "\n" -) +""".lstrip("\n") def original_apply_qkv(self, hidden_states): @@ -66,13 +58,13 @@ def check_self_attn_is_patchable() -> bool: def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None: from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss - def UnslothForCausalLMLoss( # pylint: disable=invalid-name + def UnslothForCausalLMLoss( logits, labels, - vocab_size: int, # pylint: disable=unused-argument + vocab_size: int, num_items_in_batch: int = None, - ignore_index: int = -100, # pylint: disable=unused-argument - **kwargs, # pylint: disable=unused-argument + ignore_index: int = -100, + **kwargs, ): # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() @@ -93,18 +85,16 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None: raise ValueError("Unsupported model type") -self_attn_lora_patched = False # pylint: disable=invalid-name +self_attn_lora_patched = False def patch_self_attn_lora(): - global self_attn_lora_patched # pylint: disable=global-statement + global self_attn_lora_patched if self_attn_lora_patched: # prevent patching multiple times return self_attn_forward = get_self_attn_code() - LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access - self_attn_forward - ) + LlamaFlashAttention2._original_forward = self_attn_forward self_attn_forward, _ = detab_code(self_attn_forward) assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found" assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found" @@ -125,27 +115,25 @@ def patch_self_attn_lora(): if item in self_attn_forward: items_to_import.append(item) - exec( # pylint: disable=exec-used # nosec B102 + exec( "from transformers.models.llama.modeling_llama import (" + ", ".join(x for x in items_to_import) + ")", globals(), ) - exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 + exec(self_attn_forward, globals()) self_attn_lora_patched = True LOG.info("patching unsloth attn lora") - LlamaFlashAttention2.forward = ( - unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 - ) + LlamaFlashAttention2.forward = unsloth_attn_forward def integrate_rope_embeddings(): import transformers.models.llama.modeling_llama from unsloth.kernels.rope_embedding import fast_rope_embedding - def apply_rotary_pos_emb( # pylint: disable=unused-argument - q, # pylint: disable=invalid-name - k, # pylint: disable=invalid-name + def apply_rotary_pos_emb( + q, + k, cos, sin, position_ids=None, diff --git a/src/axolotl/monkeypatch/xformers_/__init__.py b/src/axolotl/monkeypatch/xformers_/__init__.py index a052ea49e..6f5b43f77 100644 --- a/src/axolotl/monkeypatch/xformers_/__init__.py +++ b/src/axolotl/monkeypatch/xformers_/__init__.py @@ -36,7 +36,7 @@ class FusedMLP(torch.nn.Module): self.swiglu.w3.weight.data = down_proj.weight.data def _post_training(self, model, name): - w1, w2 = torch.split( # pylint: disable=invalid-name + w1, w2 = torch.split( self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0 ) @@ -48,5 +48,5 @@ class FusedMLP(torch.nn.Module): set_module_name(model, name, new_mlp) - def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.swiglu(x) diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py index ce9b6a838..07b114163 100644 --- a/src/axolotl/processing_strategies.py +++ b/src/axolotl/processing_strategies.py @@ -5,11 +5,15 @@ from typing import Optional from PIL import Image, ImageOps from PIL.Image import Resampling -from torch import Tensor +from torch import Tensor, zeros_like from transformers import ProcessorMixin from transformers.image_utils import load_image +from transformers.models.smolvlm import SmolVLMProcessor +from transformers.models.voxtral import VoxtralProcessor +from axolotl.utils.dict import remove_none_values from axolotl.utils.logging import get_logger +from axolotl.utils.mistral.mistral3_processor import Mistral3Processor LOG = get_logger(__name__) @@ -137,12 +141,12 @@ class ProcessingStrategy: image_key = key break - # if the image key exists, add the image to the first message + # if the image key exists, add the image to the first user message if image_key is not None and processed_example[image_key] is not None: # TODO: check if it's normal to be single image only for common datasets # From observation, it's usually a list of single image but some datasets may have several columns for images # Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages - if len(processed_example[image_key]) > 0: + if len(processed_example[image_key]) > 1: LOG.warning( f"Found {len(processed_example[image_key])} images in a sample. Using the first one." "If you are using a dataset with multiple images per sample, please convert it to use multi-content Messages." @@ -155,9 +159,9 @@ class ProcessingStrategy: image_value = load_image(image_value) if self.image_size is not None: - assert hasattr( - image_value, "resize" - ), "Image does not have a resize method" + assert hasattr(image_value, "resize"), ( + "Image does not have a resize method" + ) if isinstance(self.image_size, tuple): image_value = image_value.resize( @@ -178,39 +182,56 @@ class ProcessingStrategy: # Look for any image type in the first message # some dataset have an {type: "image"} in the first message + msg_ind_to_add = None ind_to_add = None + first_user_idx = None - for i, content in enumerate( - processed_example["messages"][0]["content"] - ): - # Usually datasets created with image columns, don't have it in the messages itself - if content["type"] == "image" and all( - k not in content for k in ["image", "url", "path", "base64"] + for msg_idx, msg_content in enumerate(processed_example["messages"]): + if first_user_idx is None and msg_content["role"] == "user": + first_user_idx = msg_idx + for i, content in enumerate( + processed_example["messages"][msg_idx]["content"] ): - ind_to_add = i - break + # Usually datasets created with image columns, don't have it in the messages itself + if content["type"] == "image" and all( + k not in content for k in ["image", "url", "path", "base64"] + ): + msg_ind_to_add = msg_idx + ind_to_add = i + break # If an image type is found, add the image to that index - if ind_to_add is not None: - processed_example["messages"][0]["content"][ind_to_add][ - "image" - ] = image_value + if ind_to_add is not None and msg_ind_to_add is not None: + processed_example["messages"][msg_ind_to_add]["content"][ + ind_to_add + ]["image"] = image_value else: - # if no image type is found, add it to end of the first message - processed_example["messages"][0]["content"].append( + # if no image type is found, add it to end of the first user message + if first_user_idx is None: + first_user_idx = 0 + processed_example["messages"][first_user_idx]["content"].append( { "type": "image", "image": image_value, } ) - processed_examples.append(processed_example) + processed_examples.append(remove_none_values(processed_example)) return processed_examples + def _mask_non_assistant(self, labels: Tensor) -> Tensor: + """ + Mask non assistant regions to -100. + To be implemented per subclass. + """ + return labels + def process_labels(self, input_ids: Tensor) -> Tensor: labels = input_ids.clone() + labels = self._mask_non_assistant(labels) + # The labels are the input_ids, and we mask the padding tokens in the loss computation labels[labels == self.processor.tokenizer.pad_token_id] = -100 @@ -264,6 +285,175 @@ class Gemma3ProcessingStrategy(ProcessingStrategy): return labels +class Gemma3nProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Gemma3n""" + + def _mask_non_assistant(self, labels: Tensor) -> Tensor: + def _find_token_sequence(label, start_pos, token_sequence): + """Check if token_sequence appears at start_pos in label""" + if start_pos + len(token_sequence) > len(label): + return False + if label[start_pos] != token_sequence[0]: + return False + return ( + label[start_pos : start_pos + len(token_sequence)].tolist() + == token_sequence + ) + + def _find_assistant_end(label, start_pos, assistant_end_tok, mask, i): + """ + Find the end of assistant response and update mask accordingly + + Returns new position to continue from and whether the end seq is found + """ + k = start_pos + while k < len(label): + if not _find_token_sequence(label, k, assistant_end_tok): + mask[i][k] = 1 + k += 1 + continue + + return k + len(assistant_end_tok), True + + return k, False + + mask = zeros_like(labels) + + assistant_start_str = "model" + assistant_end_str = "" + include_assistant_start_tok = False + include_assistant_end_tok = True + + # str to tokens + assistant_start_tok = self.processor.tokenizer.encode( + assistant_start_str, add_special_tokens=False + ) + assistant_end_tok = self.processor.tokenizer.encode( + assistant_end_str, add_special_tokens=False + ) + + for i, label in enumerate(labels): + j = 0 + # while loop through each tok index in labels[i] + while j < len(label): + # Check until match start seq + if not _find_token_sequence(label, j, assistant_start_tok): + j += 1 + continue + + if include_assistant_start_tok: + mask[i][j : j + len(assistant_start_tok)] = 1 + + # Find where the assistant response ends + start_of_content = j + len(assistant_start_tok) + end_pos, found_end_seq = _find_assistant_end( + label, start_of_content, assistant_end_tok, mask, i + ) + + # Include end token if requested + if include_assistant_end_tok and found_end_seq: + mask[i][end_pos - len(assistant_end_tok) : end_pos] = 1 + + j = end_pos + + labels[i][mask[i] == 0] = -100 + + return labels + + def process_labels(self, input_ids): + labels = input_ids.clone() + labels = self._mask_non_assistant(labels) + + # Follows https://colab.research.google.com/github/huggingface/huggingface-gemma-recipes/blob/main/notebooks/fine_tune_gemma3n_on_t4.ipynb + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + if hasattr(self.processor.tokenizer, "image_token_id"): + labels[labels == self.processor.tokenizer.image_token_id] = -100 + if hasattr(self.processor.tokenizer, "audio_token_id"): + labels[labels == self.processor.tokenizer.audio_token_id] = -100 + if hasattr(self.processor.tokenizer, "boi_token_id"): + labels[labels == self.processor.tokenizer.boi_token_id] = -100 + if hasattr(self.processor.tokenizer, "eoi_token_id"): + labels[labels == self.processor.tokenizer.eoi_token_id] = -100 + + return labels + + +class VoxtralProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Voxtral""" + + def __init__( + self, + processor: VoxtralProcessor, + chat_template: Optional[str] = None, + image_size: int | tuple[int, int] | None = None, + image_resize_algorithm: Resampling | None = None, + ): + super().__init__(processor, chat_template, image_size, image_resize_algorithm) + special_ids = ( + processor.tokenizer.tokenizer.instruct_tokenizer.audio_encoder.special_ids + ) + + self.audio_token = special_ids.audio + self.begin_audio_token = special_ids.begin_audio + + def process_labels(self, input_ids): + labels = input_ids.clone() + + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + labels[labels == self.audio_token] = -100 + labels[labels == self.begin_audio_token] = -100 + + return labels + + +class SmolVLM2ProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for SmolVLM2""" + + def __init__( + self, + processor: ProcessorMixin, + chat_template: Optional[str] = None, + image_size: int | tuple[int, int] | None = None, + image_resize_algorithm: Resampling | None = None, + ): + super().__init__(processor, chat_template, image_size, image_resize_algorithm) + self.image_token = "" # nosec + + self.image_token_id = processor.tokenizer.additional_special_tokens_ids[ + processor.tokenizer.additional_special_tokens.index(self.image_token) + ] + + +class Mistral3ProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Mistral3""" + + def __init__( + self, + processor: Mistral3Processor, + chat_template: Optional[str] = None, + image_size: int | tuple[int, int] | None = None, + image_resize_algorithm: Resampling | None = None, + ): + super().__init__(processor, chat_template, image_size, image_resize_algorithm) + special_ids = ( + processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder.special_ids + ) + + self.image_token = special_ids.img + self.image_break_token = special_ids.img_break + self.image_end_token = special_ids.img_end + + def process_labels(self, input_ids): + labels = input_ids.clone() + + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + labels[labels == self.image_token] = -100 + labels[labels == self.image_break_token] = -100 + labels[labels == self.image_end_token] = -100 + + return labels + + def get_processing_strategy( processor: ProcessorMixin, chat_template, @@ -271,22 +461,48 @@ def get_processing_strategy( image_size: int | tuple[int, int] | None = None, image_resize_algorithm: Resampling | None = None, ): + processing_kwargs = { + "processor": processor, + "chat_template": chat_template, + "image_size": image_size, + "image_resize_algorithm": image_resize_algorithm, + } + + if chat_template_type in [None, "tokenizer_default"] and hasattr( + processor.tokenizer, "chat_template" + ): + processing_kwargs["chat_template"] = processor.tokenizer.chat_template + if chat_template_type == "qwen2_vl": return Qwen2VLProcessingStrategy( - processor, chat_template, image_size, image_resize_algorithm + **processing_kwargs, ) if chat_template_type == "gemma3": return Gemma3ProcessingStrategy( - processor, chat_template, image_size, image_resize_algorithm + **processing_kwargs, ) - if chat_template_type in [ - "llama3_2_vision", - "llama4", - "llava", - "mistral_v7_tekken", - "pixtral", - ]: - return ProcessingStrategy( - processor, chat_template, image_size, image_resize_algorithm + if chat_template_type == "gemma3n": + return Gemma3nProcessingStrategy( + **processing_kwargs, ) - raise ValueError(f"Unsupported chat template type: {chat_template_type}") + + if isinstance(processor, VoxtralProcessor): + return VoxtralProcessingStrategy( + **processing_kwargs, + ) + + if isinstance(processor, SmolVLMProcessor): + return SmolVLM2ProcessingStrategy( + **processing_kwargs, + ) + + if isinstance(processor, Mistral3Processor): + return Mistral3ProcessingStrategy( + **processing_kwargs, + ) + + # llama3_2_vision, llama4, llava + # mistral_v7_tekken, pixtral, lfm2vl + return ProcessingStrategy( + **processing_kwargs, + ) diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 3cdbbb6f3..d9936b9ae 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -17,7 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None): return messages_load(tokenizer, cfg, ds_cfg, processor=processor) load_fn = "load" package = "axolotl.prompt_strategies" - if strategy.split(".")[-1].startswith("load_"): + if ( + strategy.split(".")[-1].startswith("load_") + or strategy.split(".")[-1] == "load" + ): load_fn = strategy.split(".")[-1] strategy = ".".join(strategy.split(".")[:-1]) elif len(strategy.split(".")) > 1: @@ -45,6 +48,6 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None): return func(tokenizer, cfg, **load_kwargs) except ModuleNotFoundError: return None - except Exception as exc: # pylint: disable=broad-exception-caught + except Exception as exc: LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") raise exc diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py index 975fee889..391ba6072 100644 --- a/src/axolotl/prompt_strategies/alpaca_chat.py +++ b/src/axolotl/prompt_strategies/alpaca_chat.py @@ -39,7 +39,7 @@ class AlpacaChatPrompter(AlpacaPrompter): system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n" - def __init__(self): # pylint: disable=super-init-not-called + def __init__(self): self.prompt_style = PromptStyle.CHAT.value self.match_prompt_style() @@ -54,7 +54,7 @@ class NoSystemPrompter(AlpacaPrompter): turn_format = "{instruction} {input} " turn_no_input_format = "{instruction} " - def __init__(self): # pylint: disable=super-init-not-called + def __init__(self): pass diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py index 6873c8e08..808ba517e 100644 --- a/src/axolotl/prompt_strategies/alpaca_w_system.py +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -22,10 +22,9 @@ class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy): ) def tokenize_prompt(self, prompt): - # pylint: disable=duplicate-code ( instruction, - input, # pylint: disable=redefined-builtin + input, response, system, ) = self.parse_instruction_fields(prompt) @@ -64,7 +63,7 @@ class SystemDataPrompter(AlpacaPrompter): self, system: str, instruction: str, - input: Union[None, str] = None, # pylint: disable=redefined-builtin + input: Union[None, str] = None, output: Union[None, str] = None, ) -> Generator[str, None, None]: # returns the full prompt from instruction and optional input @@ -93,7 +92,6 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter): """ def match_prompt_style(self): - # pylint: disable=duplicate-code if self.prompt_style == PromptStyle.INSTRUCT.value: self.turn_format = "### Human:\n{instruction}\n### Additional Context:\n{input}\n### Assistant:\n" self.turn_no_input_format = "### Human:\n{instruction}\n### Assistant:\n" diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py index 370a51a95..45a3ffda9 100644 --- a/src/axolotl/prompt_strategies/base.py +++ b/src/axolotl/prompt_strategies/base.py @@ -29,6 +29,6 @@ def load(strategy, cfg, module_base=None, **kwargs): mod = importlib.import_module(strategy, module_base) func = getattr(mod, load_fn) return func(cfg, **kwargs) - except Exception: # pylint: disable=broad-exception-caught + except Exception: LOG.warning(f"unable to load strategy {strategy}") return None diff --git a/src/axolotl/prompt_strategies/bradley_terry/__init__.py b/src/axolotl/prompt_strategies/bradley_terry/__init__.py index 7530aee19..7336edc71 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/__init__.py +++ b/src/axolotl/prompt_strategies/bradley_terry/__init__.py @@ -10,7 +10,6 @@ LOG = get_logger(__name__) def load(strategy, tokenizer, cfg, ds_cfg): - # pylint: disable=duplicate-code try: load_fn = "load" if strategy.split(".")[-1].startswith("load_"): @@ -30,6 +29,6 @@ def load(strategy, tokenizer, cfg, ds_cfg): return func(tokenizer, cfg, **load_kwargs) except ModuleNotFoundError: return None - except Exception as exc: # pylint: disable=broad-exception-caught + except Exception as exc: LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") return None diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index e655f85a1..fd0d76f51 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -34,7 +34,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): max_length = self.prompter.max_length - # pylint: disable=duplicate-code prompt["messages"] = [] if prompt["system"]: prompt["messages"].append({"role": "system", "content": prompt["system"]}) @@ -52,7 +51,6 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): :max_length ] - # pylint: disable=duplicate-code prompt["messages"] = [] if prompt["system"]: prompt["messages"].append({"role": "system", "content": prompt["system"]}) diff --git a/src/axolotl/prompt_strategies/bradley_terry/llama3.py b/src/axolotl/prompt_strategies/bradley_terry/llama3.py index 1d586fd5f..5548d882e 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/llama3.py +++ b/src/axolotl/prompt_strategies/bradley_terry/llama3.py @@ -6,7 +6,7 @@ chatml transforms for datasets with system, input, chosen, rejected to match lla def icr( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): """ chatml transforms for datasets with system, input, chosen, rejected ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index a0fd8d911..f4dcbd7cd 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -2,8 +2,9 @@ HF Chat Templates prompt strategy """ +import json from collections import defaultdict -from typing import Any, Dict, List, Set, Union +from typing import TYPE_CHECKING, Any, Dict, List, Set, Union from pydantic import BaseModel from transformers import ProcessorMixin @@ -12,9 +13,13 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.utils.chat_templates import get_chat_template_from_config +from axolotl.utils.dict import remove_none_values from axolotl.utils.logging import get_logger from axolotl.utils.schemas.datasets import DatasetConfig +if TYPE_CHECKING: + from axolotl.utils.mistral import HFMistralTokenizer + # Configure the logger LOG = get_logger(__name__) LOG.setLevel("INFO") @@ -34,7 +39,10 @@ class ChatTemplatePrompter(Prompter): message_field_training_detail: str | None = None, field_messages: str = "messages", field_system: str = "system", + field_tools: str = "tools", + field_thinking: str = "reasoning_content", roles: dict[str, list[str]] | None = None, + template_thinking_key: str | None = "reasoning_content", chat_template_kwargs: dict[str, Any] | None = None, drop_system_message: bool = False, ): @@ -43,8 +51,9 @@ class ChatTemplatePrompter(Prompter): message_property_mappings = { "role": "role", "content": "content", - "reasoning_content": "reasoning_content", } + if template_thinking_key and field_thinking: + message_property_mappings[template_thinking_key] = field_thinking if roles: self.roles = {s: t for t, sources in roles.items() for s in sources} @@ -66,10 +75,13 @@ class ChatTemplatePrompter(Prompter): self.message_field_training_detail = message_field_training_detail self.field_messages = field_messages self.field_system = field_system + self.field_tools = field_tools + self.field_thinking = field_thinking self.tokenizer = tokenizer self.processor: ProcessorMixin | None = processor self.chat_template = chat_template self.chat_template_kwargs = chat_template_kwargs or {} + self.template_thinking_key: str = template_thinking_key or "reasoning_content" self.max_length = max_length self.drop_system_message = drop_system_message @@ -77,36 +89,64 @@ class ChatTemplatePrompter(Prompter): def chat_template_msg_variables(self) -> Set[str]: return self._chat_template_msg_variables - def build_prompt(self, conversation, add_generation_prompt=False, images=None): + def build_prompt( + self, + conversation: list[dict], + add_generation_prompt=False, + images=None, + tools=None, + ): + """ + Build a prompt from a conversation. + + Args: + conversation: A list of messages. + add_generation_prompt: Whether to add a generation prompt. + images: A list of images. (optional) + tools: A list of tools. (optional) + """ + chat_template_kwargs = { + "chat_template": self.chat_template, + "add_generation_prompt": add_generation_prompt, + **self.chat_template_kwargs, + } + + if tools: + chat_template_kwargs["tools"] = tools + if self.processor: if not callable(self.processor): raise TypeError("Processor must be callable") text = self.processor.apply_chat_template( conversation, - chat_template=self.chat_template, tokenize=False, - add_generation_prompt=add_generation_prompt, - **self.chat_template_kwargs, + **chat_template_kwargs, ) batch = self.processor( text=text, images=images, return_tensors="pt", ) + if hasattr(batch, "to_dict"): + batch = batch.to_dict() + else: + batch = dict(batch) + # workaround since processor works in batches instead of single examples + out = {} for k, val in batch.items(): - if k in ["pixel_values"]: - batch[k] = val.tolist() + if hasattr(val, "tolist"): + out[k] = ( + val.tolist() if k == "pixel_values" else val.squeeze(0).tolist() + ) else: - batch[k] = val.squeeze().tolist() - return batch + out[k] = val + return out return self.tokenizer.apply_chat_template( conversation, - add_generation_prompt=add_generation_prompt, - chat_template=self.chat_template, - **self.chat_template_kwargs, + **chat_template_kwargs, ) def get_offsets_for_train_detail( @@ -250,9 +290,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos # Default to eos_token if eot_tokens not provided - self.eot_tokens = ( - eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token] - ) + self.eot_tokens = [] + if eot_tokens is not None: + self.eot_tokens = eot_tokens + elif ( + hasattr(self.tokenizer, "eos_token") + and self.tokenizer.eos_token is not None + ): + self.eot_tokens = [self.tokenizer.eos_token] + self.split_thinking = split_thinking self.images = "images" @@ -346,6 +392,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): Public method that can handle either a single prompt or a batch of prompts. """ + prompt = remove_none_values(prompt) + if not self.is_prompt_batched(prompt) or not self.supports_batched: return self._tokenize_single_prompt(prompt) @@ -353,9 +401,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): feature_names = list(prompt.keys()) # Process each prompt individually - for row in zip(*prompt.values()): + for row in zip(*prompt.values(), strict=False): tokenized_prompt = self._tokenize_single_prompt( - dict(zip(feature_names, row)) + dict(zip(feature_names, row, strict=False)) ) for key, val in tokenized_prompt.items(): res[key].append(val) @@ -376,15 +424,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): and not self.prompter.message_field_training_detail # type: ignore ): turns = self.get_conversation_thread(prompt) - images = self.get_images(prompt) + images = self._get_images(prompt) prompt_ids = self.prompter.build_prompt( # type: ignore turns[:-1], add_generation_prompt=True, images=images, ) - tokenized_res = self.prompter.build_prompt( - turns, images=images - ) # type: ignore + tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore tokenized_prompt = {} if isinstance(tokenized_res, list): input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] @@ -392,10 +438,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): tokenized_prompt["attention_mask"] = [1] * len(input_ids) else: input_ids = tokenized_res["input_ids"] - tokenized_prompt = tokenized_res + tokenized_prompt = dict(tokenized_res) if not self.train_on_inputs: - user_prompt_len = len(prompt_ids) + if isinstance(prompt_ids, dict): + user_prompt_len = len(prompt_ids["input_ids"]) + else: + user_prompt_len = len(prompt_ids) labels = [-100] * user_prompt_len + input_ids[user_prompt_len:] else: labels = input_ids @@ -405,7 +454,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return tokenized_prompt turns = self.get_conversation_thread(prompt) - input_ids = self.prompter.build_prompt(turns) # type: ignore + tools = self._get_tools(prompt) + input_ids = self.prompter.build_prompt(turns, tools=tools) # type: ignore labels = [IGNORE_TOKEN_ID] * len(input_ids) last_eos_idx = -1 @@ -444,12 +494,20 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): continue - turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index) + turn_start_idx, turn_end_idx = self.find_turn( + turns=turns, turn_idx=index, tools=tools + ) LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}") if should_train and turn_start_idx != -1 and turn_end_idx != -1: if train_detail: + # Block multi-content for now + if not isinstance(content, str): + raise ValueError( + "`train_detail` is not supported when `content` is not a string." + ) + token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore content, train_detail ) @@ -546,11 +604,12 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return i return -1 - def find_turn(self, turns: list[dict], turn_idx: int): + def find_turn( + self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None + ): """ Locate the starting and ending indices of the specified turn in a conversation. """ - # pylint: disable=too-many-return-statements if turn_idx >= len(turns): raise ValueError(f"Turn index {turn_idx} out of range") @@ -559,11 +618,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): if ( turn_idx == 0 and turns[0].get("role") == "system" - and ( - "mistral" in self.tokenizer.name_or_path.lower() - or "gemma" - in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer - ) + and ("mistral" in self.tokenizer.name_or_path.lower()) ): return -1, -1 @@ -577,10 +632,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): turns_with_content = turns[: turn_idx + 1] # Generate the conversation up to the turn, with final turn replaced with dummy content - dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore + dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore # Generate the conversation up to the turn, with final turn included - full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore + full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore if not full_ids or not dummy_ids: LOG.warning(f"Empty template generated for turn {turn_idx}") @@ -633,9 +688,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): def get_conversation_thread(self, prompt): turns = [] - possible_sys_turn = self.transform_message( - prompt[self.prompter.field_messages][0] - ) + messages = self._get_messages(prompt) + + possible_sys_turn = self.transform_message(messages[0]) + if ( possible_sys_turn["role"] != "system" and self.prompter.field_system in prompt @@ -643,16 +699,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): turn = {"role": "system", "content": prompt[self.prompter.field_system]} turns.append(turn) - for message in prompt[self.prompter.field_messages]: + for message in messages: transformed_message = self.transform_message(message) - turn = { - **transformed_message, - "training": message.get(self.prompter.message_field_training), - "training_detail": message.get( - self.prompter.message_field_training_detail - ), - } + turn = transformed_message + + training = message.get(self.prompter.message_field_training) + training_detail = message.get(self.prompter.message_field_training_detail) + if training is not None: + turn["training"] = training + if training_detail is not None: + turn["training_detail"] = training_detail turns.append(turn) @@ -661,7 +718,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return turns - def transform_message(self, message): + def transform_message(self, message: dict) -> dict: # Build the initial transformed message from the mappings transformed_message = {} for key, value in self.prompter.message_property_mappings.items(): @@ -697,7 +754,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): # get the thinking content thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx] - transformed_message["reasoning_content"] = thinking_content.strip() + transformed_message[self.prompter.template_thinking_key] = ( + thinking_content.strip() + ) # take remainder of the content # strip whitespace from beginning of the remainder (thinking tokens) @@ -736,20 +795,144 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): if val is not None: transformed_message[key] = val + if "tool_calls" in transformed_message and transformed_message["tool_calls"]: + for tool_call in transformed_message["tool_calls"]: + if "function" in tool_call and "arguments" in tool_call["function"]: + args = tool_call["function"]["arguments"] + if isinstance(args, str): + try: + tool_call["function"]["arguments"] = json.loads(args) + except json.JSONDecodeError as e: + LOG.error( + f"Error parsing tool_calls arguments as JSON. " + f"Function: {tool_call.get('function', {}).get('name', 'unknown')}, " + f"Arguments string: {args!r}, " + f"Error: {e}" + ) + raise + return transformed_message - def get_images(self, prompt): + def _get_images(self, prompt): return prompt.get(self.images, None) + def _get_tools(self, prompt) -> list[dict] | None: + """Get tools from prompt if available.""" + tools = prompt.get(self.prompter.field_tools, None) + if tools is None: + return None + + if isinstance(tools, list): + return tools + + raise ValueError( + "Unknown tools format. Please convert it into a list[dict].\n" + f"Current format: {type(tools)}" + ) + + def _get_messages(self, prompt): + messages = prompt.get(self.prompter.field_messages, None) + if messages is None: + raise ValueError("Messages is null. Please check `field_messages`.") + + if isinstance(messages, list): + return messages + + raise ValueError( + "Unknown messages format. Please convert it into a list[dict].\n" + f"Current format: {type(messages)}" + ) + + +class MistralStrategy(ChatTemplateStrategy): + """ + Mistral strategy for chat template. + """ + + def __init__( + self, + prompter: "ChatTemplatePrompter", + tokenizer: "HFMistralTokenizer", + train_on_inputs: bool, + sequence_len: int, + roles_to_train: list[str] | None = None, + train_on_eos: str | None = None, + train_on_eot: str | None = None, + eot_tokens: list[str] | None = None, + split_thinking: bool | None = False, + ): + # Call the parent's parent __init__ (PromptTokenizingStrategy) to skip ChatTemplateStrategy's validation + + PromptTokenizingStrategy.__init__( + self, prompter, tokenizer, train_on_inputs, sequence_len + ) + self.prompter: ChatTemplatePrompter = prompter + + self.roles_to_train = [] + if roles_to_train: + # map roles if exist in prompter.roles else use the role as is + self.roles_to_train = [ + prompter.roles.get(role, role) for role in roles_to_train + ] + + self.train_on_eos = train_on_eos + # Backward compatibility, load from train_on_eos + self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos + + # Default to eos_token if eot_tokens not provided + self.eot_tokens = [] + if eot_tokens is not None: + self.eot_tokens = eot_tokens + else: + # set eot_tokens to the eos_token + self.eot_tokens = [self.tokenizer.eos_token] + + self.split_thinking = split_thinking + + self.images = "images" + + LOG.debug( + f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}" + ) + + # Skip the validation that ChatTemplateStrategy calls + # TODO: address this in the future with mistral-specific checks + # self._validate_eot_and_eos_tokens() + + def find_first_eot_token(self, input_ids, start_idx): + """Find the first EOT token in the input_ids starting from start_idx.""" + # mistral-common tokenizer does not support eot_tokens + return self.find_first_eos_token(input_ids, start_idx) + + +class MistralPrompter(ChatTemplatePrompter): + """ + Mistral prompter for chat template. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._chat_template_msg_variables = set(["tool_call_id", "name", "tool_calls"]) + class StrategyLoader: """ Load chat template strategy based on configuration. """ - def _get_strategy_cls(self): + def _get_strategy_cls(self, cfg): + if cfg.tokenizer_use_mistral_common: + return MistralStrategy + return ChatTemplateStrategy + def _get_prompter_cls(self, cfg): + if cfg.tokenizer_use_mistral_common: + return MistralPrompter + + return ChatTemplatePrompter + def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]): return { "train_on_inputs": cfg.train_on_inputs, @@ -775,9 +958,14 @@ class StrategyLoader: else: dataset_config = ds_cfg - chat_template_string = get_chat_template_from_config( - cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer - ) + if cfg.tokenizer_use_mistral_common: + # mistral-common does not use this, so we pass an empty string + chat_template_string = "" + else: + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer + ) + LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---") prompter_params = { @@ -795,6 +983,10 @@ class StrategyLoader: None, ), "field_messages": dataset_config.get("field_messages", "messages"), + "field_thinking": dataset_config.get("field_thinking", "reasoning_content"), + "template_thinking_key": dataset_config.get( + "template_thinking_key", "reasoning_content" + ), "roles": dataset_config.get("roles"), "drop_system_message": dataset_config.get("drop_system_message", False), # we need to add one for detecting sequences with exceeding the `sequence_len` limit. @@ -803,10 +995,11 @@ class StrategyLoader: } strategy_params = self._get_strategy_params(cfg, dataset_config) - strategy_cls = self._get_strategy_cls() + strategy_cls = self._get_strategy_cls(cfg) + prompter_cls = self._get_prompter_cls(cfg) strategy = strategy_cls( - ChatTemplatePrompter(**prompter_params), + prompter_cls(**prompter_params), tokenizer=tokenizer, **strategy_params, ) diff --git a/src/axolotl/prompt_strategies/completion.py b/src/axolotl/prompt_strategies/completion.py index 62a4b90b2..f43f25793 100644 --- a/src/axolotl/prompt_strategies/completion.py +++ b/src/axolotl/prompt_strategies/completion.py @@ -42,8 +42,8 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): def tokenize_prompt(self, prompt): res = defaultdict(lambda: []) feature_names = list(prompt.keys()) - for row in zip(*prompt.values()): - prompt_row = dict(zip(feature_names, row)) + for row in zip(*prompt.values(), strict=False): + prompt_row = dict(zip(feature_names, row, strict=False)) ( instruction, _, @@ -59,9 +59,7 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): return dict(res) - def _build_full_prompt( - self, instruction, input, response - ): # pylint: disable=redefined-builtin + def _build_full_prompt(self, instruction, input, response): return next(iter(self.prompter.build_prompt(instruction, input, response))) @@ -73,8 +71,8 @@ class CompletionPrompter: def build_prompt( self, instruction: str, - input=None, # pylint: disable=redefined-builtin, unused-argument - output=None, # pylint: disable=unused-argument + input=None, + output=None, ) -> Generator[str, None, None]: yield instruction diff --git a/src/axolotl/prompt_strategies/context_qa.py b/src/axolotl/prompt_strategies/context_qa.py index aac44e0b2..09e96d26e 100644 --- a/src/axolotl/prompt_strategies/context_qa.py +++ b/src/axolotl/prompt_strategies/context_qa.py @@ -86,7 +86,6 @@ class ContextV2Prompter(AlpacaPrompter): system_no_input_prompt = "" def match_prompt_style(self): - # pylint: disable=duplicate-code self.turn_format = "{instruction}\n{input}" self.turn_no_input_format = "{instruction}" self.system_format = "{system}" diff --git a/src/axolotl/prompt_strategies/creative_acr.py b/src/axolotl/prompt_strategies/creative_acr.py index ea67034b3..3e016e30e 100644 --- a/src/axolotl/prompt_strategies/creative_acr.py +++ b/src/axolotl/prompt_strategies/creative_acr.py @@ -134,9 +134,7 @@ class CreativePrompterBase: def build_prompt( self, instruction: str, - input: Union[ # pylint: disable=redefined-builtin, unused-argument - None, str - ] = None, + input: Union[None, str] = None, output: Union[None, str] = None, ) -> Generator[str, None, None]: if self.system_prompt: diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index f3427022f..85c4d2182 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -6,9 +6,7 @@ from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_te from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic -def default( - cfg, dataset_idx=0, **kwargs -): # pylint: disable=possibly-unused-variable,unused-argument +def default(cfg, dataset_idx=0, **kwargs): ds_cfg = cfg["datasets"][dataset_idx] ds_cfg = handle_legacy_message_fields_logic(ds_cfg) @@ -46,6 +44,14 @@ def default( ) messages = sample[field_messages] + if isinstance(messages, str): + messages = [ + { + message_property_mappings["role"]: "user", + message_property_mappings["content"]: messages, + } + ] + messages = [ { "role": role_map[m[message_property_mappings["role"]]], @@ -53,13 +59,35 @@ def default( } for m in messages ] + + chosen_raw = sample[field_chosen] + if isinstance(chosen_raw, str): + chosen_msg = { + message_property_mappings["role"]: "assistant", + message_property_mappings["content"]: chosen_raw, + } + elif isinstance(chosen_raw, dict): + chosen_msg = chosen_raw + else: + chosen_msg = chosen_raw[-1] chosen = { - "role": role_map[sample[field_chosen][message_property_mappings["role"]]], - "content": sample[field_chosen][message_property_mappings["content"]], + "role": role_map[chosen_msg[message_property_mappings["role"]]], + "content": chosen_msg[message_property_mappings["content"]], } + + rejected_raw = sample[field_rejected] + if isinstance(rejected_raw, str): + rejected_msg = { + message_property_mappings["role"]: "assistant", + message_property_mappings["content"]: rejected_raw, + } + elif isinstance(rejected_raw, dict): + rejected_msg = rejected_raw + else: + rejected_msg = rejected_raw[-1] rejected = { - "role": role_map[sample[field_rejected][message_property_mappings["role"]]], - "content": sample[field_rejected][message_property_mappings["content"]], + "role": role_map[rejected_msg[message_property_mappings["role"]]], + "content": rejected_msg[message_property_mappings["content"]], } dummy_user_message = {"role": "user", "content": "[[dummy_message]]"} diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py index 34a54aaa0..8614708eb 100644 --- a/src/axolotl/prompt_strategies/dpo/chatml.py +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -6,7 +6,7 @@ DPO strategies for chatml def default( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): def transform_fn(sample): if "prompt" in sample.keys(): prompt_key = "prompt" @@ -46,7 +46,7 @@ def default( def argilla_chat( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): """ for argilla/dpo-mix-7k conversations """ @@ -65,7 +65,7 @@ def argilla_chat( def icr( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): """ chatml transforms for datasets with system, input, chosen, rejected ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs @@ -88,7 +88,7 @@ def icr( return transform_fn -def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def intel(cfg, **kwargs): """ For Intel Orca DPO Pairs """ @@ -110,9 +110,7 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg return transform_fn -def prompt_pairs( - cfg, **kwargs -): # pylint: disable=possibly-unused-variable,unused-argument +def prompt_pairs(cfg, **kwargs): def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( @@ -130,7 +128,7 @@ def prompt_pairs( return transform_fn -def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def ultra(cfg, **kwargs): """ for ultrafeedback binarized conversations """ diff --git a/src/axolotl/prompt_strategies/dpo/llama3.py b/src/axolotl/prompt_strategies/dpo/llama3.py index eed420017..c13ff55e4 100644 --- a/src/axolotl/prompt_strategies/dpo/llama3.py +++ b/src/axolotl/prompt_strategies/dpo/llama3.py @@ -6,9 +6,8 @@ DPO strategies for llama-3 chat template def default( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): def transform_fn(sample): - # pylint: disable=duplicate-code if "prompt" in sample.keys(): prompt_key = "prompt" elif "input" in sample.keys(): @@ -47,7 +46,7 @@ def default( def argilla_chat( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): """ for argilla/dpo-mix-7k conversations """ @@ -66,7 +65,7 @@ def argilla_chat( def icr( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): """ chatml transforms for datasets with system, input, chosen, rejected ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs @@ -89,7 +88,7 @@ def icr( return transform_fn -def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def intel(cfg, **kwargs): """ For Intel Orca DPO Pairs """ @@ -111,9 +110,7 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg return transform_fn -def prompt_pairs( - cfg, **kwargs -): # pylint: disable=possibly-unused-variable,unused-argument +def prompt_pairs(cfg, **kwargs): def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( @@ -131,7 +128,7 @@ def prompt_pairs( return transform_fn -def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def ultra(cfg, **kwargs): """ for ultrafeedback binarized conversations """ diff --git a/src/axolotl/prompt_strategies/dpo/passthrough.py b/src/axolotl/prompt_strategies/dpo/passthrough.py index 1fcb838db..52b5ceac1 100644 --- a/src/axolotl/prompt_strategies/dpo/passthrough.py +++ b/src/axolotl/prompt_strategies/dpo/passthrough.py @@ -3,12 +3,8 @@ DPO prompt strategies passthrough/zero-processing strategy """ -def default( - cfg, dataset_idx=0, **kwargs -): # pylint: disable=possibly-unused-variable,unused-argument - def transform_fn( - sample, tokenizer=None - ): # pylint: disable=possibly-unused-variable,unused-argument +def default(cfg, dataset_idx=0, **kwargs): + def transform_fn(sample, tokenizer=None): return sample return transform_fn diff --git a/src/axolotl/prompt_strategies/dpo/user_defined.py b/src/axolotl/prompt_strategies/dpo/user_defined.py index 1d5f891af..0bcb1d94c 100644 --- a/src/axolotl/prompt_strategies/dpo/user_defined.py +++ b/src/axolotl/prompt_strategies/dpo/user_defined.py @@ -3,7 +3,7 @@ User-defined DPO strategies """ -def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument +def default(cfg, dataset_idx=0, **kwargs): ds_cfg = cfg["datasets"][dataset_idx]["type"] if not isinstance(ds_cfg, dict): raise ValueError( @@ -33,7 +33,7 @@ def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument system=sample[field_system], prompt=sample[field_prompt] ) else: - sample["prompt"] = prompt_format.format(prompt=sample["prompt"]) + sample["prompt"] = prompt_format.format(prompt=sample[field_prompt]) sample["chosen"] = chosen_format.format(chosen=sample[field_chosen]) sample["rejected"] = rejected_format.format(rejected=sample[field_rejected]) return sample diff --git a/src/axolotl/prompt_strategies/dpo/zephyr.py b/src/axolotl/prompt_strategies/dpo/zephyr.py index 9eb895009..781227181 100644 --- a/src/axolotl/prompt_strategies/dpo/zephyr.py +++ b/src/axolotl/prompt_strategies/dpo/zephyr.py @@ -3,14 +3,11 @@ DPO strategies for zephyr """ -def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def nectar(cfg, **kwargs): def transform_fn(sample): data = {} data["prompt"] = ( - "<|system|>\n\n" - "<|user|>\n" - f"{sample['prompt']}\n" - "<|assistant|>\n" + f"<|system|>\n\n<|user|>\n{sample['prompt']}\n<|assistant|>\n" ) answers = sorted(sample["answers"], key=lambda x: x["rank"]) data["chosen"] = answers[-1]["answer"] diff --git a/src/axolotl/prompt_strategies/input_output.py b/src/axolotl/prompt_strategies/input_output.py index 8be745b20..c84eecffc 100644 --- a/src/axolotl/prompt_strategies/input_output.py +++ b/src/axolotl/prompt_strategies/input_output.py @@ -16,7 +16,6 @@ class RawInputOutputStrategy(PromptTokenizingStrategy): self.eos_token = self.tokenizer.eos_token def tokenize_prompt(self, prompt): - # pylint: disable=duplicate-code input_ids = [] labels = [] for label, text in self.prompter.build_prompt(prompt["segments"]): diff --git a/src/axolotl/prompt_strategies/jinja_template_analyzer.py b/src/axolotl/prompt_strategies/jinja_template_analyzer.py index a5f89cfe5..e16a1e22b 100644 --- a/src/axolotl/prompt_strategies/jinja_template_analyzer.py +++ b/src/axolotl/prompt_strategies/jinja_template_analyzer.py @@ -3,6 +3,7 @@ from typing import Dict, Optional, Set, TypedDict, Union from jinja2 import Environment, meta, nodes +from jinja2.ext import Extension class JinjaTemplateAnalysis(TypedDict): @@ -27,6 +28,18 @@ class JinjaTemplateAnalysis(TypedDict): iteration_target: Optional[Union[str, list[str]]] +class GenerationTagIgnore(Extension): + """ + Ignores the generation and endgeneration tags in Jinja templates. + """ + + tags = {"generation", "endgeneration"} + + def parse(self, parser): + parser.stream.skip(1) + return nodes.Const("") + + class JinjaTemplateAnalyzer: """ Analyzes Jinja templates to extract information about variable usage, @@ -57,7 +70,9 @@ class JinjaTemplateAnalyzer: """ def __init__(self, template: str): - self.env: Environment = Environment(autoescape=True) + self.env: Environment = Environment( + autoescape=True, extensions=[GenerationTagIgnore] + ) self.property_access: Dict[str, Set[str]] = {} self.iteration_targets: Dict[str, Union[str, list[str]]] = {} self.index_access: Dict[str, Set[Union[int, float]]] = {} diff --git a/src/axolotl/prompt_strategies/kto/chatml.py b/src/axolotl/prompt_strategies/kto/chatml.py index 97ae59ed5..945940f3f 100644 --- a/src/axolotl/prompt_strategies/kto/chatml.py +++ b/src/axolotl/prompt_strategies/kto/chatml.py @@ -2,13 +2,11 @@ KTO strategies for chatml """ -# pylint: disable=duplicate-code - def argilla( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( @@ -28,7 +26,7 @@ def argilla( def argilla_chat( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): """ for argilla/kto-mix-15k conversations """ @@ -43,7 +41,7 @@ def argilla_chat( return transform_fn -def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def intel(cfg, **kwargs): """ For Intel Orca KTO ex: argilla/distilabel-intel-orca-kto @@ -65,9 +63,7 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg return transform_fn -def prompt_pairs( - cfg, **kwargs -): # pylint: disable=possibly-unused-variable,unused-argument +def prompt_pairs(cfg, **kwargs): def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( @@ -84,7 +80,7 @@ def prompt_pairs( return transform_fn -def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def ultra(cfg, **kwargs): """ for ultrafeedback binarized conversations ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto diff --git a/src/axolotl/prompt_strategies/kto/llama3.py b/src/axolotl/prompt_strategies/kto/llama3.py index fde3c2ed4..9061f6f5e 100644 --- a/src/axolotl/prompt_strategies/kto/llama3.py +++ b/src/axolotl/prompt_strategies/kto/llama3.py @@ -2,13 +2,11 @@ KTO strategies for llama-3 chat template """ -# pylint: disable=duplicate-code - def argilla( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( @@ -28,7 +26,7 @@ def argilla( def argilla_chat( cfg, **kwargs, -): # pylint: disable=possibly-unused-variable,unused-argument +): """ for argilla/kto-mix-15k conversations """ @@ -43,7 +41,7 @@ def argilla_chat( return transform_fn -def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def intel(cfg, **kwargs): """ For Intel Orca KTO ex: argilla/distilabel-intel-orca-kto @@ -65,9 +63,7 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg return transform_fn -def prompt_pairs( - cfg, **kwargs -): # pylint: disable=possibly-unused-variable,unused-argument +def prompt_pairs(cfg, **kwargs): def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( @@ -84,7 +80,7 @@ def prompt_pairs( return transform_fn -def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def ultra(cfg, **kwargs): """ for ultrafeedback binarized conversations ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto diff --git a/src/axolotl/prompt_strategies/kto/user_defined.py b/src/axolotl/prompt_strategies/kto/user_defined.py index 7c68a3000..e26683cde 100644 --- a/src/axolotl/prompt_strategies/kto/user_defined.py +++ b/src/axolotl/prompt_strategies/kto/user_defined.py @@ -2,10 +2,8 @@ User-defined KTO strategies """ -# pylint: disable=duplicate-code - -def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument +def default(cfg, dataset_idx=0, **kwargs): ds_cfg = cfg["datasets"][dataset_idx]["type"] if not isinstance(ds_cfg, dict): raise ValueError( diff --git a/src/axolotl/prompt_strategies/llama2_chat.py b/src/axolotl/prompt_strategies/llama2_chat.py index eef2e1d4d..9eff062ec 100644 --- a/src/axolotl/prompt_strategies/llama2_chat.py +++ b/src/axolotl/prompt_strategies/llama2_chat.py @@ -153,7 +153,7 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy): } -class Llama2ChatPrompter: # pylint: disable=too-few-public-methods +class Llama2ChatPrompter: """ A prompter that generates prompts for Llama2 models. """ @@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods # Skip the first one if it is not from human source = source[1:] - conv.messages = [] # pylint: disable=R0801 + conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], ALTERNATING_ASSERTION_FAILED_ROLE diff --git a/src/axolotl/prompt_strategies/messages/__init__.py b/src/axolotl/prompt_strategies/messages/__init__.py index cc7b84da1..2c920a568 100644 --- a/src/axolotl/prompt_strategies/messages/__init__.py +++ b/src/axolotl/prompt_strategies/messages/__init__.py @@ -11,7 +11,7 @@ LOG = get_logger(__name__) def load(tokenizer, cfg, ds_cfg, processor=None): try: strategy = ds_cfg.get("input_transform", "chat") - # pylint: disable=duplicate-code + load_fn = "load" if strategy.split(".")[-1].startswith("load_"): load_fn = strategy.split(".")[-1] @@ -29,7 +29,6 @@ def load(tokenizer, cfg, ds_cfg, processor=None): return func(tokenizer, cfg, **load_kwargs) except ModuleNotFoundError: return None - except Exception as exc: # pylint: disable=broad-exception-caught + except Exception as exc: LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") raise exc - return None diff --git a/src/axolotl/prompt_strategies/messages/chat.py b/src/axolotl/prompt_strategies/messages/chat.py index eaed2396a..854d25e42 100644 --- a/src/axolotl/prompt_strategies/messages/chat.py +++ b/src/axolotl/prompt_strategies/messages/chat.py @@ -19,7 +19,7 @@ class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy): processor, message_transform=None, formatter=None, - **kwargs, # pylint: disable=unused-argument + **kwargs, ): """ :param processor: tokenizer or image processor @@ -35,7 +35,7 @@ class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy): dataset, process_count: Optional[int] = None, keep_in_memory: Optional[bool] = False, - **kwargs, # pylint: disable=unused-argument + **kwargs, ): self.dataset = TokenizedChatDataset( dataset, @@ -72,9 +72,10 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): builder_kwargs["message_field_training"] = message_field_training chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml")) - format_message = ( - lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment - ) + + def format_message(x): + return x + if chat_template == "chatml": from axolotl.core.chat.format.chatml import format_message # noqa F811 if chat_template.startswith("llama3"): diff --git a/src/axolotl/prompt_strategies/metharme.py b/src/axolotl/prompt_strategies/metharme.py index 66da72389..35f1ef3b3 100644 --- a/src/axolotl/prompt_strategies/metharme.py +++ b/src/axolotl/prompt_strategies/metharme.py @@ -10,8 +10,6 @@ LOG = get_logger(__name__) IGNORE_TOKEN_ID = -100 -# pylint: disable=duplicate-code - class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): """ @@ -66,7 +64,7 @@ class MetharmePrompter(AlpacaPrompter): turn_format = "{instruction}" turn_no_input_format = "{instruction}" - def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called + def __init__(self, *args, **kwargs): pass diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index fdee28ea1..b655bc970 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -23,9 +23,7 @@ class MessageList(BaseModel): messages: List[Message] -def load( - tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs -): # pylint: disable=possibly-unused-variable,unused-argument +def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs): """ chatml transforms for datasets with system, input, chosen, rejected """ @@ -219,29 +217,38 @@ class ORPOPrompter(Prompter): for message in message_list.messages: conversation.append(message.model_dump()) if message.role == "system": - yield self.tokenizer.apply_chat_template( - conversation, - add_generation_prompt=False, - chat_template=self.chat_template, - tokenize=False, - ), False + yield ( + self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=False, + chat_template=self.chat_template, + tokenize=False, + ), + False, + ) if message.role == "user": - yield self.tokenizer.apply_chat_template( - conversation, - add_generation_prompt=True, - chat_template=self.chat_template, - tokenize=False, - ), False + yield ( + self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=True, + chat_template=self.chat_template, + tokenize=False, + ), + False, + ) if message.role == "assistant": - yield self.tokenizer.apply_chat_template( - conversation, - add_generation_prompt=False, - chat_template=self.chat_template, - tokenize=False, - ), True + yield ( + self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=False, + chat_template=self.chat_template, + tokenize=False, + ), + True, + ) -def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument +def argilla(cfg, **kwargs): dataset_parser = ORPODatasetParsingStrategy() def transform_fn(sample, tokenizer=None): diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py index 51f92f397..8c53a5f27 100644 --- a/src/axolotl/prompt_strategies/pygmalion.py +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -69,7 +69,6 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): LOG.warning(f"unknown role in conversation: {role}") res = defaultdict(lambda: []) - # pylint: disable=duplicate-code result, current_len = parse_tokenized_to_result( result, current_len, @@ -89,7 +88,10 @@ class PygmalionPrompter: pass def build_prompt( - self, source, *args, **kwargs # pylint: disable=unused-argument + self, + source, + *args, + **kwargs, ) -> Generator[Tuple[str, str], None, None]: for msg in source: yield msg["role"], msg["value"] diff --git a/src/axolotl/prompt_strategies/stepwise_supervised.py b/src/axolotl/prompt_strategies/stepwise_supervised.py index 8be7c35e3..9175126e7 100644 --- a/src/axolotl/prompt_strategies/stepwise_supervised.py +++ b/src/axolotl/prompt_strategies/stepwise_supervised.py @@ -66,7 +66,7 @@ class StepwiseSupervisedPromptTokenizingStrategy: # Create step-wise labels labels = [ [IGNORE_INDEX] * (len(completion) - 1) + [label] # type: ignore - for completion, label in zip(completions_ids, labels) + for completion, label in zip(completions_ids, labels, strict=False) ] # Join all steps diff --git a/src/axolotl/prompt_strategies/user_defined.py b/src/axolotl/prompt_strategies/user_defined.py index e20e80c3a..0bff514e7 100644 --- a/src/axolotl/prompt_strategies/user_defined.py +++ b/src/axolotl/prompt_strategies/user_defined.py @@ -83,16 +83,12 @@ def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None): cfg.sequence_len, ) - setattr( - strat, - "parse_instruction_fields", - partial( - parse_instruction_fields, - ds_cfg.field_instruction, - ds_cfg.field_input, - ds_cfg.field_output, - ds_cfg.field_system, - system_prompt, - ), + strat.parse_instruction_fields = partial( # type: ignore[method-assign] + parse_instruction_fields, + ds_cfg.field_instruction, + ds_cfg.field_input, + ds_cfg.field_output, + ds_cfg.field_system, + system_prompt, ) return strat diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index cb1a1ba4e..a7bd963f8 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -3,6 +3,7 @@ import abc from typing import Callable, Dict, List, Optional, Tuple, Union +from datasets import Dataset from transformers import BatchEncoding, PreTrainedTokenizer from axolotl.prompters import Prompter @@ -28,6 +29,16 @@ class DatasetWrappingStrategy(abc.ABC): Abstract class for wrapping datasets for Chat Messages """ + @abc.abstractmethod + def wrap_dataset( + self, + dataset, + process_count: int | None = None, + keep_in_memory: bool | None = False, + **kwargs, + ) -> Dataset: + pass + class PromptTokenizingStrategy(abc.ABC): """ @@ -64,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC): ) -> BatchEncoding: empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) if not prompt: - LOG.warning("Empty text requested for tokenization.") + LOG.warning_once("Empty text requested for tokenization.") return empty result = self.tokenizer( @@ -107,7 +118,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): def tokenize_prompt(self, prompt): ( instruction, - input, # pylint: disable=redefined-builtin + input, response, ) = self.parse_instruction_fields(prompt) user_prompt = next( @@ -133,7 +144,10 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): return tokenized_prompt def _build_full_prompt( - self, instruction, input, response # pylint: disable=redefined-builtin + self, + instruction, + input, + response, ): return next( iter( @@ -246,10 +260,9 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): raise NotImplementedError def tokenize_prompt(self, prompt): - # pylint: disable=duplicate-code ( instruction, - input, # pylint: disable=redefined-builtin + input, output, reflection, corrected, @@ -276,9 +289,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): return tokenized_full_prompt - def _build_full_prompt( - self, instruction, input, output, reflection, corrected - ): # pylint: disable=redefined-builtin + def _build_full_prompt(self, instruction, input, output, reflection, corrected): return next( iter( self.prompter.build_prompt( diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index d29da075e..9543996f7 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -46,7 +46,6 @@ class AlpacaPrompter(Prompter): self.match_prompt_style() def match_prompt_style(self): - # pylint: disable=duplicate-code if self.prompt_style == PromptStyle.INSTRUCT.value: self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" self.turn_no_input_format = ( @@ -93,7 +92,7 @@ class AlpacaPrompter(Prompter): def build_prompt( self, instruction: str, - input: Union[None, str] = None, # pylint: disable=redefined-builtin + input: Union[None, str] = None, output: Union[None, str] = None, ) -> Generator[str, None, None]: yield self._build_result(instruction, input, output) @@ -218,7 +217,7 @@ class ReflectAlpacaPrompter(Prompter): def _build_result( self, instruction: str, - input: Union[None, str] = None, # pylint: disable=redefined-builtin + input: Union[None, str] = None, output: Union[None, str] = None, reflection: Union[None, str] = None, corrected: Union[None, str] = None, @@ -242,12 +241,11 @@ class ReflectAlpacaPrompter(Prompter): def build_prompt( self, instruction: str, - input: Union[None, str] = None, # pylint: disable=redefined-builtin + input: Union[None, str] = None, output: Union[None, str] = None, reflection: Union[None, str] = None, corrected: Union[None, str] = None, ) -> Generator[str, None, None]: - # pylint: disable=duplicate-code yield self._build_result( instruction, input, diff --git a/src/axolotl/train.py b/src/axolotl/train.py index f4f4a5a91..6cbcb9aec 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -1,18 +1,23 @@ """Prepare and train a model on a dataset. Can also infer from a model or merge lora""" +from __future__ import annotations + import importlib import inspect +import json import os +import shutil import signal import sys +import typing import weakref +from collections import OrderedDict from contextlib import ExitStack from pathlib import Path from typing import Any, Dict import torch import transformers.modelcard -from accelerate.utils import save_fsdp_model from datasets import Dataset from huggingface_hub.errors import OfflineModeIsEnabled from peft import PeftConfig, PeftModel @@ -20,18 +25,12 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.trainer import Trainer -from axolotl.cli.art import print_axolotl_text_art from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) -from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.integrations.base import PluginManager -from axolotl.loaders import ( - ModelLoader, - load_processor, - load_tokenizer, -) +from axolotl.loaders import ModelLoader, load_processor, load_tokenizer from axolotl.telemetry.errors import send_errors from axolotl.telemetry.manager import TelemetryManager from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager @@ -40,12 +39,11 @@ from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType +from axolotl.utils.train import determine_last_checkpoint from axolotl.utils.trainer import setup_trainer -try: - from optimum.bettertransformer import BetterTransformer -except ImportError: - BetterTransformer = None +if typing.TYPE_CHECKING: + from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder LOG = get_logger(__name__) @@ -58,8 +56,8 @@ def setup_model_and_tokenizer( ) -> tuple[ PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None ]: - """ - Load the tokenizer, processor (for multimodal models), and model based on configuration. + """Load the tokenizer, processor (for multimodal models), and model based on + configuration. Args: cfg: Dictionary mapping `axolotl` config keys to values. @@ -80,11 +78,8 @@ def setup_model_and_tokenizer( if cfg.is_multimodal: processor = load_processor(cfg, tokenizer) - # Load the model and peft_config - msg = "loading model" - if cfg.adapter: - msg += " and peft_config..." - LOG.debug(msg) + # Load the model + LOG.debug("Loading model") model_loader = ModelLoader(cfg, tokenizer, processor=processor) model, peft_config = model_loader.load() @@ -131,38 +126,15 @@ def setup_reference_model( LOG.debug("Passing model_ref: None to RL trainer") model_ref = None # explicit setting to None else: + reference_model: bool = True + if cfg.rl == RLType.GRPO and cfg.trl.beta == 0: + reference_model = False # load the model again for model_ref/baseline - model_loader = ModelLoader(cfg, tokenizer, reference_model=True) + model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model) model_ref, _ = model_loader.load() return model_ref -def determine_resume_checkpoint(cfg: DictDefault) -> str | None: - """ - Determine the checkpoint to resume from based on configuration. - - Args: - cfg: Dictionary mapping `axolotl` config keys to values. - - Returns: - Path to the checkpoint to resume from, or `None` if not resuming. - """ - if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: - possible_checkpoints = [ - str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") - ] - if len(possible_checkpoints) > 0: - sorted_paths = sorted( - possible_checkpoints, - key=lambda path: int(path.split("-")[-1]), - ) - cfg.resume_from_checkpoint = sorted_paths[-1] - LOG.info( - f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" - ) - return cfg.resume_from_checkpoint - - def setup_signal_handler( cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool ): @@ -180,8 +152,6 @@ def setup_signal_handler( def terminate_handler(_, __, model_weakref): if model_weakref() is not None: _model = model_weakref() - if cfg.flash_optimum and BetterTransformer: - _model = BetterTransformer.reverse(_model) _model.save_pretrained( cfg.output_dir, safe_serialization=safe_serialization ) @@ -218,7 +188,7 @@ def execute_training( ) ) - if cfg.sequence_parallel_degree > 1: + if cfg.context_parallel_size > 1: models = [trainer.model] if hasattr(trainer, "ref_model") and trainer.ref_model: models.append(trainer.ref_model) @@ -226,16 +196,25 @@ def execute_training( stack.enter_context( SequenceParallelContextManager( models=models, - sequence_parallel_degree=cfg.sequence_parallel_degree, + context_parallel_size=cfg.context_parallel_size, gradient_accumulation_steps=cfg.gradient_accumulation_steps, ring_attn_func=cfg.ring_attn_func, heads_k_stride=cfg.heads_k_stride, + gather_outputs=cfg.rl is RLType.GRPO, + device_mesh=trainer.accelerator.torch_device_mesh, ) ) + # TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers + # if cfg.bf16: + # torch.set_default_dtype(torch.bfloat16) + LOG.info("Starting trainer...") trainer.train(resume_from_checkpoint=resume_from_checkpoint) + plugin_manager = PluginManager.get_instance() + plugin_manager.post_train(cfg, trainer.model) + def save_trained_model( cfg: DictDefault, @@ -257,54 +236,80 @@ def save_trained_model( # Post training module hooks for name, module in model.named_modules(): if hasattr(module, "_post_training"): - module._post_training(model, name) # pylint: disable=protected-access + module._post_training(model, name) # handle QAT if cfg.qat: - from axolotl.utils.quantization import convert_qat_model_for_ptq + from axolotl.utils.quantization import convert_qat_model - LOG.info("Processing QAT model for saving...") - convert_qat_model_for_ptq( + convert_qat_model( model, quantize_embedding=cfg.qat.quantize_embedding, ) LOG.info( - "QAT modules have been converted for PTQ. Please ensure you quantize " - "your model weights with `axolotl quantize`." + "QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`" + " with the same config which you used for training." ) - - # Handle FSDP state dict type - state_dict_type = "FULL_STATE_DICT" - if trainer.is_fsdp_enabled and str(cfg.fsdp_config.fsdp_version) != "2": - if cfg.fsdp_final_state_dict_type: - state_dict_type = cfg.fsdp_final_state_dict_type - trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type) - LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.") - # Handle ReLoRA early return case - if cfg.relora_steps: + if cfg.relora: if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): model = model.merge_and_unload() else: # final model weights have already been saved by `ReLoRACallback.on_train_end` return - if cfg.fsdp: - # TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading - # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple - # processes attempt to write the same file - if ( - state_dict_type == "SHARDED_STATE_DICT" - and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT" - ): - save_fsdp_model( - trainer.accelerator.state.fsdp_plugin, - trainer.accelerator, - trainer.model, - cfg.output_dir, + if trainer.is_fsdp_enabled or cfg.fsdp_config: + if cfg.fsdp_config or cfg.fsdp: + if cfg.fsdp_config.final_state_dict_type: + state_dict_type = cfg.fsdp_config.final_state_dict_type + else: + state_dict_type = cfg.fsdp_config.state_dict_type + trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type) + trainer.save_model(cfg.output_dir) # only handles FULL_STATE_DICT + if state_dict_type == "SHARDED_STATE_DICT": + LOG.info( + "The final model was saved with a sharded state dict. Please ensure you merge " + "the sharded weights with `merge-sharded-fsdp-weights`." ) - elif state_dict_type == "FULL_STATE_DICT": - trainer.save_model(cfg.output_dir) + checkpoint_dir = determine_last_checkpoint(cfg, update=False) + if ( + not (Path(cfg.output_dir) / "model.safetensors.index.json").exists() + and checkpoint_dir + ): + # import here to prevent circular import + from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights + + fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0" + merged_path = str(Path(cfg.output_dir) / "merged") + merge_fsdp_weights( + checkpoint_dir=str(fsdp_dir), + output_path=merged_path, + safe_serialization=True, + ) + trainer.accelerator.wait_for_everyone() + if trainer.accelerator.is_main_process: + # move all files in merged_path to cfg.output_dir + for merged_file in Path(merged_path).iterdir(): + if (Path(cfg.output_dir) / merged_file.name).exists(): + (Path(cfg.output_dir) / merged_file.name).unlink() + shutil.move(str(merged_file), cfg.output_dir) + shutil.rmtree(merged_path) # remove what should be an empty dir + # TODO(wing):see https://github.com/huggingface/transformers/pull/40207 + # cleanup the FSDP prefix in the model config.json + if trainer.accelerator.is_main_process: + with open( + Path(cfg.output_dir) / "config.json", "r", encoding="utf-8" + ) as config_file_io: + # read the model config as an OrderedDict + config = json.load(config_file_io, object_pairs_hook=OrderedDict) + config["architectures"] = [ + name.lstrip("FSDP") for name in config["architectures"] + ] + # write the updated model config back + with open( + os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8" + ) as config_file_io: + json.dump(config, config_file_io, indent=2) elif cfg.deepspeed and is_deepspeed_zero3_enabled(): # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading trainer.accelerator.wait_for_everyone() @@ -325,9 +330,6 @@ def save_trained_model( except FileNotFoundError: pass elif cfg.local_rank == 0: - if cfg.flash_optimum and BetterTransformer: - model = BetterTransformer.reverse(model) - if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model: trainer.model.save_pretrained( cfg.output_dir, safe_serialization=safe_serialization @@ -337,9 +339,7 @@ def save_trained_model( if hasattr(cfg, "llmcompressor") and cfg.llmcompressor: # TODO: add integration support so this can be implemented completely within the plugin - from axolotl.integrations.llm_compressor.utils import ( - save_compressed_model, - ) + from axolotl.integrations.llm_compressor.utils import save_compressed_model save_compressed_model( model=model, @@ -416,7 +416,9 @@ def save_initial_configs( # Pre-save the tokenizer and model configs LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...") - tokenizer.save_pretrained(str(output_dir)) + tokenizer.save_pretrained( + str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files + ) if hasattr(model, "config"): LOG.info(f"Pre-saving model config to {cfg.output_dir}...") model.config.save_pretrained(str(output_dir)) @@ -436,7 +438,7 @@ def setup_model_card(cfg: DictDefault): badge_markdown = """[Built with Axolotl](https://github.com/axolotl-ai-cloud/axolotl)""" transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" - if getattr(cfg, "axolotl_config_path"): + if cfg.axolotl_config_path: raw_axolotl_cfg = Path(cfg.axolotl_config_path) version = importlib.metadata.version("axolotl") if raw_axolotl_cfg.is_file(): @@ -487,8 +489,10 @@ def handle_untrained_tokens_fix( ) -def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[ - HFRLTrainerBuilder | HFCausalTrainerBuilder, +def setup_model_and_trainer( + cfg: DictDefault, dataset_meta: TrainDatasetMeta +) -> tuple[ + "HFRLTrainerBuilder" | "HFCausalTrainerBuilder", PeftModel | PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, @@ -535,6 +539,20 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> ) PLUGIN_MANAGER.post_trainer_create(cfg, trainer) + plugin_manager = PluginManager.get_instance() + plugin_manager.post_trainer_create(cfg, trainer) + + if cfg.use_ray: + try: + import ray.train.huggingface.transformers + + trainer = ray.train.huggingface.transformers.prepare_trainer(trainer) + except ImportError: + LOG.warning( + "The Ray integration with Hugging Face Transformers is not available. " + "To use Ray, install the 'ray[train]' package." + ) + return ( trainer, model, @@ -558,8 +576,6 @@ def train( Returns: Tuple of (model, tokenizer) after training """ - print_axolotl_text_art() - # Setup model, tokenizer, (causal or RLHF) trainer, etc. ( trainer, @@ -569,12 +585,6 @@ def train( processor, ) = setup_model_and_trainer(cfg, dataset_meta) - # Determine if we need to resume from a checkpoint - resume_from_checkpoint = determine_resume_checkpoint(cfg) - - # Configuration for saving - safe_serialization = cfg.save_safetensors is True - # Handle untrained tokens if configured safe_serialization = cfg.save_safetensors is True train_dataset = dataset_meta.train_dataset @@ -588,10 +598,18 @@ def train( setup_model_card(cfg) # Execute the training + resume_from_checkpoint = determine_last_checkpoint(cfg) execute_training(cfg, trainer, resume_from_checkpoint) + # clear cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Save the trained model and cleanup save_trained_model(cfg, trainer, model, safe_serialization) + tokenizer.save_pretrained( + str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files + ) create_model_card(cfg, trainer) if not cfg.use_ray: cleanup_distributed() diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 3d0ba7c9c..7256a5700 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -17,7 +17,6 @@ def is_comet_available(): return importlib.util.find_spec("comet_ml") is not None -# pylint: disable=duplicate-code def get_pytorch_version() -> tuple[int, int, int]: """ Get Pytorch version as a tuple of (major, minor, patch). @@ -45,10 +44,8 @@ def set_pytorch_cuda_alloc_conf(): ) -def patch_optimized_env(): +def get_not_null(value, default=None): """ - Patch environment variables to improve VRAM usage and increase download speed + return the value if it's not None, otherwise return the default value """ - if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None: - os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" - set_pytorch_cuda_alloc_conf() + return value if value is not None else default diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py index d1e972c81..0a4594991 100644 --- a/src/axolotl/utils/bench.py +++ b/src/axolotl/utils/bench.py @@ -1,6 +1,7 @@ """Benchmarking and measurement utilities""" import functools +import logging import torch from transformers.utils.import_utils import is_torch_npu_available @@ -56,16 +57,17 @@ def gpu_memory_usage(device=0): @check_cuda_device((0.0, 0.0, 0.0)) def gpu_memory_usage_all(device=0): - usage = torch.cuda.memory_allocated(device) / 1024.0**3 - reserved = torch.cuda.memory_reserved(device) / 1024.0**3 - smi = gpu_memory_usage_smi(device) - return usage, reserved - usage, max(0, smi - reserved) + active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3 + allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3 + reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3 + torch.cuda.reset_peak_memory_stats(device) + return active, allocated, reserved def mps_memory_usage_all(): - usage = torch.mps.current_allocated_memory() / 1024.0**3 - reserved = torch.mps.driver_allocated_memory() / 1024.0**3 - return usage, reserved - usage, 0 + active = torch.mps.current_allocated_memory() / 1024.0**3 + allocated = torch.mps.driver_allocated_memory() / 1024.0**3 + return active, allocated, 0 def npu_memory_usage_all(device=0): @@ -91,21 +93,38 @@ def gpu_memory_usage_smi(device=0): return 0.0 -def log_gpu_memory_usage(log, msg, device): - cur_device = get_device_type() +def get_gpu_memory_usage(device: int | torch.device = 0): + cur_device_type = str(get_device_type()) if torch.backends.mps.is_available(): usage, cache, misc = mps_memory_usage_all() - elif "npu" in str(cur_device) and is_torch_npu_available(): + elif "npu" in cur_device_type and is_torch_npu_available(): usage, cache, misc = npu_memory_usage_all(device) - else: + elif "cuda" in cur_device_type and torch.cuda.is_available(): usage, cache, misc = gpu_memory_usage_all(device) + else: + return 0.0, 0.0, 0.0 + + return usage, cache, misc + + +def log_gpu_memory_usage( + log: logging.Logger | logging.LoggerAdapter, + msg: str = "", + device: int | torch.device = 0, +): + try: + active, allocated, reserved = get_gpu_memory_usage(device) + except ValueError: + # likely CPU, ignore + return + cur_device_type = str(get_device_type()) extras = [] - if cache > 0: - extras.append(f"+{cache:.03f}GB cache") - if misc > 0: - extras.append(f"+{misc:.03f}GB misc") - log.info( - f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", + if allocated > 0: + extras.append(f"+{allocated:.03f}GB allocated") + if reserved > 0: + extras.append(f"+{reserved:.03f}GB reserved") + msg = f"{cur_device_type} memory active:" if not msg else msg + log.debug( + f"{msg} {active:.03f}GB ({', '.join(extras)})", stacklevel=2, ) - return usage, cache, misc diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 8b8a77611..36370ef13 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -16,8 +16,8 @@ import pandas as pd import torch import torch.distributed as dist import wandb +import yaml from datasets import load_dataset -from optimum.bettertransformer import BetterTransformer from tqdm import tqdm from transformers import ( GenerationConfig, @@ -27,11 +27,12 @@ from transformers import ( TrainerState, TrainingArguments, ) -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy +from transformers.trainer_utils import ( + SaveStrategy, +) from trl.models import unwrap_model_for_generation from axolotl.utils import is_comet_available, is_mlflow_available -from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.distributed import ( barrier, @@ -53,100 +54,22 @@ IGNORE_INDEX = -100 LOG = get_logger(__name__) -class EvalFirstStepCallback( - TrainerCallback -): # pylint: disable=too-few-public-methods disable=unused-argument - """ - Callback to trigger evals on the first step - """ - - def on_step_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1: - control.should_evaluate = True - return control - - -class SaveBetterTransformerModelCallback( - TrainerCallback -): # pylint: disable=too-few-public-methods - """Callback to save the BetterTransformer wrapped model""" - - def on_step_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - # Save - if ( - args.save_strategy == IntervalStrategy.STEPS - and args.save_steps > 0 - and state.global_step % args.save_steps == 0 - ): - control.should_save = True - - if control.should_save: - checkpoint_folder = os.path.join( - args.output_dir, - f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", - ) - - model = BetterTransformer.reverse(kwargs["model"]) - model.save_pretrained(checkpoint_folder) - # FIXME - need to cleanup old checkpoints - - # since we're saving here, we don't need the trainer loop to attempt to save too b/c - # the trainer will raise an exception since it can't save a BetterTransformer wrapped model - control.should_save = False - return control - - -class GPUStatsCallback( - TrainerCallback -): # pylint: disable=too-few-public-methods disable=unused-argument - """Callback to track GPU utilization""" - - def __init__(self, cfg): - self.cfg = cfg - self.logged = False - - def on_step_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - if not self.logged and state.global_step > 1: - log_gpu_memory_usage(LOG, "while training", self.cfg.device) - self.logged = True - return control - - class LossWatchDogCallback(TrainerCallback): """Callback to track loss and stop training if loss is too high""" def __init__(self, cfg): self.cfg = cfg - self.logged = False self.violations = 0 self.threshold = cfg.loss_watchdog_threshold self.patience = cfg.loss_watchdog_patience or 3 def on_step_end( self, - _args: TrainingArguments, + args: TrainingArguments, state: TrainerState, control: TrainerControl, **_kwargs, - ): + ) -> TrainerControl: if len(state.log_history) > 0 and "loss" in state.log_history[-1]: if state.log_history[-1]["loss"] > self.threshold: self.violations += 1 @@ -160,6 +83,21 @@ class LossWatchDogCallback(TrainerCallback): return control +class SaveModelOnFirstStepCallback(TrainerCallback): + """Callback to save the model on the first step of training if enabled""" + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **_kwargs, + ) -> TrainerControl: + if state.global_step == 1: + control.should_save = True + return control + + def bench_eval_callback_factory(trainer, tokenizer): accuracy = evaluate.load("accuracy") abcd_idx = [ @@ -263,10 +201,10 @@ def bench_eval_callback_factory(trainer, tokenizer): def on_evaluate( self, args: AxolotlTrainingArguments, - state: TrainerState, # pylint: disable=unused-argument - control: TrainerControl, # pylint: disable=unused-argument - metrics: Dict[str, float], # pylint: disable=unused-argument - **kwargs, # pylint: disable=unused-argument + state: TrainerState, + control: TrainerControl, + metrics: Dict[str, float], + **kwargs, ): data_loader = trainer.get_bench_dataloader( bench_dataset.remove_columns(["input", "subject", "output", "name"]) @@ -296,7 +234,7 @@ def bench_eval_callback_factory(trainer, tokenizer): # Extract results by subject. bench_name = bench_dataset["name"] bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)} - for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name + for s, p, r in zip(bench_name, preds, refs, strict=False): bench_names[s]["preds"].append(p) bench_names[s]["refs"].append(r) barrier() @@ -334,9 +272,7 @@ def bench_eval_callback_factory(trainer, tokenizer): bench_scores = [] bench_refs = [] bench_preds = [] - for ( - bench_name - ) in combined_bench_names: # pylint: disable=consider-using-dict-items + for bench_name in combined_bench_names: bench_score = accuracy.compute( references=combined_bench_names[bench_name]["refs"], predictions=combined_bench_names[bench_name]["preds"], @@ -385,18 +321,18 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): else: try: metrics[metric] = evaluate.load(metric) - except Exception as exc: # pylint: disable=broad-exception-caught + except Exception as exc: LOG.warning(f"{metric}: {exc.args}") return metrics def on_evaluate( self, - args: AxolotlTrainingArguments, # pylint: disable=unused-argument + args: AxolotlTrainingArguments, state: TrainerState, control: TrainerControl, - train_dataloader, # pylint: disable=unused-argument + train_dataloader, eval_dataloader, - **kwargs, # pylint: disable=unused-argument + **kwargs, ): trainer.model_wrapped.eval() @@ -404,7 +340,6 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): self.cfg.device ) # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded - # pylint: disable=duplicate-code generation_config = GenerationConfig( max_new_tokens=self.cfg.eval_max_new_tokens, bos_token_id=tokenizer.bos_token_id, @@ -435,9 +370,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): try: # Only pass the kwargs that are in the metric's feature list metric_kwargs = { - k: kwargs[k] - for k in metric._feature_names() # pylint: disable=protected-access - if k in kwargs + k: kwargs[k] for k in metric._feature_names() if k in kwargs } if isinstance(metric, Perplexity): @@ -449,7 +382,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): if "score" in metric_score else metric_score["mean_score"] ) - except Exception: # pylint: disable=broad-exception-caught + except Exception: traceback.print_exc() LOG.debug( f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}" @@ -497,6 +430,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): batch_input_ids, batch_labels, batch_pos_ids, + strict=False, ): if pos_ids is None: pos_ranges = [(0, len(input_ids_all) - 1)] @@ -547,7 +481,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer): prediction_all_tokens = predictions["sequences"].cpu().tolist() prediction_without_prompt_tokens_list = [] for prompt_token_ids, prediction_tokens in zip( - prompt_token_ids_list, prediction_all_tokens + prompt_token_ids_list, prediction_all_tokens, strict=False ): prediction_without_prompt_tokens = prediction_tokens[ len(prompt_token_ids) : @@ -585,12 +519,12 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): def on_evaluate( self, - args: AxolotlTrainingArguments, # pylint: disable=unused-argument + args: AxolotlTrainingArguments, state: TrainerState, control: TrainerControl, - train_dataloader, # pylint: disable=unused-argument + train_dataloader, eval_dataloader, - **kwargs, # pylint: disable=unused-argument + **kwargs, ): eval_table_size = self.cfg.eval_table_size @@ -600,7 +534,6 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): trainer.model.eval() device = torch.device(self.cfg.device) - # pylint: disable=duplicate-code generation_config = GenerationConfig( max_new_tokens=self.cfg.eval_max_new_tokens, bos_token_id=tokenizer.bos_token_id, @@ -668,6 +601,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): batch_labels, batch_pos_ids, batch_logits, + strict=False, ): if pos_ids is None: pos_ranges = [(0, len(input_ids_all) - 1)] @@ -721,7 +655,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): prediction_all_tokens = predictions["sequences"].cpu().tolist() prediction_without_prompt_tokens_list = [] for prompt_token_ids, prediction_tokens in zip( - prompt_token_ids_list, prediction_all_tokens + prompt_token_ids_list, prediction_all_tokens, strict=False ): prediction_without_prompt_tokens = prediction_tokens[ len(prompt_token_ids) : @@ -740,7 +674,11 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str): prediction_text, pred_step_text, ) in zip( - prompt_texts, completion_texts, predicted_texts, pred_step_texts + prompt_texts, + completion_texts, + predicted_texts, + pred_step_texts, + strict=False, ): table_data["id"].append(row_index) table_data["Prompt"].append(prompt_text) @@ -798,12 +736,12 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback): def on_train_begin( self, - args: AxolotlTrainingArguments, # pylint: disable=unused-argument - state: TrainerState, # pylint: disable=unused-argument + args: AxolotlTrainingArguments, + state: TrainerState, control: TrainerControl, - **kwargs, # pylint: disable=unused-argument + **kwargs, ): - if is_main_process(): + if state.is_world_process_zero: try: # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later. with NamedTemporaryFile( @@ -822,6 +760,37 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback): except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to WandB: {err}") + try: + with open(self.axolotl_config_path, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) or {} + + chat_tpl = cfg.get("chat_template_jinja") + if chat_tpl: + with NamedTemporaryFile( + mode="w", delete=True, suffix=".jinja", prefix="chat_template_" + ) as temp_ct_file: + if ( + isinstance(chat_tpl, str) + and os.path.exists(chat_tpl) + and os.path.isfile(chat_tpl) + ): + copyfile(chat_tpl, temp_ct_file.name) + else: + temp_ct_file.write(str(chat_tpl)) + temp_ct_file.flush() + + artifact = wandb.Artifact( + f"chat-template-{wandb.run.id}", type="jinja-template" + ) + artifact.add_file(temp_ct_file.name) + wandb.log_artifact(artifact) + wandb.save(temp_ct_file.name) + LOG.info( + "The chat_template_jinja has been saved to the WandB run under files." + ) + except (FileNotFoundError, ConnectionError, yaml.YAMLError) as err: + LOG.warning(f"Error while saving chat_template_jinja to WandB: {err}") + if args.deepspeed: try: # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later. @@ -860,22 +829,68 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback): class GCCallback(TrainerCallback): """Callback to garbage collect torch cache""" - def __init__(self, gc_steps=None): - self.gc_steps = gc_steps + def __init__(self, gc_steps: int | None = -1): + self.gc_steps: int = gc_steps or -1 + self.next_gc_on_begin_step: int = -1 - def on_step_end( - self, args, state, control, **kwargs # pylint: disable=unused-argument - ): - if self.gc_steps > 0 and state.global_step % self.gc_steps == 0: - torch.cuda.empty_cache() - gc.collect() - - def on_epoch_end( - self, args, state, control, **kwargs # pylint: disable=unused-argument - ): + def _gc(self): torch.cuda.empty_cache() gc.collect() + def on_train_begin( + self, + args, + state, + control, + **kwargs, + ): + self._gc() + + def on_step_begin( + self, + args, + state, + control, + **kwargs, + ): + if self.next_gc_on_begin_step == state.global_step or state.global_step == 0: + self._gc() + + def on_step_end( + self, + args, + state, + control, + **kwargs, + ): + if control.should_evaluate: + # automatically GC before evals so the eval memory spike from the CEL doesn't OOM the trainer + self._gc() + # also GC on the start of the next step after the eval + self.next_gc_on_begin_step = state.global_step + 1 + elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0: + self._gc() + elif ( + args.save_strategy == SaveStrategy.STEPS + and state.save_steps > 0 + and state.global_step % state.save_steps == 0 + ): + # gc on save steps in case anything is loaded to CPU RAM like offloaded tensors + self._gc() + elif state.global_step >= state.max_steps: + if args.save_strategy == SaveStrategy.STEPS: + # gc on save steps in case anything is loaded to CPU RAM like offloaded tensors + self._gc() + + def on_epoch_end( + self, + args, + state, + control, + **kwargs, + ): + self._gc() + def colab_inference_post_train_callback(trainer: Trainer): class ColabCallback(TrainerCallback): @@ -885,16 +900,12 @@ def colab_inference_post_train_callback(trainer: Trainer): self.gpu_name = torch.cuda.get_device_name(0) self.cfg = cfg - def on_train_end( - self, args, state, control, **kwargs - ): # pylint: disable=unused-argument + def on_train_end(self, args, state, control, **kwargs): """ handle T4 gpu, we need to convert attention to eager for inference """ if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention: - trainer.model.config._attn_implementation = ( # pylint: disable=protected-access - "eager" - ) + trainer.model.config._attn_implementation = "eager" trainer.model.gradient_checkpointing_disable() trainer.model.config.use_cache = True trainer.model.eval() diff --git a/src/axolotl/utils/callbacks/comet_.py b/src/axolotl/utils/callbacks/comet_.py index 7dce95145..cd3bcf70e 100644 --- a/src/axolotl/utils/callbacks/comet_.py +++ b/src/axolotl/utils/callbacks/comet_.py @@ -22,10 +22,10 @@ class SaveAxolotlConfigtoCometCallback(TrainerCallback): def on_train_begin( self, - args: "AxolotlTrainingArguments", # pylint: disable=unused-argument - state: TrainerState, # pylint: disable=unused-argument + args: "AxolotlTrainingArguments", + state: TrainerState, control: TrainerControl, - **kwargs, # pylint: disable=unused-argument + **kwargs, ): if is_main_process(): try: diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py index 348cdf2da..03f189d80 100644 --- a/src/axolotl/utils/callbacks/lisa.py +++ b/src/axolotl/utils/callbacks/lisa.py @@ -55,9 +55,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"): for param in layer.parameters(): param.requires_grad = False - def on_step_begin( - self, args, state, control, **kwargs - ): # pylint: disable=unused-argument + def on_step_begin(self, args, state, control, **kwargs): # Check if it's time to switch active layers, including at step 0 if state.global_step % self.step_interval == 0 or state.global_step == 1: self.switch_active_layers() diff --git a/src/axolotl/utils/callbacks/mlflow_.py b/src/axolotl/utils/callbacks/mlflow_.py index ac72f5e6d..30120a87d 100644 --- a/src/axolotl/utils/callbacks/mlflow_.py +++ b/src/axolotl/utils/callbacks/mlflow_.py @@ -23,7 +23,6 @@ def should_log_artifacts() -> bool: class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): - # pylint: disable=duplicate-code """Callback to save axolotl config to mlflow""" def __init__(self, axolotl_config_path): @@ -31,10 +30,10 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): def on_train_begin( self, - args: "AxolotlTrainingArguments", # pylint: disable=unused-argument - state: TrainerState, # pylint: disable=unused-argument + args: "AxolotlTrainingArguments", + state: TrainerState, control: TrainerControl, - **kwargs, # pylint: disable=unused-argument + **kwargs, ): if is_main_process(): try: diff --git a/src/axolotl/utils/callbacks/models.py b/src/axolotl/utils/callbacks/models.py new file mode 100644 index 000000000..5a20d70d9 --- /dev/null +++ b/src/axolotl/utils/callbacks/models.py @@ -0,0 +1,23 @@ +"""Helper functions for model classes""" + +from typing import Tuple + +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + + +def get_causal_lm_model_cls_prefix(model_type: str) -> Tuple[str, str]: + if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + causal_lm_cls = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type] + causal_lm_cls_prefix = causal_lm_cls + for suffix in [ + "ForCausalLM", + "ForConditionalGeneration", + "LMHeadModel", + "GenerationDecoder", + ]: + causal_lm_cls_prefix = causal_lm_cls_prefix.replace(suffix, "") + return causal_lm_cls_prefix, causal_lm_cls + causal_lm_cls_prefix = "".join( + [part.capitalize() for part in model_type.split("_")] + ) + return causal_lm_cls_prefix, f"{causal_lm_cls_prefix}ForCausalLM" diff --git a/src/axolotl/utils/callbacks/profiler.py b/src/axolotl/utils/callbacks/profiler.py index 36604813f..2cf5e0f4f 100644 --- a/src/axolotl/utils/callbacks/profiler.py +++ b/src/axolotl/utils/callbacks/profiler.py @@ -19,26 +19,57 @@ class PytorchProfilerCallback(TrainerCallback): PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps. """ - def __init__(self, steps_to_profile: int = 5): - self.steps_to_profile = steps_to_profile - if self.steps_to_profile: - torch.cuda.memory._record_memory_history( # pylint: disable=protected-access - enabled="all" - ) + def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0): + # steps are 0 indexed, so to start at 0-th step, we start at beginning of first step, + # and finish at end of last step, so 5 steps_to_profile is steps [0, 1, 2, 3, 4] + self.profiler_steps_end = profiler_steps_start + steps_to_profile - 1 + if profiler_steps_start == 0: + # start recording memory allocations before everything is allocated, because if we start + # at the beginning of step 0, we won't have any memory allocations in the traces + torch.cuda.memory._record_memory_history(enabled="all") + profiler_steps_start = -1 + self.profiler_steps_start = profiler_steps_start - def on_step_end( # pylint: disable=unused-argument + def on_step_begin( self, - args: TrainingArguments, # pylint: disable=unused-argument + args: TrainingArguments, state: TrainerState, - control: TrainerControl, # pylint: disable=unused-argument - **kwargs, # pylint: disable=unused-argument + control: TrainerControl, + **kwargs, ): - if state.global_step == self.steps_to_profile: - snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access + if state.global_step == self.profiler_steps_start: + torch.cuda.memory._record_memory_history(enabled="all") + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if state.global_step == self.profiler_steps_end: + snapshot = torch.cuda.memory._snapshot() with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout: dump(snapshot, fout) # tell CUDA to stop recording memory allocations now - torch.cuda.memory._record_memory_history( # pylint: disable=protected-access - enabled=None - ) + torch.cuda.memory._record_memory_history(enabled=None) + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # make sure to record if we happen to have more steps than steps to profile + if ( + state.global_step >= self.profiler_steps_start + and state.global_step < self.profiler_steps_end + ): + snapshot = torch.cuda.memory._snapshot() + with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout: + dump(snapshot, fout) + + # tell CUDA to stop recording memory allocations now + torch.cuda.memory._record_memory_history(enabled=None) diff --git a/src/axolotl/utils/callbacks/qat.py b/src/axolotl/utils/callbacks/qat.py index cf4d9a937..70746d6be 100644 --- a/src/axolotl/utils/callbacks/qat.py +++ b/src/axolotl/utils/callbacks/qat.py @@ -38,9 +38,7 @@ class QATCallback(TrainerCallback): def __init__(self, cfg: QATConfig): self.cfg = cfg - def on_step_begin( - self, args, state, control, model, **kwargs - ): # pylint: disable=unused-argument + def on_step_begin(self, args, state, control, model, **kwargs): if self.cfg.fake_quant_after_n_steps is not None: if state.global_step == 0: LOG.info(f"Disabling fake quantization at step {state.global_step}") diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py new file mode 100644 index 000000000..ead129240 --- /dev/null +++ b/src/axolotl/utils/callbacks/tokens_per_second.py @@ -0,0 +1,64 @@ +"""A callback for calculating tokens per second during training.""" + +import time + +import torch +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + + +class TokensPerSecondCallback(TrainerCallback): + """ + A callback to measure and log tokens per second during training. + """ + + def __init__(self, tensor_parallel_size, context_parallel_size): + super().__init__() + self.step_time = 0.0 + self.start_time = 0.0 + self.non_data_parallel_size = 1 + if tensor_parallel_size is not None: + self.non_data_parallel_size *= tensor_parallel_size + if context_parallel_size is not None: + self.non_data_parallel_size *= context_parallel_size + + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): # pylint: disable=unused-argument + self.start_time = time.perf_counter() + state.last_tokens_per_second = torch.zeros(1) + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): # pylint: disable=unused-argument + if hasattr(state, "num_tokens"): + step_time = time.perf_counter() - self.start_time + num_tokens_per_device = state.num_tokens.clone() + # non data parallel groups have duplicated tokens, so we avoid double-counting + num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size + state.last_tokens_per_second = num_tokens_per_device / step_time + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs=None, + **kwargs, + ): # pylint: disable=unused-argument + # after logging, clear the running metrics + if hasattr(state, "last_tokens_per_second"): + state.last_tokens_per_second.zero_() + state.num_tokens = torch.zeros(1) diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py deleted file mode 100644 index bf496d2c5..000000000 --- a/src/axolotl/utils/chat_templates.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -This module provides functionality for selecting chat templates based on user choices. -These templates are used for formatting messages in a conversation. -""" - -from typing import TYPE_CHECKING, Any, Dict, Optional - -from axolotl.utils.logging import get_logger - -if TYPE_CHECKING: - from transformers import PreTrainedTokenizerBase - -LOG = get_logger("axolotl.utils.chat_templates") - -_JINJA_TEMPALTE_CHOICE = "jinja" -_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default" -_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_" - -_CHAT_TEMPLATES = { - "alpaca": "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' and loop.first %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '### Instruction:\n' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '### Response:\n' + message['content'] + eos_token }}{% endif %}{% if not loop.last %}{{ '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\n\n### Response:\n' }}{% endif %}", - "mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1... - "mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large... - "mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral... - "mistral_v7_tekken": "{%- set today = strftime_now(\"%Y-%m-%d\") %}\n{%- set default_system_message = \"You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\\nYour knowledge base was last updated on 2023-10-01. The current date is \" + today + \".\\n\\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \\\"What are some good restaurants around me?\\\" => \\\"Where are you?\\\" or \\\"When is the next flight to Tokyo\\\" => \\\"Where do you travel from?\\\")\" %}\n\n{{- bos_token }}\n\n{%- if messages[0]['role'] == 'system' %}\n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content'] %}\n {%- else %}\n {%- set system_message = messages[0]['content'][0]['text'] %}\n {%- endif %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set system_message = default_system_message %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}\n\n{%- for message in loop_messages %}\n {%- if message['role'] == 'user' %}\n {%- if message['content'] is string %}\n {{- '[INST]' + message['content'] + '[/INST]' }}\n {%- else %}\n {{- '[INST]' }}\n {%- for block in message['content'] %}\n {%- if block['type'] == 'text' %}\n {{- block['text'] }}\n {%- elif block['type'] in ['image', 'image_url'] %}\n {{- '[IMG]' }}\n {%- else %}\n {{- raise_exception('Only text and image blocks are supported in message content!') }}\n {%- endif %}\n {%- endfor %}\n {{- '[/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'system' %}\n {%- if message['content'] is string %}\n {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}\n {%- else %}\n {{- '[SYSTEM_PROMPT]' + message['content'][0]['text'] + '[/SYSTEM_PROMPT]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {%- if message['content'] is string %}\n {{- message['content'] + eos_token }}\n {%- else %}\n {{- message['content'][0]['text'] + eos_token }}\n {%- endif %}\n {%- else %}\n {{- raise_exception('Only user, system and assistant roles are supported!') }}\n {%- endif %}\n{%- endfor %}", - "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", - "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", - "gemma3": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'model\n'}}\n{%- endif -%}\n", - "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", - "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", - "llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n', - "llama4": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %} \n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- else %}\n {#- FIXME: The processor requires an array, always. #}\n {%- set system_message = messages[0]['content'][0]['text']|trim %}\n {%- endif %}\n {%- set messages = messages[1:] %}\n {%- set user_supplied_system_message = true %}\n{%- else %}\n {%- set system_message = \"\" %}\n {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- System message if the user supplied one #}\n{%- if user_supplied_system_message %}\n {{- \"<|header_start|>system<|header_end|>\\n\\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n {%- endif %}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|header_start|>user<|header_end|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|header_start|>' + message['role'] + '<|header_end|>\\n\\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n {{- '<|header_start|>assistant<|header_end|>\\n\\n' -}}\n {{- '<|python_start|>' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|python_end|>' }}\n {%- for tool_call in message.tool_calls %}\n {{- '{\"name\": \"' + tool_call.function.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.function.arguments | tojson }}\n {{- \"}\" }}\n {%- endfor %}\n {{- \"<|eot|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|header_start|>ipython<|header_end|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|header_start|>assistant<|header_end|>\\n\\n' }}\n{%- endif %}\n", - "llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}", - "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}", - "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", - "deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '
' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}", - "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', - "qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", - "qwen3": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}", - "exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}", - "metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}", - "pixtral": '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n{%- endif %}\n{%- endfor %}\n{{- eos_token }}\n{%- else %}\n{{- message["content"] + eos_token }}\n{%- endif %}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}', - "qwen2_vl": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}", - "command_a": '{{ bos_token }}{% if documents %}\n{% set tools = [] %}\n{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n "tool_call_id": "0",\n "results": {\n{% for doc in documents %}\n "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n "is_error": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",\n "results": {\n "0": {{ tool_msg.content|tojson }}\n },\n "is_error": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0][\'role\']|lower == \'system\' %}{%- set developer_preamble = messages[0][\'content\'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user\'s requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user\'s request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it\'ll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you\'ve figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it\'s time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user\'s last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that "Reflection" and "Response" above can be grounded.\nGrounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "" and "" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "span" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {"name": "{{ tool[\'function\'][\'name\'] }}", "description": "{{tool[\'function\'][\'description\']}}", "parameters": {{ tool[\'function\'][\'parameters\']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == \'system\' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'user\' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == \'assistant\' or message.role|lower == \'chatbot\' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc[\'function\'][\'name\'] }}", "parameters": {{ tc[\'function\'][\'arguments\']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == \'tool\' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == \'tool\' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>\n{%- else -%}\n{%- if messages and messages[0][\'role\']|lower == \'system\' %}{%- set developer_preamble = messages[0][\'content\'] %}{% endif %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\n{% if safety_mode|upper == \'STRICT\' -%}\nYou are in strict safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will reject requests to generate content related to violence, hate, misinformation or sex to any amount. You will avoid using profanity. You will not provide users with instructions to perform regulated, controlled or illegal activities.\n{%- else -%}\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n{%- endif %}\n\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == \'system\' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'user\' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'assistant\' or message.role|lower == \'chatbot\' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{%- if add_generation_prompt -%}<|START_RESPONSE|>{%- endif %}\n{% endif %}', - "command_a_tool_use": '{{ bos_token }}{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n "tool_call_id": "0",\n "results": {\n{% for doc in documents %}\n "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n "is_error": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",\n "results": {\n "0": {{ tool_msg.content|tojson }}\n },\n "is_error": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0][\'role\']|lower == \'system\' %}{%- set developer_preamble = messages[0][\'content\'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user\'s requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user\'s request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it\'ll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you\'ve figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it\'s time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user\'s last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that "Reflection" and "Response" above can be grounded.\nGrounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "" and "" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "span" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {"name": "{{ tool[\'function\'][\'name\'] }}", "description": "{{tool[\'function\'][\'description\']}}", "parameters": {{ tool[\'function\'][\'parameters\']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == \'system\' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'user\' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == \'assistant\' or message.role|lower == \'chatbot\' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc[\'function\'][\'name\'] }}", "parameters": {{ tc[\'function\'][\'arguments\']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == \'tool\' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == \'tool\' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>', - "command_a_rag": '{{ bos_token }}{% set tools = [] %}\n{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n "tool_call_id": "0",\n "results": {\n{% for doc in documents %}\n "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n "is_error": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",\n "results": {\n "0": {{ tool_msg.content|tojson }}\n },\n "is_error": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0][\'role\']|lower == \'system\' %}{%- set developer_preamble = messages[0][\'content\'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user\'s requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user\'s request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it\'ll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you\'ve figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it\'s time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user\'s last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that "Reflection" and "Response" above can be grounded.\nGrounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "" and "" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "span" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {"name": "{{ tool[\'function\'][\'name\'] }}", "description": "{{tool[\'function\'][\'description\']}}", "parameters": {{ tool[\'function\'][\'parameters\']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == \'system\' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'user\' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == \'assistant\' or message.role|lower == \'chatbot\' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc[\'function\'][\'name\'] }}", "parameters": {{ tc[\'function\'][\'arguments\']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == \'tool\' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == \'tool\' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>', - "aya": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", -} - - -def get_chat_template( - user_choice: str, - jinja_template: Optional[str] = None, - tokenizer: Optional["PreTrainedTokenizerBase"] = None, -) -> str: - """ - Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer. - - Args: - user_choice (str): The user's choice of template. - jinja_template (Optional[str], optional): The jinja template string. Defaults to None. - tokenizer (Optional[PreTrainedTokenizerBase], optional): The tokenizer. Defaults to None. - - Returns: - str: The chosen template string. - - Raises: - ValueError: If the user_choice is not found in the templates. - """ - if user_choice == _JINJA_TEMPALTE_CHOICE: - if not jinja_template: - raise ValueError( - f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPALTE_CHOICE}" - ) - return jinja_template - - if user_choice == _DEFAULT_TEMPLATE_CHOICE: - if not tokenizer: - raise ValueError( - f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}" - ) - if not tokenizer.chat_template: - raise ValueError( - f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. " - f"Please add a chat_template in tokenizer config" - ) - return tokenizer.chat_template # type: ignore - - if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX): - if not tokenizer: - raise ValueError( - f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}" - ) - if tokenizer.chat_template: - return tokenizer.chat_template # type: ignore - - user_choice = user_choice[ - len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) : - ] - LOG.warning( - f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template." - ) - - if user_choice in _CHAT_TEMPLATES: - return _CHAT_TEMPLATES[user_choice] - - raise ValueError(f"Template '{user_choice}' not found.") - - -def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None): - if ds_cfg and ds_cfg.get("chat_template"): - chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE - chat_template_jinja = ds_cfg.get("chat_template_jinja") - else: - chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE - chat_template_jinja = cfg.get("chat_template_jinja") - return chat_template_choice, chat_template_jinja - - -def get_chat_template_from_config( - cfg, - ds_cfg: Optional[Dict[str, Any]] = None, - tokenizer: Optional["PreTrainedTokenizerBase"] = None, -) -> str: - chat_template_choice, chat_template_jinja = extract_chat_template_args( - cfg=cfg, ds_cfg=ds_cfg - ) - return get_chat_template( - user_choice=chat_template_choice, - jinja_template=chat_template_jinja, - tokenizer=tokenizer, - ) - - -def register_chat_template(template_name: str, chat_template: str): - """ - Registers chat templates. - - Args: - template_name (str): The name of the template. - chat_template (str): The template string. - """ - - if template_name in _CHAT_TEMPLATES: - raise ValueError(f"Template '{template_name}' already exists.") - - _CHAT_TEMPLATES[template_name] = chat_template diff --git a/src/axolotl/utils/chat_templates/__init__.py b/src/axolotl/utils/chat_templates/__init__.py new file mode 100644 index 000000000..337417c7d --- /dev/null +++ b/src/axolotl/utils/chat_templates/__init__.py @@ -0,0 +1,20 @@ +""" +This module provides functionality for selecting chat templates based on user choices. +These templates are used for formatting messages in a conversation. +""" + +from .base import ( + _CHAT_TEMPLATES, + extract_chat_template_args, + get_chat_template, + get_chat_template_from_config, + register_chat_template, +) + +__all__ = [ + "get_chat_template", + "extract_chat_template_args", + "get_chat_template_from_config", + "register_chat_template", + "_CHAT_TEMPLATES", +] diff --git a/src/axolotl/utils/chat_templates/base.py b/src/axolotl/utils/chat_templates/base.py new file mode 100644 index 000000000..11d15fc1d --- /dev/null +++ b/src/axolotl/utils/chat_templates/base.py @@ -0,0 +1,125 @@ +""" +utility functions for chat templates +""" + +import os +from typing import TYPE_CHECKING, Any, Dict, Optional + +from axolotl.utils.logging import get_logger + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + +LOG = get_logger("axolotl.utils.chat_templates") + +_JINJA_TEMPLATE_CHOICE = "jinja" +_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default" +_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_" + +TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "templates") +_CHAT_TEMPLATES: dict[str, str] = {} +for filename in [f for f in os.listdir(TEMPLATE_DIR) if f.endswith(".jinja")]: + with open(os.path.join(TEMPLATE_DIR, filename), "r", encoding="utf-8") as f: + _CHAT_TEMPLATES[filename[:-6]] = f.read() + + +def get_chat_template( + user_choice: str, + jinja_template: str | None = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, +) -> str: + """ + Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer. + + Args: + user_choice (str): The user's choice of template. + jinja_template (str, optional): The jinja template string or Path to a valid jinja template file. Defaults to None. + tokenizer (PreTrainedTokenizerBase, optional): The tokenizer. Defaults to None. + + Returns: + str: The chosen template string. + + Raises: + ValueError: If the user_choice is not found in the templates. + """ + if user_choice == _JINJA_TEMPLATE_CHOICE: + if not jinja_template: + raise ValueError( + f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPLATE_CHOICE}" + ) + if os.path.exists(jinja_template) and os.path.isfile(jinja_template): + with open(jinja_template, "r", encoding="utf-8") as file: + jinja_template = file.read() + return jinja_template + + if user_choice == _DEFAULT_TEMPLATE_CHOICE: + if not tokenizer: + raise ValueError( + f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}" + ) + if not tokenizer.chat_template: + raise ValueError( + f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. " + f"Please add a chat_template in tokenizer config" + ) + return tokenizer.chat_template # type: ignore + + if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX): + if not tokenizer: + raise ValueError( + f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}" + ) + if tokenizer.chat_template: + return tokenizer.chat_template # type: ignore + + user_choice = user_choice[ + len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) : + ] + LOG.warning( + f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template." + ) + + if user_choice in _CHAT_TEMPLATES: + return _CHAT_TEMPLATES[user_choice] + + raise ValueError(f"Template '{user_choice}' not found.") + + +def extract_chat_template_args(cfg, ds_cfg: Dict[str, Any] | None = None): + if ds_cfg and ds_cfg.get("chat_template"): + chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE + chat_template_jinja = ds_cfg.get("chat_template_jinja") + else: + chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE + chat_template_jinja = cfg.get("chat_template_jinja") + return chat_template_choice, chat_template_jinja + + +def get_chat_template_from_config( + cfg, + ds_cfg: Dict[str, Any] | None = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, +) -> str: + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg=cfg, ds_cfg=ds_cfg + ) + return get_chat_template( + user_choice=chat_template_choice, + jinja_template=chat_template_jinja, + tokenizer=tokenizer, + ) + + +def register_chat_template(template_name: str, chat_template: str): + """ + Registers chat templates. + + Args: + template_name (str): The name of the template. + chat_template (str): The template string. + """ + + if template_name in _CHAT_TEMPLATES: + raise ValueError(f"Template '{template_name}' already exists.") + + _CHAT_TEMPLATES[template_name] = chat_template diff --git a/src/axolotl/utils/chat_templates/templates/alpaca.jinja b/src/axolotl/utils/chat_templates/templates/alpaca.jinja new file mode 100644 index 000000000..5e9d63c42 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/alpaca.jinja @@ -0,0 +1,8 @@ +{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' and loop.first %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '### Instruction: +' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '### Response: +' + message['content'] + eos_token }}{% endif %}{% if not loop.last %}{{ ' + +' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' + +### Response: +' }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/aya.jinja b/src/axolotl/utils/chat_templates/templates/aya.jinja new file mode 100644 index 000000000..97e54d4b1 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/aya.jinja @@ -0,0 +1 @@ +{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/chatml.jinja b/src/axolotl/utils/chat_templates/templates/chatml.jinja new file mode 100644 index 000000000..2116e45ca --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/chatml.jinja @@ -0,0 +1,4 @@ +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + ' +' + message['content'] + '<|im_end|>' + ' +'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +' }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/cohere.jinja b/src/axolotl/utils/chat_templates/templates/cohere.jinja new file mode 100644 index 000000000..638ce5ef2 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/cohere.jinja @@ -0,0 +1 @@ +{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/command_a.jinja b/src/axolotl/utils/chat_templates/templates/command_a.jinja new file mode 100644 index 000000000..ef0594172 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/command_a.jinja @@ -0,0 +1,210 @@ +{{ bos_token }}{% if documents %} +{% set tools = [] %} +{%- macro document_turn(documents) -%} +{# format documents into chat turn #} +<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[ + {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}} +]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[ + { + "tool_call_id": "0", + "results": { +{% for doc in documents %} + "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %}, + {% endif %} +{% endfor %} + + }, + "is_error": null + } +]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %} +{%- macro tool_call_id_to_int(messages, tool_call_id) %} +{%- set counter = namespace(value=0) %} +{%- set tool_call_id_seen = namespace(value=false) %} +{%- for msg in messages %} + {%- if msg.tool_calls %} + {%- for tool_call in msg.tool_calls %} + {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%} + {{ counter.value }} + {%- set tool_call_id_seen.value = true %} + {%- endif %} + {%- set counter.value = counter.value + 1 %} + {%- endfor %} + {%- endif %} +{%- endfor %} +{%- endmacro %} +{%- macro format_tool_message(messages, tool_msg) -%} +{# format tool message #} + { + "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}", + "results": { + "0": {{ tool_msg.content|tojson }} + }, + "is_error": null + } +{%- endmacro -%} +{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %} +{%- set tool_idx = namespace(value=0) %} +{%- set tool_ids_seen = namespace(value=[]) %} +{%- set sent_documents = namespace(value=false) %} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble +You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes. + +Your information cutoff date is June 2024. + +You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages. +{% if tools or documents %} + +You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests. + +## Tool Use +Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first. + +0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>. + You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed. + NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools. + +Then carry out your plan by repeatedly executing the following steps. +1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields. + When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>. +2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results. + Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id". +3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>. + You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded. + NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user. + +You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user. + +4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>. +{% if enable_citations %} + +## Grounding +Importantly, note that "Reflection" and "Response" above can be grounded. +Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "" and "" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "span" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1". +{% endif %} + +## Available Tools +Here is the list of tools that you have available to you. +You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it. +Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema). + +```json +[ +{% if documents %} + {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %} + +{% endif %} +{% for tool in tools %} + {"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %} + +{% endfor %} +] +``` + +{% endif %} +# Default Preamble +The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt. +- Your name is Command. +- You are a large language model built by Cohere. +- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions. +- If the input is ambiguous, ask clarifying follow-up questions. +- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks). +- Use LaTeX to generate mathematical notation for complex equations. +- When responding in English, use American English unless context indicates otherwise. +- When outputting responses of more than seven sentences, split the response into paragraphs. +- Prefer the active voice. +- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references. +- Use gender-neutral pronouns for unspecified persons. +- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list. +- Use the third person when asked to write a summary. +- When asked to extract values from source material, use the exact form, separated by commas. +- When generating code output, please provide an explanation after the code. +- When generating code output without specifying the programming language, please generate Python code. +- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer. +{%- if developer_preamble %} + + +# Developer Preamble +The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions. +{{ developer_preamble }} +{%- endif -%} +<|END_OF_TURN_TOKEN|> +{%- for message in messages %} + {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|> + {%- elif message.role|lower == 'user' %} +<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %} + {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %} +<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[ + {% for tc in message.tool_calls %} + {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %} + + {% set tool_idx.value = tool_idx.value + 1 %} + {% endfor %} +]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %} + {% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[ +{{ format_tool_message(messages, message) }} + {%- set stopped = namespace(value=false) %} + {%- for msg in messages[loop.index0 + 1:] %} + {%- if not stopped.value and msg.role|lower == 'tool' %}, +{{ format_tool_message(messages, msg) }} + {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %} + {%- else %} + {%- set stopped.value = true %} + {%- endif %} + {%- endfor %} + +]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|> + {%- endif %} +{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> +{%- else -%} +{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble +{% if safety_mode|upper == 'STRICT' -%} +You are in strict safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will reject requests to generate content related to violence, hate, misinformation or sex to any amount. You will avoid using profanity. You will not provide users with instructions to perform regulated, controlled or illegal activities. +{%- else -%} +You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes. +{%- endif %} + + +Your information cutoff date is June 2024. + +You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages. + +# Default Preamble +The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt. +- Your name is Command. +- You are a large language model built by Cohere. +- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions. +- If the input is ambiguous, ask clarifying follow-up questions. +- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks). +- Use LaTeX to generate mathematical notation for complex equations. +- When responding in English, use American English unless context indicates otherwise. +- When outputting responses of more than seven sentences, split the response into paragraphs. +- Prefer the active voice. +- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references. +- Use gender-neutral pronouns for unspecified persons. +- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list. +- Use the third person when asked to write a summary. +- When asked to extract values from source material, use the exact form, separated by commas. +- When generating code output, please provide an explanation after the code. +- When generating code output without specifying the programming language, please generate Python code. +- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer. +{%- if developer_preamble %} + + +# Developer Preamble +The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions. +{{ developer_preamble }} +{%- endif -%} +<|END_OF_TURN_TOKEN|> +{%- for message in messages %} + {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|> + {%- elif message.role|lower == 'user' %} +<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|> + {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %} +<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|> + {%- endif %} +{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{%- if add_generation_prompt -%}<|START_RESPONSE|>{%- endif %} +{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/command_a_rag.jinja b/src/axolotl/utils/chat_templates/templates/command_a_rag.jinja new file mode 100644 index 000000000..e4a5fd9ac --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/command_a_rag.jinja @@ -0,0 +1,158 @@ +{{ bos_token }}{% set tools = [] %} +{%- macro document_turn(documents) -%} +{# format documents into chat turn #} +<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[ + {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}} +]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[ + { + "tool_call_id": "0", + "results": { +{% for doc in documents %} + "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %}, + {% endif %} +{% endfor %} + + }, + "is_error": null + } +]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %} +{%- macro tool_call_id_to_int(messages, tool_call_id) %} +{%- set counter = namespace(value=0) %} +{%- set tool_call_id_seen = namespace(value=false) %} +{%- for msg in messages %} + {%- if msg.tool_calls %} + {%- for tool_call in msg.tool_calls %} + {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%} + {{ counter.value }} + {%- set tool_call_id_seen.value = true %} + {%- endif %} + {%- set counter.value = counter.value + 1 %} + {%- endfor %} + {%- endif %} +{%- endfor %} +{%- endmacro %} +{%- macro format_tool_message(messages, tool_msg) -%} +{# format tool message #} + { + "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}", + "results": { + "0": {{ tool_msg.content|tojson }} + }, + "is_error": null + } +{%- endmacro -%} +{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %} +{%- set tool_idx = namespace(value=0) %} +{%- set tool_ids_seen = namespace(value=[]) %} +{%- set sent_documents = namespace(value=false) %} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble +You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes. + +Your information cutoff date is June 2024. + +You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages. +{% if tools or documents %} + +You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests. + +## Tool Use +Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first. + +0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>. + You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed. + NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools. + +Then carry out your plan by repeatedly executing the following steps. +1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields. + When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>. +2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results. + Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id". +3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>. + You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded. + NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user. + +You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user. + +4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>. +{% if enable_citations %} + +## Grounding +Importantly, note that "Reflection" and "Response" above can be grounded. +Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "" and "" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "span" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1". +{% endif %} + +## Available Tools +Here is the list of tools that you have available to you. +You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it. +Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema). + +```json +[ +{% if documents %} + {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %} + +{% endif %} +{% for tool in tools %} + {"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %} + +{% endfor %} +] +``` + +{% endif %} +# Default Preamble +The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt. +- Your name is Command. +- You are a large language model built by Cohere. +- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions. +- If the input is ambiguous, ask clarifying follow-up questions. +- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks). +- Use LaTeX to generate mathematical notation for complex equations. +- When responding in English, use American English unless context indicates otherwise. +- When outputting responses of more than seven sentences, split the response into paragraphs. +- Prefer the active voice. +- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references. +- Use gender-neutral pronouns for unspecified persons. +- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list. +- Use the third person when asked to write a summary. +- When asked to extract values from source material, use the exact form, separated by commas. +- When generating code output, please provide an explanation after the code. +- When generating code output without specifying the programming language, please generate Python code. +- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer. +{%- if developer_preamble %} + + +# Developer Preamble +The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions. +{{ developer_preamble }} +{%- endif -%} +<|END_OF_TURN_TOKEN|> +{%- for message in messages %} + {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|> + {%- elif message.role|lower == 'user' %} +<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %} + {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %} +<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[ + {% for tc in message.tool_calls %} + {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %} + + {% set tool_idx.value = tool_idx.value + 1 %} + {% endfor %} +]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %} + {% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[ +{{ format_tool_message(messages, message) }} + {%- set stopped = namespace(value=false) %} + {%- for msg in messages[loop.index0 + 1:] %} + {%- if not stopped.value and msg.role|lower == 'tool' %}, +{{ format_tool_message(messages, msg) }} + {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %} + {%- else %} + {%- set stopped.value = true %} + {%- endif %} + {%- endfor %} + +]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|> + {%- endif %} +{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> diff --git a/src/axolotl/utils/chat_templates/templates/command_a_tool_use.jinja b/src/axolotl/utils/chat_templates/templates/command_a_tool_use.jinja new file mode 100644 index 000000000..eecd42488 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/command_a_tool_use.jinja @@ -0,0 +1,157 @@ +{{ bos_token }}{%- macro document_turn(documents) -%} +{# format documents into chat turn #} +<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[ + {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}} +]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[ + { + "tool_call_id": "0", + "results": { +{% for doc in documents %} + "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %}, + {% endif %} +{% endfor %} + + }, + "is_error": null + } +]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %} +{%- macro tool_call_id_to_int(messages, tool_call_id) %} +{%- set counter = namespace(value=0) %} +{%- set tool_call_id_seen = namespace(value=false) %} +{%- for msg in messages %} + {%- if msg.tool_calls %} + {%- for tool_call in msg.tool_calls %} + {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%} + {{ counter.value }} + {%- set tool_call_id_seen.value = true %} + {%- endif %} + {%- set counter.value = counter.value + 1 %} + {%- endfor %} + {%- endif %} +{%- endfor %} +{%- endmacro %} +{%- macro format_tool_message(messages, tool_msg) -%} +{# format tool message #} + { + "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}", + "results": { + "0": {{ tool_msg.content|tojson }} + }, + "is_error": null + } +{%- endmacro -%} +{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %} +{%- set tool_idx = namespace(value=0) %} +{%- set tool_ids_seen = namespace(value=[]) %} +{%- set sent_documents = namespace(value=false) %} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble +You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes. + +Your information cutoff date is June 2024. + +You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages. +{% if tools or documents %} + +You have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests. + +## Tool Use +Think about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first. + +0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>. + You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed. + NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools. + +Then carry out your plan by repeatedly executing the following steps. +1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields. + When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>. +2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results. + Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id". +3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>. + You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded. + NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user. + +You can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user. + +4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>. +{% if enable_citations %} + +## Grounding +Importantly, note that "Reflection" and "Response" above can be grounded. +Grounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "" and "" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "span" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1". +{% endif %} + +## Available Tools +Here is the list of tools that you have available to you. +You can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it. +Each tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema). + +```json +[ +{% if documents %} + {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %} + +{% endif %} +{% for tool in tools %} + {"name": "{{ tool['function']['name'] }}", "description": "{{tool['function']['description']}}", "parameters": {{ tool['function']['parameters']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %} + +{% endfor %} +] +``` + +{% endif %} +# Default Preamble +The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt. +- Your name is Command. +- You are a large language model built by Cohere. +- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions. +- If the input is ambiguous, ask clarifying follow-up questions. +- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks). +- Use LaTeX to generate mathematical notation for complex equations. +- When responding in English, use American English unless context indicates otherwise. +- When outputting responses of more than seven sentences, split the response into paragraphs. +- Prefer the active voice. +- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references. +- Use gender-neutral pronouns for unspecified persons. +- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list. +- Use the third person when asked to write a summary. +- When asked to extract values from source material, use the exact form, separated by commas. +- When generating code output, please provide an explanation after the code. +- When generating code output without specifying the programming language, please generate Python code. +- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer. +{%- if developer_preamble %} + + +# Developer Preamble +The following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions. +{{ developer_preamble }} +{%- endif -%} +<|END_OF_TURN_TOKEN|> +{%- for message in messages %} + {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|> + {%- elif message.role|lower == 'user' %} +<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %} + {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %} +<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[ + {% for tc in message.tool_calls %} + {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %} + + {% set tool_idx.value = tool_idx.value + 1 %} + {% endfor %} +]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %} + {% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %} +<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[ +{{ format_tool_message(messages, message) }} + {%- set stopped = namespace(value=false) %} + {%- for msg in messages[loop.index0 + 1:] %} + {%- if not stopped.value and msg.role|lower == 'tool' %}, +{{ format_tool_message(messages, msg) }} + {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %} + {%- else %} + {%- set stopped.value = true %} + {%- endif %} + {%- endfor %} + +]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|> + {%- endif %} +{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> diff --git a/src/axolotl/utils/chat_templates/templates/deepseek_v2.jinja b/src/axolotl/utils/chat_templates/templates/deepseek_v2.jinja new file mode 100644 index 000000000..59fde8f2c --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/deepseek_v2.jinja @@ -0,0 +1,3 @@ +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + ' + +' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/deepseek_v3.jinja b/src/axolotl/utils/chat_templates/templates/deepseek_v3.jinja new file mode 100644 index 000000000..35803578c --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/deepseek_v3.jinja @@ -0,0 +1 @@ +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/exaone.jinja b/src/axolotl/utils/chat_templates/templates/exaone.jinja new file mode 100644 index 000000000..8783ad2ec --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/exaone.jinja @@ -0,0 +1,4 @@ +{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|] +' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ ' +' }}{% else %}{{ '[|endofturn|] +' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/falcon_h1.jinja b/src/axolotl/utils/chat_templates/templates/falcon_h1.jinja new file mode 100644 index 000000000..4c03c6297 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/falcon_h1.jinja @@ -0,0 +1,17 @@ +'{{bos_token}} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "You are a function calling AI model. You are provided with function signature within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.\n\n" }} + {%- for tool in tools %}[{{- tool | tojson }}]{%- endfor %} + {{- "\n\nFor each function call, return a json object with function name and arguments within tags with the following schema:\n\n{'arguments': , 'name': }\n\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %}{% for message in messages %}{%- if message.role != 'system' %}{{'<|im_start|>' + message['role'] + ' +' + message['content'] + '<|im_end|>' + ' +'}}{%- endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant +' }}{% endif %}' diff --git a/src/axolotl/utils/chat_templates/templates/gemma.jinja b/src/axolotl/utils/chat_templates/templates/gemma.jinja new file mode 100644 index 000000000..6122fe8ae --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/gemma.jinja @@ -0,0 +1,4 @@ +{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + ' +' + message['content'] | trim + ' +' }}{% endfor %}{% if add_generation_prompt %}{{'model +'}}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/gemma3.jinja b/src/axolotl/utils/chat_templates/templates/gemma3.jinja new file mode 100644 index 000000000..1117055ab --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/gemma3.jinja @@ -0,0 +1,47 @@ +{{ bos_token }} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + ' + +' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + ' + +' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{%- for message in loop_messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {{ '' + role + ' +' + (first_user_prefix if loop.first else "") }} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + {{ ' +' }} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{'model +'}} +{%- endif -%} diff --git a/src/axolotl/utils/chat_templates/templates/gemma3n.jinja b/src/axolotl/utils/chat_templates/templates/gemma3n.jinja new file mode 100644 index 000000000..a0405ea9c --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/gemma3n.jinja @@ -0,0 +1,49 @@ +{{ bos_token }} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + ' + +' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + ' + +' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{%- for message in loop_messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {{ '' + role + ' +' + (first_user_prefix if loop.first else "") }} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'audio' -%} + {{ '' }} + {%- elif item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + {{ ' +' }} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{'model +'}} +{%- endif -%} diff --git a/src/axolotl/utils/chat_templates/templates/jamba.jinja b/src/axolotl/utils/chat_templates/templates/jamba.jinja new file mode 100644 index 000000000..975938285 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/jamba.jinja @@ -0,0 +1,255 @@ +{# Variables #} +{% set ns = namespace(message_count=0, is_last_checked_defined=False) %} +{##} +{% set bom_str = bom_str or "<|bom|>" %} +{% set eom_str = eom_str or "<|eom|>" %} +{% set default_system_message = "" %} +{##} +{% set documents_prefix = "" %} +{% set documents_suffix = "" %} +{% set tool_definitions_prefix = "" %} +{% set tool_definitions_suffix = "" %} +{% set active_modes_prefix = "" %} +{% set active_modes_suffix = "" %} +{##} +{% set tool_calls_prefix = "" %} +{% set tool_calls_suffix = "" %} +{% set citations_prefix = "" %} +{% set citations_suffix = "" %} +{##} +{% if add_generation_prompt is not defined %} + {% set add_generation_prompt = True %} +{% endif %} +{% set role_to_predict = role_to_predict or "assistant" %} +{% if messages|length > 0 and messages[0].role == "system" %} + {% set system_message = messages[0].content %} + {% set loop_messages = messages[1:] %} +{% else %} + {% set system_message = default_system_message %} + {% set loop_messages = messages %} +{% endif %} +{##} +{##} +{# Macros #} +{% macro handle_tool_definitions(tools) %} + {{- tool_definitions_prefix -}} + {{- "\n# Tools" -}} + {{- "\n\n## Functions" -}} + {% for tool in tools %} + {% set _ = is_param_set(tool, field="type") %} + {% set is_tool_type_set = ns.is_last_checked_defined %} + {% if is_tool_type_set %} + {% if tool.type == "function" %} + {% set tool = tool.function %} + {% else %} + {{ raise_exception("Currently, the only supported tool type is `function`") }} + {% endif %} + {% endif %} + {{- "\n\n" + (tool|tojson(indent=2)) -}} + {% endfor %} + {{- "\n" + tool_definitions_suffix -}} +{% endmacro %} +{##} +{% macro handle_first_system_message(system_message, tools) %} + {{- bom_str + handle_role("system") -}} + {% set _ = is_param_set(system_message) %} + {% set is_system_message_set = ns.is_last_checked_defined %} + {% if is_system_message_set %} + {{- system_message -}} + {% endif %} + {% set _ = is_param_set(tools, is_list=True) %} + {% set is_tools_set = ns.is_last_checked_defined %} + {% if is_tools_set %} + {% if system_message %} + {{- "\n\n" -}} + {% endif %} + {{- handle_tool_definitions(tools) -}} + {% endif %} + {% set ns.message_count = ns.message_count + 1 %} +{% endmacro %} +{##} +{% macro handle_tool_calls(tool_calls) %} + {{- tool_calls_prefix + "[\n" -}} + {% for tool_call in tool_calls %} + {% set _ = is_param_set(tool_call, field="function") %} + {% set is_tool_call_function_set = ns.is_last_checked_defined %} + {% if is_tool_call_function_set %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {% set arguments = tool_call.arguments %} + {% if arguments is not string %} + {%- set arguments = arguments|tojson -%} + {%- endif %} + {{ "{\"name\": \"" + tool_call.name + "\", \"arguments\": " + arguments + "}" -}} + {% if not loop.last %} + {{- "," }} + {% endif %} + {% endfor %} + {{- "\n]" + tool_calls_suffix -}} +{% endmacro %} +{##} +{% macro handle_documents(documents) %} + {{- documents_prefix -}} + {{- "\n# Documents" -}} + {{- "\n\nYou can use the following documents for reference:" -}} + {% for doc in documents %} + {{- "\n\n## Document ID: " + loop.index0|string -}} + {% set _ = is_param_set(doc, field="title") %} + {% set is_doc_title_set = ns.is_last_checked_defined %} + {% if is_doc_title_set %} + {{- "\nTitle: " + doc.title -}} + {% endif %} + {% for key, value in doc.items() %} + {% if key not in ["title", "text"] %} + {{- "\n" + key|title + ": " + value|string -}} + {% endif %} + {% endfor %} + {{- "\nText: " + doc.text -}} + {% endfor %} + {{- "\n" + documents_suffix -}} +{% endmacro %} +{##} +{% macro handle_knobs(knobs) %} + {{- active_modes_prefix -}} + {{- "\n# Active Modes" -}} + {{ "\n\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}} + {{ " active modes simultaneously." -}} + {% if knobs.citation_mode == "fast" %} + {{- "\n\n## Citation Mode" -}} + {{- "\n\nProvide a list of references only for the documents you base your response on. Format your response" -}} + {{ " with the original answer followed by a citation section. Use this template:" -}} + {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}} + {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}} + {% endif %} + {% if knobs.response_format == "json_object" %} + {{- "\n\n## JSON Mode" -}} + {{ "\n\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}} + {{ " If an appropriate JSON format exists, use it without modification." -}} + {% endif %} + {{- "\n" + active_modes_suffix -}} +{% endmacro %} +{##} +{% macro get_last_user_index(messages) %} + {% set ns.last_user_index = 0 %} + {% for message in messages %} + {% if message.role == 'user' %} + {% set ns.last_user_index = loop.index0 %} + {% endif %} + {% endfor %} + {{- ns.last_user_index -}} +{% endmacro %} +{##} +{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %} + {{- bom_str + handle_role("system") -}} + {% set macros_to_call = [] %} + {% set params_for_macros = [] %} + {% if use_documents %} + {% set macros_to_call = macros_to_call + [handle_documents] %} + {% set params_for_macros = params_for_macros + [[documents]] %} + {% endif %} + {% if use_knobs %} + {% set macros_to_call = macros_to_call + [handle_knobs] %} + {% set params_for_macros = params_for_macros + [[knobs]] %} + {% endif %} + {% for i in range(macros_to_call|length) %} + {% if i > 0 %} + {{- "\n\n" -}} + {% endif %} + {{- macros_to_call[i](*params_for_macros[i]) -}} + {% endfor %} + {% set ns.message_count = ns.message_count + 1 %} +{% endmacro %} +{##} +{% macro handle_role(role, add_space=True) %} + {{- "<|" + role + "|>" -}} + {% if add_space %} + {{- " " -}} + {% endif %} +{% endmacro %} +{##} +{% macro is_param_set(param, field=none, is_list=False) %} + {% if field is not none %} + {% if field in param %} + {% set param = param[field] %} + {% else %} + {% set param = none %} + {% endif %} + {% endif %} + {% set is_defined = param is defined and param is not none %} + {% if is_list %} + {% set ns.is_last_checked_defined = is_defined and param|length > 0 %} + {% else %} + {% set ns.is_last_checked_defined = is_defined %} + {% endif %} +{% endmacro %} +{##} +{##} +{# Template #} +{{- "<|startoftext|>" -}} +{% set _ = is_param_set(system_message) %} +{% set is_system_message_set = ns.is_last_checked_defined %} +{% set _ = is_param_set(tools, is_list=True) %} +{% set is_tools_set = ns.is_last_checked_defined %} +{% set has_system_message = (is_system_message_set or is_tools_set) %} +{% if has_system_message %} + {{- handle_first_system_message(system_message, tools) -}} +{% endif %} +{% set last_user_index = get_last_user_index(loop_messages)|int %} +{% for message in loop_messages %} + {% if loop.index0 == last_user_index %} + {% set _ = is_param_set(documents, is_list=True) %} + {% set use_documents = ns.is_last_checked_defined %} + {% set _ = is_param_set(knobs) %} + {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %} + {% set add_last_system_message = use_documents or use_knobs %} + {% if add_last_system_message %} + {% if ns.message_count > 0 %} + {{- eom_str -}} + {% endif %} + {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}} + {% endif %} + {% endif %} + {% set role = message.role %} + {% set _ = is_param_set(message, field="name") %} + {% set is_message_name_set = ns.is_last_checked_defined %} + {% if is_message_name_set %} + {% set message_prefix = handle_role(role) + "(" + message.name + ")" %} + {% else %} + {% set message_prefix = handle_role(role) %} + {% endif %} + {% set content = (message.content or "") %} + {% if content is not string %} + {% set content = content|tojson %} + {% endif %} + {% if ns.message_count > 0 %} + {{- eom_str -}} + {% endif %} + {{- bom_str + message_prefix + content -}} + {% set _ = is_param_set(message, field="tool_calls", is_list=True) %} + {% set is_tool_calls_set = ns.is_last_checked_defined %} + {% if role == "assistant" and is_tool_calls_set %} + {{- handle_tool_calls(message.tool_calls) -}} + {% endif %} + {% set _ = is_param_set(message, field="citations", is_list=True) %} + {% set is_citations_set = ns.is_last_checked_defined %} + {% if role == "assistant" and is_citations_set %} + {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}} + {% endif %} + {% set ns.message_count = ns.message_count + 1 %} +{% endfor %} +{% if add_generation_prompt %} + {% if ns.message_count > 0 %} + {{- eom_str -}} + {% endif %} + {{- bom_str + handle_role(role_to_predict, add_space=False) -}} + {% set _ = is_param_set(generation_preamble) %} + {% set is_generation_preamble_set = ns.is_last_checked_defined %} + {% if is_generation_preamble_set and generation_preamble.strip() != "" %} + {{- " " + generation_preamble -}} + {% endif %} + {% set ns.message_count = ns.message_count + 1 %} +{% else %} + {% if ns.message_count > 0 %} + {{- eom_str -}} + {% endif %} +{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/llama3.jinja b/src/axolotl/utils/chat_templates/templates/llama3.jinja new file mode 100644 index 000000000..870322b8f --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/llama3.jinja @@ -0,0 +1,5 @@ +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|> + +'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|> + +' }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/llama3_2_vision.jinja b/src/axolotl/utils/chat_templates/templates/llama3_2_vision.jinja new file mode 100644 index 000000000..cf488310f --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/llama3_2_vision.jinja @@ -0,0 +1,122 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "" %} +{%- endif %} + +{#- Find out if there are any images #} +{% set image_ns = namespace(has_images=false) %} +{%- for message in messages %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {%- set image_ns.has_images = true %} + {%- endif %} + {%- endfor %} +{%- endfor %} + +{#- Error out if there are images and system message #} +{%- if image_ns.has_images and not system_message == "" %} + {{- raise_exception("Prompting with images is incompatible with system messages.") }} +{%- endif %} + +{#- System message if there are no images #} +{%- if not image_ns.has_images %} + {{- "<|start_header_id|>system<|end_header_id|>\n\n" }} + {%- if tools is not none %} + {{- "Environment: ipython\n" }} + {%- endif %} + {{- "Cutting Knowledge Date: December 2023\n" }} + {{- "Today Date: " + date_string + "\n\n" }} + {%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {%- endif %} + {{- system_message }} + {{- "<|eot_id|>" }} +{%- endif %} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} +{%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} + {%- if message['content'] is string %} + {{- message['content'] }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {{- "<|eot_id|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/src/axolotl/utils/chat_templates/templates/llama4.jinja b/src/axolotl/utils/chat_templates/templates/llama4.jinja new file mode 100644 index 000000000..224052e7d --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/llama4.jinja @@ -0,0 +1,123 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content']|trim %} + {%- else %} + {#- FIXME: The processor requires an array, always. #} + {%- set system_message = messages[0]['content'][0]['text']|trim %} + {%- endif %} + {%- set messages = messages[1:] %} + {%- set user_supplied_system_message = true %} +{%- else %} + {%- set system_message = "" %} + {%- set user_supplied_system_message = false %} +{%- endif %} + +{#- System message if the user supplied one #} +{%- if user_supplied_system_message %} + {{- "<|header_start|>system<|header_end|>\n\n" }} + {%- if tools is not none %} + {{- "Environment: ipython\n" }} + {%- endif %} + {%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {%- endif %} + {{- system_message }} + {{- "<|eot|>" }} +{%- endif %} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} +{%- endif %} + {{- '<|header_start|>user<|header_end|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} + {%- if message['content'] is string %} + {{- message['content'] }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- "<|eot|>" }} + {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %} + {{- '<|header_start|>assistant<|header_end|>\n\n' -}} + {{- '<|python_start|>' }} + {%- if message['content'] is string %} + {{- message['content'] }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|python_end|>' }} + {%- for tool_call in message.tool_calls %} + {{- '{"name": "' + tool_call.function.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.function.arguments | tojson }} + {{- "}" }} + {%- endfor %} + {{- "<|eot|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|header_start|>ipython<|header_end|>\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|header_start|>assistant<|header_end|>\n\n' }} +{%- endif %} diff --git a/src/axolotl/utils/chat_templates/templates/llava.jinja b/src/axolotl/utils/chat_templates/templates/llava.jinja new file mode 100644 index 000000000..448bf4dbf --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/llava.jinja @@ -0,0 +1,2 @@ +{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ ' +' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/metharme.jinja b/src/axolotl/utils/chat_templates/templates/metharme.jinja new file mode 100644 index 000000000..626d48f29 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/metharme.jinja @@ -0,0 +1 @@ +{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/mistral_v1.jinja b/src/axolotl/utils/chat_templates/templates/mistral_v1.jinja new file mode 100644 index 000000000..409b06d83 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/mistral_v1.jinja @@ -0,0 +1 @@ +{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %} diff --git a/src/axolotl/utils/chat_templates/templates/mistral_v2v3.jinja b/src/axolotl/utils/chat_templates/templates/mistral_v2v3.jinja new file mode 100644 index 000000000..3dc6f523d --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/mistral_v2v3.jinja @@ -0,0 +1 @@ +{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %} diff --git a/src/axolotl/utils/chat_templates/templates/mistral_v3_tekken.jinja b/src/axolotl/utils/chat_templates/templates/mistral_v3_tekken.jinja new file mode 100644 index 000000000..2a6749447 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/mistral_v3_tekken.jinja @@ -0,0 +1 @@ +{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %} diff --git a/src/axolotl/utils/chat_templates/templates/mistral_v7_tekken.jinja b/src/axolotl/utils/chat_templates/templates/mistral_v7_tekken.jinja new file mode 100644 index 000000000..b97e2a097 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/mistral_v7_tekken.jinja @@ -0,0 +1,51 @@ +{%- set today = strftime_now("%Y-%m-%d") %} +{%- set default_system_message = "You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYour knowledge base was last updated on 2023-10-01. The current date is " + today + ".\n\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\")" %} + +{{- bos_token }} + +{%- if messages[0]['role'] == 'system' %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content'] %} + {%- else %} + {%- set system_message = messages[0]['content'][0]['text'] %} + {%- endif %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set system_message = default_system_message %} + {%- set loop_messages = messages %} +{%- endif %} +{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }} + +{%- for message in loop_messages %} + {%- if message['role'] == 'user' %} + {%- if message['content'] is string %} + {{- '[INST]' + message['content'] + '[/INST]' }} + {%- else %} + {{- '[INST]' }} + {%- for block in message['content'] %} + {%- if block['type'] == 'text' %} + {{- block['text'] }} + {%- elif block['type'] in ['image', 'image_url'] %} + {{- '[IMG]' }} + {%- else %} + {{- raise_exception('Only text and image blocks are supported in message content!') }} + {%- endif %} + {%- endfor %} + {{- '[/INST]' }} + {%- endif %} + {%- elif message['role'] == 'system' %} + {%- if message['content'] is string %} + {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }} + {%- else %} + {{- '[SYSTEM_PROMPT]' + message['content'][0]['text'] + '[/SYSTEM_PROMPT]' }} + {%- endif %} + {%- elif message['role'] == 'assistant' %} + {%- if message['content'] is string %} + {{- message['content'] + eos_token }} + {%- else %} + {{- message['content'][0]['text'] + eos_token }} + {%- endif %} + {%- else %} + {{- raise_exception('Only user, system and assistant roles are supported!') }} + {%- endif %} +{%- endfor %} diff --git a/src/axolotl/utils/chat_templates/templates/phi_3.jinja b/src/axolotl/utils/chat_templates/templates/phi_3.jinja new file mode 100644 index 000000000..853942eba --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/phi_3.jinja @@ -0,0 +1,7 @@ +{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + ' +' + message['content'] + '<|end|>' + ' +'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + ' +' + message['content'] + '<|end|>' + ' +' + '<|assistant|>' + ' +'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + ' +'}}{% endif %}{% endfor %} diff --git a/src/axolotl/utils/chat_templates/templates/phi_35.jinja b/src/axolotl/utils/chat_templates/templates/phi_35.jinja new file mode 100644 index 000000000..aae8a8f51 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/phi_35.jinja @@ -0,0 +1,8 @@ +{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'user' %}{{'<|user|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|> +' + message['content'] + '<|end|> +'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> +' }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/phi_4.jinja b/src/axolotl/utils/chat_templates/templates/phi_4.jinja new file mode 100644 index 000000000..ed1861f6c --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/phi_4.jinja @@ -0,0 +1 @@ +{% set system_message = 'You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: {Thought section} {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. Now, try to solve the following question through the above guidelines:' -%}{%- if messages and messages[0]['role'] == 'system' -%}{%- set system_message = messages[0]['content'] -%}{%- set messages = messages[1:] -%}{%- endif -%}<|im_start|>system<|im_sep|>{{ system_message }}<|im_end|>{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>'}}{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/pixtral.jinja b/src/axolotl/utils/chat_templates/templates/pixtral.jinja new file mode 100644 index 000000000..a94177112 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/pixtral.jinja @@ -0,0 +1,53 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }} + {%- endif %} + {%- if message["role"] == "user" %} + {%- if loop.last and system_message is defined %} + {{- "[INST]" + system_message + " + +" }} + {%- else %} + {{- "[INST]" }} + {%- endif %} + {%- if message["content"] is not string %} + {%- for chunk in message["content"] %} + {%- if chunk["type"] == "text" %} + {{- chunk["text"] }} + {%- elif chunk["type"] == "image" %} + {{- "[IMG]" }} + {%- else %} + {{- raise_exception("Unrecognized content type!") }} + {%- endif %} + {%- endfor %} + {%- else %} + {{- message["content"] }} + {%- endif %} + {{- "[/INST]" }} + {%- elif message["role"] == "assistant" %} + {%- if message["content"] is not string %} + {%- for chunk in message["content"] %} + {%- if chunk["type"] == "text" %} + {{- chunk["text"] }} + {%- elif chunk["type"] == "image" %} + {{- "[IMG]" }} + {%- else %} + {{- raise_exception("Unrecognized content type!") }} +{%- endif %} +{%- endfor %} +{{- eos_token }} +{%- else %} +{{- message["content"] + eos_token }} +{%- endif %} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/src/axolotl/utils/chat_templates/templates/qwen2_vl.jinja b/src/axolotl/utils/chat_templates/templates/qwen2_vl.jinja new file mode 100644 index 000000000..426b7642d --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/qwen2_vl.jinja @@ -0,0 +1,7 @@ +{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system +You are a helpful assistant.<|im_end|> +{% endif %}<|im_start|>{{ message['role'] }} +{% if message['content'] is string %}{{ message['content'] }}<|im_end|> +{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|> +{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant +{% endif %} diff --git a/src/axolotl/utils/chat_templates/templates/qwen3.jinja b/src/axolotl/utils/chat_templates/templates/qwen3.jinja new file mode 100644 index 000000000..09b82ed03 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/qwen3.jinja @@ -0,0 +1,87 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in message.content %} + {%- set content = message.content.split('')[-1].lstrip('\n') %} + {%- set reasoning_content = message.content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} +{%- endif %} diff --git a/src/axolotl/utils/chat_templates/templates/qwen_25.jinja b/src/axolotl/utils/chat_templates/templates/qwen_25.jinja new file mode 100644 index 000000000..bdf7919a9 --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/qwen_25.jinja @@ -0,0 +1,54 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }} + {%- endif %} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py index 8c60f223c..d5e6ad17d 100644 --- a/src/axolotl/utils/collators/__init__.py +++ b/src/axolotl/utils/collators/__init__.py @@ -1,11 +1,17 @@ -""" -shared axolotl collators for multipack, mamba, multimodal -""" +"""Shared axolotl collators for multipacking, mamba, multimodal.""" -from .batching import ( # noqa: F401 +from .batching import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq, ) -from .mamba import MambaDataCollator # noqa: F401 +from .mamba import MambaDataCollator + +__all__ = [ + "DataCollatorForSeq2Seq", + "BatchSamplerDataCollatorForSeq2Seq", + "V2BatchSamplerDataCollatorForSeq2Seq", + "PretrainingBatchSamplerDataCollatorForSeq2Seq", + "MambaDataCollator", +] diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 45facf832..55e630fbe 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -1,7 +1,7 @@ """Data collators for axolotl to pad labels and position_ids for packed sequences""" from dataclasses import dataclass -from typing import Any +from typing import Any, List import numpy as np from transformers import PreTrainedTokenizerBase @@ -81,9 +81,11 @@ class DataCollatorForSeq2Seq: padding_side = self.tokenizer.padding_side for feature in features: - remainder = [pad_token_id] * ( - max_feature_length - len(feature[feature_name]) - ) + remainder_len = max_feature_length - len(feature[feature_name]) + if feature_name == "position_ids": + remainder = list(range(remainder_len)) + else: + remainder = [pad_token_id] * remainder_len if isinstance(feature[feature_name], list): feature[feature_name] = ( feature[feature_name] + remainder @@ -106,7 +108,7 @@ class DataCollatorForSeq2Seq: pad_to_multiple_of=self.pad_to_multiple_of, return_tensors=return_tensors, ) - if not has_attn_mask: + if not has_attn_mask and "attention_mask" in features: del features["attention_mask"] # prepare decoder_input_ids @@ -159,9 +161,11 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): Collator for multipack specific to the using the BatchSampler """ + squash_position_ids: bool = False + def __call__(self, features, return_tensors=None): if not isinstance(features[0], list): - features = [features] + features: List[List[dict]] = [features] out_features = [{} for _ in features] for i, features_ in enumerate(features): for feature in features_[0].keys(): @@ -174,6 +178,15 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if feature in item ] out_features[i][feature] = np.concatenate(arrays) + elif feature == "position_ids" and self.squash_position_ids: + arrays = [ + np.array(item[feature]) for item in features_ if feature in item + ] + # concatenate, get total length and create arange of new total position ids + position_ids = np.concatenate(arrays) + total_length = position_ids.shape[0] + position_ids = np.arange(total_length) + out_features[i][feature] = position_ids else: arrays = [ np.array(item[feature]) for item in features_ if feature in item diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index 75d72f8dc..542918527 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -5,7 +5,6 @@ Collators for multi-modal chat messages and packing from dataclasses import dataclass from typing import Any, Optional, Union -import torch from torch import Tensor from transformers import PreTrainedTokenizerBase from transformers.data.data_collator import DataCollatorMixin @@ -42,51 +41,19 @@ class MultiModalChatDataCollator(DataCollatorMixin): examples = self.processing_strategy(examples) # Initialize batch - batch: dict[str, Any] = {} + messages = [ex["messages"] for ex in examples] - # Process each example - for example in examples: - # Apply chat template to process the example - # This method requires transformers>=4.49.0 - result = self.processing_strategy.processor.apply_chat_template( - example["messages"], - add_generation_prompt=True, - tokenize=True, - return_tensors="pt", - padding=True, - return_dict=True, - chat_template=self.processing_strategy.chat_template, - ) - - # TODO: Check if need handling for len(input_ids) > sequence_len - - # Add the processed tensors to our batch - for key in result.keys(): - if key not in batch: - batch[key] = [] - - batch[key].append(result[key].squeeze(0)) - - # Pad sequences to the same length - input_ids = torch.nn.utils.rnn.pad_sequence( - batch["input_ids"], - batch_first=True, - padding_value=self.tokenizer.pad_token_id, + batch = self.processing_strategy.processor.apply_chat_template( + messages, + add_generation_prompt=False, + tokenize=True, + return_tensors="pt", + padding=True, + return_dict=True, + chat_template=self.processing_strategy.chat_template, ) - attention_mask = torch.nn.utils.rnn.pad_sequence( - batch["attention_mask"], batch_first=True, padding_value=0 - ) - - # Create the final batch - final_batch = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - # Process the labels - final_batch["labels"] = self.processing_strategy.process_labels( - final_batch["input_ids"] - ) + batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"]) - return final_batch + return batch diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index e0eaf9ac9..7a2bbd6f9 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -17,11 +17,11 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, + AxolotlInputConfig as AxolotlInputConfigBase, ) -from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset -LOG = get_logger(__name__, use_environ=True) +LOG = get_logger(__name__) def choose_device(cfg): @@ -37,7 +37,7 @@ def choose_device(cfg): return f"npu:{cfg.local_rank}" raise SystemError("No CUDA/mps/npu device found") - except Exception: # pylint: disable=broad-exception-caught + except Exception: return "cpu" cfg.device = get_device() @@ -77,7 +77,7 @@ def resolve_dtype(cfg): if cfg.device == "mps": cfg.load_in_8bit = False cfg.tf32 = False - if cfg.bf16: + if cfg.bf16 and cfg.fp16 is not False: cfg.fp16 = True cfg.bf16 = False else: @@ -116,9 +116,10 @@ def normalize_config(cfg): ] choose_device(cfg) cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 - if cfg.ddp: + if cfg.world_size != 1: cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))} - cfg.batch_size = cfg.batch_size * cfg.world_size + if cfg.fsdp or cfg.fsdp_config or cfg.ddp: + cfg.batch_size = cfg.batch_size * cfg.world_size if not cfg.use_ray: # delay resolving dtype until on worker node when launching with ray @@ -147,8 +148,6 @@ def normalize_config(cfg): f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations." ) - cfg.dataset_processes = cfg.dataset_processes or os.cpu_count() - if not cfg.base_model_config: cfg.base_model_config = cfg.base_model @@ -267,14 +266,16 @@ def validate_config( if cfg.plugins: ( - AxolotlConfigWCapabilities, # pylint: disable=invalid-name - AxolotlInputConfig, # pylint: disable=invalid-name + AxolotlConfigWCapabilities, + AxolotlInputConfig, ) = merge_input_args() # Convert datasets to proper format if needed if cfg.get("datasets"): for idx, ds_cfg in enumerate(cfg["datasets"]): - if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset): + if cfg.get("rl") in ["dpo", "ipo", "simpo"] and not isinstance( + ds_cfg, DPODataset + ): cfg["datasets"][idx] = DPODataset(**ds_cfg) elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset): cfg["datasets"][idx] = KTODataset(**dict(ds_cfg)) diff --git a/src/axolotl/utils/ctx_managers/__init__.py b/src/axolotl/utils/ctx_managers/__init__.py index e544621b5..6ffda9e55 100644 --- a/src/axolotl/utils/ctx_managers/__init__.py +++ b/src/axolotl/utils/ctx_managers/__init__.py @@ -1,6 +1,5 @@ """Init for context manager submodule""" -# pylint: disable=unused-import # flake8: noqa from .sequence_parallel import SequenceParallelContextManager diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 491cb9877..78b3d1cae 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -6,15 +6,14 @@ import inspect import torch import torch.distributed as dist from torch import nn +from torch.distributed import DeviceMesh from torch.utils.hooks import RemovableHandle from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import ModelOutput from axolotl.monkeypatch.ring_attn import ( get_ring_attn_group, - patch_prepare_data_loader, - patch_prepare_device_mesh, - register_ring_attn, + register_ring_attn_from_device_mesh, update_ring_attn_params, ) from axolotl.utils.schemas.enums import RingAttnFunc @@ -27,7 +26,7 @@ def apply_sequence_parallelism( local_rank: int, local_world_size: int, gradient_accumulation_steps: int, - ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument + ring_attn_func: RingAttnFunc, ) -> tuple[dict[str, torch.Tensor], int, int]: """ Apply sequence parallelism slicing to a batch. @@ -49,10 +48,10 @@ def apply_sequence_parallelism( - The original sequence length before padding. - The number of padding tokens added. """ - original_seq_len = batch["input_ids"].size(1) + batch_size, original_seq_len = batch["input_ids"].shape # Update ring attention params if needed - if batch.get("position_ids") is not None: + if batch.get("position_ids") is not None and batch_size == 1: update_ring_attn_params(position_ids=batch["position_ids"]) else: # If position_ids aren't already in the batch, create them @@ -152,9 +151,18 @@ def apply_sequence_parallelism( if "num_items_in_batch" in batch: # Approximation; this needed since num_items_in_batch may be counted across # all samples in a gradient accumulated batch, not on a per-step basis. + local_valid_tokens = (batch["labels"] != -100).sum() + + # All-reduce across sequence parallel ranks to get global token count + cp_group = get_ring_attn_group() + global_valid_tokens = local_valid_tokens.clone() + # we use AVG instead of SUM as using sum seems to scale down the loss by over-accounting the number of tokens + dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.AVG, group=cp_group) + global_valid_tokens = int(global_valid_tokens.item()) + batch["num_items_in_batch"] = ( - batch["labels"] != -100 - ).sum() * gradient_accumulation_steps + global_valid_tokens * gradient_accumulation_steps + ) return batch, original_seq_len, pad_len @@ -169,26 +177,33 @@ class SequenceParallelContextManager: Args: models: List of models to apply sequence parallelism to pre- and post- forward hooks. - sequence_parallel_degree: Number of processes to split sequences over. + context_parallel_size: Number of processes to split sequences over. gradient_accumulation_steps: Number of steps to accumulate gradients over. ring_attn_func: Which ring attention function to use. Currently unused. heads_k_stride: Sequence parallelism K head stride size. Passed through to `varlen_llama3` `ring_flash_attn` implementation. + gather_outputs: Whether to gather outputs after model forward pass across the + sequence parallel group. """ def __init__( self, models: list[nn.Module], - sequence_parallel_degree: int, + context_parallel_size: int, gradient_accumulation_steps: int, ring_attn_func: RingAttnFunc, heads_k_stride: int | None, + gather_outputs: bool, + device_mesh: DeviceMesh | None = None, ): self.models = models - self.sequence_parallel_degree = sequence_parallel_degree + self.context_parallel_size = context_parallel_size self.gradient_accumulation_steps = gradient_accumulation_steps self.ring_attn_func = ring_attn_func self.heads_k_stride = heads_k_stride + self.gather_outputs = gather_outputs + self.device_mesh = device_mesh + self._register_ring_attn() # Set distributed info for local rank @@ -227,18 +242,13 @@ class SequenceParallelContextManager: def _register_ring_attn(self): # Initialize ring attn for sequence parallelism - register_ring_attn( - sequence_parallel_degree=self.sequence_parallel_degree, + register_ring_attn_from_device_mesh( + device_mesh=self.device_mesh, + context_parallel_dim=("cp",), heads_k_stride=self.heads_k_stride, ring_attn_func=self.ring_attn_func, ) - # Patches for accelerate functionality - patch_prepare_data_loader() - patch_prepare_device_mesh( - sequence_parallel_degree=self.sequence_parallel_degree - ) - def _register_model_hooks(self): # Forward pre-hook to apply sequence parallelism def sequence_parallel_pre_hook(_, args, kwargs): @@ -277,16 +287,17 @@ class SequenceParallelContextManager: return output - # Register both hooks + # Register hooks for model in self.models: self.hook_handles.append( model.register_forward_pre_hook( sequence_parallel_pre_hook, with_kwargs=True ) ) - self.hook_handles.append( - model.register_forward_hook(sequence_parallel_post_hook) - ) + if self.gather_outputs: + self.hook_handles.append( + model.register_forward_hook(sequence_parallel_post_hook) + ) def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: """Gather sharded outputs from all ranks and reconstruct the full tensor.""" diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py index 8dedcbe69..8b9e4e91d 100644 --- a/src/axolotl/utils/data/__init__.py +++ b/src/axolotl/utils/data/__init__.py @@ -1,16 +1,21 @@ -""" -Data processing modules -""" +"""Init for `axolotl.utils.data` module.""" -from axolotl.utils.data.pretraining import ( # noqa: F401 - encode_pretraining, - wrap_pretraining_dataset, -) -from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401 -from axolotl.utils.data.sft import ( # noqa: F401 +from axolotl.utils.data.rl import prepare_preference_datasets +from axolotl.utils.data.sft import ( get_dataset_wrapper, - load_prepare_datasets, - load_tokenized_prepared_datasets, - prepare_dataset, + prepare_datasets, ) -from axolotl.utils.data.utils import md5 # noqa: F401 +from axolotl.utils.data.streaming import ( + encode_streaming, + wrap_streaming_dataset, +) +from axolotl.utils.data.utils import md5 + +__all__ = [ + "encode_streaming", + "wrap_streaming_dataset", + "prepare_preference_datasets", + "get_dataset_wrapper", + "prepare_datasets", + "md5", +] diff --git a/src/axolotl/utils/data/lock.py b/src/axolotl/utils/data/lock.py new file mode 100644 index 000000000..afd1547af --- /dev/null +++ b/src/axolotl/utils/data/lock.py @@ -0,0 +1,68 @@ +"""Logic for loading / preparing a dataset once over all processes.""" + +import time +from pathlib import Path +from typing import Any, Callable + +from filelock import FileLock + +from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.utils.dict import DictDefault + +LOCK_FILE_NAME = "datasets_prep.lock" +READY_FILE_NAME = "datasets_ready.flag" +PROCESS_COUNTER_FILE_NAME = "process_counter.txt" + + +class FileLockLoader: + """ + Simple class for abstracting single process data loading / processing. The first + process that creates a lock file does the work; the remaining procesees simply load + the preprocessed dataset once the first process is done. + """ + + def __init__(self, cfg: DictDefault): + self.cfg = cfg + self.dataset_prepared_path = ( + cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH + ) + self.lock_file_path = Path(self.dataset_prepared_path) / LOCK_FILE_NAME + self.ready_flag_path = Path(self.dataset_prepared_path) / READY_FILE_NAME + self.counter_path = Path(self.dataset_prepared_path) / PROCESS_COUNTER_FILE_NAME + + def load(self, load_fn: Callable[[], Any]) -> Any: + with FileLock(str(self.lock_file_path)): + self._increment_counter() + + if not self.ready_flag_path.exists(): + result = load_fn() + self.ready_flag_path.touch() + return result + + while not self.ready_flag_path.exists(): + time.sleep(1) + return load_fn() + + def _increment_counter(self): + """Safely increment the process counter.""" + if self.counter_path.exists(): + counter_content = self.counter_path.read_text().strip() + count = int(counter_content) if counter_content else 0 + else: + count = 0 + self.counter_path.write_text(str(count + 1)) + + def cleanup(self): + """Clean up ready flag when last process is done.""" + with FileLock(str(self.lock_file_path)): + counter_content = self.counter_path.read_text().strip() + count = int(counter_content) if counter_content else 0 + count -= 1 + + if count <= 0: + # Last process cleans everything up + self.ready_flag_path.unlink(missing_ok=True) + self.counter_path.unlink(missing_ok=True) + else: + # Still have active processes + self.counter_path.write_text(str(count)) diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 9264c86ab..f7a5ec04c 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -1,89 +1,145 @@ -"""data handling specific to DPO""" +"""Data handling specific to RL trainers.""" import inspect from functools import partial -from pathlib import Path -from typing import Any, List, Union +from typing import Any, Callable, Literal -import yaml -from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk +from datasets import Dataset, DatasetDict +from transformers import PreTrainedTokenizer -from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.loaders import load_tokenizer from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.orpo import load as load_orpo -from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config -from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 +from axolotl.utils.data.lock import FileLockLoader +from axolotl.utils.data.shared import ( + create_train_validation_split, + datasets_with_name_generator, + generate_dataset_hash_from_config, + load_dataset_with_config, + load_preprocessed_dataset, + merge_datasets, + save_preprocessed_dataset, + try_load_from_hub, +) +from axolotl.utils.data.utils import ( + deduplicate_and_log_datasets, + retry_on_request_exceptions, +) from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType LOG = get_logger(__name__) -def _get_path(ds_hash, cfg): - prepared_ds_path = ( - Path(cfg.dataset_prepared_path) / ds_hash - if cfg.dataset_prepared_path - else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash - ) +@retry_on_request_exceptions(max_retries=3, delay=5) +def prepare_preference_datasets( + cfg: DictDefault, tokenizer: PreTrainedTokenizer +) -> tuple[Dataset, Dataset | None]: + """Load and prepare preference datasets for RL training. - return prepared_ds_path + Loads training and evaluation datasets, handling preprocessing, caching, and + deduplication as configured. Uses FileLock for distributed coordination. + + Args: + cfg: Configuration object containing dataset and training settings. + tokenizer: Tokenizer to use for processing text. + + Returns: + Tuple of (train_dataset, eval_dataset). eval_dataset may be None + if no evaluation dataset is configured. + """ + + def _load_datasets(): + # Load training dataset + train_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="train") + + # Load or create evaluation dataset + eval_dataset: Dataset | None = None + if cfg.test_datasets: + eval_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="test") + elif cfg.val_set_size: + # Create validation split from training data + train_dataset, eval_dataset = create_train_validation_split( + train_dataset, cfg, cfg.val_set_size + ) + + return train_dataset, eval_dataset + + # Prepare datasets (with file locking logic for multiple ranks) + loader = FileLockLoader(cfg) + try: + train_dataset, eval_dataset = loader.load(_load_datasets) + finally: + loader.cleanup() + + # Apply deduplication if configured + if cfg.dataset_exact_deduplication: + train_dataset, eval_dataset = deduplicate_and_log_datasets( + dataset=train_dataset, other_dataset=eval_dataset + ) + + return train_dataset, eval_dataset -def _load_preprocessed_ds(cfg, sub_cfg): - ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper)) - prepared_ds_path = _get_path(ds_hash, cfg) - dataset = None +def _map_dataset( + cfg: DictDefault, + dataset: Dataset | DatasetDict, + ds_transform_fn: Callable[..., Any], + tokenizer: Any | None = None, + **map_kwargs: Any, +) -> Dataset: + """Apply transformation function to dataset. - # pylint: disable=duplicate-code - if ( - cfg.dataset_prepared_path - and any(prepared_ds_path.glob("*")) - and not cfg.is_preprocess - ): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") - dataset = load_from_disk(str(prepared_ds_path)) + Args: + cfg: Configuration object. + dataset: Dataset to transform. + ds_transform_fn: Transformation function to apply. + tokenizer: Optional tokenizer for transformation. + **map_kwargs: Additional arguments for dataset mapping. - return dataset - - -def _save_preprocessed_ds(cfg, sub_cfg, dataset): - ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper)) - prepared_ds_path = _get_path(ds_hash, cfg) - - if cfg.is_preprocess and is_main_process(): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") - dataset.save_to_disk(str(prepared_ds_path)) - - -def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs): + Returns: + Transformed dataset. + """ sig = inspect.signature(ds_transform_fn) if "tokenizer" in sig.parameters: if not tokenizer: tokenizer = load_tokenizer(cfg) ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer) - if isinstance(data_set, DatasetDict): - data_set = data_set["train"] + if isinstance(dataset, DatasetDict): + dataset = dataset["train"] - data_set = data_set.map( + dataset = dataset.map( ds_transform_fn, - num_proc=cfg.dataset_processes, + num_proc=cfg.dataset_num_proc, load_from_cache_file=not cfg.is_preprocess, desc="Mapping RL Dataset", **map_kwargs, ) - return data_set + return dataset -def drop_long_rl_seq( - sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name -): - if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO): +def _drop_long_sequences( + sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int +) -> bool: + """Filter out samples that exceed maximum sequence length. + + Args: + sample: Dataset sample to check. + rl: Reinforcement learning type. + tokenizer: Tokenizer for length calculation. + sequence_len: Maximum allowed sequence length. + + Returns: + True if sample should be kept, False if it should be dropped. + + Raises: + ValueError: If required keys are missing or RL type is unknown. + """ + if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}: if not ( sample.get("prompt") and sample.get("chosen") and sample.get("rejected") ): @@ -123,132 +179,114 @@ def drop_long_rl_seq( raise ValueError("Unknown RL type") -def load_prepare_preference_datasets(cfg): - def load_split(dataset_cfgs, _cfg): - split_datasets: List[Any] = [] - use_auth_token = _cfg.hf_use_auth_token - for config_dataset in datasets_w_name_generator(dataset_cfgs): - ds: Union[Dataset, DatasetDict] = load_dataset_w_config( - config_dataset, use_auth_token, streaming=False - ) - split_datasets.append(ds) +def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: + """Load and process dataset split for RL training. - tokenizer = load_tokenizer(cfg) + Args: + cfg: Configuration object containing dataset settings. + split: Dataset split to load ("train" or "test"). - for i, data_set in enumerate(split_datasets): - _type = dataset_cfgs[i]["type"] - if _type: - if isinstance(_type, DictDefault): - _type = "user_defined.default" - if _cfg.rl is RLType.ORPO: - ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i) - elif _cfg.rl is RLType.KTO: - ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) - else: - ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) + Returns: + Combined and processed dataset for the specified split. + """ + datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets + split_datasets: list[Dataset | DatasetDict] = [] - map_kwargs = {} - if isinstance(ds_transform_fn, tuple): - ds_transform_fn, map_kwargs = ds_transform_fn - split_datasets[i] = map_dataset( - cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs - ) - elif _cfg.rl is RLType.KTO: - ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) - map_kwargs = {} - if isinstance(ds_transform_fn, tuple): - ds_transform_fn, map_kwargs = ds_transform_fn - split_datasets[i] = map_dataset( - cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs - ) - else: - # If no `type` is provided, assume the dataset is already in the expected format with - # "prompt", "chosen" and "rejected" already preprocessed - split_datasets[i] = data_set - - if not cfg.skip_prepare_dataset: - drop_long = partial( - drop_long_rl_seq, - rl=_cfg.rl, - tokenizer=tokenizer, - sequence_len=cfg.sequence_len, - ) - - prior_len = len(split_datasets[i]) - split_datasets[i] = split_datasets[i].filter( - drop_long, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Dropping Long Sequences", - ) - dropped = prior_len - len(split_datasets[i]) - if dropped: - LOG.warning( - f"Dropped {dropped} long samples from dataset index {i}" - ) - - combined_datasets = concatenate_datasets(split_datasets) - combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42) - - return combined_datasets - - with zero_first(is_main_process()): - train_is_preprocessed = False - eval_is_preprocessed = False - if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets): - train_is_preprocessed = True - else: - train_dataset = load_split(cfg.datasets, cfg) - - eval_dataset = None - if cfg.test_datasets: - if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets): - eval_is_preprocessed = True - else: - eval_dataset = load_split(cfg.test_datasets, cfg) - if not eval_dataset: - if cfg.val_set_size: - seed = cfg.seed if cfg.seed is not None else 42 - - # ensure we end up with the same fingerprint by doing rank0 first and being able to cache - to_hash_train = ( - train_dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(cfg.val_set_size) - + "|" - + "train" - + "|" - + str(cfg.seed or 42) - ) - to_hash_test = ( - train_dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(cfg.val_set_size) - + "|" - + "test" - + "|" - + str(cfg.seed or 42) - ) - train_fingerprint = md5(to_hash_train) - test_fingerprint = md5(to_hash_test) - ds_w_test_split = train_dataset.train_test_split( - test_size=cfg.val_set_size, - seed=seed, - shuffle=False, - train_new_fingerprint=train_fingerprint, - test_new_fingerprint=test_fingerprint, - ) - eval_dataset = ds_w_test_split["test"] - train_dataset = ds_w_test_split["train"] - - if not train_is_preprocessed: - _save_preprocessed_ds(cfg, cfg.datasets, train_dataset) - if eval_dataset and not eval_is_preprocessed: - _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset) - - if cfg.dataset_exact_deduplication: - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=train_dataset, eval_dataset=eval_dataset + for dataset_config in datasets_with_name_generator(datasets_configs): + dataset: Dataset | DatasetDict = load_dataset_with_config( + dataset_config, cfg.hf_use_auth_token, streaming=False ) + split_datasets.append(dataset) - return train_dataset, eval_dataset + tokenizer = load_tokenizer(cfg) + + for i, dataset in enumerate(split_datasets): + _type = datasets_configs[i]["type"] + if _type: + if isinstance(_type, DictDefault): + _type = "user_defined.default" + if cfg.rl is RLType.ORPO: + ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i) + elif cfg.rl is RLType.KTO: + ds_transform_fn = load_kto(_type, cfg, dataset_idx=i) + else: + ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i) + + map_kwargs: dict[str, Any] = {} + if isinstance(ds_transform_fn, tuple): + ds_transform_fn, map_kwargs = ds_transform_fn + split_datasets[i] = _map_dataset( + cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs + ) + else: + # If no `type` is provided, assume the dataset is already in the expected format with + # "prompt", "chosen", and "rejected" already preprocessed + split_datasets[i] = dataset + + if not cfg.skip_prepare_dataset: + drop_long = partial( + _drop_long_sequences, + rl=cfg.rl, + tokenizer=tokenizer, + sequence_len=cfg.sequence_len, + ) + + prior_len = len(split_datasets[i]) + split_datasets[i] = split_datasets[i].filter( + drop_long, + num_proc=cfg.dataset_num_proc, + load_from_cache_file=not cfg.is_preprocess, + desc="Dropping Long Sequences", + ) + dropped = prior_len - len(split_datasets[i]) + if dropped: + LOG.warning(f"Dropped {dropped} long samples from dataset index {i}") + + # Merge datasets + dataset = merge_datasets(split_datasets, cfg) + + if not cfg.skip_prepare_dataset: + # Save preprocessed dataset + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_configs, tokenizer.name_or_path + ) + save_preprocessed_dataset(cfg, dataset, dataset_hash, split) + + return dataset + + +def _load_or_create_dataset_split( + cfg: DictDefault, tokenizer: PreTrainedTokenizer, split: Literal["train", "test"] +) -> Dataset: + """Load preprocessed dataset or create new one for given split. + + Args: + cfg: Configuration object. + tokenizer: Tokenizer to use for processing text. + split: Dataset split to load. + + Returns: + Tuple of (dataset, is_preprocessed). + """ + # Select correct dataset configuration based on split + datasets_config = cfg.datasets if split == "train" else cfg.test_datasets + + # Generate dataset hash for caching + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_config, tokenizer.name_or_path + ) + + # Try loading from hub if push_dataset_to_hub is configured + dataset = None + if cfg.push_dataset_to_hub: + dataset = try_load_from_hub(cfg, dataset_hash, split) + + # Attempt to load preprocessed dataset + if dataset is None: + dataset = load_preprocessed_dataset(cfg, dataset_hash) + + # Otherwise, load it + if dataset is None: + dataset = _load_split(cfg, split=split) + + return dataset diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 88c78174b..ba5aec2d6 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -1,58 +1,40 @@ -"""data handling specific to SFT""" +"""Data handling specific to SFT.""" import functools import os import tempfile -from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Literal from datasets import ( Dataset, DatasetDict, IterableDataset, - Sequence, - Value, - concatenate_datasets, + IterableDatasetDict, load_dataset, - load_from_disk, ) -from transformers import PreTrainedTokenizerBase +from transformers import PreTrainedTokenizer, ProcessorMixin -from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt -from axolotl.prompt_strategies import load -from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load -from axolotl.prompt_tokenizers import ( - AlpacaMultipleChoicePromptTokenizingStrategy, - AlpacaPromptTokenizingStrategy, - AlpacaReflectionPTStrategy, - DatasetWrappingStrategy, - GPTeacherPromptTokenizingStrategy, - JeopardyPromptTokenizingStrategy, - OpenAssistantPromptTokenizingStrategy, - SummarizeTLDRPromptTokenizingStrategy, +from axolotl.prompters import Prompter +from axolotl.utils.data.lock import FileLockLoader +from axolotl.utils.data.shared import ( + create_train_validation_split, + datasets_with_name_generator, + generate_dataset_hash_from_config, + load_dataset_with_config, + load_preprocessed_dataset, + merge_datasets, + save_preprocessed_dataset, + try_load_from_hub, ) -from axolotl.prompters import ( - AlpacaPrompter, - GPTeacherPrompter, - JeopardyPrompter, - MultipleChoiceConcisePrompter, - MultipleChoiceExplainPrompter, - Prompter, - ReflectAlpacaPrompter, - SummarizeTLDRPrompter, - UnsupportedPrompter, -) -from axolotl.utils.data.pretraining import wrap_pretraining_dataset -from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config +from axolotl.utils.data.streaming import wrap_streaming_dataset from axolotl.utils.data.utils import ( deduplicate_and_log_datasets, - drop_long_seq_in_dataset, - md5, + handle_long_seq_in_dataset, retry_on_request_exceptions, ) +from axolotl.utils.data.wrappers import get_dataset_wrapper from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_local_main_process, zero_first +from axolotl.utils.distributed import is_local_main_process from axolotl.utils.logging import get_logger from axolotl.utils.trainer import ( calculate_total_num_steps, @@ -63,121 +45,73 @@ LOG = get_logger(__name__) @retry_on_request_exceptions(max_retries=3, delay=5) -def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): - prompters = [] - if not cfg.pretraining_dataset: - with zero_first(is_local_main_process()): - if cfg.test_datasets: - train_dataset, _, prompters = load_prepare_datasets( - tokenizer, - cfg, - DEFAULT_DATASET_PREPARED_PATH, - split="train", - processor=processor, - preprocess_iterable=preprocess_iterable, - ) - _, eval_dataset, _ = load_prepare_datasets( - tokenizer, - cfg, - DEFAULT_DATASET_PREPARED_PATH, - split="test", - processor=processor, - preprocess_iterable=preprocess_iterable, - ) - else: - train_dataset, eval_dataset, prompters = load_prepare_datasets( - tokenizer, - cfg, - DEFAULT_DATASET_PREPARED_PATH, - processor=processor, - preprocess_iterable=preprocess_iterable, - ) - else: - # Load streaming dataset if pretraining_dataset is given - path = cfg.pretraining_dataset - split = "train" - name = None - data_files = None - skip = 0 - if isinstance(cfg.pretraining_dataset, list) and isinstance( - cfg.pretraining_dataset[0], dict - ): - path = cfg.pretraining_dataset[0]["path"] - name = cfg.pretraining_dataset[0]["name"] - skip = cfg.pretraining_dataset[0]["skip"] - if "split" in cfg.pretraining_dataset[0]: - split = cfg.pretraining_dataset[0]["split"] +def prepare_datasets( + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None = None, +) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]: + """Prepare training and evaluation datasets based on configuration. - data_files = cfg.pretraining_dataset[0].get("data_files") + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + tokenizer: Tokenizer to use for processing text. + processor: Optional processor for multimodal datasets. - ds_wrapper_partial = functools.partial( - get_dataset_wrapper, - cfg.pretraining_dataset[0], + Returns: + Tuple of (train_dataset, eval_dataset, total_steps, prompters). + """ + if cfg.streaming or cfg.pretraining_dataset: + return _prepare_streaming_dataset(cfg, tokenizer, processor) + return _prepare_standard_dataset(cfg, tokenizer, processor) + + +def _prepare_standard_dataset( + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None, +) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]: + """Prepare standard (non-pretraining) datasets.""" + + def _load_datasets(): + # Always load training dataset + train_dataset, eval_dataset, prompters = _load_and_prepare_datasets( tokenizer, cfg, - cfg.pretraining_dataset[0]["type"] or "pretrain", + split="train", + processor=processor, ) - # when letting accelerator dispatch batches from the main process, we don't need to load the dataset from - # other ranks, we just need to present a fake dataset - if ( - cfg.accelerator_config - and cfg.accelerator_config.dispatch_batches - and not is_local_main_process() - ): - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: - f.write("text\n") - f.write("lorem ipsum dolor sit amet\n") - # rewind the file pointer to the beginning so we can read it again - f.seek(0) - iter_ds = load_dataset( - "csv", data_files=f.name, split="train", streaming=True - ) - else: - iter_ds = load_dataset( - path, streaming=True, split=split, name=name, data_files=data_files - ) - - if skip: - LOG.info(f"Skipping {skip} samples from the dataset") - iter_ds = iter_ds.skip(skip) - train_dataset = wrap_pretraining_dataset( - iter_ds, - tokenizer, - cfg, - ds_wrapper_partial, - max_tokens=cfg.sequence_len, - batch_size=cfg.micro_batch_size, - seed=cfg.seed if cfg.seed is not None else 42, - buffer_size=cfg.pretrain_multipack_buffer_size or 10_000, - ) - # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 - train_dataset = train_dataset.with_format("torch") - - # Load eval dataset (non-streaming) if specified - eval_dataset = None + # Overwrite eval_dataset if test data exists if cfg.test_datasets: - _, eval_dataset, _ = load_prepare_datasets( + _, eval_dataset, _ = _load_and_prepare_datasets( tokenizer, cfg, - DEFAULT_DATASET_PREPARED_PATH, split="test", processor=processor, - preprocess_iterable=preprocess_iterable, ) - if cfg.dataset_exact_deduplication: - LOG.info("Deduplication not available for pretrained datasets") + return train_dataset, eval_dataset, prompters - return train_dataset, eval_dataset, cfg.max_steps, prompters + # Prepare datasets (with file locking logic for multiple ranks) + loader = FileLockLoader(cfg) + try: + train_dataset, eval_dataset, prompters = loader.load(_load_datasets) + finally: + loader.cleanup() + if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1": + return train_dataset, eval_dataset, -1, prompters + + # Validate sample packing configuration for evaluation if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) if total_eval_steps == 0: raise ValueError( - "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. " + "eval dataset split is too small for sample_packing. " + "You should set `eval_sample_packing: False` in your config." ) + # Calculate total number of training steps if cfg.max_steps: total_num_steps = min( calculate_total_num_steps(cfg, train_dataset), cfg.max_steps @@ -188,219 +122,353 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): return train_dataset, eval_dataset, total_num_steps, prompters -def load_tokenized_prepared_datasets( - tokenizer, - cfg, - default_dataset_prepared_path, - split="train", - processor=None, - preprocess_iterable: Optional[bool] = None, -) -> Tuple[DatasetDict, List[Prompter]]: - cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets - tokenizer_name = cfg.tokenizer_config +def _prepare_streaming_dataset( + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None, +) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]: + """ + Prepare dataset for streaming mode. - ds_hash = str( - md5( - ( - str(cfg.sequence_len) - + "@" - + str(cfg.sample_packing) - + "@" - + str(cfg.eval_sample_packing) - + "@" - + str(cfg.group_by_length) - + "@" - + str(cfg.kd_temperature or 1.0) - + "|".join( - sorted( - [ - f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}" - for d in cfg_datasets - ] - ) - ) - + "|" - + tokenizer_name - ) - ) - ) - prepared_ds_path = ( - Path(cfg.dataset_prepared_path) / ds_hash - if cfg.dataset_prepared_path - else Path(default_dataset_prepared_path) / ds_hash - ) - dataset = None - prompters = [] - use_auth_token = cfg.hf_use_auth_token - try: - if cfg.push_dataset_to_hub: - LOG.info( - f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." - ) - dataset = load_dataset( - cfg.push_dataset_to_hub, - ds_hash, - token=use_auth_token, - ) - dataset = dataset[split] - except Exception: # pylint: disable=broad-except # nosec - pass + Note: Streaming datasets are loaded incrementally from the source. + """ + if cfg.pretraining_dataset: + dataset_config = _extract_pretraining_config(cfg) + train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer) + elif cfg.sample_packing: + # TODO(djsaunde): Implement for multiple datasets + dataset_config = DictDefault(cfg.datasets[0]) - # pylint: disable=duplicate-code - if dataset: - # This is for the case where we already loaded a pretokenized dataset from the hub - ... - elif ( - cfg.dataset_prepared_path - and any(prepared_ds_path.glob("*")) - and not cfg.is_preprocess - and not cfg.skip_prepare_dataset - ): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") - dataset = load_from_disk(str(prepared_ds_path)) - LOG.info("Prepared dataset loaded from disk...") + # Ensure we have a split set - default to 'train' if not specified + if not hasattr(dataset_config, "split") or not dataset_config.split: + dataset_config.split = "train" + train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer) else: - if cfg.push_dataset_to_hub: - LOG.info("Unable to find prepared dataset in Huggingface hub") - if cfg.is_preprocess: - LOG.info( - f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..." - ) - else: - LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") - LOG.info("Loading raw datasets...") - if not cfg.is_preprocess: - LOG.warning( - "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset." - ) + # Use legacy loading function for non-packed streaming datasets + train_dataset, eval_dataset, prompters = _load_and_prepare_datasets( + tokenizer, + cfg, + split="train", + processor=processor, + streaming=True, + ) - if cfg.seed: - seed = cfg.seed - else: - LOG.info("No seed provided, using default seed of 42") - seed = 42 + # Return early for non-packed streaming datasets + total_num_steps = cfg.max_steps if cfg.max_steps else -1 + return train_dataset, eval_dataset, total_num_steps, prompters - datasets = [] + # Load evaluation dataset if specified + eval_dataset = None + if cfg.test_datasets: + _, eval_dataset, _ = _load_and_prepare_datasets( + tokenizer, + cfg, + split="test", + processor=processor, + streaming=False, + ) - streaming_ds = False - if preprocess_iterable: - streaming_ds = True - # pylint: disable=invalid-name - for config_dataset in datasets_w_name_generator(cfg_datasets): - ds: Union[Dataset, DatasetDict] = load_dataset_w_config( - config_dataset, use_auth_token, streaming=streaming_ds - ) + # For streaming, we return max_steps directly from config or -1 if not set + total_num_steps = cfg.max_steps if cfg.max_steps else -1 + return train_dataset, eval_dataset, total_num_steps, [] - d_base_type = d_prompt_style = None - d_type = config_dataset.type - if isinstance(d_type, str): - d_type_split = d_type.split(":") - d_base_type = d_type_split[0] - d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None - if isinstance(ds, DatasetDict): - if config_dataset.split and config_dataset.split in ds: - ds = ds[config_dataset.split] - elif split in ds: - ds = ds[split] - else: - raise ValueError( - f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `" - ) +def _extract_pretraining_config(cfg: DictDefault) -> DictDefault: + """Extract pretraining configuration from the main config.""" + if isinstance(cfg.pretraining_dataset, list) and isinstance( + cfg.pretraining_dataset[0], dict + ): + config = cfg.pretraining_dataset[0] + return DictDefault( + { + "path": config["path"], + "name": config["name"], + "skip": config["skip"], + "split": config.get("split", "train"), + "data_files": config.get("data_files"), + "type": config.get("type", "pretrain"), + } + ) + # Simple string path case + return DictDefault( + { + "path": cfg.pretraining_dataset, + "name": None, + "skip": 0, + "split": "train", + "data_files": None, + "type": "pretrain", + } + ) - # support for using a subset of the data - if config_dataset.shards: - shards_idx = config_dataset.get("shards_idx", 0) - ds = ds.shuffle(seed=seed).shard( - num_shards=config_dataset.shards, index=shards_idx - ) - dataset_wrapper, dataset_prompter = get_dataset_wrapper( - config_dataset=config_dataset, - tokenizer=tokenizer, - cfg=cfg, - d_base_type=d_base_type, - dataset=ds, - d_prompt_style=d_prompt_style, - processor=processor, - ) - datasets.append(dataset_wrapper) - prompters.append(dataset_prompter) +def _load_streaming_dataset( + pretraining_config: DictDefault, cfg: DictDefault, tokenizer: PreTrainedTokenizer +) -> IterableDataset: + """Load and prepare a streaming dataset for pretraining.""" + # Create dataset wrapper partial function + dataset_wrapper_partial = functools.partial( + get_dataset_wrapper, + dataset_config=pretraining_config, + tokenizer=tokenizer, + cfg=cfg, + dataset_base_type=pretraining_config["type"], + ) - if len(datasets) == 1: - dataset = datasets[0] - else: - LOG.info("Merging datasets...") - dataset = concatenate_datasets(datasets) + # Load the actual dataset + if ( + cfg.accelerator_config + and cfg.accelerator_config.dispatch_batches + and not is_local_main_process() + ): + iter_dataset = _create_placeholder_dataset() + else: + iter_dataset = load_dataset( + pretraining_config["path"], + streaming=True, + split=pretraining_config["split"], + name=pretraining_config["name"], + data_files=pretraining_config["data_files"], + ) - if len(datasets) > 1: - if cfg.shuffle_merged_datasets: - LOG.debug("Shuffling merged datasets...") - dataset = dataset.shuffle(seed=seed) - else: - LOG.debug("NOT shuffling merged datasets") + # Apply skip if specified + if pretraining_config["skip"]: + LOG.info(f"Skipping {pretraining_config['skip']} samples from the dataset") + iter_dataset = iter_dataset.skip(pretraining_config["skip"]) - if not cfg.skip_prepare_dataset: - dataset = drop_long_seq_in_dataset(dataset, cfg) + # Wrap the dataset for pretraining + train_dataset = wrap_streaming_dataset( + iter_dataset, + tokenizer, + cfg, + dataset_wrapper_partial, + ) - if cfg.sample_packing: - dataset, _ = process_datasets_for_packing(cfg, dataset, None) + # Format for PyTorch + return train_dataset.with_format("torch") - if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: - LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") - if isinstance(dataset, IterableDataset): - num_workers = cfg.dataset_processes - def gen_from_iter_ds(_ds, worker_id: List[int], num_workers: List[int]): - """Generator function to correctly splice the dataset for each worker""" - for i, item in enumerate(_ds): - if i % num_workers[0] == worker_id[0]: - yield item +def _create_placeholder_dataset() -> IterableDataset: + """Create a minimal placeholder dataset for non-main processes.""" + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: + f.write("text\n") + f.write("lorem ipsum dolor sit amet\n") + f.seek(0) + return load_dataset("csv", data_files=f.name, split="train", streaming=True) - ds_from_iter = Dataset.from_generator( - functools.partial(gen_from_iter_ds, dataset), - features=dataset.features, - num_proc=num_workers, - split=split, - gen_kwargs={ - "worker_id": list(range(num_workers)), - "num_workers": [num_workers] * num_workers, - }, - ) - ds_from_iter.save_to_disk(str(prepared_ds_path)) - else: - os.makedirs(prepared_ds_path, exist_ok=True) - dataset.save_to_disk(str(prepared_ds_path)) - if cfg.push_dataset_to_hub: - LOG.info( - f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." - ) - dataset.push_to_hub( - cfg.push_dataset_to_hub, - ds_hash, - private=True, - ) + +def _load_tokenized_prepared_datasets( + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + split: Literal["train", "test"] = "train", + processor: ProcessorMixin | None = None, + streaming: bool = False, +) -> tuple[Dataset | DatasetDict, list[Prompter | None]]: + """Load or create tokenized and prepared datasets for training or testing. + + Args: + tokenizer: Tokenizer for processing text. + cfg: Configuration object. + split: Dataset split to load ('train' or 'test'). + processor: Optional processor for multimodal datasets. + streaming: Whether to use iterable preprocessing. + + Returns: + Tuple of (dataset, prompters list). + """ + # Select correct dataset configuration based on split + datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets + + # Generate dataset hash for caching + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_configs, tokenizer.name_or_path + ) + + # Try loading from hub if push_dataset_to_hub is configured + dataset = None + if cfg.push_dataset_to_hub: + dataset = try_load_from_hub(cfg, dataset_hash, split) + + # If not found on hub, try loading from disk + if dataset is None: + dataset = load_preprocessed_dataset(cfg, dataset_hash) + + # If not found on disk or skipping prepared dataset, load and process raw datasets + prompters: list[Prompter | None] = [] + if dataset is None: + dataset, prompters = _load_raw_datasets( + cfg, + datasets_configs, + tokenizer, + split, + processor, + streaming, + ) return dataset, prompters -def load_prepare_datasets( - tokenizer: PreTrainedTokenizerBase, - cfg, - default_dataset_prepared_path, - split="train", - processor=None, - preprocess_iterable: Optional[bool] = False, -) -> Tuple[Dataset, Dataset, List[Prompter]]: - dataset, prompters = load_tokenized_prepared_datasets( - tokenizer, - cfg, - default_dataset_prepared_path, - split=split, - processor=processor, - preprocess_iterable=preprocess_iterable, +def _load_raw_datasets( + cfg: DictDefault, + datasets_configs: list, + tokenizer: PreTrainedTokenizer, + split: str, + processor: ProcessorMixin | None = None, + streaming: bool = False, +) -> tuple[Dataset, list[Prompter | None]]: + """Load, process, merge, and save raw datasets.""" + LOG.info("Loading raw datasets...", main_process_only=False) + if not cfg.is_preprocess and not cfg.skip_prepare_dataset: + LOG.warning( + "Processing datasets during training can lead to VRAM instability. Please " + "pre-process your dataset using `axolotl preprocess path/to/config.yml`." + ) + + # Load and process individual datasets + datasets = [] + prompters = [] + for dataset_config in datasets_with_name_generator(datasets_configs): + dataset_wrapper, dataset_prompter = _load_and_process_single_dataset( + dataset_config=dataset_config, + cfg=cfg, + tokenizer=tokenizer, + split=split, + seed=cfg.seed, + processor=processor, + streaming=streaming, + ) + datasets.append(dataset_wrapper) + prompters.append(dataset_prompter) + + # Merge datasets + dataset = merge_datasets(datasets, cfg) + + if not cfg.skip_prepare_dataset and not streaming: + if split == "test" and cfg.eval_sequence_len: + dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg) + else: + dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg) + if cfg.sample_packing: + dataset, _ = process_datasets_for_packing(cfg, dataset, None) + + # Save the prepared dataset + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_configs, tokenizer.name_or_path + ) + save_preprocessed_dataset(cfg, dataset, dataset_hash, split) + + return dataset, prompters + + +def _load_and_process_single_dataset( + dataset_config: DictDefault, + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + split: str, + seed: int, + processor: ProcessorMixin | None = None, + streaming: bool = False, +) -> tuple[Dataset | IterableDataset, Prompter | None]: + """Load and process a single dataset based on the passed config.""" + # Load the dataset + dataset = load_dataset_with_config( + dataset_config, cfg.hf_use_auth_token, streaming=streaming ) + # Parse dataset type + d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type) + + # Select the appropriate split + if isinstance(dataset, (DatasetDict, IterableDatasetDict)): + if dataset_config.split and dataset_config.split in dataset: + dataset = dataset[dataset_config.split] + elif split in dataset: + dataset = dataset[split] + else: + raise ValueError( + f"no {split} split found for dataset {dataset_config.path}, you may " + "specify a split with 'split: ...'" + ) + + # Apply sharding if configured + if dataset_config.shards: + shards_idx = dataset_config.get("shards_idx", 0) + dataset = dataset.shuffle(seed=seed).shard( + num_shards=dataset_config.shards, index=shards_idx + ) + + # Apply dataset wrapper + dataset_wrapper, dataset_prompter = get_dataset_wrapper( + dataset_config=dataset_config, + tokenizer=tokenizer, + cfg=cfg, + dataset_base_type=d_base_type, + dataset=dataset, + dataset_prompt_style=d_prompt_style, + processor=processor, + ) + + return dataset_wrapper, dataset_prompter + + +def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]: + """Parse the dataset type string into base type and prompt style.""" + if not isinstance(d_type, str): + return None, None + + d_type_split = d_type.split(":") + d_base_type = d_type_split[0] + d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None + + return d_base_type, d_prompt_style + + +def _handle_train_dataset_split( + dataset: Dataset, cfg: DictDefault +) -> tuple[Dataset, Dataset | None]: + """Handle processing for train split, including validation set creation.""" + val_set_size = ( + int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size) + ) + + if val_set_size: + # Create train/validation split + train_dataset, eval_dataset = create_train_validation_split( + dataset, cfg, val_set_size + ) + return train_dataset, eval_dataset + + # No validation split - apply deduplication if needed and return as train dataset + if cfg.dataset_exact_deduplication: + train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) + else: + train_dataset = dataset + + return train_dataset, None + + +def _handle_test_dataset_split( + dataset: Dataset, cfg: DictDefault +) -> tuple[None, Dataset | None]: + """Handle processing for test split.""" + if cfg.dataset_exact_deduplication: + eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) + else: + eval_dataset = dataset + + return None, eval_dataset + + +def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset: + """Apply dataset sharding if configured. + + Args: + dataset: Dataset to shard. + cfg: Configuration object containing shard settings. + + Returns: + Sharded dataset or original dataset if no sharding configured. + """ if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: LOG.info( f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards" @@ -409,259 +477,44 @@ def load_prepare_datasets( num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx, ) + return dataset - val_set_size = ( - int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size) + +def _load_and_prepare_datasets( + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + split: Literal["train", "test"] = "train", + processor: ProcessorMixin | None = None, + streaming: bool = False, +) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]: + """Load and prepare datasets with optional validation split and sharding. + + Args: + tokenizer: Tokenizer for processing text. + cfg: Configuration object. + split: Dataset split to load ('train' or 'test'). + processor: Optional processor for multimodal datasets. + streaming: Whether to use iterable preprocessing. + + Returns: + Tuple of (train_dataset, eval_dataset, prompters). + """ + # Load the base dataset + dataset, prompters = _load_tokenized_prepared_datasets( + tokenizer, + cfg, + split=split, + processor=processor, + streaming=streaming, ) - if split == "train" and val_set_size: - seed = cfg.seed if cfg.seed is not None else 42 + # Apply dataset sharding if configured using shared function + dataset = _apply_dataset_sharding(dataset, cfg) - # ensure we end up with the same fingerprint by doing rank0 first and being able to cache - to_hash_train = ( - dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(val_set_size) - + "|" - + "train" - + "|" - + str(cfg.seed or 42) - ) - to_hash_test = ( - dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(val_set_size) - + "|" - + "test" - + "|" - + str(cfg.seed or 42) - ) - train_fingerprint = md5(to_hash_train) - test_fingerprint = md5(to_hash_test) - if cfg.dataset_exact_deduplication: - _, _, dataset = deduplicate_and_log_datasets(dataset=dataset) - dataset = dataset.train_test_split( - test_size=val_set_size, - shuffle=False, - seed=seed, - train_new_fingerprint=train_fingerprint, - test_new_fingerprint=test_fingerprint, - ) - - train_dataset = dataset["train"] - eval_dataset = dataset["test"] - elif split == "test": - if cfg.dataset_exact_deduplication: - _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset) - else: - eval_dataset = dataset - train_dataset = None + # Apply deduplication and create train / validation splits based on the split type + if split == "train": + train_dataset, eval_dataset = _handle_train_dataset_split(dataset, cfg) else: - if cfg.dataset_exact_deduplication: - train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) - else: - train_dataset = dataset - eval_dataset = None + train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg) + return train_dataset, eval_dataset, prompters - - -def get_dataset_wrapper( - config_dataset, - tokenizer, - cfg, - d_base_type, - dataset, - d_prompt_style=None, - processor=None, # pylint: disable=unused-argument -): - dataset_wrapper = None - dataset_prompter = None - - ds_kwargs = { - "process_count": cfg.dataset_processes, - "keep_in_memory": cfg.dataset_keep_in_memory is True, - } - - LOG.info( - f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}" - ) - - if ( - isinstance(dataset, Dataset) - and "input_ids" in dataset.features - and "attention_mask" in dataset.features - and "labels" in dataset.features - ): - # dataset is already tokenized, just drop it straight in - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = dataset - elif isinstance(config_dataset.type, DictDefault): - ds_strategy = load( - "user_defined", tokenizer, cfg, config_dataset.type.to_dict() - ) - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - elif cfg.skip_prepare_dataset: - dataset_wrapper = dataset - elif ds_strategy := config_dataset.type.startswith( - "bradley_terry" - ) and bradley_terry_load( - config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset - ): - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - elif config_dataset.type.startswith("stepwise_supervised"): - dataset_prompter = UnsupportedPrompter() - ds_strategy = load(config_dataset.type, tokenizer, cfg, config_dataset) - # we need to explicitly cast boolean labels to int - # for compatibility with how trl's PRMTrainer works - dataset = dataset.cast_column("labels", Sequence(Value("int64"))) - dataset_wrapper = TokenizedPromptDataset( - ds_strategy, - dataset, - **ds_kwargs, - ) - elif ds_strategy := load( - config_dataset.type, tokenizer, cfg, config_dataset, processor=processor - ): - if isinstance(ds_strategy, DatasetWrappingStrategy): - dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs) - else: - dataset_prompter = UnsupportedPrompter() - dataset_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - elif d_base_type == "alpaca": - dataset_prompter = AlpacaPrompter(d_prompt_style) - ds_strategy = AlpacaPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "explainchoice": - dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style) - ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "concisechoice": - dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style) - ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "summarizetldr": - dataset_prompter = SummarizeTLDRPrompter(d_prompt_style) - ds_strategy = SummarizeTLDRPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "jeopardy": - dataset_prompter = JeopardyPrompter(d_prompt_style) - ds_strategy = JeopardyPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "oasst": - dataset_prompter = AlpacaPrompter(d_prompt_style) - ds_strategy = OpenAssistantPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "gpteacher": - dataset_prompter = GPTeacherPrompter(d_prompt_style) - ds_strategy = GPTeacherPromptTokenizingStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - elif d_base_type == "reflection": - dataset_prompter = ReflectAlpacaPrompter(d_prompt_style) - ds_strategy = AlpacaReflectionPTStrategy( - dataset_prompter, - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = wrap_dataset_for_tokenized_prompt( - ds_strategy, - dataset, - **ds_kwargs, - ) - dataset_wrapper = ds_wrapper - else: - suffix = "" - if ":load_" in config_dataset.type: - suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?" - LOG.error( - f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}" - ) - raise ValueError( - f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}" - ) - - return dataset_wrapper, dataset_prompter diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index d2e119f77..c9a91b829 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -1,11 +1,21 @@ -""" -dataset loading shared utils -""" +"""Dataset loading shared utils.""" +from __future__ import annotations + +import functools +import os from pathlib import Path -from typing import Optional, Union +from typing import TYPE_CHECKING, Any, Generator -from datasets import Dataset, DatasetDict, load_dataset, load_from_disk +from datasets import ( + Dataset, + DatasetDict, + IterableDataset, + IterableDatasetDict, + concatenate_datasets, + load_dataset, + load_from_disk, +) from huggingface_hub import hf_hub_download, snapshot_download from huggingface_hub.errors import ( HFValidationError, @@ -13,78 +23,142 @@ from huggingface_hub.errors import ( RevisionNotFoundError, ) +from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 +from axolotl.utils.datasets import get_default_process_count from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +if TYPE_CHECKING: + from adlfs import AzureBlobFileSystem + from gcsfs import GCSFileSystem + from ocifs import OCIFileSystem + from s3fs import S3FileSystem + +LOG = get_logger(__name__) + +EXTENSIONS_TO_DATASET_TYPES = { + ".parquet": "parquet", + ".arrow": "arrow", + ".csv": "csv", + ".txt": "text", +} -def get_ds_type(config_dataset: DictDefault): - """ - Get the dataset type from the path if it's not specified - """ - ds_type = "json" - if config_dataset.ds_type: - ds_type = config_dataset.ds_type - elif ".parquet" in config_dataset.path: - ds_type = "parquet" - elif ".arrow" in config_dataset.path: - ds_type = "arrow" - elif ".csv" in config_dataset.path: - ds_type = "csv" - elif ".txt" in config_dataset.path: - ds_type = "text" - return ds_type +def get_dataset_type(dataset_config: DictDefault) -> str: + """Get the dataset type from the path if it's not specified.""" + if dataset_config.ds_type: + return dataset_config.ds_type + + for extension, dataset_type in EXTENSIONS_TO_DATASET_TYPES.items(): + if extension in dataset_config.path: + return dataset_type + + return "json" -def datasets_w_name_generator(dataset_configs: list[DictDefault]): - """ - Yields dataset configs handling multiple names or preprocess_shards +def datasets_with_name_generator( + dataset_configs: list[DictDefault], +) -> Generator[DictDefault, None, None]: + """Yields expanded dataset configurations based on multiple names or preprocessing + shards. + + When a dataset config has a list of names, it yields separate configs for each + name. When a dataset config specifies preprocessing shards, it yields configs for + each shard. Args: - dataset_configs: list of dataset configs (equivalent to cfg.datasets) + dataset_configs: List of dataset configuration objects. + + Yields: + Individual dataset configurations, expanded as needed for names or shards. """ - for dataset in dataset_configs: - if dataset.name and isinstance(dataset.name, list): - # load_dataset doesn't properly handle multiple named configurations - # at the same time for a given dataset - for name in dataset.name: - yield DictDefault({**dataset, "name": name}) - elif dataset.preprocess_shards and not dataset.shards: - for shard in range(dataset.preprocess_shards): + for config in dataset_configs: + if config.name and isinstance(config.name, list): + for name in config.name: + yield DictDefault({**config, "name": name}) + elif config.preprocess_shards and not config.shards: + for shard_idx in range(config.preprocess_shards): yield DictDefault( { - **dataset, - "shards": dataset.preprocess_shards, - "shards_idx": shard, + **config, + "shards": config.preprocess_shards, + "shards_idx": shard_idx, } ) else: - yield dataset + yield config -def load_dataset_w_config( - config_dataset: DictDefault, use_auth_token: bool, streaming=False -) -> Union[Dataset, DatasetDict]: - """ - Load a dataset from a config +def load_dataset_with_config( + dataset_config: DictDefault, use_auth_token: bool, streaming=False +) -> Dataset | IterableDataset: + """Load a dataset from a config. Handles datasets that are stored locally, in the + HuggingFace Hub, in a remote filesystem (S3, GCS, Azure, OCI), a URL, or + `data_files`. Args: - config_dataset: single dataset config - use_auth_token: whether to use HF auth token - streaming: whether to stream the dataset + dataset_config: Single dataset config. + use_auth_token: Whether to use HF auth token. + streaming: Whether to stream the dataset. + + Returns: + Loaded dataset. """ - # pylint: disable=invalid-name - ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name - ds_from_hub = False + # Set up common kwargs for dataset loading + load_dataset_kwargs = { + "split": dataset_config.split if dataset_config.split else None, + "name": dataset_config.name, + "streaming": streaming, + "trust_remote_code": dataset_config.trust_remote_code, + } + + # First check if it's a local path + if Path(dataset_config.path).exists(): + return _load_from_local_path(dataset_config, load_dataset_kwargs) + + # Check if it's a HuggingFace dataset + is_hub_dataset = _check_if_hub_dataset(dataset_config, use_auth_token) + + # Check if it's a cloud storage path and get appropriate filesystem + remote_fs, storage_options = _get_remote_filesystem(dataset_config.path) + is_cloud_dataset = False + if remote_fs: + try: + is_cloud_dataset = remote_fs.exists(dataset_config.path) + except (FileNotFoundError, ConnectionError): + pass + + # Load from appropriate source + if is_hub_dataset: + return _load_from_hub(dataset_config, use_auth_token, load_dataset_kwargs) + if is_cloud_dataset: + return _load_from_cloud( + dataset_config, remote_fs, storage_options, load_dataset_kwargs + ) + if dataset_config.path.startswith("https://"): + return _load_from_url(dataset_config, load_dataset_kwargs) + if dataset_config.data_files: + return _load_from_data_files(dataset_config, load_dataset_kwargs) + + raise ValueError( + f"The dataset could not be loaded. This could be due to a misconfigured dataset path " + f"({dataset_config.path}). Try double-check your path / name / data_files. " + f"This is not caused by the dataset type." + ) + + +def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) -> bool: + """Check if a dataset exists on the HuggingFace Hub.""" try: - # this is just a basic check to see if the path is a - # valid HF dataset that's loadable snapshot_download( - repo_id=config_dataset.path, + repo_id=dataset_config.path, repo_type="dataset", token=use_auth_token, - revision=config_dataset.revision, + revision=dataset_config.revision, ignore_patterns=["*"], ) - ds_from_hub = True + return True except ( RepositoryNotFoundError, RevisionNotFoundError, @@ -93,198 +167,399 @@ def load_dataset_w_config( HFValidationError, ValueError, ): - pass + return False - ds_from_cloud = False - storage_options: dict = {} - remote_file_system = None - if config_dataset.path.startswith("s3://"): + +def _get_remote_filesystem( + path: str, +) -> tuple[ + S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem | None, dict +]: + """Get the appropriate filesystem for a remote path.""" + if path.startswith("s3://"): try: - import s3fs # type: ignore + import s3fs + + storage_options = {"anon": False} + return s3fs.S3FileSystem(**storage_options), storage_options except ImportError as exc: raise ImportError("s3:// paths require s3fs to be installed") from exc - # Reads env, credentials from ~/.aws/credentials, or IAM metadata provider - # https://s3fs.readthedocs.io/en/latest/index.html?highlight=storage_options#credentials - storage_options = {"anon": False} - remote_file_system = s3fs.S3FileSystem(**storage_options) - elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith( - "gcs://" - ): + elif path.startswith(("gs://", "gcs://")): try: - import gcsfs # type: ignore + import gcsfs + + storage_options = {"token": None} # type: ignore + return gcsfs.GCSFileSystem(**storage_options), storage_options except ImportError as exc: raise ImportError( "gs:// or gcs:// paths require gcsfs to be installed" ) from exc - # gcsfs will use default credentials from the environment else anon - # https://gcsfs.readthedocs.io/en/latest/#credentials - storage_options = {"token": None} - remote_file_system = gcsfs.GCSFileSystem(**storage_options) - elif ( - config_dataset.path.startswith("adl://") - or config_dataset.path.startswith("abfs://") - or config_dataset.path.startswith("az://") - ): + elif path.startswith(("adl://", "abfs://", "az://")): try: import adlfs + + storage_options = {"anon": False} + return adlfs.AzureBlobFileSystem(**storage_options), storage_options except ImportError as exc: raise ImportError( "adl:// or abfs:// paths require adlfs to be installed" ) from exc - # # Ensure you have the following environment variables set: - # # Gen 1 - # storage_options = { - # "tenant_id": AZURE_STORAGE_TENANT_ID, - # "client_id": AZURE_STORAGE_CLIENT_ID, - # "client_secret": AZURE_STORAGE_CLIENT_SECRET, - # } - # # Gen 2 - # storage_options = { - # "account_name": AZURE_STORAGE_ACCOUNT_NAME, - # "account_key": AZURE_STORAGE_ACCOUNT_KEY, - # } - - # Reads env - # https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials - storage_options = {"anon": False} - remote_file_system = adlfs.AzureBlobFileSystem(**storage_options) - elif config_dataset.path.startswith("oci://"): + elif path.startswith("oci://"): try: import ocifs + + storage_options = {} + return ocifs.OCIFileSystem(**storage_options), storage_options except ImportError as exc: raise ImportError("oci:// paths require ocifs to be installed") from exc - # https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables - remote_file_system = ocifs.OCIFileSystem(**storage_options) + return None, {} - try: - if remote_file_system and remote_file_system.exists(config_dataset.path): - ds_from_cloud = True - except (FileNotFoundError, ConnectionError): - pass - # gather extra args from the config - load_ds_kwargs = {} - if config_dataset.split: - load_ds_kwargs["split"] = config_dataset.split +def _load_from_local_path( + dataset_config: DictDefault, load_dataset_kwargs: dict +) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: + """Load a dataset from a local path.""" + local_path = Path(dataset_config.path) + + if local_path.is_dir(): + if dataset_config.data_files: + dataset_type = get_dataset_type(dataset_config) + return load_dataset( + dataset_type, + data_files=dataset_config.data_files, + **load_dataset_kwargs, + ) + try: + return load_from_disk(dataset_config.path) + except FileNotFoundError: + return load_dataset(dataset_config.path, **load_dataset_kwargs) + elif local_path.is_file(): + dataset_type = get_dataset_type(dataset_config) + return load_dataset( + dataset_type, + data_files=dataset_config.path, + **load_dataset_kwargs, + ) else: - load_ds_kwargs["split"] = None - - # prefer local dataset, even if hub exists - local_path = Path(config_dataset.path) - if local_path.exists(): - if local_path.is_dir(): - if config_dataset.data_files: - ds_type = get_ds_type(config_dataset) - ds = load_dataset( # pylint: disable=invalid-name - ds_type, - name=config_dataset.name, - data_files=config_dataset.data_files, - streaming=streaming, - **load_ds_kwargs, - ) - else: - try: - ds = load_from_disk( - config_dataset.path - ) # pylint: disable=invalid-name - except FileNotFoundError: - ds = load_dataset( - config_dataset.path, - name=config_dataset.name, - streaming=False, - **load_ds_kwargs, - ) - elif local_path.is_file(): - ds_type = get_ds_type(config_dataset) - - ds = load_dataset( # pylint: disable=invalid-name - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=False, - **load_ds_kwargs, - ) - else: - raise ValueError( - "unhandled dataset load: local path exists, but is neither a directory or a file" - ) - elif ds_from_hub: - ds = load_dataset( - config_dataset.path, - name=config_dataset.name, - streaming=streaming, - data_files=config_dataset.data_files, - token=use_auth_token, - revision=config_dataset.revision, - trust_remote_code=config_dataset.trust_remote_code, - **load_ds_kwargs, - ) - elif ds_from_cloud and remote_file_system: - if remote_file_system.isdir(config_dataset.path): - ds = load_from_disk( - config_dataset.path, - storage_options=storage_options, - ) - elif remote_file_system.isfile(config_dataset.path): - ds_type = get_ds_type(config_dataset) - ds = load_dataset( - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=streaming, - storage_options=storage_options, - trust_remote_code=config_dataset.trust_remote_code, - **load_ds_kwargs, - ) - elif config_dataset.path.startswith("https://"): - ds_type = get_ds_type(config_dataset) - ds = load_dataset( - ds_type, - name=config_dataset.name, - data_files=config_dataset.path, - streaming=streaming, - storage_options=storage_options, - trust_remote_code=config_dataset.trust_remote_code, - **load_ds_kwargs, - ) - elif config_dataset.data_files: - fp: str | list[str] | None = None - if isinstance(config_dataset.data_files, str): - fp = hf_hub_download( - repo_id=config_dataset.path, - repo_type="dataset", - filename=config_dataset.data_files, - revision=config_dataset.revision, - ) - elif isinstance(config_dataset.data_files, list): - fp = [] - for file in config_dataset.data_files: - fp.append( - hf_hub_download( - repo_id=config_dataset.path, - repo_type="dataset", - filename=file, - revision=config_dataset.revision, - ) - ) - else: - raise ValueError("data_files must be either a string or list of strings") - ds = load_dataset( - "json", - name=config_dataset.name, - data_files=fp, - streaming=streaming, - **load_ds_kwargs, - ) - if not ds: raise ValueError( - "The dataset could not be loaded. This could be due to a misconfigured dataset path " - f"({config_dataset.path}). Try double-check your path / name / data_files. " - "This is not caused by the dataset type." + "Unhandled dataset load: local path exists, but is neither a directory or a file" ) - return ds + +def _load_from_hub( + dataset_config: DictDefault, use_auth_token: bool, load_dataset_kwargs: dict +) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: + """Load a dataset from the HuggingFace Hub.""" + return load_dataset( + dataset_config.path, + data_files=dataset_config.data_files, + token=use_auth_token, + revision=dataset_config.revision, + **load_dataset_kwargs, + ) + + +def _load_from_cloud( + dataset_config: DictDefault, + remote_fs: S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem, + storage_options: dict, + load_dataset_kwargs: dict, +) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: + """Load a dataset from cloud storage.""" + if remote_fs.isdir(dataset_config.path): + return load_from_disk( + dataset_config.path, + storage_options=storage_options, + ) + + if remote_fs.isfile(dataset_config.path): + dataset_type = get_dataset_type(dataset_config) + return load_dataset( + dataset_type, + data_files=dataset_config.path, + storage_options=storage_options, + **load_dataset_kwargs, + ) + + raise ValueError( + f"Cloud path {dataset_config.path} is neither a directory nor a file" + ) + + +def _load_from_url( + dataset_config: DictDefault, load_dataset_kwargs: dict +) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: + """Load a dataset from a URL.""" + dataset_type = get_dataset_type(dataset_config) + return load_dataset( + dataset_type, + data_files=dataset_config.path, + **load_dataset_kwargs, + ) + + +def _load_from_data_files( + dataset_config: DictDefault, load_dataset_kwargs: dict +) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: + """Load a dataset from data files.""" + file_path = None + + if isinstance(dataset_config.data_files, str): + file_path = hf_hub_download( + repo_id=dataset_config.path, + repo_type="dataset", + filename=dataset_config.data_files, + revision=dataset_config.revision, + ) + elif isinstance(dataset_config.data_files, list): + file_path = [ + hf_hub_download( + repo_id=dataset_config.path, + repo_type="dataset", + filename=file, + revision=dataset_config.revision, + ) + for file in dataset_config.data_files + ] + else: + raise ValueError("data_files must be either a string or list of strings") + + return load_dataset("json", data_files=file_path, **load_dataset_kwargs) + + +def generate_split_fingerprints( + dataset: Dataset, val_set_size: int | float, seed: int +) -> tuple[str, str]: + """Generate consistent fingerprints for train/test splits.""" + fingerprint = dataset._fingerprint + + train_hash_input = f"{fingerprint}|{val_set_size}|train|{seed}" + test_hash_input = f"{fingerprint}|{val_set_size}|test|{seed}" + + train_fingerprint = md5(train_hash_input) + test_fingerprint = md5(test_hash_input) + + return train_fingerprint, test_fingerprint + + +def get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path: + """Get standardized path for prepared datasets. + + Args: + cfg: Configuration object. + dataset_hash: Hash identifying the specific dataset configuration. + + Returns: + Path where the prepared dataset should be stored. + """ + base_path = cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH + return Path(base_path) / dataset_hash + + +def create_train_validation_split( + dataset: Dataset, cfg: DictDefault, val_set_size: int | float +) -> tuple[Dataset, Dataset]: + """Create train/validation split with consistent fingerprinting. + + Args: + dataset: Dataset to split. + cfg: Configuration object containing seed and other settings. + val_set_size: Size of validation set (absolute number or fraction). + + Returns: + Tuple of (train_dataset, eval_dataset). + """ + train_fingerprint, test_fingerprint = generate_split_fingerprints( + dataset, val_set_size, cfg.seed + ) + + # Apply deduplication before splitting if configured + if cfg.dataset_exact_deduplication: + dataset, _ = deduplicate_and_log_datasets(dataset=dataset) + + split_dataset = dataset.train_test_split( + test_size=val_set_size, + shuffle=False, + seed=cfg.seed, + train_new_fingerprint=train_fingerprint, + test_new_fingerprint=test_fingerprint, + ) + + return split_dataset["train"], split_dataset["test"] + + +def _generate_from_iterable_dataset( + dataset: IterableDataset, worker_id: list[int], num_workers: list[int] +) -> Generator[Any, None, None]: + """Generator function to correctly split the dataset for each worker""" + for i, item in enumerate(dataset): + if i % num_workers[0] == worker_id[0]: + yield item + + +def save_preprocessed_dataset( + cfg: DictDefault, + dataset: Dataset, + dataset_hash: str, + split: str, +) -> None: + """Save preprocessed dataset to disk and optionally push to the HF Hub.""" + prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash) + num_workers = cfg.dataset_num_proc or get_default_process_count() + if isinstance(dataset, IterableDataset): + ds_from_iter = Dataset.from_generator( + functools.partial(_generate_from_iterable_dataset, dataset), + features=dataset.features, + num_proc=num_workers, + split=split, + gen_kwargs={ + "worker_id": list(range(num_workers)), + "num_workers": [num_workers] * num_workers, + }, + ) + ds_from_iter.save_to_disk( + str(prepared_ds_path), + num_proc=num_workers, + max_shard_size=None, + num_shards=cfg.num_dataset_shards_to_save, + ) + else: + min_rows_per_proc = 256 + os.makedirs(prepared_ds_path, exist_ok=True) + dataset.save_to_disk( + str(prepared_ds_path), + num_proc=min(max(1, len(dataset) // min_rows_per_proc), num_workers), + max_shard_size=None, + num_shards=cfg.num_dataset_shards_to_save, + ) + if cfg.push_dataset_to_hub: + LOG.info( + "Pushing merged prepared dataset to Huggingface hub at " + f"{cfg.push_dataset_to_hub} (version {dataset_hash})...", + main_process_only=False, + ) + dataset.push_to_hub( + cfg.push_dataset_to_hub, + dataset_hash, + private=True, + ) + + +def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None: + """Load preprocessed dataset from disk if available. + + Args: + cfg: Configuration object. + dataset_hash: Hash identifying the dataset configuration. + + Returns: + Loaded dataset if found and conditions are met, None otherwise. + """ + prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash) + + if ( + cfg.dataset_prepared_path + and any(prepared_ds_path.glob("*")) + and not cfg.skip_prepare_dataset + and not cfg.is_preprocess + ): + LOG.info( + f"Loading prepared dataset from disk at {prepared_ds_path}...", + main_process_only=True, + ) + return load_from_disk(str(prepared_ds_path)) + + LOG.info( + f"Unable to find prepared dataset in {prepared_ds_path}", + main_process_only=True, + ) + return None + + +def try_load_from_hub( + cfg: DictDefault, dataset_hash: str, split: str +) -> Dataset | None: + """Try to load the prepared dataset from HuggingFace Hub.""" + try: + LOG.info( + "Attempting to load prepared dataset from HuggingFace Hub at " + f"{cfg.push_dataset_to_hub} (version {dataset_hash})..." + ) + dataset = load_dataset( + cfg.push_dataset_to_hub, + dataset_hash, + token=cfg.hf_use_auth_token, + ) + return dataset[split] + except Exception: + LOG.info("Unable to find prepared dataset in HuggingFace Hub") + return None + + +def generate_dataset_hash_from_config( + cfg: DictDefault, cfg_datasets: list, tokenizer_name: str +) -> str: + """Generate a hash to uniquely identify a dataset configuration for SFT. + + Args: + cfg: Main configuration object. + cfg_datasets: List of dataset configurations. + tokenizer_name: Name of the tokenizer being used. + + Returns: + MD5 hash string representing the configuration. + """ + config_str = ( + f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@" + f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|" + f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}" + f"|{tokenizer_name}" + ) + return str(md5(config_str)) + + +def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: + """Merge multiple datasets into one with optional shuffling. + + Args: + datasets: List of datasets to merge. + cfg: Configuration object containing shuffle settings. + + Returns: + Merged dataset. + """ + if len(datasets) == 1: + ds = datasets[0] + + # Do not shuffle if curriculum sampling is enabled or + # shuffle_merged_datasets is disabled + if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets: + return ds + + return ds.shuffle(seed=cfg.seed) + + # If enabled, shuffle each dataset independently before merging. + # This allows curriculum learning strategies to be applied at the dataset level. + if cfg.shuffle_before_merging_datasets: + LOG.info("Shuffling each dataset individually before merging...") + datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets] + + LOG.info("Merging datasets...") + merged_dataset = concatenate_datasets(datasets) + + if cfg.shuffle_merged_datasets: + LOG.debug("Shuffling merged datasets...") + if cfg.curriculum_sampling: + LOG.warning( + "Shuffling merged datasets with curriculum sampling is not recommended. " + "This will randomize the order of samples." + ) + merged_dataset = merged_dataset.shuffle(seed=cfg.seed) + else: + LOG.debug("Not shuffling merged datasets.") + + return merged_dataset diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/streaming.py similarity index 84% rename from src/axolotl/utils/data/pretraining.py rename to src/axolotl/utils/data/streaming.py index 44d8d6fed..2cb35ee7c 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/streaming.py @@ -1,4 +1,4 @@ -"""data handling specific to pretraining""" +"""Data handling specific to streaming datasets.""" import functools from collections import defaultdict @@ -17,10 +17,10 @@ from axolotl.utils.trainer import process_pretraining_datasets_for_packing LOG = get_logger(__name__) -def encode_pretraining( +def encode_streaming( + examples: Dict[str, List], tokenizer: PreTrainedTokenizerBase, max_tokens: int, - examples: Dict[str, List], text_column: str = "text", concatenate: bool = True, ) -> Dict[str, List]: @@ -67,7 +67,7 @@ def encode_pretraining( buffer_labels = torch.tensor([], dtype=torch.long) buffer_attention_mask = torch.tensor([], dtype=torch.long) - for ids, labels, mask in zip(input_ids, targets, attention_mask): + for ids, labels, mask in zip(input_ids, targets, attention_mask, strict=False): if buffer_input_ids.numel() == max_tokens: new_input_ids.append(buffer_input_ids) new_labels.append(buffer_labels) @@ -176,45 +176,57 @@ def encode_pretraining( return ret -def wrap_pretraining_dataset( +def wrap_streaming_dataset( dataset, tokenizer, cfg, ds_wrapper_fn, - max_tokens=2048, - batch_size=1, - seed=42, - buffer_size=10_000, ): if cfg.sample_packing: + # For SFT (non-pretraining) datasets, always use multipack_attn=True to ensure + # attention isolation between packed sequences + multipack_attn = ( + True if not cfg.pretraining_dataset else cfg.pretrain_multipack_attn + ) + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( tokenizer, return_tensors="pt", padding=True, - pad_to_multiple_of=max_tokens, - multipack_attn=cfg.pretrain_multipack_attn, + pad_to_multiple_of=cfg.sequence_len, + multipack_attn=multipack_attn, ) encode = functools.partial( - encode_packed_pretraining, + encode_packed_streaming, collate_fn, ds_wrapper_fn, - max_seq_length=max_tokens, - batch_size=batch_size, - multipack_attn=cfg.pretrain_multipack_attn, + max_seq_length=cfg.sequence_len, + batch_size=cfg.micro_batch_size, + multipack_attn=multipack_attn, ) - # set this to 1 so downstream data_loader doesn't try to increase the batch again + + # Set this to 1 so downstream data_loader doesn't try to increase the batch size + # again cfg.micro_batch_size = 1 else: + # NOTE: This is not reachable for SFT datasets since we use the pre-existing + # loading function for non-packed streaming datasets. Refer to + # _prepare_streaming_datasets in sft.py for that code path. + text_column = ( + getattr(cfg.pretraining_dataset[0], "text_column", "text") or "text" + ) encode = functools.partial( - encode_pretraining, - tokenizer, - max_tokens, - text_column=cfg.pretraining_dataset[0].text_column or "text", + encode_streaming, + tokenizer=tokenizer, + max_tokens=cfg.sequence_len, + text_column=text_column, concatenate=cfg.pretraining_sample_concatenation is True, ) if cfg.shuffle_merged_datasets: - dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) + dataset = dataset.shuffle( + seed=cfg.seed, buffer_size=cfg.streaming_multipack_buffer_size + ) else: LOG.debug("NOT shuffling merged pretraining datasets") @@ -224,22 +236,21 @@ def wrap_pretraining_dataset( remove_columns = [] if dataset.features is None: for first_row in dataset: - remove_columns = first_row.keys() + remove_columns = list(first_row.keys()) break else: - remove_columns = dataset.features.keys() + remove_columns = list(dataset.features.keys()) dataset = dataset.map( encode, batched=True, - batch_size=buffer_size, - # input_columns="text", + batch_size=cfg.streaming_multipack_buffer_size, remove_columns=remove_columns, ) return dataset -def encode_packed_pretraining( +def encode_packed_streaming( collate_fn, ds_wrapper: Callable, examples: Dict[str, List], @@ -247,10 +258,9 @@ def encode_packed_pretraining( batch_size: int = 4, multipack_attn: Optional[bool] = True, ) -> Dict[str, List]: - # pylint: disable=duplicate-code # tokenize all the examples # rows get split with stride (overlap) - train_dataset = ds_wrapper(Dataset.from_dict(examples))[0] + train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0] train_dataset = process_pretraining_datasets_for_packing( train_dataset, @@ -267,6 +277,7 @@ def encode_packed_pretraining( batch_size=1, batch_max_len=batch_size * max_seq_length, drop_last=True, + num_processes=1, ) chunked_data = defaultdict(list) @@ -274,8 +285,6 @@ def encode_packed_pretraining( for batch in sampler: for data in batch: features = train_dataset[data] - if "num_truncated_tokens" in features: - del features["num_truncated_tokens"] if "num_truncated_tokens" in features: del features["num_truncated_tokens"] if "overflow_to_sample_mapping" in features: diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 5f3b8d3cc..2d0ca9d0e 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -1,9 +1,11 @@ -"""data handling helpers""" +"""Data handling helpers""" +import contextlib import functools import hashlib import time from enum import Enum +from typing import Callable import huggingface_hub import numpy as np @@ -19,9 +21,7 @@ LOG = get_logger(__name__) class RetryStrategy(Enum): - """ - Enum for retry strategies. - """ + """Enum for retry strategies.""" CONSTANT = 1 LINEAR = 2 @@ -30,16 +30,28 @@ class RetryStrategy(Enum): def retry_on_request_exceptions( max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR -): +) -> Callable: + """Decorator that retries function calls on specific request exceptions. + + Args: + max_retries: Maximum number of retry attempts. + delay: Base delay between retries in seconds. + retry_strategy: Strategy for calculating retry delays. + + Returns: + Decorated function with retry logic. + """ + def decorator(func): @functools.wraps(func) - def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements + def wrapper(*args, **kwargs): for attempt in range(max_retries): try: return func(*args, **kwargs) except ( requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError, + requests.exceptions.HTTPError, huggingface_hub.errors.HfHubHTTPError, ) as exc: if attempt < max_retries - 1: @@ -59,6 +71,7 @@ def retry_on_request_exceptions( def md5(to_hash: str, encoding: str = "utf-8") -> str: + """Generate MD5 hash of a string.""" try: return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest() except TypeError: @@ -66,137 +79,172 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str: def sha256(to_hash: str, encoding: str = "utf-8") -> str: + """Generate SHA256 hash of a string.""" return hashlib.sha256(to_hash.encode(encoding)).hexdigest() -def deduplicate_dataset( - dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None -) -> Dataset: - unique_indices = [] +def _deduplicate_dataset( + dataset: Dataset, + seen_hashes: set[str] | None = None, +) -> tuple[Dataset, set[str]]: + """Remove duplicate rows from a dataset using SHA256 hashes. + Args: + dataset: Dataset to deduplicate. + seen_hashes: Set of previously seen row hashes (for cross-deduplication). + + Returns: + Tuple of deduplicated dataset and the set of seen hashes. + """ + if seen_hashes is None: + seen_hashes = set() + + unique_indices = [] for idx, row in enumerate(dataset): - row_hash = sha256(str(row)) # Using SHA256 for collision resistance. + row_hash = sha256(str(row)) # Using SHA256 for collision resistance if row_hash not in seen_hashes: - seen_hashes[row_hash] = [idx] + seen_hashes.add(row_hash) unique_indices.append(idx) - else: - # Check for collision by looking up the original dataset indices - original_indices = seen_hashes[row_hash] - is_duplicate = False - for original_idx in original_indices: - if ( - not idx == original_idx - and original_idx < len(dataset) - and str(dataset[original_idx]) == str(row) - ): - is_duplicate = True - break - # Check in the other dataset if provided - if other_dataset is not None: - if original_idx < len(other_dataset) and str( - other_dataset[original_idx] - ) == str(row): - is_duplicate = True - break - if not is_duplicate: - seen_hashes[row_hash].append(idx) - unique_indices.append(idx) - continue - return dataset.select(unique_indices) + + return dataset.select(unique_indices), seen_hashes def deduplicate_and_log_datasets( - *, - train_dataset: Dataset = None, - eval_dataset: Dataset = None, - dataset: Dataset = None, -) -> tuple[Dataset, Dataset, Dataset]: - """ - Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes. + dataset: Dataset, + other_dataset: Dataset | None = None, + dataset_name: str | None = "train", + other_name: str | None = "eval", +) -> tuple[Dataset, Dataset | None]: + """Deduplicate datasets, with optional cross-dataset deduplication. + + Args: + dataset: Primary dataset to deduplicate. + other_dataset: Optional second dataset to deduplicate against the first. + dataset_name: Name for the primary dataset (for logging). + other_name: Name for the second dataset (for logging). Returns: - tuple: Deduplicated train, eval, and additional datasets. + Tuple of (deduplicated_dataset, deduplicated_other_dataset). """ - seen_hashes: dict[str, list[int]] = {} + # Deduplicate primary dataset + LOG.info( + f"Starting deduplication for {dataset_name} dataset. Original size: {len(dataset)}" + ) + dataset, seen_rows = _deduplicate_dataset(dataset) + LOG.info( + f"Deduplication complete for {dataset_name} dataset. New size: {len(dataset)}" + ) - # Handle cases where datasets are None - if train_dataset is not None: + # Deduplicate second dataset if provided + if other_dataset is not None: LOG.info( - f"Starting deduplication for train dataset. Original size: {len(train_dataset)}" - ) - train_dataset = deduplicate_dataset( - dataset=train_dataset, seen_hashes=seen_hashes + f"Starting deduplication for {other_name} dataset. Original size: {len(other_dataset)}" ) + other_dataset, _ = _deduplicate_dataset(other_dataset, seen_rows) LOG.info( - f"Deduplication complete for train dataset. New size: {len(train_dataset)}" - ) - else: - LOG.info("Train dataset is None. Skipping deduplication.") - - if eval_dataset is not None: - LOG.info( - f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}" - ) - eval_dataset = deduplicate_dataset( - dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset - ) - LOG.info( - f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}" - ) - else: - LOG.info("Eval dataset is None. Skipping deduplication.") - - if dataset is not None and (eval_dataset is None and train_dataset is None): - LOG.info( - f"Starting deduplication for combined dataset. Original size: {len(dataset)}" - ) - dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes) - LOG.info( - f"Deduplication complete for combined dataset. New size: {len(dataset)}" + f"Deduplication complete for {other_name} dataset. New size: {len(other_dataset)}" ) - return train_dataset, eval_dataset, dataset + return dataset, other_dataset -def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): - if "input_ids" not in dataset.column_names: +def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2): + """ + Truncate samples whose sequence length is too long (> sequence_len) + or drop those too short (< min_sequence_len). + """ + min_sequence_len = min_sequence_len or 2 + + input_ids = sample["input_ids"] + results = [] + + # Batched (input_ids is a list of lists) + for i, seq in enumerate(input_ids): + length = len(seq) + if length < min_sequence_len: + results.append(False) + elif length > sequence_len: + sample["input_ids"][i] = seq[:sequence_len] + if "attention_mask" in sample: + sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len] + if "labels" in sample: + sample["labels"][i] = sample["labels"][i][:sequence_len] + if "position_ids" in sample: + sample["position_ids"][i] = sample["position_ids"][i][:sequence_len] + results.append(True) + else: + results.append(True) + return results + + +def handle_long_seq_in_dataset( + dataset: Dataset, sequence_len: int, cfg: DictDefault +) -> Dataset: + """Remove sequences longer than configured maximum from dataset. + + Args: + dataset: Dataset to filter. + sequence_len: Maximum length for sequences to keep + cfg: Dictionary mapping `axolotl` config keys to values. + + Returns: + Filtered dataset with long sequences removed. + """ + if ( + hasattr(dataset, "column_names") + and dataset.column_names + and "input_ids" not in dataset.column_names + ): LOG.warning( - "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling." + "Dataset does not contain 'input_ids' column. Skip drop long seq. This is " + "expected for reward modeling." + ) + return dataset + elif not hasattr(dataset, "column_names") or dataset.column_names is None: + LOG.info( + "Dataset is streaming (IterableDataset), skipping long sequence handling" ) return dataset drop_long = functools.partial( drop_long_seq, - sequence_len=cfg.sequence_len, + sequence_len=sequence_len, min_sequence_len=cfg.min_sample_len, ) - try: + with contextlib.suppress(AttributeError): ds_lengths = get_dataset_lengths(dataset, from_arrow=True) min_input_len = np.min(ds_lengths) LOG.info(f"min_input_len: {min_input_len}") max_input_len = np.max(ds_lengths) LOG.info(f"max_input_len: {max_input_len}") - except AttributeError: - pass - try: - prior_len = len(dataset) - except TypeError: - # handle iterable datasets case - prior_len = None + prior_len = len(dataset) if hasattr(dataset, "__len__") else None filter_map_kwargs = {} if not isinstance(dataset, IterableDataset): - filter_map_kwargs["num_proc"] = cfg.dataset_processes + filter_map_kwargs["num_proc"] = cfg.dataset_num_proc filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess drop_long_kwargs = {} if filter_map_kwargs: - drop_long_kwargs["desc"] = "Dropping Long Sequences" + drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})" + + excess_length_strategy = (cfg.excess_length_strategy or "drop").lower() + if excess_length_strategy == "truncate": + process_fn = functools.partial( + truncate_long_seq, + sequence_len=sequence_len, + min_sequence_len=cfg.min_sample_len, + ) + drop_long_kwargs["desc"] = ( + f"Truncating/Filtering Sequences (target_len={sequence_len})" + ) + else: + process_fn = drop_long dataset = dataset.filter( - drop_long, + process_fn, batched=True, **filter_map_kwargs, **drop_long_kwargs, @@ -204,6 +252,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if prior_len: dropped = prior_len - len(dataset) if dropped: - LOG.warning(f"Dropped {dropped} long samples from dataset") + action = ( + "truncated/filtered" + if excess_length_strategy == "truncate" + else "dropped" + ) + LOG.warning(f"{action.title()} {dropped} samples from dataset") return dataset diff --git a/src/axolotl/utils/data/wrappers.py b/src/axolotl/utils/data/wrappers.py new file mode 100644 index 000000000..3a10bde00 --- /dev/null +++ b/src/axolotl/utils/data/wrappers.py @@ -0,0 +1,424 @@ +"""Data handling specific to SFT.""" + +import logging +from typing import Any, NoReturn, cast + +from datasets import ( + Dataset, + IterableDataset, + Sequence, + Value, +) +from transformers import PreTrainedTokenizer +from transformers.processing_utils import ProcessorMixin + +from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt +from axolotl.prompt_strategies import load +from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load +from axolotl.prompt_tokenizers import ( + AlpacaMultipleChoicePromptTokenizingStrategy, + AlpacaPromptTokenizingStrategy, + AlpacaReflectionPTStrategy, + DatasetWrappingStrategy, + GPTeacherPromptTokenizingStrategy, + JeopardyPromptTokenizingStrategy, + OpenAssistantPromptTokenizingStrategy, + PromptTokenizingStrategy, + SummarizeTLDRPromptTokenizingStrategy, +) +from axolotl.prompters import ( + AlpacaPrompter, + GPTeacherPrompter, + JeopardyPrompter, + MultipleChoiceConcisePrompter, + MultipleChoiceExplainPrompter, + Prompter, + ReflectAlpacaPrompter, + SummarizeTLDRPrompter, + UnsupportedPrompter, +) +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) + + +def handle_unknown_dataset_strategy(dataset_config: DictDefault) -> NoReturn: + """Raise error for unknown dataset strategy.""" + ds_type = dataset_config.type + suffix = "" + if ":load_" in ds_type: + suffix = f"Did you mean {ds_type.replace(':load_', '.load_')}?" + + error_message = f"unhandled prompt tokenization strategy: {ds_type}. {suffix}" + LOG.error(error_message) + raise ValueError(error_message) + + +def get_dataset_wrapper( + dataset_config: DictDefault, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset_base_type: str | None, + dataset: Dataset | IterableDataset, + dataset_prompt_style: str | None = None, + processor: ProcessorMixin | None = None, +) -> tuple[Dataset | IterableDataset, Prompter | None]: + """Create an appropriate dataset wrapper and prompter based on dataset + configuration. + + Args: + dataset_config: Configuration for the dataset. + tokenizer: Tokenizer to use for processing text. + cfg: Global configuration object. + dataset_base_type: The base type of the dataset. + dataset: The actual dataset object. + dataset_prompt_style: Optional prompt style specification. + processor: Optional processor for multimodal datasets. + + Returns: + tuple of (dataset_wrapper, dataset_prompter). + """ + # Common parameters for dataset wrapping + dataset_kwargs: dict[str, Any] = { + "process_count": cfg.dataset_num_proc, + "keep_in_memory": cfg.dataset_keep_in_memory is True, + } + + LOG.info( + f"Loading dataset: {dataset_config['path']} with base_type: " + f"{dataset_base_type} and prompt_style: {dataset_prompt_style}" + ) + + # Dataset is already tokenized + if _is_dataset_already_tokenized(dataset): + return dataset, UnsupportedPrompter() + + # Custom dataset type definition + if isinstance(dataset_config.type, DictDefault): + return _handle_custom_dataset_type( + dataset_config, tokenizer, cfg, dataset, dataset_kwargs + ) + + # Skip preparation if configured + if cfg.skip_prepare_dataset: + return dataset, None + + # Bradley-Terry dataset + if dataset_config.type.startswith("bradley_terry"): + return _handle_bradley_terry_dataset( + dataset_config, tokenizer, cfg, dataset, dataset_kwargs + ) + + # Stepwise supervised dataset + if dataset_config.type.startswith("stepwise_supervised"): + return _handle_stepwise_supervised_dataset( + dataset_config, tokenizer, cfg, dataset, dataset_kwargs + ) + + # Try to load prompt tokenizer / dataset wrapper strategy from registry + dataset_strategy = load( + dataset_config.type, tokenizer, cfg, dataset_config, processor=processor + ) + if dataset_strategy: + return _handle_loaded_strategy(dataset_strategy, dataset, dataset_kwargs) + + # Known dataset types with specific handling + if dataset_base_type in DATASET_HANDLERS: + handler = DATASET_HANDLERS[dataset_base_type] + return handler(dataset_prompt_style, tokenizer, cfg, dataset, dataset_kwargs) + + # Unhandled dataset type + handle_unknown_dataset_strategy(dataset_config) + + +def _is_dataset_already_tokenized(dataset: Dataset | IterableDataset) -> bool: + """Check if the dataset is already tokenized.""" + return ( + isinstance(dataset, Dataset) + and "input_ids" in dataset.features + and "attention_mask" in dataset.features + and "labels" in dataset.features + ) + + +def _handle_custom_dataset_type( + dataset_config: DictDefault, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a custom dataset type defined in the configuration.""" + dataset_strategy = cast( + PromptTokenizingStrategy, + load("user_defined", tokenizer, cfg, dataset_config.type.to_dict()), + ) + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_bradley_terry_dataset( + dataset_config: DictDefault, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter | None]: + """Handle a Bradley-Terry dataset.""" + bt_type = dataset_config.type.split(".", 1)[1] + dataset_strategy = bradley_terry_load(bt_type, tokenizer, cfg, dataset_config) + + if not dataset_strategy: + handle_unknown_dataset_strategy(dataset_config) + + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + + return dataset_wrapper, dataset_prompter + + +def _handle_stepwise_supervised_dataset( + dataset_config: DictDefault, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a stepwise supervised dataset.""" + dataset_prompter = UnsupportedPrompter() + dataset_strategy = load(dataset_config.type, tokenizer, cfg, dataset_config) + + # We need to explicitly cast boolean labels to int + # for compatibility with how trl's PRMTrainer works + if isinstance(dataset, Dataset): + dataset = dataset.cast_column("labels", Sequence(Value("int64"))) + + dataset_wrapper = TokenizedPromptDataset( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_loaded_strategy( + dataset_strategy: PromptTokenizingStrategy | DatasetWrappingStrategy, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter | None]: + """Handle a dataset with a strategy loaded from the registry.""" + if isinstance(dataset_strategy, DatasetWrappingStrategy): + return dataset_strategy.wrap_dataset(dataset, **dataset_kwargs), None + + dataset_prompter = UnsupportedPrompter() + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_alpaca_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle an Alpaca dataset.""" + dataset_prompter = AlpacaPrompter(dataset_prompt_style) + dataset_strategy = AlpacaPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_explainchoice_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle an ExplainChoice dataset.""" + dataset_prompter = MultipleChoiceExplainPrompter(dataset_prompt_style) + dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_concisechoice_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a ConciseChoice dataset.""" + dataset_prompter = MultipleChoiceConcisePrompter(dataset_prompt_style) + dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_summarizetldr_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a SummarizeTLDR dataset.""" + dataset_prompter = SummarizeTLDRPrompter(dataset_prompt_style) + dataset_strategy = SummarizeTLDRPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_jeopardy_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a Jeopardy dataset.""" + dataset_prompter = JeopardyPrompter(dataset_prompt_style) + dataset_strategy = JeopardyPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_oasst_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle an OpenAssistant dataset.""" + dataset_prompter = AlpacaPrompter(dataset_prompt_style) + dataset_strategy = OpenAssistantPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_gpteacher_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a GPTeacher dataset.""" + dataset_prompter = GPTeacherPrompter(dataset_prompt_style) + dataset_strategy = GPTeacherPromptTokenizingStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +def _handle_reflection_dataset( + dataset_prompt_style: str | None, + tokenizer: PreTrainedTokenizer, + cfg: DictDefault, + dataset: Dataset | IterableDataset, + dataset_kwargs: dict[str, Any], +) -> tuple[Dataset | IterableDataset, Prompter]: + """Handle a Reflection dataset.""" + dataset_prompter = ReflectAlpacaPrompter(dataset_prompt_style) + dataset_strategy = AlpacaReflectionPTStrategy( + dataset_prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + dataset_wrapper = wrap_dataset_for_tokenized_prompt( + dataset_strategy, + dataset, + **dataset_kwargs, + ) + return dataset_wrapper, dataset_prompter + + +DATASET_HANDLERS = { + "alpaca": _handle_alpaca_dataset, + "explainchoice": _handle_explainchoice_dataset, + "concisechoice": _handle_concisechoice_dataset, + "summarizetldr": _handle_summarizetldr_dataset, + "jeopardy": _handle_jeopardy_dataset, + "oasst": _handle_oasst_dataset, + "gpteacher": _handle_gpteacher_dataset, + "reflection": _handle_reflection_dataset, +} diff --git a/src/axolotl/utils/datasets.py b/src/axolotl/utils/datasets.py new file mode 100644 index 000000000..9b8a8e25a --- /dev/null +++ b/src/axolotl/utils/datasets.py @@ -0,0 +1,13 @@ +"""helper functions for datasets""" + +import os + + +def get_default_process_count(): + if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"): + return int(axolotl_dataset_num_proc) + if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"): + return int(axolotl_dataset_processes) + if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"): + return int(runpod_cpu_count) + return os.cpu_count() diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py index f24f7c4a9..7d146c7a9 100644 --- a/src/axolotl/utils/dict.py +++ b/src/axolotl/utils/dict.py @@ -17,15 +17,15 @@ class DictDefault(Dict): def __setitem__(self, name, value): # workaround for pickle/unpickle issues and __frozen not being available try: - isFrozen = hasattr( # pylint: disable=invalid-name + isFrozen = hasattr(self, "__frozen") and object.__getattribute__( self, "__frozen" - ) and object.__getattribute__(self, "__frozen") + ) except AttributeError: - isFrozen = False # pylint: disable=invalid-name + isFrozen = False if isFrozen and name not in super().keys(): raise KeyError(name) - super(Dict, self).__setitem__(name, value) # pylint: disable=bad-super-call + super(Dict, self).__setitem__(name, value) try: p = object.__getattribute__(self, "__parent") key = object.__getattribute__(self, "__key") @@ -36,3 +36,16 @@ class DictDefault(Dict): p[key] = self object.__delattr__(self, "__parent") object.__delattr__(self, "__key") + + +def remove_none_values(obj): + """ + Remove null from a dictionary-like obj or list. + These can appear due to Dataset loading causing schema merge. + See https://github.com/axolotl-ai-cloud/axolotl/pull/2909 + """ + if hasattr(obj, "items"): + return {k: remove_none_values(v) for k, v in obj.items() if v is not None} + if isinstance(obj, list): + return [remove_none_values(elem) for elem in obj] + return obj diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index 0673c6e95..840772d91 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -1,6 +1,4 @@ -""" -utility helpers for distributed checks -""" +"""Utilities for distributed functionality.""" import os import pickle # nosec @@ -10,16 +8,17 @@ from datetime import timedelta import torch import torch.distributed as dist from accelerate import PartialState +from accelerate.utils import ParallelismConfig from transformers.utils.import_utils import ( is_torch_cuda_available, is_torch_mps_available, is_torch_npu_available, ) -distributed_state = None # pylint: disable=invalid-name +distributed_state = None -def get_device_type(): +def get_device_type() -> torch.device: device = torch.device("cpu") if is_torch_cuda_available(): device = torch.device("cuda") @@ -30,7 +29,7 @@ def get_device_type(): return device -def get_device_count(): +def get_device_count() -> int: cur_device = get_device_type() if "cuda" in str(cur_device): return torch.cuda.device_count() @@ -39,7 +38,7 @@ def get_device_count(): return 1 -def get_current_device(): +def get_current_device() -> int: cur_device = get_device_type() if "cuda" in str(cur_device): return torch.cuda.current_device() @@ -48,14 +47,26 @@ def get_current_device(): return 0 -def is_distributed(): - """ - Check if distributed training is initialized. - """ - global distributed_state # pylint: disable=global-statement - if not distributed_state: +def init_distributed_state(): + global distributed_state + if distributed_state is None: timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800)) - distributed_state = PartialState(timeout=timedelta(seconds=timeout)) + try: + distributed_state = PartialState(timeout=timedelta(seconds=timeout)) + except ValueError: + pass + + +def get_distributed_state() -> PartialState | None: + return distributed_state + + +def is_distributed() -> bool: + """Check if distributed training is initialized.""" + init_distributed_state() + + if distributed_state is None: + return False return distributed_state.use_distributed and distributed_state.initialized @@ -69,31 +80,31 @@ def barrier(): dist.barrier() -def is_main_process(use_environ=False): +def is_main_process() -> bool: """ Check if the current process is the main process. If not in distributed mode, always return `True`. - Args: - - use_environ (bool, optional): Use environment variable to determine main process. + We use a simpler logic when the distributed state is not initialized: we just log + on the 0-th local rank. Returns: - - bool: `True` if the current process is the main process, `False` otherwise. + `True` if the current process is the main process, `False` otherwise. """ - if use_environ: + if get_distributed_state() is None: return os.environ.get("LOCAL_RANK", "0") == "0" if not is_distributed(): return True return dist.get_rank() == 0 -def is_local_main_process(use_environ=False): - if use_environ: +def is_local_main_process() -> bool: + if get_distributed_state() is None: return os.environ.get("LOCAL_RANK", "0") == "0" return PartialState().is_local_main_process -def get_world_size(): +def get_world_size() -> int: return int(os.getenv("WORLD_SIZE", "1")) @@ -115,7 +126,7 @@ def cleanup_distributed(): @contextmanager -def zero_first(is_main): +def zero_first(is_main: bool): """ runs the wrapped context so that rank 0 runs first before other ranks """ @@ -126,7 +137,7 @@ def zero_first(is_main): barrier() -def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name +def gather_scalar_from_all_ranks(fn, world_size=1): """ Run a callable 'fn' on all ranks and gather the results on the specified rank. @@ -190,7 +201,7 @@ def broadcast_dict(vals: dict): return vals -def compute_and_broadcast(fn): # pylint: disable=invalid-name +def compute_and_broadcast(fn): """ Compute a value using the function 'fn' only on the specified rank (default is 0). The value is then broadcasted to all other ranks. @@ -223,7 +234,7 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name return float(value_tensor.item()) -def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name +def gather_from_all_ranks(fn, world_size=1): """ Run a callable 'fn' on all ranks and gather the results on the specified rank. @@ -283,3 +294,77 @@ def reduce_and_broadcast(fn1, fn2): # Use compute_and_broadcast to compute the reduced value on the main process # and then broadcast it to all ranks return compute_and_broadcast(lambda: fn2(gathered_values)) + + +def build_parallelism_config(cfg): + pc_kwargs = _get_parallel_config_kwargs( + get_world_size(), + cfg.tensor_parallel_size, + cfg.context_parallel_size, + cfg.dp_shard_size, + cfg.dp_replicate_size, + bool(cfg.fsdp or cfg.fsdp_config), + ) + + if pc_kwargs: + parallelism_config = ParallelismConfig( + **pc_kwargs, + ) + device_mesh = parallelism_config.build_device_mesh("cuda") + + return parallelism_config, device_mesh + return None, None + + +def _get_parallel_config_kwargs( + world_size: int, + tensor_parallel_size: int = 1, + context_parallel_size: int = 1, + dp_shard_size: int | None = None, + dp_replicate_size: int | None = None, + is_fsdp: bool = False, +): + pc_kwargs = {} + remaining_world_size = world_size + + if tensor_parallel_size and tensor_parallel_size > 1: + pc_kwargs["tp_size"] = tensor_parallel_size + remaining_world_size = remaining_world_size // tensor_parallel_size + + if context_parallel_size and context_parallel_size > 1: + pc_kwargs["cp_size"] = context_parallel_size + remaining_world_size = remaining_world_size // context_parallel_size + + if dp_shard_size is None and dp_replicate_size in (None, 1): + if remaining_world_size > 1: + pc_kwargs["dp_shard_size"] = remaining_world_size + remaining_world_size = 1 + + if dp_replicate_size and dp_replicate_size > 1: + pc_kwargs["dp_replicate_size"] = dp_replicate_size + remaining_world_size = remaining_world_size // dp_replicate_size + + if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1: + if not is_fsdp: + raise ValueError( + "dp_shard_size was configured without a corresponding fsdp_config! " + "Please ensure you have configured FSDP using fsdp_config." + ) + pc_kwargs["dp_shard_size"] = dp_shard_size + remaining_world_size = remaining_world_size // dp_shard_size + if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs: + pc_kwargs["dp_replicate_size"] = remaining_world_size + remaining_world_size = 1 + + if remaining_world_size > 1: + if "dp_shard_size" not in pc_kwargs and is_fsdp: + pc_kwargs["dp_shard_size"] = remaining_world_size + remaining_world_size = 1 + + if remaining_world_size > 1: + raise ValueError( + f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n" + f"{pc_kwargs}" + ) + + return pc_kwargs diff --git a/src/axolotl/utils/environment.py b/src/axolotl/utils/environment.py index 1cc609a68..d5f2d9f78 100644 --- a/src/axolotl/utils/environment.py +++ b/src/axolotl/utils/environment.py @@ -2,27 +2,55 @@ utils to get GPU info for the current environment """ +import os +from importlib.metadata import version + +import torch from accelerate.utils.environment import ( check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support, ) -from accelerate.utils.environment import ( - get_gpu_info, -) +from packaging.version import Version, parse + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) def check_cuda_p2p_ib_support(): if not accelerate_check_cuda_p2p_ib_support(): return False - unsupported_devices = {"RTX 6000 Ada", "L40S"} - try: - device_names, device_count = get_gpu_info() - if 1 < device_count < 8: - if any( - unsupported_device in device_name - for device_name in device_names - for unsupported_device in unsupported_devices - ): - return False - except Exception: # pylint: disable=broad-except # nosec - pass + if not check_cuda_p2p_support(): + return False return True + + +def check_cuda_p2p_support() -> bool: + try: + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + except ValueError: + return True + + if world_size > 1: + node_world_size = int(os.environ.get("NODE_WORLD_SIZE", "8")) + local_other_rank = (local_rank // node_world_size) * node_world_size + local_other_rank += 1 if (local_rank % node_world_size) == 0 else 0 + try: + can_p2p = torch.cuda.can_device_access_peer(local_rank, local_other_rank) + except AssertionError as exc: + # some sort of logic error in indexing processes, assume p2p is fine for now + LOG.warning(exc) + return True + return can_p2p + + return True + + +def get_package_version(package: str) -> Version: + version_str = version(package) + return parse(version_str) + + +def is_package_version_ge(package: str, version_: str) -> bool: + package_version = get_package_version(package) + return package_version >= parse(version_) diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py index 65ca62137..936708f04 100644 --- a/src/axolotl/utils/freeze.py +++ b/src/axolotl/utils/freeze.py @@ -5,9 +5,8 @@ module to freeze/unfreeze parameters by name import re from typing import Callable, List, Tuple, Union -from accelerate.logging import get_logger - from axolotl.utils.distributed import is_main_process +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/utils/import_helper.py b/src/axolotl/utils/import_helper.py new file mode 100644 index 000000000..f7d20099c --- /dev/null +++ b/src/axolotl/utils/import_helper.py @@ -0,0 +1,28 @@ +""" +Helper for importing modules from strings +""" + +import importlib + + +def get_cls_from_module_str(module_str: str): + # use importlib to dynamically load the reward function from the module + if not isinstance(module_str, str) or not module_str.strip(): + raise ValueError("module_str must be a non-empty string") + + parts = module_str.split(".") + if len(parts) < 2: + raise ValueError(f"Invalid module string format: {module_str}") + + try: + cls_name = parts[-1] + module_path = ".".join(parts[:-1]) + mod = importlib.import_module(module_path) + mod_cls = getattr(mod, cls_name) + return mod_cls + except ImportError as e: + raise ImportError(f"Failed to import module '{module_path}': {e}") from e + except AttributeError as e: + raise AttributeError( + f"Class '{cls_name}' not found in module '{module_path}': {e}" + ) from e diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 80daab4ea..35810897a 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -1,10 +1,7 @@ -""" -logging helpers to only log on main process -""" +"""Logging helpers to only log on main process.""" import functools import logging -import os from axolotl.utils.distributed import is_main_process @@ -14,27 +11,18 @@ from axolotl.utils.distributed import is_main_process class MultiProcessAdapter(logging.LoggerAdapter): """ - logger adapter for distributed logging, specifically to only log on main process + Logger adapter for distributed logging, specifically to only log on main process. """ - def __init__(self, logger, use_environ=False, extra=None): - super().__init__(logger, extra) - self.use_environ = use_environ - @staticmethod - def _should_log(main_process_only, use_environ=False): - return not main_process_only or ( - main_process_only and is_main_process(use_environ=use_environ) - ) + def _should_log(main_process_only: bool): + return not main_process_only or is_main_process() def log(self, level, msg, *args, **kwargs): - use_environ = kwargs.pop("use_environ", self.use_environ) main_process_only = kwargs.pop("main_process_only", True) kwargs.setdefault("stacklevel", 2) - if self.isEnabledFor(level) and self._should_log( - main_process_only, use_environ=use_environ - ): + if self.isEnabledFor(level) and self._should_log(main_process_only): msg, kwargs = self.process(msg, kwargs) self.logger.log(level, msg, *args, **kwargs) @@ -50,13 +38,7 @@ class MultiProcessAdapter(logging.LoggerAdapter): self.warning(*args, **kwargs) -def get_logger( - name: str, log_level: str | None = None, use_environ: bool = False -) -> MultiProcessAdapter: - if log_level is None: - log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None) +def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter: logger = logging.getLogger(name) - if log_level is not None: - logger.setLevel(log_level.upper()) - logger.root.setLevel(log_level.upper()) - return MultiProcessAdapter(logger, use_environ=use_environ, extra={}) + logger.setLevel(logging.DEBUG) + return MultiProcessAdapter(logger, extra={}) diff --git a/src/axolotl/utils/lora.py b/src/axolotl/utils/lora.py index 759c17ac2..6ae481b6b 100644 --- a/src/axolotl/utils/lora.py +++ b/src/axolotl/utils/lora.py @@ -15,6 +15,7 @@ """ module to get the state dict of a merged lora model """ + import torch from peft.tuners.tuners_utils import onload_layer from peft.utils import ModulesToSaveWrapper, _get_submodules diff --git a/src/axolotl/utils/mistral/__init__.py b/src/axolotl/utils/mistral/__init__.py new file mode 100644 index 000000000..eb51031ec --- /dev/null +++ b/src/axolotl/utils/mistral/__init__.py @@ -0,0 +1,6 @@ +"""Init for `axolotl.utils.mistral` module.""" + +from axolotl.utils.mistral.mistral3_processor import Mistral3Processor +from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer + +__all__ = ["HFMistralTokenizer", "Mistral3Processor"] diff --git a/src/axolotl/utils/mistral/mistral3_processor.py b/src/axolotl/utils/mistral/mistral3_processor.py new file mode 100644 index 000000000..85479ca7b --- /dev/null +++ b/src/axolotl/utils/mistral/mistral3_processor.py @@ -0,0 +1,169 @@ +"""Processor for Mistral3 multimodal models with image support""" + +from typing import Any, Dict, Optional, Union + +import torch +from transformers import ProcessorMixin +from transformers.feature_extraction_utils import BatchFeature +from transformers.processing_utils import ProcessingKwargs +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer + + +class Mistral3ProcessorKwargs(ProcessingKwargs): + _defaults: Dict[str, Dict[str, Any]] = { + "text_kwargs": { + "padding": True, + }, + "common_kwargs": { + "return_tensors": "pt", + "return_dict": True, + "tokenize": True, + }, + } + + +class Mistral3Processor(ProcessorMixin): + """ + Processor for Mistral3 multimodal models that handles text and images. + Wraps HFMistralTokenizer and adds image processing capabilities. + """ + + attributes = ["tokenizer"] + tokenizer_class = "HFMistralTokenizer" + + def __init__(self, tokenizer: HFMistralTokenizer): + # Don't call super().__init__ to avoid the class validation issue + self.tokenizer = tokenizer + + @property + def chat_template(self) -> None: + """Chat template is not supported. Dummy method to satisfy HuggingFace API.""" + return None + + @property + def audio_tokenizer(self) -> None: + """Audio tokenizer is not supported. Dummy method to satisfy HuggingFace API.""" + return None + + def _merge_kwargs( + self, processor_kwargs_class: Any, **kwargs: Any + ) -> Dict[str, Dict[str, Any]]: + """Merge kwargs with defaults similar to ProcessorMixin""" + defaults = processor_kwargs_class._defaults + output_kwargs: Dict[str, Dict[str, Any]] = {} + + for kwarg_type, default_values in defaults.items(): + output_kwargs[kwarg_type] = {**default_values} + + # Update with provided kwargs + for key, value in kwargs.items(): + # Try to match key to appropriate kwarg type + if key in ["padding", "truncation", "max_length"]: + output_kwargs.setdefault("text_kwargs", {}).update({key: value}) + elif key in ["return_tensors", "return_dict", "tokenize"]: + output_kwargs.setdefault("common_kwargs", {}).update({key: value}) + else: + # Add to text_kwargs by default + output_kwargs.setdefault("text_kwargs", {}).update({key: value}) + + return output_kwargs + + def apply_chat_template( + self, + conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], + **kwargs: Any, + ) -> Union[BatchFeature, str, list[str]]: + """ + Apply chat template with image support for Mistral3. + + Similar to VoxtralProcessor, this method extracts images from the conversation, + calls the tokenizer's apply_chat_template, then adds pixel_values and image_sizes + to the result. + """ + output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs) + text_kwargs = output_kwargs["text_kwargs"] + common_kwargs = output_kwargs["common_kwargs"] + + return_tensors = common_kwargs.pop("return_tensors", "pt") + if return_tensors != "pt": + raise ValueError( + f"{self.__class__.__name__} only supports `return_tensors='pt'`." + ) + + return_dict = common_kwargs.pop("return_dict", False) + tokenize = common_kwargs.pop("tokenize", False) + + # Determine if batched + if isinstance(conversation, (list, tuple)) and ( + isinstance(conversation[0], (list, tuple)) + or hasattr(conversation[0], "content") + ): + is_batched = True + conversations = conversation + else: + is_batched = False + conversations = [conversation] # type: ignore + + # Call tokenizer's apply_chat_template + tokenizer_kwargs = {**text_kwargs, **common_kwargs} + tokenizer_kwargs["return_tensors"] = return_tensors + tokenizer_kwargs["tokenize"] = tokenize + tokenizer_kwargs["return_dict"] = return_dict + + encoded_instruct_inputs = self.tokenizer.apply_chat_template( + conversations, + **tokenizer_kwargs, + ) + + if tokenize: + if return_dict: + # The tokenizer already handles pixel_values, we just need to add image_sizes + if hasattr(encoded_instruct_inputs, "items"): + data: Dict[str, Any] = dict(encoded_instruct_inputs) # type: ignore + elif hasattr(encoded_instruct_inputs, "data"): + data = encoded_instruct_inputs.data # type: ignore + else: + raise ValueError("Unknown data type") + + if "pixel_values" in data: + pixel_values = data["pixel_values"] + + # MistralTokenizer returns a Double, so we convert to fp32 + data["pixel_values"] = pixel_values.to(dtype=torch.float32) + + # Always batched: [B, C, H, W] -> image_sizes: [B, 2] + # Since tensor is homogeneous, all images have same H, W + batch_size = pixel_values.shape[0] + image_sizes = torch.tensor([pixel_values.shape[-2:]] * batch_size) + data["image_sizes"] = image_sizes + + return BatchFeature(data=data, tensor_type=return_tensors) + + if not is_batched: + return encoded_instruct_inputs[0] + + return encoded_instruct_inputs + + def __call__( + self, + text: Optional[ + Union[ + TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] + ] + ], + **kwargs: Any, + ) -> BatchFeature: + """ + Forward text processing to the tokenizer. + This method does not support images - use apply_chat_template instead. + """ + output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs) + text_kwargs = output_kwargs["text_kwargs"] + common_kwargs = output_kwargs["common_kwargs"] + + out = self.tokenizer(text, **text_kwargs) + return BatchFeature( + data=out, tensor_type=common_kwargs.pop("return_tensors", None) + ) diff --git a/src/axolotl/utils/mistral/mistral_tokenizer.py b/src/axolotl/utils/mistral/mistral_tokenizer.py new file mode 100644 index 000000000..0414ece78 --- /dev/null +++ b/src/axolotl/utils/mistral/mistral_tokenizer.py @@ -0,0 +1,220 @@ +"""Wrapper for MistralTokenizer from mistral-common""" + +import os +from typing import Optional + +import numpy as np +from mistral_common.protocol.instruct.validator import ValidationMode +from mistral_common.tokens.tokenizers.utils import download_tokenizer_from_hf_hub +from torch import Tensor +from transformers.tokenization_mistral_common import MistralCommonTokenizer +from transformers.tokenization_utils_base import VERY_LARGE_INTEGER + + +class HFMistralTokenizer(MistralCommonTokenizer): + """ + Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer + and exposes HuggingFace API for special tokens. + """ + + def __init__(self, name_or_path: str, **kwargs): + """ + Args: + name_or_path: The name or path to the tokenizer files or the repo id. + **kwargs: Additional keyword arguments passed to the parent class. + """ + kwargs.pop("mode", None) + + mode = ValidationMode.finetuning + super().__init__(**kwargs, mode=mode) + + self._name_or_path = name_or_path + + # set mode as is not set upstream + self._set_mode(mode) + + @property + def name_or_path(self) -> str: + return self._name_or_path + + @property + def chat_template(self) -> str | None: + """Chat template is not supported. Dummy method to satisfy HuggingFace API.""" + return "[This is a dummy chat template]" + + def _set_mode(self, mode: ValidationMode): + """Set the mode of the MistralRequestValidator. + + Args: + mode: The mode to set. + + Raises: + RuntimeError: If the MistralRequestValidator does not have a _mode attribute. + """ + # Check if MistralRequestValidator has a _mode attribute. + # This is a private API and may change in the future. + + from mistral_common.protocol.instruct.validator import MistralRequestValidator + + if not ( + hasattr(self.tokenizer, "_chat_completion_request_validator") + and isinstance( + self.tokenizer._chat_completion_request_validator, + MistralRequestValidator, + ) + and hasattr(self.tokenizer._chat_completion_request_validator, "_mode") + ): + raise RuntimeError( + f"Unable to switch mistral tokenizer to {mode.value} mode - " + "private API `_chat_completion_request_validator._mode` missing." + ) + + self.tokenizer._chat_completion_request_validator._mode = mode + + def apply_chat_template( # type: ignore + self, + conversation: list[dict] | list[list[dict]], + chat_template: str | None = None, + add_generation_prompt: bool = False, + **kwargs, + ) -> str | list[int]: + """Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg""" + + try: + if add_generation_prompt: + self._set_mode(ValidationMode.serving) + kwargs["continue_final_message"] = True + + out = super().apply_chat_template(conversation, **kwargs) + + return out # type: ignore + + finally: + if add_generation_prompt: + self._set_mode(ValidationMode.finetuning) + + def decode( # type: ignore + self, + token_ids: int | list[int] | np.ndarray | Tensor, + **kwargs, + ) -> str: + """ + Decode token_ids into str. + + This overrides upstream.decode to convert int to list[int] + """ + + if isinstance(token_ids, int): + token_ids = [token_ids] + + return super().decode(token_ids, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str | os.PathLike, + *init_inputs, + mode: ValidationMode = ValidationMode.test, + cache_dir: Optional[str | os.PathLike] = None, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[str | bool] = None, + revision: str = "main", + model_max_length: int = VERY_LARGE_INTEGER, + padding_side: str = "left", + truncation_side: str = "right", + model_input_names: Optional[list[str]] = None, + clean_up_tokenization_spaces: bool = False, + **kwargs, + ): + r""" + Patched fn to pass `name_or_path` and remove extra kwargs. + + Instantiate a `MistralCommonTokenizer` from a predefined + tokenizer. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co. + - A path to a *directory* containing the tokenizer config, for instance saved + using the [`MistralCommonTokenizer.tokenization_mistral_common.save_pretrained`] method, e.g., + `./my_model_directory/`. + mode (`ValidationMode`, *optional*, defaults to `ValidationMode.test`): + Validation mode for the `MistralTokenizer` tokenizer. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download the vocabulary files and override the cached versions if they + exist. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only rely on local files and not to attempt to download any files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + padding_side (`str`, *optional*, defaults to `"left"`): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + truncation_side (`str`, *optional*, defaults to `"right"`): + The side on which the model should have truncation applied. Should be selected between ['right', 'left']. + model_input_names (`List[string]`, *optional*): + The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or + `"attention_mask"`). Default value is picked from the class attribute of the same name. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not the model should cleanup the spaces that were added when splitting the input text during the + tokenization process. + kwargs (additional keyword arguments, *optional*): + Not supported by `MistralCommonTokenizer.from_pretrained`. + Will raise an error if used. + """ + if init_inputs: + raise ValueError( + "`init_inputs` are not supported by `MistralCommonTokenizer.from_pretrained`." + ) + + # Delete trust_remote_code as it does nothing + kwargs.pop("trust_remote_code", None) + + # Delete tokenizer as it does nothing + kwargs.pop("tokenizer", None) + + # Handle kwargs and AutoTokenizer case + if kwargs and not kwargs.keys() == {"_from_auto"}: + raise ValueError( + f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.from_pretrained`." + ) + + if not os.path.isfile(pretrained_model_name_or_path): + tokenizer_path = download_tokenizer_from_hf_hub( + repo_id=str(pretrained_model_name_or_path), + cache_dir=str(cache_dir), + token=token, + revision=revision, + force_download=force_download, + local_files_only=local_files_only, + ) + else: + tokenizer_path = str(pretrained_model_name_or_path) + + return cls( + name_or_path=str(pretrained_model_name_or_path), + tokenizer_path=tokenizer_path, + mode=mode, + model_max_length=model_max_length, + padding_side=padding_side, + truncation_side=truncation_side, + model_input_names=model_input_names, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + ) diff --git a/src/axolotl/utils/model_shard_quant.py b/src/axolotl/utils/model_shard_quant.py index 5c5006eda..ca152113a 100644 --- a/src/axolotl/utils/model_shard_quant.py +++ b/src/axolotl/utils/model_shard_quant.py @@ -46,13 +46,11 @@ def _replace_linear( if isinstance(module, torch.nn.Linear) and name not in skip_modules: if issubclass(linear_replacement, Linear4bit): - model._modules[name] = ( # pylint: disable=protected-access - linear_replacement( - module.in_features, - module.out_features, - module.bias is not None, - **kwargs, - ) + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + **kwargs, ) else: raise ValueError( @@ -150,8 +148,8 @@ def load_sharded_model( model = AutoModelForCausalLM.from_pretrained( model_name, use_cache=False, - torch_dtype=torch.float32, - _attn_implementation=model_config._attn_implementation, # pylint: disable=protected-access + dtype=torch.float32, + _attn_implementation=model_config._attn_implementation, trust_remote_code=cfg.trust_remote_code, ) dtype = torch_dtype if not cfg.float32 else None @@ -160,7 +158,7 @@ def load_sharded_model( with init_empty_weights(): model = AutoModelForCausalLM.from_config( model_config, - torch_dtype=torch_dtype, + dtype=torch_dtype, trust_remote_code=cfg.trust_remote_code, ) return model diff --git a/src/axolotl/utils/optimizers/adopt.py b/src/axolotl/utils/optimizers/adopt.py index 6f064abbf..20ddfa7ec 100644 --- a/src/axolotl/utils/optimizers/adopt.py +++ b/src/axolotl/utils/optimizers/adopt.py @@ -6,7 +6,6 @@ Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeo """ # mypy: ignore-errors -# pylint: skip-file # flake8: noqa # mypy: allow-untyped-decorators # mypy: allow-untyped-defs @@ -288,7 +287,9 @@ def _single_tensor_adopt( assert ( param.device.type == step_t.device.type and param.device.type in capturable_supported_devices - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) step = step_t if capturable or differentiable else _get_value(step_t) @@ -365,7 +366,9 @@ def _multi_tensor_adopt( p.device.type == step.device.type and p.device.type in capturable_supported_devices for p, step in zip(params, state_steps) - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ), ( + f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." + ) assert grad_scale is None and found_inf is None diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py index f9a30b660..6c29a5442 100644 --- a/src/axolotl/utils/quantization.py +++ b/src/axolotl/utils/quantization.py @@ -3,30 +3,47 @@ Utilities for quantization including QAT and PTQ using torchao. """ import torch -from torch import nn +from packaging import version from torchao.core.config import AOBaseConfig from torchao.quantization import quantize_ from torchao.quantization.qat import ( - FakeQuantizeConfig, - FromIntXQuantizationAwareTrainingConfig, - IntXQuantizationAwareTrainingConfig, + QATConfig, ) from torchao.quantization.quant_api import ( - Int4DynamicActivationInt4WeightConfig, - Int4WeightOnlyConfig, + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt4WeightConfig, - Int8DynamicActivationInt8WeightConfig, - Int8WeightOnlyConfig, - UIntXWeightOnlyConfig, - _is_linear, ) -from axolotl.utils.schemas.enums import TorchIntDType +from axolotl.utils.schemas.enums import TorchAOQuantDType + +quantization_config_to_str = { + Int8DynamicActivationInt4WeightConfig: "int8int4", + Float8DynamicActivationFloat8WeightConfig: "fp8fp8", + Float8DynamicActivationInt4WeightConfig: "fp8int4", +} + +if version.parse(torch.__version__) >= version.parse("2.8.0"): + try: + from torchao.prototype.mx_formats import NVFP4InferenceConfig + + quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4" + except: + pass + + # int4 weight config imports will fail on machines with fbgemm-gpu installed + # without a CUDA runtime available so we do this safely + try: + from torchao.quantization.quant_api import Int4WeightOnlyConfig + + quantization_config_to_str[Int4WeightOnlyConfig] = "int4" + except: + pass -def get_ptq_config( - weight_dtype: TorchIntDType, - activation_dtype: TorchIntDType | None = None, +def get_quantization_config( + weight_dtype: TorchAOQuantDType, + activation_dtype: TorchAOQuantDType | None = None, group_size: int | None = None, ) -> AOBaseConfig: """ @@ -45,44 +62,101 @@ def get_ptq_config( or if the group size is not specified for int8 or int4 weight only quantization. """ if activation_dtype is None: - if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr] - return UIntXWeightOnlyConfig( - dtype=weight_dtype.value, - group_size=group_size, - set_inductor_config=False, - ) - if weight_dtype == TorchIntDType.int8: - if group_size is None: - raise ValueError( - "group_size must be specified for int8 weight only quantization" - ) - return Int8WeightOnlyConfig( - group_size=group_size, - ) - if weight_dtype == TorchIntDType.int4: - if group_size is None: - raise ValueError( - "group_size must be specified for int4 weight only quantization" - ) - return Int4WeightOnlyConfig( - group_size=group_size, - ) - if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4: - return Int4DynamicActivationInt4WeightConfig() - if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8: - return Int8DynamicActivationInt8WeightConfig() - if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4: - return Int8DynamicActivationInt4WeightConfig() + if weight_dtype == TorchAOQuantDType.int8: + raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.") + if weight_dtype == TorchAOQuantDType.int4: + from torchao.quantization.quant_api import Int4WeightOnlyConfig + + if group_size is not None: + return Int4WeightOnlyConfig(group_size=group_size, version=2) + else: + return Int4WeightOnlyConfig(version=2) + if ( + activation_dtype == TorchAOQuantDType.int4 + and weight_dtype == TorchAOQuantDType.int4 + ): + raise ValueError( + "Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT." + ) + if ( + activation_dtype == TorchAOQuantDType.int8 + and weight_dtype == TorchAOQuantDType.int8 + ): + raise ValueError( + "Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT." + ) + if ( + activation_dtype == TorchAOQuantDType.int8 + and weight_dtype == TorchAOQuantDType.int4 + ): + if group_size is not None: + return Int8DynamicActivationInt4WeightConfig(group_size=group_size) + else: + return Int8DynamicActivationInt4WeightConfig() + if ( + activation_dtype == TorchAOQuantDType.float8_e4m3fn + and weight_dtype == TorchAOQuantDType.float8_e4m3fn + ): + return Float8DynamicActivationFloat8WeightConfig() + if ( + activation_dtype == TorchAOQuantDType.float8_e4m3fn + and weight_dtype == TorchAOQuantDType.int4 + ): + return Float8DynamicActivationInt4WeightConfig() + if weight_dtype == TorchAOQuantDType.nvfp4: + from torchao.prototype.mx_formats import NVFP4InferenceConfig + + if group_size is not None and group_size != 16: + raise ValueError("NVFP4 quantization must use a group_size of 16") + return NVFP4InferenceConfig() raise ValueError( f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}" ) +def quantize_model( + model, + weight_dtype: TorchAOQuantDType, + group_size: int | None = None, + activation_dtype: TorchAOQuantDType | None = None, + quantize_embedding: bool | None = None, +): + """ + This function is used to quantize a model. + + Args: + model: The model to quantize. + weight_dtype: The dtype to use for weight quantization. + group_size: The group size to use for weight quantization. + activation_dtype: The dtype to use for activation quantization. + quantize_embedding: Whether to quantize the model's embedding weights. + + """ + linear_ptq_config = get_quantization_config( + weight_dtype=weight_dtype, + activation_dtype=activation_dtype, + group_size=group_size, + ) + quantize_(model, linear_ptq_config) + if quantize_embedding: + # activation fake quantization is not supported for embedding layers + embedding_quantize_config = get_quantization_config( + weight_dtype=weight_dtype, + activation_dtype=None, + group_size=group_size, + ) + quantize_( + model, + embedding_quantize_config, + filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), + ) + + def prepare_model_for_qat( model, - weight_dtype: TorchIntDType, - group_size: int, - activation_dtype: TorchIntDType | None = None, + weight_dtype: TorchAOQuantDType, + group_size: int | None = None, + activation_dtype: TorchAOQuantDType | None = None, quantize_embedding: bool = False, ): """ @@ -100,86 +174,40 @@ def prepare_model_for_qat( Raises: ValueError: If the activation/weight dtype combination is invalid. """ - if activation_dtype: - activation_config = FakeQuantizeConfig( - dtype=activation_dtype.value, granularity="per_token", is_symmetric=False - ) - weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size) - linear_quantize_config = IntXQuantizationAwareTrainingConfig( - activation_config=None if activation_dtype is None else activation_config, - weight_config=weight_config, - ) - quantize_(model, linear_quantize_config) - if quantize_embedding: - # activation fake quantization is not supported for embedding layers - embedding_quantize_config = IntXQuantizationAwareTrainingConfig( - activation_config=None, - weight_config=weight_config, - ) - quantize_( - model, - embedding_quantize_config, - filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), - ) - - -def quantize_model_for_ptq( - model, - weight_dtype: TorchIntDType, - group_size: int | None = None, - activation_dtype: TorchIntDType | None = None, - quantize_embedding: bool | None = None, -): - """ - This function is used to quantize a model for post-training quantization. - It swaps the model's linear layers with fake quantized linear layers. - If `quantize_embedding` is True, it will also swap the model's embedding weights with fake quantized embedding weights. - - Args: - model: The model to quantize. - weight_dtype: The dtype to use for weight quantization. - group_size: The group size to use for weight quantization. - activation_dtype: The dtype to use for activation quantization. - quantize_embedding: Whether to quantize the model's embedding weights. - - """ - linear_ptq_config = get_ptq_config( + base_config = get_quantization_config( weight_dtype=weight_dtype, activation_dtype=activation_dtype, group_size=group_size, ) - quantize_(model, linear_ptq_config) + qat_config = QATConfig(base_config) + quantize_(model, qat_config) if quantize_embedding: - embedding_quantize_config = get_ptq_config( + # activation fake quantization is not supported for embedding layers + embedding_base_config = get_quantization_config( weight_dtype=weight_dtype, activation_dtype=None, group_size=group_size, ) + embedding_qat_config = QATConfig(embedding_base_config) quantize_( model, - embedding_quantize_config, + embedding_qat_config, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), ) -def convert_qat_model_for_ptq( +def convert_qat_model( model, - *, - quantize_embedding: bool | None = None, + quantize_embedding: bool = False, ): """ - This function is used to convert a swap fake-quantized modules in a model - which has been trained with QAT back to the original modules, ready for PTQ. - - Args: - model: The model to convert. - quantize_embedding: Whether to quantize the model's embedding weights. + This function converts a QAT model which has fake quantized layers back to the original model. """ + config = QATConfig(step="convert") + quantize_(model, config) if quantize_embedding: - - def filter_fn(m, _): - return isinstance(m, nn.Embedding) or _is_linear(m) - - else: - filter_fn = _is_linear - quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn) + quantize_( + model, + config, + filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), + ) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index e488ed7d5..662c63caa 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -3,7 +3,10 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length into fixed-capacity batches to optimize memory usage and training throughput. """ +import gc import math +import os +import time from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count, get_context from typing import Iterable, Iterator, Union @@ -126,7 +129,7 @@ def pack_parallel( bin_size: int, num_processes: int | None = None, safe_mode: bool = True, - mp_start_method: str | None = "spawn", + mp_start_method: str | None = "fork", ) -> list[list[int]]: """Pack sequences into bins using parallel processing. @@ -145,7 +148,7 @@ def pack_parallel( """ num_items = len(sequence_lengths) if num_processes is None: - num_processes = max(1, min(num_items // group_size, cpu_count())) + num_processes = max(1, min(num_items // group_size, cpu_count(), 16)) # Create tasks for parallel processing tasks = [] @@ -258,14 +261,15 @@ class MultipackBatchSampler(BatchSampler): batch_max_len: int, # Maximum sequence length (bin capacity) lengths: np.ndarray, # Sequence lengths packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate - drop_last: bool = False, # Whether to drop final batches (might be incomplete) - num_count_samples: int = 16, # Number of times to estimate batch count + drop_last: bool = True, # Whether to drop final batches (might be incomplete) + num_count_samples: int = 4, # Number of times to estimate batch count sequential: bool = False, # Whether to use sequential packing group_size: int = 100_000, # Size of groups for parallel packing bin_size: int = 200, # The max number of samples that can be packed in a single bin num_processes: int | None = None, # Number of processes for parallel packing safe_mode: bool = True, # Conservative packing to prevent training instability - **kwargs, # pylint: disable=unused-argument + mp_start_method: str = "fork", + **kwargs, ): super().__init__(sampler, batch_size, drop_last) self.batch_size = batch_size @@ -277,6 +281,7 @@ class MultipackBatchSampler(BatchSampler): self.bin_size = bin_size self.num_processes = num_processes self.safe_mode = safe_mode + self.mp_start_method = mp_start_method assert isinstance(self.lengths, np.ndarray) @@ -287,7 +292,10 @@ class MultipackBatchSampler(BatchSampler): self.total_token_slots = 0 # The number of times to calculate batches to determine minimum packed dataset length - self.num_count_samples = num_count_samples + world_size = int(os.environ.get("WORLD_SIZE", "1")) + self.num_count_samples = ( + 1 if world_size >= num_count_samples else num_count_samples + ) if self.sequential and not isinstance(sampler, SequentialSampler): LOG.warning( @@ -313,9 +321,7 @@ class MultipackBatchSampler(BatchSampler): return self._batches # Get indices from the sampler - indices = [ # pylint: disable=unnecessary-comprehension - idx for idx in self.sampler - ] + indices = [idx for idx in self.sampler] # Get lengths of the selected sequences lengths = self.lengths[indices] @@ -332,13 +338,15 @@ class MultipackBatchSampler(BatchSampler): bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins] else: # Use parallel packing + num_processes = self.num_processes or 1 all_bins = pack_parallel( lengths, bin_capacity=self.batch_max_len, group_size=self.group_size, bin_size=self.bin_size, - num_processes=self.num_processes, + num_processes=min(4, num_processes) if num_processes else 4, safe_mode=self.safe_mode, + mp_start_method=self.mp_start_method, ) # Map bin indices back to original indices @@ -349,6 +357,7 @@ class MultipackBatchSampler(BatchSampler): # Calculate efficiency statistics total_used = lengths.sum() total_slots = len(all_bins) * self.batch_max_len + del all_bins # Group bins into batches (each batch contains batch_size bins) batches = [ @@ -368,6 +377,7 @@ class MultipackBatchSampler(BatchSampler): self.total_token_slots += total_slots self._batches = batches + gc.collect() return batches def __iter__(self) -> Iterator[list[list[int]]]: @@ -409,7 +419,7 @@ class MultipackBatchSampler(BatchSampler): # Gather efficiency from all ranks and apply the calculation function sample_packing_actual_eff_all = reduce_and_broadcast( - lambda: float(self.efficiency()), # pylint: disable=unnecessary-lambda + lambda: float(self.efficiency()), calc_sample_packing_eff_est, ) @@ -443,10 +453,21 @@ class MultipackBatchSampler(BatchSampler): if self._len_across_ranks is None: # Sample multiple times to get stable estimate - len_batches = min( # pylint: disable=consider-using-generator - [len(self._batches) for _ in range(self.num_count_samples)] - ) + _sampled_lens = [] + for _ in range(self.num_count_samples): + self._batches = None # Reset cached batches + # log timer for generating batches + start_time = time.time() + _sampled_lens.append(len(self.generate_batches(set_stats=False))) + LOG.debug(f"generate_batches time: {time.time() - start_time}") + len_batches = min(_sampled_lens) + # Gather minimum across all ranks - self._len_across_ranks = self.gather_len_batches(len_batches) + if self._len_across_ranks is None: + self._len_across_ranks = self.gather_len_batches(len_batches) + else: + self._len_across_ranks = min( + self._len_across_ranks, self.gather_len_batches(len_batches) + ) return self._len_across_ranks diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py index b550ac02c..83a993089 100644 --- a/src/axolotl/utils/schedulers.py +++ b/src/axolotl/utils/schedulers.py @@ -2,7 +2,9 @@ import math from functools import partial +from typing import Sequence +from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler @@ -44,8 +46,10 @@ class RexLR(LRScheduler): # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming. for group in optimizer.param_groups: - group.setdefault("initial_lr", group["lr"]) - + initial_lr = group["lr"] + if isinstance(initial_lr, Tensor): + initial_lr = initial_lr.clone() + group.setdefault("initial_lr", initial_lr) # Pass self.last_step as last_epoch to the parent. super().__init__(optimizer, last_epoch=self.last_step) @@ -103,9 +107,7 @@ class InterpolatingLogScheduler(LRScheduler): self.num_steps = num_steps self.min_lr = min_lr self.max_lr = max_lr - self.q = (max_lr / min_lr) ** ( # pylint: disable=invalid-name - 1 / (num_steps - 1) - ) + self.q = (max_lr / min_lr) ** (1 / (num_steps - 1)) super().__init__(optimizer, last_epoch) def get_lr(self): @@ -292,3 +294,49 @@ def get_cosine_schedule_with_warmup_decay_constant( num_cycles=num_cycles, ) return LambdaLR(optimizer, lr_lambda, last_epoch) + + +class JaggedLRRestartScheduler(LRScheduler): + """Wraps another scheduler to apply per-lora-restart learning rate warmups.""" + + def __init__( + self, + optimizer: Optimizer, + inner_schedule: LRScheduler, + jagged_restart_steps: int, + jagged_restart_warmup_steps: int, + jagged_restart_anneal_steps: int = 1, + min_lr_scale: float = 0.001, + ) -> None: + self.inner_schedule = inner_schedule + self.restarts_steps = jagged_restart_steps + self.warmup_steps = jagged_restart_warmup_steps + self.anneal_steps = jagged_restart_anneal_steps + self.min_lr_scale = min_lr_scale + super().__init__(optimizer, inner_schedule.last_epoch) + + def get_lr(self) -> float | Sequence[float]: + self.inner_schedule.last_epoch = self.last_epoch + + original = self.inner_schedule.get_lr() + step = self.last_epoch + + if step < self.restarts_steps - self.anneal_steps: + scale = 1 + else: + per_restart_progress = step % self.restarts_steps + if per_restart_progress < self.warmup_steps: + cycle_t = min(1.0, (per_restart_progress) / self.warmup_steps) + elif per_restart_progress > (self.restarts_steps - self.anneal_steps): + cycle_t = min( + 1.0, + (self.restarts_steps - per_restart_progress) / self.anneal_steps, + ) + else: + cycle_t = 1 + scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale + + if isinstance(original, Sequence): + return [lr * scale for lr in original] + + return original * scale diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 936ccf85b..2a2d63af8 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1,8 +1,5 @@ """Module with Pydantic models for configuration.""" -# pylint: disable=too-many-lines - -import os from typing import Annotated, Any, Literal from annotated_types import MinLen @@ -12,11 +9,10 @@ from pydantic import ( Field, StringConstraints, field_serializer, - field_validator, model_validator, ) -from transformers.utils.import_utils import is_torch_npu_available +from axolotl.utils.datasets import get_default_process_count from axolotl.utils.logging import get_logger from axolotl.utils.schemas.datasets import ( DatasetConfig, @@ -28,6 +24,7 @@ from axolotl.utils.schemas.datasets import ( ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType +from axolotl.utils.schemas.fsdp import FSDPConfig from axolotl.utils.schemas.integrations import ( CometConfig, GradioConfig, @@ -45,21 +42,20 @@ from axolotl.utils.schemas.model import ( from axolotl.utils.schemas.multimodal import MultiModalConfig from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig from axolotl.utils.schemas.quantization import PTQConfig, QATConfig -from axolotl.utils.schemas.training import HyperparametersConfig +from axolotl.utils.schemas.training import HyperparametersConfig, JaggedLRConfig from axolotl.utils.schemas.trl import TRLConfig +from axolotl.utils.schemas.validation import ValidationMixin from axolotl.utils.schemas.vllm import VllmConfig -LOG = get_logger(__name__, use_environ=True) - -SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} +LOG = get_logger(__name__) -# pylint: disable=too-many-public-methods,too-many-ancestors class AxolotlInputConfig( ModelInputConfig, ModelOutputConfig, LoraConfig, ReLoRAConfig, + JaggedLRConfig, HyperparametersConfig, WandbConfig, MLFlowConfig, @@ -70,38 +66,105 @@ class AxolotlInputConfig( MultiModalConfig, RemappedParameters, DeprecatedParameters, + ValidationMixin, BaseModel, ): - """Wrapper of all config options""" + """Wrapper of all config options.""" model_config = {"populate_by_name": True} - strict: bool | None = Field(default=False) - resume_from_checkpoint: str | None = None - auto_resume_from_checkpoints: bool | None = None - resize_token_embeddings_to_32x: bool | None = None + strict: bool | None = Field( + default=False, + json_schema_extra={"description": "Allow overwrite yml config using from cli"}, + ) + resume_from_checkpoint: str | None = Field( + default=None, + json_schema_extra={"description": "Resume from a specific checkpoint dir"}, + ) + auto_resume_from_checkpoints: bool | None = Field( + default=None, + json_schema_extra={ + "description": "If resume_from_checkpoint isn't set and you simply want it to start where it left off. Be careful with this being turned on between different models." + }, + ) + resize_token_embeddings_to_32x: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Resize the model embeddings when new tokens are added to multiples of 32. This is reported to improve training speed on some models" + }, + ) mean_resizing_embeddings: bool | None = False # optionally shrink the embeddings when the tokenizer vocab size is smaller - shrink_embeddings: bool | None = None - embeddings_skip_upcast: bool | None = None + shrink_embeddings: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink." + }, + ) + embeddings_skip_upcast: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs" + }, + ) + reinit_weights: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Reinitialize model weights randomly instead of loading pretrained weights" + }, + ) - rl: RLType | None = None + trainer_cls: str | None = Field( + default=None, + json_schema_extra={ + "description": "module to custom trainer class to use for training" + }, + ) + + rl: RLType | None = Field( + default=None, + json_schema_extra={ + "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'" + }, + ) trl: TRLConfig | None = Field( - default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda + default_factory=lambda: TRLConfig(), ) vllm: VllmConfig | None = Field( - default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda + default_factory=lambda: VllmConfig(), ) qat: QATConfig | None = None quantization: PTQConfig | None = None - reward_model: bool | None = None - process_reward_model: bool | None = None + reward_model: bool | None = Field( + default=None, + json_schema_extra={"description": "Reward modelling: `True` or `False`"}, + ) + process_reward_model: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Process reward modelling: `True` or `False`" + }, + ) + center_rewards_coefficient: float | None = Field( + default=None, + json_schema_extra={ + "description": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." + }, + ) num_labels: int | None = None # Whether to use weighting in DPO trainer. # If `None`, default is `False` in the trainer. - dpo_use_weighting: bool | None = None + dpo_use_weighting: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to perform weighting in DPO trainer" + }, + ) dpo_use_logits_to_keep: bool | None = None dpo_label_smoothing: float | None = None + dpo_norm_loss: bool | None = None + dpo_padding_free: bool | None = None + dpo_generate_during_eval: bool | None = None datasets: ( Annotated[ @@ -109,7 +172,12 @@ class AxolotlInputConfig( MinLen(1), ] | None - ) = None + ) = Field( + default=None, + json_schema_extra={ + "description": "A list of one or more datasets to finetune the model with" + }, + ) test_datasets: ( Annotated[ @@ -117,22 +185,85 @@ class AxolotlInputConfig( MinLen(1), ] | None - ) = None - shuffle_merged_datasets: bool | None = True - dataset_prepared_path: str | None = None - dataset_shard_num: int | None = None - dataset_shard_idx: int | None = None + ) = Field( + default=None, + json_schema_extra={ + "description": "A list of one or more datasets to eval the model with. You can use either test_datasets, or val_set_size, but not both." + }, + ) + shuffle_merged_datasets: bool | None = Field( + default=True, + json_schema_extra={ + "description": "If false, the datasets will not be shuffled and will keep their original order in `datasets`. The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true." + }, + ) + shuffle_before_merging_datasets: bool | None = Field( + default=False, + json_schema_extra={ + "description": "If true, each dataset in `datasets` will be shuffled before merging. This allows curriculum learning strategies to be applied at the dataset level. Default is false." + }, + ) + dataset_prepared_path: str | None = Field( + default=None, + json_schema_extra={ + "description": "Axolotl attempts to save the dataset as an arrow after packing the data together so subsequent training attempts load faster, relative path" + }, + ) + dataset_shard_num: int | None = Field( + default=None, json_schema_extra={"description": "Num shards for whole dataset"} + ) + dataset_shard_idx: int | None = Field( + default=None, + json_schema_extra={"description": "Index of shard to use for whole dataset"}, + ) skip_prepare_dataset: bool | None = False + num_dataset_shards_to_save: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of shards to save the prepared dataset" + }, + ) pretraining_dataset: ( Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None ) = Field( default=None, - json_schema_extra={"description": "streaming dataset to use for pretraining"}, + json_schema_extra={ + "description": "Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize" + }, + ) + dataset_processes: int | None = Field( + default=None, + deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.", + json_schema_extra={ + "description": ( + "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n" + "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT." + ) + }, + ) + dataset_num_proc: int | None = Field( + default=None, + json_schema_extra={ + "description": ( + "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n" + "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT." + ) + }, + ) + + dataset_exact_deduplication: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Deduplicates datasets and test_datasets with identical entries" + }, + ) + dataset_keep_in_memory: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Keep dataset in memory while preprocessing. Only needed if cached dataset is taking too much storage" + }, ) - dataset_processes: int | None = Field(default=min(32, os.cpu_count() or 1)) - dataset_exact_deduplication: bool | None = None - dataset_keep_in_memory: bool | None = None dataloader_pin_memory: bool | None = None dataloader_num_workers: int | None = None dataloader_prefetch_factor: int | None = None @@ -142,75 +273,239 @@ class AxolotlInputConfig( remove_unused_columns: bool | None = None - push_dataset_to_hub: str | None = None - hf_use_auth_token: bool | None = None + push_dataset_to_hub: str | None = Field( + default=None, + json_schema_extra={ + "description": "Push prepared dataset to hub - repo_org/repo_name" + }, + ) + hf_use_auth_token: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets. Required to be true when used in combination with `push_dataset_to_hub`" + }, + ) device: Any | None = None - device_map: Any | None = None + device_map: Any | None = Field( + default=None, + json_schema_extra={ + "description": "Passed through to transformers when loading the model when launched without accelerate. Use `sequential` when training w/ model parallelism to limit memory" + }, + ) world_size: int | None = None - local_rank: int | None = None + local_rank: int | None = Field( + default=None, + json_schema_extra={ + "description": "Don't mess with this, it's here for accelerate and torchrun" + }, + ) ddp: bool | None = None - seed: int | None = None - ddp_timeout: int | None = None - ddp_bucket_cap_mb: int | None = None - ddp_broadcast_buffers: bool | None = None + seed: int | None = Field( + default=None, json_schema_extra={"description": "Seed for reproducibility"} + ) + ddp_timeout: int | None = Field( + default=None, + json_schema_extra={"description": "Advanced DDP Arguments - timeout"}, + ) + ddp_bucket_cap_mb: int | None = Field( + default=None, + json_schema_extra={"description": "Advanced DDP Arguments - bucket cap in MB"}, + ) + ddp_broadcast_buffers: bool | None = Field( + default=None, + json_schema_extra={"description": "Advanced DDP Arguments - broadcast buffers"}, + ) ddp_find_unused_parameters: bool | None = None - eval_table_size: int | None = None - eval_max_new_tokens: int | None = None - do_causal_lm_eval: bool | None = None - eval_causal_lm_metrics: list[str] | None = None + eval_table_size: int | None = Field( + default=None, + json_schema_extra={ + "description": "Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0" + }, + ) + eval_max_new_tokens: int | None = Field( + default=None, + json_schema_extra={ + "description": "Total number of tokens generated for predictions sent to wandb. Default is 128" + }, + ) + do_causal_lm_eval: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`" + }, + ) + eval_causal_lm_metrics: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": "HF evaluate metrics used during evaluation. Default is ['sacrebleu', 'comet', 'ter', 'chrf', 'perplexity']" + }, + ) do_bench_eval: bool | None = None bench_dataset: str | None = None bench_split: str | None = None metric_for_best_model: str | None = None greater_is_better: bool | None = None - loss_watchdog_threshold: float | None = None - loss_watchdog_patience: int | None = None - - gc_steps: int | None = None - - bf16: Literal["auto"] | bool | None = "auto" - fp16: bool | None = None - fp8: bool | None = None - bfloat16: bool | None = None # for non-AMP cases - float16: bool | None = None # for non-AMP cases - tf32: bool | None = None - float32: bool | None = None - - # torch_dtype: torch.dtype | None - - gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field( - default=False - ) - gradient_checkpointing_kwargs: dict[str, Any] | None = None - - unfrozen_parameters: list[str] | None = None - - sequence_len: int = Field(default=512) - min_sample_len: int | None = None - max_prompt_len: int = Field( - default=512, - json_schema_extra={"description": "maximum prompt length for RL training"}, - ) - sample_packing: bool | None = None - sample_packing_group_size: int | None = 100_000 - sample_packing_bin_size: int | None = 200 - sample_packing_sequentially: bool | None = None - eval_sample_packing: bool | None = None - pad_to_sequence_len: bool | None = None - curriculum_sampling: bool | None = None - multipack_real_batches: bool | None = None - pretraining_sample_concatenation: bool | None = Field( + loss_watchdog_threshold: float | None = Field( default=None, json_schema_extra={ - "description": "whether to soft pack/concatenate samples during pretraining", + "description": "High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)" + }, + ) + loss_watchdog_patience: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of high-loss steps in a row before the trainer aborts (default: 3)" }, ) - batch_flattening: Literal["auto"] | bool | None = None + gc_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "Run garbage collection every `gc_steps` steps. -1 will run on epoch end and before evaluations. Default is 0 (disabled)." + }, + ) + + bf16: Literal["auto"] | bool | None = Field( + default="auto", + json_schema_extra={ + "description": "Use CUDA bf16. bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere" + }, + ) + fp16: bool | None = Field( + default=None, json_schema_extra={"description": "Use CUDA fp16"} + ) + fp8: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Enable FP8 mixed precision training using TorchAO. Best " + "used in combination with torch.compile." + }, + ) + fp8_enable_fsdp_float8_all_gather: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Enable FSDP float8 all-gather optimization for FP8 training. Can " + "improve training speed by 10-15% when FSDP is enabled." + }, + ) + bfloat16: bool | None = Field( + default=None, + json_schema_extra={ + "description": "No AMP (automatic mixed precision) - require >=ampere" + }, + ) # for non-AMP cases + float16: bool | None = Field( + default=None, + json_schema_extra={"description": "No AMP (automatic mixed precision)"}, + ) # for non-AMP cases + tf32: bool | None = Field( + default=None, + json_schema_extra={"description": "Use CUDA tf32 - require >=ampere"}, + ) + float32: bool | None = None + + gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field( + default=False, + json_schema_extra={ + "description": "Whether to use gradient checkpointing. Available options are: true, false, 'offload', 'offload_disk'. https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing" + }, + ) + gradient_checkpointing_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Additional kwargs to pass to the trainer for gradient checkpointing" + }, + ) + activation_offloading: Literal["legacy", "disk"] | bool | None = Field( + default=False, + json_schema_extra={ + "description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'." + }, + ) + + unfrozen_parameters: list[str] | None = None + + sequence_len: int = Field( + default=512, + json_schema_extra={ + "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048" + }, + ) + excess_length_strategy: Literal["drop", "truncate"] | None = Field( + default=None, + json_schema_extra={ + "description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility." + }, + ) + eval_sequence_len: int | None = Field( + default=None, + json_schema_extra={ + "description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len" + }, + ) + min_sample_len: int | None = None + max_prompt_len: int | None = Field( + default=None, + json_schema_extra={"description": "maximum prompt length for RL training"}, + ) + sample_packing: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'" + }, + ) + sample_packing_group_size: int | None = Field( + default=100_000, + json_schema_extra={ + "description": "The number of samples packed at a time. Increasing the following values helps with packing, but usually only slightly (<%1.)" + }, + ) + sample_packing_bin_size: int | None = Field( + default=200, + json_schema_extra={ + "description": "The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples." + }, + ) + sample_packing_sequentially: bool | None = Field( + default=None, + json_schema_extra={"description": "Whether to pack samples sequentially"}, + ) + sample_packing_mp_start_method: str | None = Field( + default=None, + json_schema_extra={ + "description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'" + }, + ) + eval_sample_packing: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Set to 'false' if getting errors during eval with sample_packing on" + }, + ) + pad_to_sequence_len: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled" + }, + ) + curriculum_sampling: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use sequential sampling for curriculum learning" + }, + ) + multipack_real_batches: bool | None = None + + batch_flattening: Literal["auto"] | bool | None = Field( + default=None, + json_schema_extra={ + "description": "Use batch flattening for speedups when not using sample_packing" + }, + ) # for PoSE context length extension use_pose: bool | None = None @@ -218,28 +513,93 @@ class AxolotlInputConfig( pose_max_context_len: int | None = None pose_num_chunks: int | None = None - pretrain_multipack_buffer_size: int | None = 10_000 + # Deprecated: Use streaming_multipack_buffer_size instead + pretrain_multipack_buffer_size: int | None = Field( + default=None, + deprecated="Deprecated in v0.13.0, will be removed in v0.14.0. Use streaming_multipack_buffer_size instead", + ) pretrain_multipack_attn: bool | None = Field( default=True, json_schema_extra={ "description": "whether to prevent cross attention for packed sequences during pretraining", }, ) + pretraining_sample_concatenation: bool | None = Field( + default=None, + json_schema_extra={ + "description": "whether to concatenate samples during pretraining", + }, + ) - xformers_attention: bool | None = None - sdp_attention: bool | None = None - s2_attention: bool | None = None + streaming: bool | None = Field( + default=None, + json_schema_extra={"description": "Use streaming mode for loading datasets"}, + ) + streaming_multipack_buffer_size: int | None = Field( + default=10_000, + json_schema_extra={ + "description": "Buffer size for multipack streaming datasets" + }, + ) + + xformers_attention: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use xformers attention patch https://github.com/facebookresearch/xformers" + }, + ) + sdp_attention: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use scaled-dot-product attention https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" + }, + ) + s2_attention: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf" + }, + ) flex_attention: bool | None = None flex_attn_compile_kwargs: dict[str, Any] | None = None - flash_attention: bool | None = None - flash_attn_cross_entropy: bool | None = None - flash_attn_rms_norm: bool | None = None - flash_attn_fuse_qkv: bool | None = None - flash_attn_fuse_mlp: bool | None = None - flash_optimum: bool | None = None + flash_attention: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention" + }, + ) + flash_attn_cross_entropy: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use flash-attention cross entropy implementation - advanced use only" + }, + ) + flash_attn_rms_norm: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use flash-attention rms norm implementation - advanced use only" + }, + ) + flash_attn_fuse_mlp: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to fuse part of the MLP into a single operation" + }, + ) + flash_optimum: bool | None = Field( + default=None, + json_schema_extra={"description": "Whether to use bettertransformers"}, + ) eager_attention: bool | None = None + attn_implementation: str | None = Field( + default=None, + json_schema_extra={ + "description": "Specify a custom attention implementation, used mostly for kernels." + }, + ) + unsloth_cross_entropy_loss: bool | None = None unsloth_lora_mlp: bool | None = None unsloth_lora_qkv: bool | None = None @@ -247,118 +607,427 @@ class AxolotlInputConfig( unsloth_rms_norm: bool | None = None unsloth_rope: bool | None = None - lora_mlp_kernel: bool | None = None - lora_qkv_kernel: bool | None = None - lora_o_kernel: bool | None = None + lora_mlp_kernel: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html" + }, + ) + lora_qkv_kernel: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html" + }, + ) + lora_o_kernel: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html" + }, + ) + + chunked_cross_entropy: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use chunked cross entropy loss for memory efficiency" + }, + ) + chunked_cross_entropy_num_chunks: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of chunks to use for chunked cross entropy loss" + }, + ) + + tiled_mlp: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use ALST tiled mlp for memory efficient long context" + }, + ) + + tiled_mlp_num_shards: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of shards to use for ALST tiled mlp. If unset, it will be set based on seqlen/hidden_size" + }, + ) + + tiled_mlp_use_original_mlp: bool | None = Field( + default=True, + json_schema_extra={ + "description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama." + }, + ) llama4_linearized_experts: bool | None = None - deepspeed: str | dict[str, Any] | None = None - fsdp: list[str] | None = None - fsdp_config: dict[str, Any] | None = None + deepspeed: str | dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json" + }, + ) + deepcompile: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use deepcompile for faster training with deepspeed" + }, + ) + fsdp: list[str] | None = Field( + default=None, + json_schema_extra={"description": "FSDP configuration"}, + deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ", + ) + fsdp_config: FSDPConfig | None = Field( + default=None, json_schema_extra={"description": "FSDP configuration options"} + ) + fsdp_version: int | None = Field( + default=None, + json_schema_extra={"description": "FSDP version"}, + ) fsdp_final_state_dict_type: ( Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None - ) = None + ) = Field( + default=None, + deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.", + ) - val_set_size: float | None = Field(default=0.0) + val_set_size: float | None = Field( + default=0.0, + json_schema_extra={ + "description": "How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval." + }, + ) - sequence_parallel_degree: int | None = None - heads_k_stride: int | None = None - ring_attn_func: RingAttnFunc | None = None + dp_shard_size: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of devices to shard across. If not set, will use all available devices." + }, + ) + dp_replicate_size: int | None = Field( + default=None, + json_schema_extra={"description": "Number of devices to replicate across."}, + ) + sequence_parallel_degree: int | None = Field( + default=None, + json_schema_extra={ + "description": "Deprecated: use `context_parallel_size` instead" + }, + ) + context_parallel_size: int | None = Field( + default=None, + json_schema_extra={ + "description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details." + }, + ) + heads_k_stride: int | None = Field( + default=None, + json_schema_extra={ + "description": "Optional; strides across the key dimension. Larger values use more memory but should make training faster. Must evenly divide the number of KV heads in your model." + }, + ) + ring_attn_func: RingAttnFunc | None = Field( + default=None, + json_schema_extra={ + "description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case." + }, + ) + tensor_parallel_size: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP." + }, + ) + special_tokens: SpecialTokensConfig | None = Field( + default=None, + json_schema_extra={ + "description": "Add or change special tokens. If you add tokens here, you don't need to add them to the `tokens` list." + }, + ) + tokens: list[str] | None = Field( + default=None, + json_schema_extra={"description": "Add extra tokens to the tokenizer"}, + ) + added_tokens_overrides: dict[int, str] | None = Field( + default=None, + json_schema_extra={ + "description": "Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer. Only works for tokens that are not part of the base vocab (aka are added_tokens). Can be checked if they exist in tokenizer.json added_tokens." + }, + ) - special_tokens: SpecialTokensConfig | None = None - tokens: list[str] | None = None - added_tokens_overrides: dict[int, str] | None = None - - torch_compile: Literal["auto"] | bool | None = None - torch_compile_backend: str | None = None + torch_compile: Literal["auto"] | bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0" + }, + ) + torch_compile_backend: str | None = Field( + default=None, + json_schema_extra={"description": "Backend to use for torch.compile"}, + ) torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = ( None ) - max_steps: int | None = None - warmup_steps: int | None = None - warmup_ratio: float | None = None - eval_steps: int | float | None = None - evals_per_epoch: int | None = None - eval_strategy: str | None = None - save_steps: int | float | None = None - saves_per_epoch: int | None = None - save_strategy: str | None = None - save_total_limit: int | None = None - logging_steps: int | None = None - early_stopping_patience: int | None = None + max_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "Maximum number of iterations to train for. It precedes num_epochs which means that if both are set, num_epochs will not be guaranteed. e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps" + }, + ) + warmup_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of warmup steps. Cannot use with warmup_ratio" + }, + ) + warmup_ratio: float | None = Field( + default=None, + json_schema_extra={"description": "Warmup ratio. Cannot use with warmup_steps"}, + ) + eval_steps: int | float | None = Field( + default=None, + json_schema_extra={ + "description": "Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps" + }, + ) + evals_per_epoch: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of times per epoch to run evals, mutually exclusive with eval_steps" + }, + ) + eval_strategy: str | None = Field( + default=None, + json_schema_extra={ + "description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`" + }, + ) + + save_steps: int | float | None = Field( + default=None, + json_schema_extra={ + "description": "Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps" + }, + ) + saves_per_epoch: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of times per epoch to save a checkpoint, mutually exclusive with save_steps" + }, + ) + save_strategy: str | None = Field( + default=None, + json_schema_extra={ + "description": "Set to `no` to skip checkpoint saves, `epoch` at end of each epoch, `best` when better result is achieved, leave empty to infer from `save_steps`" + }, + ) + save_total_limit: int | None = Field( + default=None, json_schema_extra={"description": "Checkpoints saved at a time"} + ) + save_first_step: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to checkpoint a model after the first step of training. Defaults to False." + }, + ) + + logging_steps: int | None = Field( + default=None, json_schema_extra={"description": "Logging frequency"} + ) + early_stopping_patience: int | None = Field( + default=None, + json_schema_extra={ + "description": "Stop training after this many evaluation losses have increased in a row. https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback" + }, + ) load_best_model_at_end: bool | None = False - save_only_model: bool | None = False - use_tensorboard: bool | None = None - profiler_steps: int | None = None - include_tokens_per_second: bool | None = None + save_only_model: bool | None = Field( + default=False, + json_schema_extra={ + "description": "Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints." + }, + ) + use_tensorboard: bool | None = Field( + default=None, json_schema_extra={"description": "Use tensorboard for logging"} + ) + profiler_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz" + }, + ) + profiler_steps_start: int | None = Field( + default=0, + json_schema_extra={ + "description": "Which step to start the profiler at. Useful for only capturing a few steps mid-run." + }, + ) + include_tokens_per_second: bool | None = Field( + default=None, + json_schema_extra={ + "description": "bool of whether to report tokens per second at the end of training. This is not supported with pre-training datasets." + }, + ) + include_tkps: bool | None = Field( + default=True, + json_schema_extra={ + "description": "bool of whether to report tokens per second per-gpu during training by measuring throughput of non-padding tokens." + }, + ) + neftune_noise_alpha: float | None = Field( + default=None, + json_schema_extra={ + "description": "NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings. Currently only supported on Llama and Mistral" + }, + ) - neftune_noise_alpha: float | None = None + orpo_alpha: float | None = Field( + default=None, + json_schema_extra={ + "description": "Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping." + }, + ) + rpo_alpha: float | None = Field( + default=None, + json_schema_extra={ + "description": "Weighting of NLL term in loss from RPO paper" + }, + ) + simpo_gamma: float | None = Field( + default=None, + json_schema_extra={"description": "Target reward margin for the SimPO loss"}, + ) + cpo_alpha: float | None = Field( + default=None, json_schema_extra={"description": "Weight of the BC regularizer"} + ) - orpo_alpha: float | None = None - rpo_alpha: float | None = None - simpo_gamma: float | None = None - cpo_alpha: float | None = None + kto_desirable_weight: float | None = Field( + default=None, + json_schema_extra={"description": "Factor for desirable loss term in KTO loss"}, + ) + kto_undesirable_weight: float | None = Field( + default=None, + json_schema_extra={ + "description": "Factor for undesirable loss term in KTO loss" + }, + ) + rl_beta: float | None = Field( + default=None, + json_schema_extra={"description": "The beta parameter for the RL training"}, + ) - kto_desirable_weight: float | None = None - kto_undesirable_weight: float | None = None - rl_beta: float | None = None - - max_memory: dict[int | Literal["cpu", "disk"], int | str] | None = None - gpu_memory_limit: int | str | None = None - low_cpu_mem_usage: bool | None = None + max_memory: dict[int | Literal["cpu", "disk"], int | str] | None = Field( + default=None, + json_schema_extra={ + "description": "Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model." + }, + ) + gpu_memory_limit: int | str | None = Field( + default=None, + json_schema_extra={ + "description": "Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset" + }, + ) + low_cpu_mem_usage: bool | None = Field( + default=None, + json_schema_extra={"description": "Whether to use low_cpu_mem_usage"}, + ) chat_template: ( ChatTemplate | Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")] - ) | None = None - chat_template_jinja: str | None = None - chat_template_kwargs: dict[str, Any] | None = None - eot_tokens: list[str] | None = None - default_system_message: str | None = None + ) | None = Field( + default=None, + json_schema_extra={ + "description": "The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. The selected chat template will be saved to the tokenizer_config.json for easier inferencing" + }, + ) + chat_template_jinja: str | None = Field( + default=None, + json_schema_extra={ + "description": "Custom jinja template or path to jinja file for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null." + }, + ) + chat_template_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Additional kwargs to pass to the chat template. This is useful for customizing the chat template. For example, you can pass `thinking=False` to add a generation prompt to the chat template." + }, + ) + eot_tokens: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": "Custom EOT (End-of-Turn) tokens to mask/unmask during training. These tokens mark the boundaries between conversation turns. For example: ['/INST', '', '[/SYSTEM_PROMPT]']. If not specified, defaults to just the model's eos_token. This is useful for templates that use multiple delimiter tokens." + }, + ) + default_system_message: str | None = Field( + default=None, + json_schema_extra={ + "description": "Changes the default system message. Currently only supports chatml." + }, + ) - fix_untrained_tokens: int | list[int] | None = None + fix_untrained_tokens: int | list[int] | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Token index or indices to adjust embedding weights to the mean of the other tokens. " + "This is useful when the model has untrained embeddings." + ) + }, + ) # INTERNALS - document for now, generally not set externally is_preprocess: bool | None = None preprocess_iterable: bool | None = None - total_num_tokens: int | None = None + total_num_tokens: int | None = Field( + default=None, + json_schema_extra={"description": "Total number of tokens - internal use"}, + ) total_supervised_tokens: int | None = None - sample_packing_eff_est: float | None = None + sample_packing_eff_est: float | None = Field( + default=None, + json_schema_extra={ + "description": "You can set these packing optimizations AFTER starting a training at least once. The trainer will provide recommended values for these values." + }, + ) axolotl_config_path: str | None = None - is_falcon_derived_model: bool | None = Field(default=None) - is_llama_derived_model: bool | None = Field(default=None) - is_mistral_derived_model: bool | None = Field(default=None) - is_qwen_derived_model: bool | None = Field(default=None) + is_falcon_derived_model: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Internal use only - Used to identify which the model is based on" + }, + ) + is_llama_derived_model: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Internal use only - Used to identify which the model is based on" + }, + ) + is_mistral_derived_model: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Internal use only - Used to identify which the model is based on. Please note that if you set this to true, `padding_side` will be set to 'left' by default" + }, + ) + is_qwen_derived_model: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Internal use only - Used to identify which the model is based on" + }, + ) - plugins: list[str] | None = Field(default=None) - - @field_validator("datasets", mode="before") - @classmethod - def deprecate_sharegpt_datasets(cls, datasets): - for _, ds_cfg in enumerate(datasets): - # Handle both dict and pydantic model cases - ds_type = ( - ds_cfg.get("type") - if isinstance(ds_cfg, dict) - else getattr(ds_cfg, "type", None) - ) - if not ds_type: - continue - - # skip if it's a dict (for custom user instruction prompt) - if isinstance(ds_type, dict): - continue - - if isinstance(ds_type, str) and ds_type.startswith("sharegpt"): - raise ValueError( - "`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead." - ) - - return datasets + plugins: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html" + }, + ) @field_serializer("datasets") def datasets_serializer( @@ -370,896 +1039,27 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod - def check_attention_fields(cls, data): - fields = ( - "xformers_attention", - "sdp_attention", - "s2_attention", - "flash_attention", - "flex_attention", - ) - non_empty_count = sum(1 for field in fields if data.get(field)) - - if non_empty_count > 1: - raise ValueError(f"Only one of {', '.join(fields)} must be set") - return data - - @model_validator(mode="before") - @classmethod - def check_batch_size_fields(cls, data): - fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size") - non_empty_count = sum(1 for field in fields if data.get(field)) - - if non_empty_count < 2: - raise ValueError(f"At least two of {', '.join(fields)} must be set") - return data - - @model_validator(mode="before") - @classmethod - def check_pretraining_w_max_steps(cls, data): - if data.get("pretraining_dataset") and not data.get("max_steps"): - raise ValueError( - "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_pretraining_w_group_by_length(cls, data): - if data.get("pretraining_dataset") and data.get("group_by_length"): - LOG.warning( - "You probably want to disable group_by_length as it will force a streamed dataset to download completely." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_pretraining_split_batches_accelerate(cls, data): - # alternatively set ACCELERATE_SPLIT_BATCHES=False - if data.get("pretraining_dataset"): - accelerator_config = data.get("accelerator_config", {}) - if not accelerator_config: - data["accelerator_config"] = { - "split_batches": False, - "dispatch_batches": False, - } - else: - if accelerator_config.get("split_batches") is None: - data["accelerator_config"]["split_batches"] = False - if accelerator_config.get("dispatch_batches") is None: - data["accelerator_config"]["dispatch_batches"] = False - return data - - @model_validator(mode="before") - @classmethod - def check_gptq_w_revision(cls, data): - if data.get("gptq") and data.get("revision_of_model"): - raise ValueError( - "revision_of_model is not supported for GPTQ models. " - + "Please download the model from HuggingFace Hub manually for correct branch, " - + "point to its path, and remove revision_of_model from the config." - ) - return data - - @model_validator(mode="before") - @classmethod - # pylint: disable=duplicate-code - def check_chat_template_config(cls, data): - # if chat_template is set to jinja, chat_template_jinja is required - if data.get("chat_template") == ChatTemplate.jinja and not data.get( - "chat_template_jinja" - ): - raise ValueError( - "chat_template_jinja is required when chat_template is set to jinja" - ) - - # If chat_template_jinja is set, set chat_template to jinja - if data.get("chat_template_jinja") and not data.get("chat_template"): - data["chat_template"] = ChatTemplate.jinja - - return data - - @model_validator(mode="before") - @classmethod - def check_sample_packing_wo_flash(cls, data): + def warn_peft_trainable_token_to_fix_untrained(cls, data): if ( - data.get("sample_packing") - and not data.get("flash_attention") - and not data.get("sdp_attention") - and not data.get("flex_attention") - and not data.get("xformers_attention") - ): - LOG.warning( - "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination." - ) + peft_trainable_token_indices := data.get("peft_trainable_token_indices") + ) and (fix_untrained_tokens := data.get("fix_untrained_tokens")): + if isinstance(fix_untrained_tokens, int): + fix_untrained_tokens = (fix_untrained_tokens,) - return data + if isinstance(peft_trainable_token_indices, int): + peft_trainable_token_indices = (peft_trainable_token_indices,) - @model_validator(mode="before") - @classmethod - def check_sample_packing_with_s2attn(cls, data): - if data.get("sample_packing") and data.get("s2_attention"): - raise ValueError( - "Received `sample_packing=true` and `s2_attention=true`; however, \ - shifted-sparse attention does not currently support sample packing." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_batch_flattening_fa(cls, data): - if data.get("batch_flattening"): - batch_flattening_auto = data.get("batch_flattening") == "auto" - if not data.get("flash_attention") and not batch_flattening_auto: - raise ValueError("batch_flattening requires flash attention") - if data.get("sample_packing") and not batch_flattening_auto: - raise ValueError("batch_flattening not compatible with sample_packing") - if data.get("micro_batch_size") == 1 and not batch_flattening_auto: - LOG.warning("batch_flattening has no effect with micro_batch_size == 1") - - if ( - batch_flattening_auto - and data.get("flash_attention") - and not data.get("sample_packing") - and data.get("micro_batch_size") > 1 - ): - data["batch_flattening"] = True - elif batch_flattening_auto: - data["batch_flattening"] = False - - return data - - @model_validator(mode="before") - @classmethod - def check_sample_packing_w_rl(cls, data): - if data.get("sample_packing") and data.get("rl"): - raise ValueError("`sample_packing: true` does not work with RLHF training") - return data - - @model_validator(mode="before") - @classmethod - def hint_sample_packing_padding(cls, data): - if data.get("sample_packing"): - pad_to_sequence_len = data.get("pad_to_sequence_len") - if pad_to_sequence_len is False: - LOG.warning( - "`pad_to_sequence_len: true` is recommended when using sample_packing" - ) - elif pad_to_sequence_len is None: - LOG.info( - "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing" - ) - data["pad_to_sequence_len"] = True - return data - - @model_validator(mode="before") - @classmethod - def hint_reward_model_pad(cls, data): - if data.get("reward_model") and not data.get("pad_to_sequence_len"): - LOG.warning( - "`pad_to_sequence_len: true` is recommended when using reward_model" - ) - if data.get("pad_to_sequence_len") is None: - data["pad_to_sequence_len"] = True - return data - - @model_validator(mode="before") - @classmethod - def check_gas_bsz(cls, data): - if data.get("gradient_accumulation_steps") and data.get("batch_size"): - raise ValueError( - "please set only one of gradient_accumulation_steps or batch_size" - ) - return data - - @model_validator(mode="before") - @classmethod - def hint_eval_train_mbsz(cls, data): - if ( - data.get("eval_batch_size") - and data.get("micro_batch_size") - and data.get("eval_batch_size") != data.get("micro_batch_size") - ): - LOG.warning( - "eval_batch_size != micro_batch_size. This can lead to VRAM instability." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_push_ds_auth(cls, data): - if ( - data.get("push_dataset_to_hub") - and data.get("hf_use_auth_token") is not True - ): - raise ValueError( - "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" - ) - return data - - @model_validator(mode="after") - def check_falcon_fsdp(self): - if (self.base_model and "falcon" in self.base_model.lower()) and self.fsdp: - raise ValueError("FSDP is not supported for falcon models") - return self - - @model_validator(mode="after") - def check_mpt_checkpointing(self): - if ( - self.base_model and "mpt" in self.base_model.lower() - ) and self.gradient_checkpointing: - raise ValueError("gradient_checkpointing is not supported for MPT models") - return self - - @model_validator(mode="after") - def check_offload_grad_checkpointing(self): - if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth": - LOG.warning( - "`unsloth` is deprecated for gradient_checkpointing, use `offload`" - ) - self.gradient_checkpointing = "offload" - return self - - @model_validator(mode="after") - def check_better_transformers(self): - if self.flash_optimum is True: - if self.adapter: - LOG.warning( - "BetterTransformers probably doesn't work with PEFT adapters" - ) - if self.fp16 or self.bf16: - raise ValueError("AMP is not supported with BetterTransformer") - if self.float16 is not True and self.bfloat16 is not True: - LOG.warning( - "You should probably set bfloat16 or float16 to true to " - "load the model in float16 for BetterTransformers" - ) - return self - - @model_validator(mode="after") - def check_adamw_optimizer_params(self): - if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and ( - not self.optimizer or "adamw" not in str(self.optimizer).lower() - ): - LOG.warning("adamw hyperparameters found, but no adamw optimizer set") - return self - - @model_validator(mode="before") - @classmethod - def check_lr_groups(cls, data): - if data.get("lr_groups") and data.get("loraplus_lr_ratio"): - raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.") - return data - - @model_validator(mode="before") - @classmethod - def check_saves(cls, data): - if ( - data.get("save_strategy") - and data.get("save_steps") - and data.get("save_strategy") != "steps" - ): - raise ValueError( - "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." - ) - if data.get("saves_per_epoch") and data.get("save_steps"): - raise ValueError( - "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_push_save(cls, data): - if data.get("hub_model_id") and ( - data.get("save_strategy") not in ["steps", "epoch", None] - ): - LOG.warning( - "hub_model_id is set without any models being saved. To save a model, set save_strategy." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_evals(cls, data): - if ( - data.get("eval_strategy") - and data.get("eval_steps") - and data.get("eval_strategy") != "steps" - ): - raise ValueError( - "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps." - ) - - if ( - data.get("val_set_size") == 0 - and (data.get("eval_steps") or data.get("eval_strategy")) - and not data.get("test_datasets") - and data.get("eval_strategy") != "no" - ): - raise ValueError( - "eval_steps and eval_strategy are not supported with val_set_size == 0" - ) - if data.get("evals_per_epoch") and data.get("eval_steps"): - raise ValueError( - "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." - ) - if ( - data.get("evals_per_epoch") - and data.get("eval_strategy") - and data.get("eval_strategy") != "steps" - ): - raise ValueError( - "eval_strategy must be empty or set to `steps` when used with evals_per_epoch." - ) - - if data.get("do_bench_eval") and not ( - data.get("evals_per_epoch") or data.get("eval_steps") - ): - raise ValueError( - "do_bench_eval requires evals_per_epoch or eval_steps to be set." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_test_datasets_bench(cls, data): - if ( - data.get("do_bench_eval") - and not data.get("test_datasets") - and not data.get("val_set_size") - ): - LOG.warning( - "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset." - ) - data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}] - return data - - @model_validator(mode="before") - @classmethod - def check_eval_packing(cls, data): - # TODO also should check test_datasets and val_set_size as we can skip - # if there are no eval datasets/splits - if ( - data.get("sample_packing") - and data.get("eval_table_size") - and data.get("eval_sample_packing") is not False - ): - raise ValueError( - "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false." - ) - if ( - data.get("sample_packing") - and data.get("eval_sample_packing") is None - and not data.get("eval_table_size") - ): - LOG.info( - "explicitly setting `eval_sample_packing` to match `sample_packing`" - ) - data["eval_sample_packing"] = True - - if ( - data.get("sample_packing") - and data.get("eval_sample_packing") is False - and data.get("remove_unused_columns") is None - ): - LOG.info( - "setting `remove_unused_columns: false` for when sample_packing and eval_sample_packing don't match" - ) - data["remove_unused_columns"] = False - - return data - - @model_validator(mode="before") - @classmethod - def check_mm_prepare(cls, data): - if data.get("skip_prepare_dataset"): - if data.get("remove_unused_columns") is None: - LOG.info( - "setting `remove_unused_columns: false` for skip_prepare_dataset" - ) - data["remove_unused_columns"] = False - - return data - - @model_validator(mode="before") - @classmethod - def check_warmup(cls, data): - if data.get("warmup_steps") and data.get("warmup_ratio"): - raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") - return data - - @model_validator(mode="before") - @classmethod - def check_neftune(cls, data): - if data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"): - data["neftune_noise_alpha"] = data["noisy_embedding_alpha"] - del data["noisy_embedding_alpha"] - elif data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"): - raise ValueError( - "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting" - ) - return data - - @field_validator("neftune_noise_alpha") - @classmethod - def validate_neftune_noise_alpha(cls, neftune_noise_alpha): - if neftune_noise_alpha is not None and neftune_noise_alpha <= 0.0: - raise ValueError("neftune_noise_alpha must be > 0.0") - return neftune_noise_alpha - - @model_validator(mode="after") - def check_rl_beta(self): - if self.dpo_beta and not self.rl_beta: - self.rl_beta = self.dpo_beta - del self.dpo_beta - return self - - @model_validator(mode="after") - def check_simpo_warmup(self): - if self.rl is RLType.SIMPO and self.warmup_ratio: - raise ValueError( - "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead" - ) - return self - - @model_validator(mode="before") - @classmethod - def check_frozen(cls, data): - if ( - data.get("adapter") - and data.get("peft_layers_to_transform") - and data.get("unfrozen_parameters") - ): - raise ValueError( - "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior." - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_peft_layers_pattern(cls, data): - if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"): - raise ValueError( - "peft_layers_pattern requires peft_layers_to_transform to be set" - ) - return data - - @model_validator(mode="after") - def check_fft_possible_bad_config(self): - if ( - # pylint: disable=too-many-boolean-expressions - not (self.bf16 or self.bfloat16) - and (self.fp16 or self.float16) - and not self.adapter - and not self.flash_attention - and self.sample_packing - ): - LOG.warning( - "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA." - ) - # ValueError: Attempting to unscale FP16 gradients. - # OR - # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half - return self - - @model_validator(mode="after") - def check_fused_lora(self): - if self.adapter in ["lora", "qlora"] and ( - self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp - ): - raise ValueError("Fused modules are not supported with LoRA/QLoRA") - return self - - @model_validator(mode="after") - def hint_lora_8bit(self): - loftq = ( - self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits - ) - if not self.load_in_8bit and self.adapter == "lora" and not loftq: - LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") - return self - - @model_validator(mode="after") - def check_early_stopping(self): - if self.early_stopping_patience: - if not self.save_steps or not self.eval_steps: - raise ValueError( - "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps." - ) - if self.save_steps % self.eval_steps != 0: - raise ValueError( - "`early_stopping_patience` requires that eval_steps should evenly divide save_steps." - ) - return self - - @model_validator(mode="after") - def check_relora(self): - if self.relora_steps: - if self.adapter not in ("lora", "qlora"): - raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") - - if self.fsdp: - raise ValueError("fsdp not supported with ReLoRA") - - if self.deepspeed: - raise ValueError("deepspeed not supported with ReLoRA") - - if self.lr_scheduler == "one_cycle": - raise ValueError( - "ReLoRA is not compatible with the one_cycle scheduler" - ) - - if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp: - raise ValueError("Fused modules are not supported with ReLoRA") - return self - - @model_validator(mode="before") - @classmethod - def check_mem_mismatch(cls, data): - if ( - data.get("max_memory") is not None - and data.get("gpu_memory_limit") is not None - ): - raise ValueError( - "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_use_reentrant_mismatch(cls, data): - if ( - data.get("unfrozen_parameters") - and data.get("gradient_checkpointing_kwargs") - and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") - is True - ): - # https://github.com/huggingface/transformers/issues/21381 - raise ValueError( - "`use_reentrant` must be false when used with partially frozen model." - ) - return data - - @model_validator(mode="before") - @classmethod - def warn_qlora_zero3_w_use_reentrant(cls, data): - if ( - data.get("adapter") == "qlora" - and data.get("gradient_checkpointing_kwargs", {}) - and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") - is False - and data.get("deepspeed", "") is not None - and "zero3" in data.get("deepspeed", "") - ): - # may result in: - # torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: - # Recomputed values for the following tensors have different metadata - # than during the forward pass. - LOG.warning( - "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_val_w_test_datasets(cls, data): - if data.get("test_datasets") and data.get("val_set_size"): - raise ValueError( - "non-zero val_set_size should not be used with test_datasets configuration" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_eval_strategy(cls, data): - if ( - data.get("evaluation_strategy") is not None - and data.get("eval_strategy") is None - ): - LOG.info( - "explicitly setting `eval_strategy` from the `evaluation_strategy`" - ) - data["eval_strategy"] = data.get("evaluation_strategy") - return data - - @model_validator(mode="before") - @classmethod - def check_fsdp_offload_w_8bit_optimizer(cls, data): - if ( - data.get("fsdp") - and "8bit" in data.get("optimizer", "") - and data.get("fsdp_config") - and data["fsdp_config"].get("fsdp_offload_params") - and str(data["fsdp_config"].get("fsdp_version")) != "2" - ): - raise ValueError( - f"FSDP Offload not compatible with {data.get('optimizer')}" - ) - if ( - data.get("fsdp") - and "8bit" in data.get("optimizer", "") - and data.get("fsdp_config") - and str(data["fsdp_config"].get("fsdp_version")) == "2" - ): - if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]: - # CUDA ops errors with bnb 8bit optimizer + FSDP2 - raise ValueError( - f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead" - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_fsdp_sharded_state_dict_w_safetensors(cls, data): - if ( - data.get("fsdp") - and data.get("save_safetensors") - and data.get("fsdp_config") - and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" - ): - raise ValueError( - "FSDP SHARDED_STATE_DICT not compatible with save_safetensors" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_causal_lm_evals(cls, data): - if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"): - raise ValueError( - "do_causal_lm_eval is enabled, eval_sample_packing must be set to False" - ) - - if data.get("eval_causal_lm_metrics"): - if not isinstance(data.get("eval_causal_lm_metrics"), list): - raise ValueError("eval_causal_lm_metrics must be a list") - # only ["sacrebleu", "comet", "ter", "chrf"] supported - if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS: - raise ValueError( - f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_dataset_or_pretraining_dataset(cls, data): - if data.get("datasets") is None and data.get("pretraining_dataset") is None: - raise ValueError("either datasets or pretraining_dataset is required") - return data - - @model_validator(mode="before") - @classmethod - def check_xentropy_patch_conflicts(cls, data): - if data.get("flash_attn_cross_entropy") and data.get( - "unsloth_cross_entropy_loss" - ): - raise ValueError( - "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_qlora_unsloth(cls, data): - if ( - data.get("unsloth_lora_mlp") - or data.get("unsloth_lora_qkv") - or data.get("unsloth_lora_o") - ): - if data.get("adapter") == "lora" and data.get("load_in_8bit"): - raise ValueError( - "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_lora_kernel_8bit(cls, data): - if ( - data.get("lora_mlp_kernel") - or data.get("lora_qkv_kernel") - or data.get("lora_o_kernel") - ): - if data.get("adapter") == "lora" and data.get("load_in_8bit"): - raise ValueError( - "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_lora_kernel_rl(cls, data): - if ( - data.get("lora_mlp_kernel") - or data.get("lora_qkv_kernel") - or data.get("lora_o_kernel") - ) and data.get("rl"): - raise ValueError( - "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_lora_axolotl_unsloth(cls, data): - is_lora_kernel = any( - data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"] - ) - is_unsloth_lora = any( - data.get(k) - for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"] - ) - if is_lora_kernel and is_unsloth_lora: - raise ValueError( - "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_torch_compile_deepspeed(cls, data): - if data.get("deepspeed") and data.get("torch_compile"): - raise ValueError( - "torch_compile should be set within your deepspeed config file" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_npu_config(cls, data): - if is_torch_npu_available(): - # check attention config - attn_list = ["flash_attention", "sdp_attention", "s2_attention"] - for attn in attn_list: - if data.get(attn): - raise NotImplementedError( - f"{attn} is currently not supported in Ascend npu, please disable this configuration." + for untrained_token_id in fix_untrained_tokens: + if untrained_token_id not in peft_trainable_token_indices: + LOG.warning_once( + f"Token {untrained_token_id} is fixed via `fix_untrained_tokens`, yet not in `peft_trainable_token_indices: ` list. " + "Please add it, otherwise the token won't be trained on." ) - - # check quant config - if data.get("optimizer") is not None and "bit" in data.get("optimizer"): - optimizer = data.get("optimizer") - raise NotImplementedError( - f"{optimizer} is currently not supported in Ascend npu, choose another one please." - ) - - quant_list = ["load_in_8bit", "load_in_4bit"] - for quant in quant_list: - if data.get(quant): - raise NotImplementedError( - f"Quantification is currently not supported in Ascend npu, please disable {quant}." - ) - - # check dtype config - if data.get("tf32"): - raise NotImplementedError( - "tf32 dtype is currently not supported in Ascend npu, please disable this configuration" - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_rl_config_gradient_checkpointing(cls, data): - # TODO: SalmanMohammadi - # Distributed RL with QLoRA + gradient checkpointing - # and use_reentrant = True is broken upstream in TRL - # pylint: disable=too-many-boolean-expressions - if ( - data.get("rl") - and data.get("gradient_checkpointing") - and data.get("gradient_checkpointing_kwargs") - and data.get("gradient_checkpointing_kwargs").get("use_reentrant") - and data.get("load_in_4bit") - and data.get("adapter") == "qlora" - and data.get("capabilities") - and data.get("capabilities").get("n_gpu", 1) > 1 - ): - raise ValueError( - "The `use_reentrant: True` implementation of gradient checkpointing " - "is not supported for distributed RL training with QLoRA. Please set " - "`use_reentrant: False` in `gradient_checkpointing_kwargs`." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_kto_config(cls, data): - if data.get("rl") == "kto": - if data.get("sample_packing") or data.get("eval_sample_packing"): - raise ValueError("sample_packing is not supported with kto") - - if data.get("remove_unused_columns") is not False: - raise ValueError("Set `remove_unused_columns: False` when using kto") - - return data - - @model_validator(mode="before") - @classmethod - def check_grpo_liger_sequence_parallel(cls, data): - if ( - data.get("rl") == "grpo" - and data.get("trl", {}) - and data.get("trl").get("use_liger_loss") - and data.get("sequence_parallel_degree", 1) > 1 - ): - raise ValueError("GRPO + SP + Liger not currently supported") - return data - - @model_validator(mode="after") - def check_sequence_parallel_degree(self): - if not self.sequence_parallel_degree: - self.sequence_parallel_degree = 1 - elif self.sequence_parallel_degree > 1: - if not self.flash_attention: - raise ValueError( - "flash_attention: true must be set with sequence_parallel_degree > 1" - ) - - if self.sample_packing and self.micro_batch_size > 1: - raise ValueError( - "micro_batch_size must be set to 1 when sample_packing is enabled " - "due to a `ring-flash-attn` requirement" - ) - - try: - import ring_flash_attn # noqa: F401 # pylint:disable=unused-import - except ImportError as exception: - raise ImportError( - "sequence_parallel_degree > 1 but ring_flash_attn is not installed. " - "Please install it with `pip install axolotl[ring-flash-attn] " - "or `pip install ring-flash-attn>=0.1.4`." - ) from exception - - # TODO: monkeypatch / callback to average losses correctly across SP ranks - # / fix gradient scaling across SP ranks. Losses, grads should be scaled - # according to the proportion of non-padding tokens per rank. - LOG.warning( - "Sequence parallelism (SP) is enabled with " - f"sequence_parallel_degree={self.sequence_parallel_degree}. " - "Please note that logged losses may differ slightly to the non-SP " - "losses due to transformers Trainer implementation details. " - "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " - "for more details." - ) - - return self - - @model_validator(mode="after") - def validate_ring_attn_func(self): - if getattr(self, "sequence_parallel_degree", 1) == 1: - return self - - if self.ring_attn_func is not None: - self.ring_attn_func = RingAttnFunc(self.ring_attn_func) - else: - # Default ring attention function selection - sample_packing = getattr(self, "sample_packing", False) - self.ring_attn_func = ( - RingAttnFunc.VARLEN_LLAMA3 - if sample_packing - else RingAttnFunc.BATCH_RING - ) - - return self - - @model_validator(mode="before") - @classmethod - def check_muon_deepspeed_fsdp(cls, data): - if data.get("optimizer") == "muon" and ( - data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config") - ): - raise ValueError( - "Muon optimizer is currently incompatible with DeepSpeed and FSDP" - ) return data class AxolotlConfigWCapabilities(AxolotlInputConfig): - """Wrapper to validate GPU capabilities with the config options""" + """Wrapper to valdiate GPU capabilities with the configured options""" capabilities: GPUCapabilities env_capabilities: EnvCapabilities @@ -1303,13 +1103,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): return data - @model_validator(mode="before") - @classmethod - def check_fsdp_deepspeed(cls, data): - if data.get("deepspeed") and data.get("fsdp"): - raise ValueError("deepspeed and fsdp cannot be used together.") - return data - @model_validator(mode="before") @classmethod def check_multigpu_unsloth(cls, data): @@ -1334,11 +1127,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): or data.get("lora_o_kernel") ): capabilities = data.get("capabilities") - is_fsdp = data.get("fsdp") is not None - is_fsdp2 = ( - data.get("fsdp_config") is not None - and str(data.get("fsdp_config").get("fsdp_version")) == "2" - ) + is_fsdp = data.get("fsdp_config") is not None + is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2" + if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2: if is_fsdp: raise ValueError( @@ -1372,11 +1163,8 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): # Check multi-GPU compatibility capabilities = data.get("capabilities") is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1 - is_fsdp = data.get("fsdp") is not None - is_fsdp2 = ( - data.get("fsdp_config") is not None - and str(data.get("fsdp_config").get("fsdp_version")) == "2" - ) + is_fsdp = data.get("fsdp_config") is not None + is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2" if ( not is_multi_gpu @@ -1468,9 +1256,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): def check_min_torch_version(self): if self.env_capabilities and self.env_capabilities.torch_version: torch_version = self.env_capabilities.torch_version - if version.parse(torch_version) < version.parse("2.5.1"): + if version.parse(torch_version) < version.parse("2.6.0"): LOG.warning( - f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1." + f"torch=={torch_version} not be supported. Please upgrade to torch>=2.6.0." ) return self @@ -1499,17 +1287,69 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): torch_version = str(torch.__version__).split("+", maxsplit=1)[0] - if ( - data.get("fsdp") - and data.get("fsdp_config") - and str(data["fsdp_config"].get("fsdp_version")) == "2" - ): - if version.parse(torch_version) < version.parse("2.7.0"): - raise ValueError( - "FSDP2 and QAT are not supported on torch version < 2.7.0" - ) - if version.parse(torch_version) < version.parse("2.6.0"): raise ValueError("QAT is not supported on torch version < 2.6.0") return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_torch_version(cls, data): + env_capabilities = data.get("env_capabilities", {}) + torch_version = env_capabilities.get("torch_version") + + if torch_version is None: + import torch + + torch_version = str(torch.__version__).split("+", maxsplit=1)[0] + + if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2": + if version.parse(torch_version) < version.parse("2.7.0"): + raise ValueError("FSDP2 is not supported on torch version < 2.7.0") + + return data + + @model_validator(mode="before") + @classmethod + def default_dataloader_opts(cls, data): + if ( + data.get("dataloader_num_workers") is None + and data.get("dataloader_pin_memory") is None + and data.get("dataloader_prefetch_factor") is None + ): + data["dataloader_num_workers"] = data.get("capabilities").get("n_gpu", 1) + data["dataloader_pin_memory"] = True + data["dataloader_prefetch_factor"] = 256 + + return data + + @model_validator(mode="before") + @classmethod + def default_dataset_num_proc(cls, data): + if data.get("dataset_processes") is not None: + if data.get("dataset_num_proc") is None: + data["dataset_num_proc"] = data["dataset_processes"] + LOG.warning( + "dataset_processes is deprecated and will be removed in a future version. " + "Please use dataset_num_proc instead." + ) + else: + LOG.warning( + "Both dataset_processes and dataset_num_proc are set. " + "Using dataset_num_proc and ignoring dataset_processes." + ) + del data["dataset_processes"] + elif data.get("dataset_num_proc") is None: + data["dataset_num_proc"] = get_default_process_count() + return data + + @model_validator(mode="before") + @classmethod + def check_deduplication_with_streaming(cls, data): + if data.get("dataset_exact_deduplication") and ( + data.get("streaming") or data.get("pretraining_dataset") + ): + raise NotImplementedError( + "dataset_exact_deduplication is not available for streaming datasets. " + ) + return data diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index cc5d6daba..e32468706 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -1,6 +1,8 @@ """Pydantic models for datasets-related configuration""" -from pydantic import BaseModel, model_validator +from typing import Literal + +from pydantic import BaseModel, Field, model_validator from axolotl.utils.schemas.enums import ChatTemplate from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic @@ -9,56 +11,189 @@ from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic class UserDefinedPrompterType(BaseModel): """Structure for user defined prompt types""" - system_prompt: str | None = None - system_format: str | None = None + system_prompt: str | None = Field( + default=None, + json_schema_extra={"description": "Custom user instruction prompt"}, + ) + system_format: str | None = Field( + default=None, + json_schema_extra={"description": "Use {system} as key to be replaced"}, + ) field_system: str | None = None field_instruction: str | None = None field_input: str | None = None field_output: str | None = None - format: str | None = None - no_input_format: str | None = None - field: str | None = None + format: str | None = Field( + default=None, + json_schema_extra={ + "description": "Customizable to be single line or multi-line. Use {instruction}/{input} as key to be replaced. 'format' can include {input}" + }, + ) + no_input_format: str | None = Field( + default=None, + json_schema_extra={"description": "'no_input_format' cannot include {input}"}, + ) class SFTDataset(BaseModel): """SFT configuration subset""" - path: str | None = None - split: str | None = None - type: str | UserDefinedPrompterType | None = None + path: str | None = Field( + default=None, + json_schema_extra={ + "description": "HuggingFace dataset repo | s3:// | gs:// | path to local file or directory" + }, + ) + split: str | None = Field( + default=None, + json_schema_extra={"description": "name of dataset split to load from"}, + ) + type: str | UserDefinedPrompterType | None = Field( + default=None, + json_schema_extra={ + "description": "The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]" + }, + ) input_transform: str | None = None - shards: int | None = None - shards_idx: int | None = None - preprocess_shards: int | None = None + shards: int | None = Field( + default=None, + json_schema_extra={ + "description": "split dataset into N pieces (use with shards_idx)" + }, + ) + shards_idx: int | None = Field( + default=None, + json_schema_extra={"description": "the index of sharded dataset to use"}, + ) + preprocess_shards: int | None = Field( + default=None, + json_schema_extra={ + "description": "process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)" + }, + ) conversation: str | None = None # Do not make this too strict or it will break the validator to choose different dataset class - chat_template: ChatTemplate | str | None = None - chat_template_jinja: str | None = None - data_files: str | list[str] | None = None + chat_template: ChatTemplate | str | None = Field( + default=None, + json_schema_extra={ + "description": "The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field." + }, + ) + chat_template_jinja: str | None = Field( + default=None, + json_schema_extra={ + "description": "Custom jinja chat template or path to jinja file. Used only if `chat_template: jinja` or empty." + }, + ) + data_files: str | list[str] | None = Field( + default=None, json_schema_extra={"description": "path to source data files"} + ) input_format: str | None = None - name: str | None = None - ds_type: str | None = None - field: str | None = None + name: str | None = Field( + default=None, + json_schema_extra={"description": "name of dataset configuration to load"}, + ) + ds_type: str | None = Field( + default=None, + json_schema_extra={"description": "defines the datatype when path is a file"}, + ) + field: str | None = Field( + default=None, + json_schema_extra={ + "description": "For `completion` datasets only, uses the provided field instead of `text` column" + }, + ) field_human: str | None = None field_model: str | None = None - field_messages: str | None = None + field_messages: str | None = Field( + default=None, + json_schema_extra={ + "description": 'Key containing the messages (default: "messages")' + }, + ) + field_tools: str | None = Field( + default=None, + json_schema_extra={ + "description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).' + }, + ) + field_thinking: str | None = Field( + default=None, + json_schema_extra={ + "description": 'Key containing the reasoning trace (default: "reasoning_content").' + }, + ) + template_thinking_key: str | None = Field( + default=None, + json_schema_extra={ + "description": "The key the chat template expects that indicates the reasoning trace." + }, + ) # deprecated, use message_property_mappings message_field_role: str | None = None # deprecated, use message_property_mappings message_field_content: str | None = None - message_property_mappings: dict[str, str] | None = None - message_field_training: str | None = None - message_field_training_detail: str | None = None - split_thinking: bool | None = None + message_property_mappings: dict[str, str] | None = Field( + default=None, + json_schema_extra={ + "description": "Mapping of properties from the input dataset to the chat template. (default: message_property_mappings={'role':'role', 'content':'content'}) If a property exists in the template but not in this mapping, the system will attempt to load it directly from the message using the property name as the key. Example: In the mapping below, 'from' is loaded from input dataset and used as 'role', while 'value' is loaded and used as 'content' in the chat template." + }, + ) + message_field_training: str | None = Field( + default=None, + json_schema_extra={ + "description": "The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`." + }, + ) + message_field_training_detail: str | None = Field( + default=None, + json_schema_extra={ + "description": "The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train)." + }, + ) + split_thinking: bool | None = Field( + default=None, + json_schema_extra={ + "description": "(for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags" + }, + ) logprobs_field: str | None = None temperature: float | None = None - roles_to_train: list[str] | None = None - train_on_eos: str | None = None - roles: dict[str, list[str]] | None = None - drop_system_message: bool | None = None - trust_remote_code: bool | None = False - revision: str | None = None + roles_to_train: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": "Roles to train on. The tokens from these roles will be considered for the loss." + }, + ) + train_on_eos: Literal["all", "turn", "last"] | None = Field( + default=None, + json_schema_extra={ + "description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation" + }, + ) + roles: dict[str, list[str]] | None = Field( + default=None, + json_schema_extra={ + "description": 'Roles mapping in the messages. The format is {target_role: [source_roles]}. All source roles will be mapped to the target role. The default is: user: ["human", "user"], assistant: ["gpt", "assistant"], system: ["system"], tool: ["tool"]' + }, + ) + drop_system_message: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to drop the system turn from the dataset. Only works with chat_template. This does not drop the default system message from chat_template if it exists. If you wish to, we recommend using a custom jinja template with the default system message removed or adding a system turn with empty content." + }, + ) + trust_remote_code: bool | None = Field( + default=False, + json_schema_extra={"description": "Trust remote code for untrusted source"}, + ) + revision: str | None = Field( + default=None, + json_schema_extra={ + "description": "The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets." + }, + ) @model_validator(mode="before") @classmethod @@ -68,7 +203,6 @@ class SFTDataset(BaseModel): @model_validator(mode="before") @classmethod - # pylint: disable=duplicate-code def check_chat_template_config(cls, data): if isinstance(data, BaseModel): data = data.model_dump() diff --git a/src/axolotl/utils/schemas/deprecated.py b/src/axolotl/utils/schemas/deprecated.py index b8904136e..972fe0ccf 100644 --- a/src/axolotl/utils/schemas/deprecated.py +++ b/src/axolotl/utils/schemas/deprecated.py @@ -60,10 +60,30 @@ class RemappedParameters(BaseModel): """Parameters that have been remapped to other names""" overrides_of_model_config: dict[str, Any] | None = Field( - default=None, alias="model_config" + default=None, + alias="model_config", + json_schema_extra={ + "description": "optional overrides to the base model configuration" + }, ) overrides_of_model_kwargs: dict[str, Any] | None = Field( - default=None, alias="model_kwargs" + default=None, + alias="model_kwargs", + json_schema_extra={ + "description": "optional overrides the base model loading from_pretrained" + }, + ) + type_of_model: str | None = Field( + default=None, + alias="model_type", + json_schema_extra={ + "description": "If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too" + }, + ) + revision_of_model: str | None = Field( + default=None, + alias="model_revision", + json_schema_extra={ + "description": "You can specify to choose a specific model revision from huggingface hub" + }, ) - type_of_model: str | None = Field(default=None, alias="model_type") - revision_of_model: str | None = Field(default=None, alias="model_revision") diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index d09ab6387..bcd03e1a2 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -5,84 +5,90 @@ from enum import Enum import torch -class TorchIntDType(Enum): - """Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4""" +class TorchAOQuantDType(Enum): + int4 = torch.int4 + int8 = torch.int8 + float8_e4m3fn = torch.float8_e4m3fn + nvfp4 = "nvfp4" - uint1 = getattr(torch, "uint1", None) # pylint: disable=invalid-name - uint2 = getattr(torch, "uint2", None) # pylint: disable=invalid-name - uint3 = getattr(torch, "uint3", None) # pylint: disable=invalid-name - uint4 = getattr(torch, "uint4", None) # pylint: disable=invalid-name - uint5 = getattr(torch, "uint5", None) # pylint: disable=invalid-name - uint6 = getattr(torch, "uint6", None) # pylint: disable=invalid-name - uint7 = getattr(torch, "uint7", None) # pylint: disable=invalid-name - int4 = getattr(torch, "int4", None) # pylint: disable=invalid-name - int8 = getattr(torch, "int8", None) # pylint: disable=invalid-name + def from_string(str): + if str == "int4": + return TorchAOQuantDType.int4 + if str == "int8": + return TorchAOQuantDType.int8 + if str in ["float8_e4m3fn", "fp8", "float8"]: + return TorchAOQuantDType.float8_e4m3fn + if str == "nvfp4": + return TorchAOQuantDType.nvfp4 class RLType(str, Enum): """RL trainer type configuration subset""" - DPO = "dpo" # pylint: disable=invalid-name - GRPO = "grpo" # pylint: disable=invalid-name - IPO = "ipo" # pylint: disable=invalid-name - ORPO = "orpo" # pylint: disable=invalid-name - KTO = "kto" # pylint: disable=invalid-name - SIMPO = "simpo" # pylint: disable=invalid-name + DPO = "dpo" + GRPO = "grpo" + IPO = "ipo" + ORPO = "orpo" + KTO = "kto" + SIMPO = "simpo" class ChatTemplate(str, Enum): """Chat templates configuration subset""" - alpaca = "alpaca" # pylint: disable=invalid-name - chatml = "chatml" # pylint: disable=invalid-name - mistral_v1 = "mistral_v1" # pylint: disable=invalid-name - mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name - mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name - mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name - gemma = "gemma" # pylint: disable=invalid-name - cohere = "cohere" # pylint: disable=invalid-name - llama3 = "llama3" # pylint: disable=invalid-name - llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name - llama4 = "llama4" # pylint: disable=invalid-name - phi_3 = "phi_3" # pylint: disable=invalid-name - phi_35 = "phi_35" # pylint: disable=invalid-name - deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name - deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name - jamba = "jamba" # pylint: disable=invalid-name - jinja = "jinja" # pylint: disable=invalid-name - qwen_25 = "qwen_25" # pylint: disable=invalid-name - qwen3 = "qwen3" # pylint: disable=invalid-name - tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name - exaone = "exaone" # pylint: disable=invalid-name - metharme = "metharme" # pylint: disable=invalid-name - pixtral = "pixtral" # pylint: disable=invalid-name - llava = "llava" # pylint: disable=invalid-name - qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name - gemma3 = "gemma3" # pylint: disable=invalid-name - command_a = "command_a" # pylint: disable=invalid-name - command_a_tool_use = "command_a_tool_use" # pylint: disable=invalid-name - command_a_rag = "command_a_rag" # pylint: disable=invalid-name - aya = "aya" # pylint: disable=invalid-name + alpaca = "alpaca" + chatml = "chatml" + mistral_v1 = "mistral_v1" + mistral_v2v3 = "mistral_v2v3" + mistral_v3_tekken = "mistral_v3_tekken" + mistral_v7_tekken = "mistral_v7_tekken" + gemma = "gemma" + cohere = "cohere" + llama3 = "llama3" + llama3_2_vision = "llama3_2_vision" + llama4 = "llama4" + phi_3 = "phi_3" + phi_35 = "phi_35" + deepseek_v2 = "deepseek_v2" + deepseek_v3 = "deepseek_v3" + jamba = "jamba" + jinja = "jinja" + qwen_25 = "qwen_25" + qwen3 = "qwen3" + falcon_h1 = "falcon_h1" + tokenizer_default = "tokenizer_default" + exaone = "exaone" + metharme = "metharme" + pixtral = "pixtral" + llava = "llava" + qwen2_vl = "qwen2_vl" + gemma3 = "gemma3" + gemma3n = "gemma3n" + command_a = "command_a" + command_a_tool_use = "command_a_tool_use" + command_a_rag = "command_a_rag" + aya = "aya" class CustomSupportedOptimizers(str, Enum): """Custom supported optimizers""" - optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name - ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name - ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name - ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name - adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name - came_pytorch = "came_pytorch" # pylint: disable=invalid-name - muon = "muon" # pylint: disable=invalid-name + optimi_adamw = "optimi_adamw" + ao_adamw_4bit = "ao_adamw_4bit" + ao_adamw_8bit = "ao_adamw_8bit" + ao_adamw_fp8 = "ao_adamw_fp8" + adopt_adamw = "adopt_adamw" + came_pytorch = "came_pytorch" + muon = "muon" + dion = "dion" class RingAttnFunc(str, Enum): """Enum class for supported `ring-flash-attn` implementations""" - # VARLEN_RING = "varlen_ring" - # VARLEN_ZIGZAG = "varlen_zigzag" VARLEN_LLAMA3 = "varlen_llama3" BATCH_RING = "batch_ring" + # VARLEN_RING = "varlen_ring" + # VARLEN_ZIGZAG = "varlen_zigzag" # BATCH_ZIGZAG = "batch_zigzag" # BATCH_STRIPE = "batch_stripe" diff --git a/src/axolotl/utils/schemas/fsdp.py b/src/axolotl/utils/schemas/fsdp.py new file mode 100644 index 000000000..f34f40e8e --- /dev/null +++ b/src/axolotl/utils/schemas/fsdp.py @@ -0,0 +1,71 @@ +""" +FSDP Configuration Schema +""" + +from typing import Literal + +from pydantic import BaseModel, Field + + +class FSDPConfig(BaseModel): + """ + FSDP Configuration Schema + """ + + activation_checkpointing: bool | None = Field( + default=None, + description="Enable activation checkpointing to reduce memory usage during forward passes", + ) + offload_params: bool | None = Field( + default=None, + description="Offload parameters to CPU to reduce GPU memory usage", + ) + sync_module_states: bool | None = Field( + default=None, + description="Synchronize module states across all processes", + ) + cpu_ram_efficient_loading: bool | None = Field( + default=None, + description="Enable CPU RAM efficient loading to reduce memory usage during model loading", + ) + cpu_offload_pin_memory: bool | None = Field( + default=None, + description="Disabling this enables swap memory usage for resource-constrained setups when offload_params is enabled.", + ) + use_orig_params: bool | None = Field( + default=None, + description="Use original parameters instead of flattened parameters", + ) + + state_dict_type: ( + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None + ) = Field( + default=None, + description="Type of state dict to use for saving/loading checkpoints", + ) + final_state_dict_type: ( + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None + ) = Field( + default=None, + description="Final state dict type to use after training completion", + ) + + auto_wrap_policy: Literal["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP"] | None = ( + Field( + default=None, + description="Policy for automatically wrapping modules with FSDP", + ) + ) + transformer_layer_cls_to_wrap: str | None = Field( + default=None, + description="Class name of transformer layers to wrap (e.g., 'LlamaDecoderLayer')", + ) + + reshard_after_forward: bool | None = Field( + default=None, + description="Reshard parameters after forward pass to save memory", + ) + mixed_precision_policy: str | None = Field( + default=None, + description="Mixed precision policy for FSDP (e.g., 'fp16', 'bf16')", + ) diff --git a/src/axolotl/utils/schemas/integrations.py b/src/axolotl/utils/schemas/integrations.py index 4843e3592..7332c7d39 100644 --- a/src/axolotl/utils/schemas/integrations.py +++ b/src/axolotl/utils/schemas/integrations.py @@ -13,10 +13,21 @@ class MLFlowConfig(BaseModel): """MLFlow configuration subset""" use_mlflow: bool | None = None - mlflow_tracking_uri: str | None = None - mlflow_experiment_name: str | None = None - mlflow_run_name: str | None = None - hf_mlflow_log_artifacts: bool | None = None + mlflow_tracking_uri: str | None = Field( + default=None, json_schema_extra={"description": "URI to mlflow"} + ) + mlflow_experiment_name: str | None = Field( + default=None, json_schema_extra={"description": "Your experiment name"} + ) + mlflow_run_name: str | None = Field( + default=None, json_schema_extra={"description": "Your run name"} + ) + hf_mlflow_log_artifacts: bool | None = Field( + default=None, + json_schema_extra={ + "description": "set to true to copy each saved checkpoint on each save to mlflow artifact registry" + }, + ) class LISAConfig(BaseModel): @@ -40,13 +51,33 @@ class WandbConfig(BaseModel): """Wandb configuration subset""" use_wandb: bool | None = None - wandb_name: str | None = None - wandb_run_id: str | None = None - wandb_mode: str | None = None - wandb_project: str | None = None - wandb_entity: str | None = None + wandb_name: str | None = Field( + default=None, + json_schema_extra={"description": "Set the name of your wandb run"}, + ) + wandb_run_id: str | None = Field( + default=None, json_schema_extra={"description": "Set the ID of your wandb run"} + ) + wandb_mode: str | None = Field( + default=None, + json_schema_extra={ + "description": '"offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb' + }, + ) + wandb_project: str | None = Field( + default=None, json_schema_extra={"description": "Your wandb project name"} + ) + wandb_entity: str | None = Field( + default=None, + json_schema_extra={"description": "A wandb Team name if using a Team"}, + ) wandb_watch: str | None = None - wandb_log_model: str | None = None + wandb_log_model: str | None = Field( + default=None, + json_schema_extra={ + "description": '"checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training' + }, + ) @model_validator(mode="before") @classmethod @@ -64,14 +95,52 @@ class WandbConfig(BaseModel): class CometConfig(BaseModel): """Comet configuration subset""" - use_comet: bool | None = None - comet_api_key: str | None = None - comet_workspace: str | None = None - comet_project_name: str | None = None - comet_experiment_key: str | None = None - comet_mode: str | None = None - comet_online: bool | None = None - comet_experiment_config: dict[str, Any] | None = None + use_comet: bool | None = Field( + default=None, + json_schema_extra={"description": "Enable or disable Comet integration."}, + ) + comet_api_key: str | None = Field( + default=None, + json_schema_extra={ + "description": "API key for Comet. Recommended to set via `comet login`." + }, + ) + comet_workspace: str | None = Field( + default=None, + json_schema_extra={ + "description": "Workspace name in Comet. Defaults to the user's default workspace." + }, + ) + comet_project_name: str | None = Field( + default=None, + json_schema_extra={ + "description": "Project name in Comet. Defaults to Uncategorized." + }, + ) + comet_experiment_key: str | None = Field( + default=None, + json_schema_extra={ + "description": "Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key." + }, + ) + comet_mode: str | None = Field( + default=None, + json_schema_extra={ + "description": 'Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.' + }, + ) + comet_online: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Set to True to log data to Comet server, or False for offline storage. Default is True." + }, + ) + comet_experiment_config: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Dictionary for additional configuration settings, see the doc for more details." + }, + ) class GradioConfig(BaseModel): diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index 57f5ae309..04312eedd 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -1,10 +1,12 @@ """Pydantic models for model input / output, etc. configuration""" +from typing import Any, Literal + from pydantic import BaseModel, Field, field_validator from axolotl.utils.logging import get_logger -LOG = get_logger(__name__, use_environ=True) +LOG = get_logger(__name__) class ModelInputConfig(BaseModel): @@ -12,19 +14,82 @@ class ModelInputConfig(BaseModel): model_config = {"protected_namespaces": ()} - base_model: str - base_model_config: str | None = None + base_model: str = Field( + json_schema_extra={ + "description": "This is the huggingface model that contains *.pt, *.safetensors, or *.bin files. This can also be a relative path to a model on disk" + } + ) + base_model_config: str | None = Field( + default=None, + json_schema_extra={ + "description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model" + }, + ) cls_model_config: str | None = None - tokenizer_config: str | None = None - tokenizer_use_fast: bool | None = None - tokenizer_legacy: bool | None = None + tokenizer_config: str | None = Field( + default=None, + json_schema_extra={ + "description": "Optional tokenizer configuration path in case you want to use a different tokenizer than the one defined in the base model" + }, + ) + tokenizer_use_fast: bool | None = Field( + default=None, + json_schema_extra={ + "description": "use_fast option for tokenizer loading from_pretrained, default to True" + }, + ) + tokenizer_legacy: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use the legacy tokenizer setting, defaults to True" + }, + ) + tokenizer_use_mistral_common: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer." + }, + ) tokenizer_type: str | None = Field( - default=None, json_schema_extra={"description": "transformers tokenizer class"} + default=None, + json_schema_extra={ + "description": "Corresponding tokenizer for the model AutoTokenizer is a good choice" + }, ) processor_type: str | None = Field( default=None, json_schema_extra={"description": "transformers processor class"} ) - trust_remote_code: bool | None = None + tokenizer_save_jinja_files: bool | None = Field( + default=True, # match the default behavior from transformers + json_schema_extra={ + "description": "Whether to save jinja files for tokenizer, transformers default is True" + }, + ) + trust_remote_code: bool | None = Field( + default=None, + json_schema_extra={"description": "Trust remote code for untrusted source"}, + ) + + experimental_skip_move_to_device: bool | None = Field( + default=True, + json_schema_extra={ + "description": "Don't move the model to the device before sharding. Set to `false` to revert to legacy behavior." + }, + ) + + use_kernels: bool | None = Field( + default=None, + json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."}, + ) + + model_quantization_config: Literal["Mxfp4Config"] | None = Field( + default=None, + json_schema_extra={"description": "Model loading quantization config"}, + ) + model_quantization_config_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={"description": "kwargs for model quantization config"}, + ) @field_validator("trust_remote_code") @classmethod @@ -39,10 +104,23 @@ class ModelInputConfig(BaseModel): class ModelOutputConfig(BaseModel): """model save configuration subset""" - output_dir: str = Field(default="./model-out") - hub_model_id: str | None = None - hub_strategy: str | None = None - save_safetensors: bool | None = True + output_dir: str = Field( + default="./model-out", + json_schema_extra={"description": "Where to save the full-finetuned model to"}, + ) + hub_model_id: str | None = Field( + default=None, json_schema_extra={"description": "push checkpoints to hub"} + ) + hub_strategy: str | None = Field( + default=None, + json_schema_extra={"description": "how to push checkpoints to hub"}, + ) + save_safetensors: bool | None = Field( + default=True, + json_schema_extra={ + "description": "Save model as safetensors (require safetensors package). Default True" + }, + ) class SpecialTokensConfig(BaseModel): diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index 5d408e1fe..af22913fd 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -9,7 +9,7 @@ class LoftQConfig(BaseModel): """LoftQ configuration subset""" loftq_bits: int = Field( - default=4, json_schema_extra={"description": "Quantization bits for LoftQ"} + default=4, json_schema_extra={"description": "typically 4 bits"} ) # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"}) @@ -17,31 +17,89 @@ class LoftQConfig(BaseModel): class PeftConfig(BaseModel): """peftq configuration subset""" - loftq_config: LoftQConfig | None = None + loftq_config: LoftQConfig | None = Field( + default=None, + json_schema_extra={ + "description": "Configuration options for loftq initialization for LoRA" + }, + ) class LoraConfig(BaseModel): """Peft / LoRA configuration subset""" - load_in_8bit: bool | None = Field(default=False) - load_in_4bit: bool | None = Field(default=False) + load_in_8bit: bool | None = Field( + default=False, + json_schema_extra={ + "description": "This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer" + }, + ) + load_in_4bit: bool | None = Field( + default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"} + ) - adapter: str | None = None - lora_model_dir: str | None = None + adapter: str | None = Field( + default=None, + json_schema_extra={ + "description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model" + }, + ) + lora_model_dir: str | None = Field( + default=None, + json_schema_extra={ + "description": "If you already have a lora model trained that you want to load, put that here. This means after training, if you want to test the model, you should set this to the value of `output_dir`. Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`." + }, + ) lora_r: int | None = None lora_alpha: int | None = None lora_fan_in_fan_out: bool | None = None lora_target_modules: str | list[str] | None = None - lora_target_linear: bool | None = None - lora_modules_to_save: list[str] | None = None + lora_target_parameters: str | list[str] | None = None + lora_target_linear: bool | None = Field( + default=None, + json_schema_extra={"description": "If true, will target all linear modules"}, + ) + lora_modules_to_save: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": "If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities." + }, + ) lora_dropout: float | None = 0.0 - peft_layers_to_transform: list[int] | None = None + peft_layers_to_transform: list[int] | None = Field( + default=None, + json_schema_extra={ + "description": "The layer indices to transform, otherwise, apply to all layers" + }, + ) peft_layers_pattern: list[str] | None = None peft: PeftConfig | None = None - peft_use_dora: bool | None = None - peft_use_rslora: bool | None = None - peft_layer_replication: list[tuple[int, int]] | None = None - peft_init_lora_weights: bool | str | None = None + peft_use_dora: bool | None = Field( + default=None, json_schema_extra={"description": "Whether to use DoRA."} + ) + peft_use_rslora: bool | None = Field( + default=None, json_schema_extra={"description": "Whether to use RSLoRA."} + ) + peft_layer_replication: list[tuple[int, int]] | None = Field( + default=None, + json_schema_extra={"description": "List of layer indices to replicate."}, + ) + peft_init_lora_weights: bool | str | None = Field( + default=None, + json_schema_extra={ + "description": "How to initialize LoRA weights. Default to True which is MS original implementation." + }, + ) + peft_trainable_token_indices: list[int] | dict[str, list[int]] | None = Field( + default=None, + json_schema_extra={ + "description": ( + "A list of token indices to fine-tune on the `embed_tokens` layer.\n" + "Otherwise, a dict mapping an embedding layer name to its trainable token indices.\n" + "See https://huggingface.co/docs/peft/v0.17.0/en/developer_guides/lora#efficiently-train-tokens-alongside-lora" + ) + }, + ) qlora_sharded_model_loading: bool | None = Field( default=False, @@ -49,9 +107,24 @@ class LoraConfig(BaseModel): "description": "load qlora model in sharded format for FSDP using answer.ai technique." }, ) - lora_on_cpu: bool | None = None - gptq: bool | None = None - bnb_config_kwargs: dict[str, Any] | None = None + lora_on_cpu: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge" + }, + ) + gptq: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether you are training a 4-bit GPTQ quantized model" + }, + ) + bnb_config_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "optional overrides to the bnb 4bit quantization configuration" + }, + ) loraplus_lr_ratio: float | None = Field( default=None, @@ -62,7 +135,7 @@ class LoraConfig(BaseModel): loraplus_lr_embedding: float | None = Field( default=1e-6, json_schema_extra={ - "description": "loraplus learning rate for lora embedding layers." + "description": "loraplus learning rate for lora embedding layers. Default value is 1e-6." }, ) @@ -125,8 +198,21 @@ class LoraConfig(BaseModel): class ReLoRAConfig(BaseModel): """ReLoRA configuration subset""" - relora_steps: int | None = None - relora_warmup_steps: int | None = None - relora_anneal_steps: int | None = None - relora_prune_ratio: float | None = None - relora_cpu_offload: bool | None = None + relora: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use ReLoRA. Use with jagged_restart_*steps options." + }, + ) + relora_prune_ratio: float | None = Field( + default=None, + json_schema_extra={ + "description": "threshold for optimizer magnitude when pruning" + }, + ) + relora_cpu_offload: bool | None = Field( + default=None, + json_schema_extra={ + "description": "True to perform lora weight merges on cpu during restarts, for modest gpu memory savings" + }, + ) diff --git a/src/axolotl/utils/schemas/quantization.py b/src/axolotl/utils/schemas/quantization.py index fe2cdb1fe..a7c130574 100644 --- a/src/axolotl/utils/schemas/quantization.py +++ b/src/axolotl/utils/schemas/quantization.py @@ -6,7 +6,23 @@ from typing import Any from pydantic import BaseModel, Field, field_validator -from axolotl.utils.schemas.enums import TorchIntDType +from axolotl.utils.schemas.enums import TorchAOQuantDType + + +def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None: + if v is None: + return None + if v == "int4": + return TorchAOQuantDType.int4 + if v == "int8": + return TorchAOQuantDType.int8 + if v in ["float8_e4m3fn", "fp8", "float8"]: + return TorchAOQuantDType.float8_e4m3fn + if v == "nvfp4": + return TorchAOQuantDType.nvfp4 + raise ValueError( + f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}" + ) class QATConfig(BaseModel): @@ -14,28 +30,29 @@ class QATConfig(BaseModel): QAT Config Schema """ - activation_dtype: TorchIntDType | None = Field( - default=None, description="Activation dtype" + activation_dtype: TorchAOQuantDType | None = Field( + default=None, + description="Fake quantization layout to use for activation quantization.", ) - weight_dtype: TorchIntDType = Field( - default=TorchIntDType.int8, description="Weight dtype" + weight_dtype: TorchAOQuantDType = Field( + default=TorchAOQuantDType.int8, + description="Fake quantization layout to use for weight quantization.", ) quantize_embedding: bool | None = Field( default=False, description="Quantize embedding" ) - group_size: int | None = Field(default=32, description="Group size") + group_size: int | None = Field( + default=32, + description="The number of elements in each group for per-group fake quantization", + ) fake_quant_after_n_steps: int | None = Field( - default=None, description="Fake quant after n steps" + default=None, description="The number of steps to apply fake quantization after" ) @field_validator("activation_dtype", "weight_dtype", mode="before") @classmethod - def validate_dtype(cls, v: Any) -> TorchIntDType | None: - if v == "int4": - return TorchIntDType.int4 - if v == "int8": - return TorchIntDType.int8 - raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']") + def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None: + return validate_ao_dtype(v) class PTQConfig(BaseModel): @@ -43,22 +60,23 @@ class PTQConfig(BaseModel): PTQ Config Schema """ - weight_dtype: TorchIntDType = Field( - default=TorchIntDType.int8, description="Weight dtype" + weight_dtype: TorchAOQuantDType = Field( + default=TorchAOQuantDType.int8, + description="Fake quantization layout to use for weight quantization.", ) - activation_dtype: TorchIntDType | None = Field( - default=None, description="Activation dtype" + activation_dtype: TorchAOQuantDType | None = Field( + default=None, + description="Fake quantization layout to use for activation quantization.", ) quantize_embedding: bool | None = Field( - default=None, description="Quantize embedding" + default=None, description="Whether to quantize the embedding layer." + ) + group_size: int | None = Field( + default=32, + description="The number of elements in each group for per-group fake quantization", ) - group_size: int | None = Field(default=32, description="Group size") @field_validator("activation_dtype", "weight_dtype", mode="before") @classmethod - def validate_dtype(cls, v: Any) -> TorchIntDType | None: - if v == "int4": - return TorchIntDType.int4 - if v == "int8": - return TorchIntDType.int8 - raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']") + def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None: + return validate_ao_dtype(v) diff --git a/src/axolotl/utils/schemas/training.py b/src/axolotl/utils/schemas/training.py index ad7f899ac..8e06e82cb 100644 --- a/src/axolotl/utils/schemas/training.py +++ b/src/axolotl/utils/schemas/training.py @@ -23,10 +23,17 @@ class LrGroup(BaseModel): class HyperparametersConfig(BaseModel): """Training hyperparams configuration subset""" - gradient_accumulation_steps: int | None = Field(default=1) + gradient_accumulation_steps: int | None = Field( + default=1, + json_schema_extra={ + "description": "If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps." + }, + ) micro_batch_size: int | None = Field( default=1, - json_schema_extra={"description": "per gpu micro batch size for training"}, + json_schema_extra={ + "description": "The number of samples to include in each batch. This is the number of samples sent to each GPU. Batch size per gpu = micro_batch_size * gradient_accumulation_steps" + }, ) batch_size: int | None = Field( default=None, @@ -41,45 +48,119 @@ class HyperparametersConfig(BaseModel): }, ) - auto_find_batch_size: bool | None = None + auto_find_batch_size: bool | None = Field( + default=None, + json_schema_extra={ + "description": "whether to find batch size that fits in memory. Passed to underlying transformers Trainer" + }, + ) - train_on_inputs: bool | None = False - group_by_length: bool | None = None + train_on_inputs: bool | None = Field( + default=False, + json_schema_extra={ + "description": "Whether to mask out or include the human's prompt from the training labels" + }, + ) + group_by_length: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Group similarly sized data to minimize padding. May be slower to start, as it must download and sort the entire dataset. Note that training loss may have an oscillating pattern with this enabled." + }, + ) learning_rate: str | float embedding_lr: float | None = None embedding_lr_scale: float | None = None - weight_decay: float | None = 0.0 - optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = ( - OptimizerNames.ADAMW_TORCH_FUSED + weight_decay: float | None = Field( + default=0.0, json_schema_extra={"description": "Specify weight decay"} + ) + optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = Field( + default=OptimizerNames.ADAMW_TORCH_FUSED, + json_schema_extra={"description": "Specify optimizer"}, ) optim_args: (str | dict[str, Any]) | None = Field( default=None, - json_schema_extra={"description": "Optional arguments to supply to optimizer."}, + json_schema_extra={ + "description": "Dictionary of arguments to pass to the optimizer" + }, ) optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field( default=None, json_schema_extra={ - "description": "The target modules to optimize, i.e. the module names that you would like to train." + "description": "The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm" }, ) - torchdistx_path: str | None = None - lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = ( - SchedulerType.COSINE + torchdistx_path: str | None = Field( + default=None, + json_schema_extra={ + "description": "Path to torch distx for optim 'adamw_anyprecision'" + }, + ) + lr_scheduler: ( + SchedulerType | Literal["one_cycle"] | Literal["rex"] + ) | None = SchedulerType.COSINE + lr_scheduler_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Specify a scheduler and kwargs to use with the optimizer" + }, ) - lr_scheduler_kwargs: dict[str, Any] | None = None lr_quadratic_warmup: bool | None = None - cosine_min_lr_ratio: float | None = None - cosine_constant_lr_ratio: float | None = None - lr_div_factor: float | None = None + cosine_min_lr_ratio: float | None = Field( + default=None, + json_schema_extra={ + "description": "decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr" + }, + ) + cosine_constant_lr_ratio: float | None = Field( + default=None, + json_schema_extra={ + "description": "freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step" + }, + ) + lr_div_factor: float | None = Field( + default=None, json_schema_extra={"description": "Learning rate div factor"} + ) lr_groups: list[LrGroup] | None = None - adam_epsilon: float | None = None - adam_epsilon2: float | None = None - adam_beta1: float | None = None - adam_beta2: float | None = None - adam_beta3: float | None = None - max_grad_norm: float | None = None + adam_epsilon: float | None = Field( + default=None, json_schema_extra={"description": "adamw hyperparams"} + ) + adam_epsilon2: float | None = Field( + default=None, json_schema_extra={"description": "only used for CAME Optimizer"} + ) + adam_beta1: float | None = Field( + default=None, json_schema_extra={"description": "adamw hyperparams"} + ) + adam_beta2: float | None = Field( + default=None, json_schema_extra={"description": "adamw hyperparams"} + ) + adam_beta3: float | None = Field( + default=None, json_schema_extra={"description": "only used for CAME Optimizer"} + ) + + dion_lr: float | None = Field( + default=None, json_schema_extra={"description": "Dion Optimizer learning rate"} + ) + dion_momentum: float | None = Field( + default=None, json_schema_extra={"description": "Dion Optimizer momentum"} + ) + dion_rank_fraction: float | None = Field( + default=1.0, + json_schema_extra={ + "description": "Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension." + }, + ) + dion_rank_multiple_of: int | None = Field( + default=1, + json_schema_extra={ + "description": "Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may be useful to ensure even sharding." + }, + ) + + max_grad_norm: float | None = Field( + default=None, json_schema_extra={"description": "Gradient clipping max norm"} + ) num_epochs: float = Field(default=1.0) @field_validator("batch_size") @@ -99,3 +180,24 @@ class HyperparametersConfig(BaseModel): if learning_rate and isinstance(learning_rate, str): learning_rate = float(learning_rate) return learning_rate + + +class JaggedLRConfig(BaseModel): + """JaggedLR configuration subset, can be used w/ ReLoRA training""" + + jagged_restart_steps: int | None = Field( + default=None, + json_schema_extra={"description": "how often to reset for jagged restarts"}, + ) + jagged_restart_warmup_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "how many warmup steps to take after reset for jagged restarts" + }, + ) + jagged_restart_anneal_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "how many anneal steps to take before reset for jagged restarts" + }, + ) diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index 37b71dba8..624f7663e 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -1,5 +1,7 @@ """Pydantic models for TRL trainer configuration""" +from typing import Literal + from pydantic import BaseModel, Field @@ -10,12 +12,14 @@ class TRLConfig(BaseModel): beta: float | None = Field( default=None, - json_schema_extra={"description": "Beta for RL training"}, + json_schema_extra={ + "description": "Beta parameter for the RL training. Same as `rl_beta`. Use" + }, ) max_completion_length: int | None = Field( default=None, json_schema_extra={ - "description": "Maximum length of the completion for RL training" + "description": "Maximum length of the completion for RL training." }, ) @@ -23,81 +27,83 @@ class TRLConfig(BaseModel): # Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23 use_vllm: bool = Field( default=False, - json_schema_extra={"description": "Whether to use VLLM for RL training"}, + json_schema_extra={"description": "Whether to use VLLM for RL training."}, + ) + vllm_mode: Literal["server", "colocate"] | None = Field( + default=None, + json_schema_extra={ + "description": "VLLM mode to use, one of 'server' or 'colocate'" + }, ) vllm_server_host: str | None = Field( default="0.0.0.0", # nosec B104 - json_schema_extra={"description": "Host of the vLLM server to connect to"}, + json_schema_extra={"description": "Host of the vLLM server to connect to."}, ) vllm_server_port: int | None = Field( default=8000, - json_schema_extra={"description": "Port of the vLLM server to connect to"}, + json_schema_extra={"description": "Port of the vLLM server to connect to."}, ) vllm_server_timeout: int | None = Field( default=None, json_schema_extra={ - "description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up " - "after the timeout, a `ConnectionError` is raised." + "description": "Total timeout (in seconds) to wait for the vLLM server to respond." }, ) vllm_guided_decoding_regex: str | None = Field( default=None, - json_schema_extra={ - "description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled." - }, + json_schema_extra={"description": "Regex for vLLM guided decoding."}, ) reward_funcs: list[str] | None = Field( default=None, - json_schema_extra={"description": "List of reward functions to load"}, + json_schema_extra={ + "description": "List of reward functions to load. Paths must be importable from current dir." + }, ) reward_weights: list[float] | None = Field( default=None, json_schema_extra={ - "description": "Weights for each reward function. Must match the number of reward functions." + "description": "List of reward weights for the reward functions." }, ) num_generations: int | None = Field( default=None, - json_schema_extra={ - "description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value." - }, + json_schema_extra={"description": "Number of generations to sample."}, ) log_completions: bool | None = Field( default=False, - json_schema_extra={"description": "Whether to log completions"}, + json_schema_extra={"description": "Whether to log completions."}, ) num_completions_to_print: int | None = Field( default=None, json_schema_extra={ - "description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged." + "description": "Number of completions to print when log_completions is True." }, ) + importance_sampling_level: Literal["sequence", "token"] | None = Field( + default=None, + json_schema_extra={ + "description": "Controls whether importance sampling ratios are computed at the `'token'` or `'sequence'` level. " + "For GSPO, use `sequence`, default is None which corresponds to the original GRPO paper." + }, + ) + sync_ref_model: bool | None = Field( default=False, - json_schema_extra={ - "description": ( - "Whether to sync the reference model every `ref_model_sync_steps` " - "steps, using the `ref_model_mixup_alpha` parameter." - ) - }, + json_schema_extra={"description": "Whether to sync the reference model."}, ) ref_model_mixup_alpha: float | None = Field( default=0.9, - json_schema_extra={ - "description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`." - }, + json_schema_extra={"description": "Mixup alpha for the reference model."}, ) ref_model_sync_steps: int | None = Field( default=64, - json_schema_extra={ - "description": "Sync steps for the reference model. Requires `sync_ref_model=True`." - }, + json_schema_extra={"description": "Sync steps for the reference model."}, ) scale_rewards: bool = Field( default=True, json_schema_extra={ - "description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation." + "description": "Whether to scale rewards by their standard deviation." }, ) @@ -124,13 +130,13 @@ class TRLConfig(BaseModel): repetition_penalty: float | None = Field( default=None, json_schema_extra={ - "description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far." + "description": "Penalty for tokens that appear in prompt and generated text." }, ) num_iterations: int | None = Field( default=None, json_schema_extra={ - "description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO." + "description": "Number of iterations per batch (μ) for GRPO." }, ) epsilon: float | None = Field( @@ -152,12 +158,18 @@ class TRLConfig(BaseModel): loss_type: str | None = Field( default=None, json_schema_extra={ - "description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`." + "description": "Loss formulation to use. Supported values: grpo, bnpo, dr_grpo." }, ) mask_truncated_completions: bool = Field( default=False, json_schema_extra={ - "description": "When enabled, truncated completions are excluded from the loss calculation." + "description": "Whether to exclude truncated completions from loss calculation." + }, + ) + vllm_enable_sleep_mode: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Enable sleep mode for vLLM to offload VRAM when idle" }, ) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py new file mode 100644 index 000000000..368976831 --- /dev/null +++ b/src/axolotl/utils/schemas/validation.py @@ -0,0 +1,1430 @@ +"""Module with validation methods for config pydantic model.""" + +import json +import sys +import tempfile +from pathlib import Path + +from pydantic import ( + field_validator, + model_validator, +) +from transformers.utils.import_utils import is_torch_npu_available + +from axolotl.utils.logging import get_logger +from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType + +LOG = get_logger(__name__) + +SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} + + +class DatasetValidationMixin: + """Validation methods related to dataset configuration.""" + + @field_validator("seed", mode="after") + @classmethod + def set_default_seed(cls, seed): + if seed is None: + LOG.info("`seed` not set in config; setting to 42") + seed = 42 + return seed + + @field_validator("datasets", mode="before") + @classmethod + def deprecate_sharegpt_datasets(cls, datasets): + for _, ds_cfg in enumerate(datasets): + ds_type = ( + ds_cfg.get("type") + if isinstance(ds_cfg, dict) + else getattr(ds_cfg, "type", None) + ) + if not ds_type: + continue + + if isinstance(ds_type, dict): + continue + + if isinstance(ds_type, str) and ds_type.startswith("sharegpt"): + raise ValueError( + "`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead." + ) + + return datasets + + @model_validator(mode="before") + @classmethod + def check_dataset_or_pretraining_dataset(cls, data): + if data.get("datasets") is None and data.get("pretraining_dataset") is None: + raise ValueError("either datasets or pretraining_dataset is required") + return data + + @model_validator(mode="before") + @classmethod + def check_pretraining_streaming_deprecation(cls, data): + # TODO(djsaunde): remove this check + implement change for 0.13.0 release + if data.get("pretraining_dataset") and not data.get("streaming"): + LOG.warning( + "Setting `pretraining_dataset` without explicitly setting `streaming: " + "true` is deprecated. In a future release, streaming will not be " + "automatically enabled when using pretraining_dataset. Please " + "explicitly set `streaming: true` in your configuration to maintain " + "current behavior." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_push_ds_auth(cls, data): + if ( + data.get("push_dataset_to_hub") + and data.get("hf_use_auth_token") is not True + ): + raise ValueError( + "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_val_w_test_datasets(cls, data): + if data.get("test_datasets") and data.get("val_set_size"): + raise ValueError( + "non-zero val_set_size should not be used with test_datasets configuration" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_test_datasets_bench(cls, data): + if ( + data.get("do_bench_eval") + and not data.get("test_datasets") + and not data.get("val_set_size") + ): + LOG.warning( + "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset." + ) + data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}] + return data + + @model_validator(mode="before") + @classmethod + def check_eval_packing(cls, data): + # TODO also should check test_datasets and val_set_size as we can skip + # if there are no eval datasets/splits + if ( + data.get("sample_packing") + and data.get("eval_table_size") + and data.get("eval_sample_packing") is not False + ): + raise ValueError( + "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false." + ) + if ( + data.get("sample_packing") + and data.get("eval_sample_packing") is None + and not data.get("eval_table_size") + ): + LOG.info( + "explicitly setting `eval_sample_packing` to match `sample_packing`", + main_process_only=True, + ) + data["eval_sample_packing"] = True + + if ( + data.get("sample_packing") + and data.get("eval_sample_packing") is False + and data.get("remove_unused_columns") is None + ): + LOG.info( + "setting `remove_unused_columns: false` for when sample_packing and eval_sample_packing don't match" + ) + data["remove_unused_columns"] = False + + return data + + @model_validator(mode="before") + @classmethod + def check_mm_prepare(cls, data): + if data.get("skip_prepare_dataset"): + if data.get("remove_unused_columns") is None: + LOG.info( + "setting `remove_unused_columns: false` for skip_prepare_dataset" + ) + data["remove_unused_columns"] = False + + return data + + +class AttentionValidationMixin: + """Validation methods related to attention mechanisms.""" + + @model_validator(mode="before") + @classmethod + def check_attention_fields(cls, data): + fields = ( + "xformers_attention", + "sdp_attention", + "s2_attention", + "flash_attention", + "flex_attention", + ) + non_empty_count = sum(1 for field in fields if data.get(field)) + + if non_empty_count > 1: + raise ValueError(f"Only one of {', '.join(fields)} must be set") + return data + + @model_validator(mode="before") + @classmethod + def check_sample_packing_without_attention(cls, data): + if ( + data.get("sample_packing") + and not data.get("flash_attention") + and not data.get("sdp_attention") + and not data.get("flex_attention") + and not data.get("xformers_attention") + ): + LOG.warning( + "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_sample_packing_with_s2attn(cls, data): + if data.get("sample_packing") and data.get("s2_attention"): + raise ValueError( + "Received `sample_packing=true` and `s2_attention=true`; however, \ + shifted-sparse attention does not currently support sample packing." + ) + return data + + +class TrainingValidationMixin: + """Validation methods related to training configuration.""" + + @model_validator(mode="before") + @classmethod + def check_batch_size_fields(cls, data): + fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size") + non_empty_count = sum(1 for field in fields if data.get(field)) + + if non_empty_count < 2: + raise ValueError(f"At least two of {', '.join(fields)} must be set") + return data + + @model_validator(mode="before") + @classmethod + def hint_sample_packing_padding(cls, data): + if data.get("sample_packing"): + pad_to_sequence_len = data.get("pad_to_sequence_len") + if pad_to_sequence_len is False: + LOG.warning( + "`pad_to_sequence_len: true` is recommended when using sample_packing" + ) + elif pad_to_sequence_len is None: + LOG.info( + "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing" + ) + data["pad_to_sequence_len"] = True + return data + + @model_validator(mode="before") + @classmethod + def hint_reward_model_pad(cls, data): + if data.get("reward_model") and not data.get("pad_to_sequence_len"): + LOG.warning( + "`pad_to_sequence_len: true` is recommended when using reward_model" + ) + if data.get("pad_to_sequence_len") is None: + data["pad_to_sequence_len"] = True + return data + + @model_validator(mode="before") + @classmethod + def check_gas_bsz(cls, data): + if data.get("gradient_accumulation_steps") and data.get("batch_size"): + raise ValueError( + "please set only one of gradient_accumulation_steps or batch_size" + ) + return data + + @model_validator(mode="before") + @classmethod + def hint_eval_train_mbsz(cls, data): + if ( + data.get("eval_batch_size") + and data.get("micro_batch_size") + and data.get("eval_batch_size") != data.get("micro_batch_size") + ): + LOG.warning( + "eval_batch_size != micro_batch_size. This can lead to VRAM instability." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_warmup(cls, data): + if data.get("warmup_steps") and data.get("warmup_ratio"): + raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") + return data + + @model_validator(mode="before") + @classmethod + def check_saves(cls, data): + if ( + data.get("save_strategy") + and data.get("save_steps") + and data.get("save_strategy") != "steps" + ): + raise ValueError( + "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." + ) + if data.get("saves_per_epoch") and data.get("save_steps"): + raise ValueError( + "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_push_save(cls, data): + if data.get("hub_model_id") and ( + data.get("save_strategy") not in ["steps", "epoch", None] + ): + LOG.warning( + "hub_model_id is set without any models being saved. To save a model, set save_strategy." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_evals(cls, data): + if ( + data.get("eval_strategy") + and data.get("eval_steps") + and data.get("eval_strategy") != "steps" + ): + raise ValueError( + "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps." + ) + + if ( + data.get("val_set_size") == 0 + and (data.get("eval_steps") or data.get("eval_strategy")) + and not data.get("test_datasets") + and data.get("eval_strategy") != "no" + ): + raise ValueError( + "eval_steps and eval_strategy are not supported with val_set_size == 0" + ) + if data.get("evals_per_epoch") and data.get("eval_steps"): + raise ValueError( + "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." + ) + if ( + data.get("evals_per_epoch") + and data.get("eval_strategy") + and data.get("eval_strategy") != "steps" + ): + raise ValueError( + "eval_strategy must be empty or set to `steps` when used with evals_per_epoch." + ) + + if data.get("do_bench_eval") and not ( + data.get("evals_per_epoch") or data.get("eval_steps") + ): + raise ValueError( + "do_bench_eval requires evals_per_epoch or eval_steps to be set." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_neftune(cls, data): + if data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"): + data["neftune_noise_alpha"] = data["noisy_embedding_alpha"] + del data["noisy_embedding_alpha"] + elif data.get("noisy_embedding_alpha") and data.get("neftune_noise_alpha"): + raise ValueError( + "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_multipack_buffer_size(cls, data): + if data.get("pretrain_multipack_buffer_size") and not data.get( + "streaming_multipack_buffer_size" + ): + LOG.warning( + "`pretrain_multipack_buffer_size` is deprecated in v0.13.0, will be " + "removed in v0.14.0. Use `streaming_multipack_buffer_size` instead." + ) + data["streaming_multipack_buffer_size"] = data[ + "pretrain_multipack_buffer_size" + ] + del data["pretrain_multipack_buffer_size"] + elif data.get("pretrain_multipack_buffer_size") and data.get( + "streaming_multipack_buffer_size" + ): + raise ValueError( + "pretrain_multipack_buffer_size is deprecated, use " + "streaming_multipack_buffer_size; both are set, please remove the " + "deprecated pretrain_multipack_buffer_size setting" + ) + return data + + @model_validator(mode="after") + def check_fft_possible_bad_config(self): + if ( + not (self.bf16 or self.bfloat16) + and (self.fp16 or self.float16) + and not self.adapter + and not self.flash_attention + and self.sample_packing + ): + LOG.warning( + "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA." + ) + # ValueError: Attempting to unscale FP16 gradients. + # OR + # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half + return self + + @model_validator(mode="before") + @classmethod + def check_fp8_config(cls, data): + if data.get("fp8") and not data.get("torch_compile"): + LOG.warning( + "torch_compile is strongly recommended for FP8 training in order to " + "see speed improvements. Please consider setting `torch_compile: " + "true` in your config." + ) + fsdp_config = data.get("fsdp_config") or {} + if data.get("fp8") and ( + fsdp_config.get("activation_checkpointing", False) is True + or fsdp_config.get("fsdp_activation_checkpointing", False) is True + ): + LOG.warning( + "FP8 + FSDP2 + activation checkpointing may be slower than BF16 " + "training. Please considering setting `activation_checkpointing: false` " + "in your FSDP config." + ) + if ( + data.get("fp8_enable_fsdp_float8_all_gather") + and not data.get("fsdp_version", None) == 2 + ): + raise ValueError( + "fp8_enable_fsdp_float8_all_gather requires FSDP2 (fsdp_version: 2) " + "to be used." + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_use_reentrant_mismatch(cls, data): + if ( + data.get("unfrozen_parameters") + and data.get("gradient_checkpointing_kwargs") + and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") + is True + ): + # https://github.com/huggingface/transformers/issues/21381 + raise ValueError( + "`use_reentrant` must be false when used with partially frozen model." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_eval_strategy(cls, data): + if ( + data.get("evaluation_strategy") is not None + and data.get("eval_strategy") is None + ): + LOG.info( + "explicitly setting `eval_strategy` from the `evaluation_strategy`" + ) + data["eval_strategy"] = data.get("evaluation_strategy") + return data + + @model_validator(mode="before") + @classmethod + def check_causal_lm_evals(cls, data): + if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"): + raise ValueError( + "do_causal_lm_eval is enabled, eval_sample_packing must be set to False" + ) + + if data.get("eval_causal_lm_metrics"): + if not isinstance(data.get("eval_causal_lm_metrics"), list): + raise ValueError("eval_causal_lm_metrics must be a list") + # only ["sacrebleu", "comet", "ter", "chrf"] supported + if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS: + raise ValueError( + f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_tokenizer_use_mistral_common(cls, data): + if data.get("tokenizer_use_mistral_common") is None: + if any( + "magistral" in name.lower() + for name in [ + data.get("base_model", ""), + data.get("base_model_config", ""), + data.get("tokenizer_config", ""), + ] + ): + LOG.warning( + "tokenizer_use_mistral_common auto inferred to True for Magistral models. Please set it to True explicitly if you want to use mistral-common tokenizer." + ) + data["tokenizer_use_mistral_common"] = True + + return data + + @field_validator("tokenizer_use_mistral_common", mode="after") + @classmethod + def check_mistral_common_import(cls, tokenizer_use_mistral_common): + if tokenizer_use_mistral_common: + import importlib.util + + if importlib.util.find_spec("mistral_common") is None: + raise ImportError( + "mistral-common is required for mistral models. Please install it with `pip install axolotl` or `pip install -e .`." + ) + + return tokenizer_use_mistral_common + + @model_validator(mode="before") + @classmethod + def check_mistral_common_incompatible_options(cls, data): + if not data.get("tokenizer_use_mistral_common"): + return data + + # NOTE: mistral-common tokenizer is not compatible with editing tokenizer at the moment + + if data.get("added_tokens_overrides"): + raise ValueError( + "added_tokens_overrides is not supported with mistral-common tokenizer" + ) + + if data.get("special_tokens"): + raise ValueError( + "special_tokens override is not supported with mistral-common tokenizer" + ) + + if data.get("tokens"): + raise ValueError( + "tokens override is not supported with mistral-common tokenizer" + ) + + if data.get("chat_template"): + raise ValueError( + "Setting chat_template is not supported with mistral-common tokenizer" + ) + + return data + + @model_validator(mode="before") + @classmethod + def pretrain_with_tps(cls, data): + if data.get("pretraining_dataset") and data.get( + "include_tokens_per_second", False + ): + # combining these would raise `TypeError: cannot pickle 'dict_keys' object` + # due to trying to count the number of tokens total in the dataset + raise ValueError( + "pretraining_dataset and include_tokens_per_second cannot be used together." + ) + + return data + + +class LoRAValidationMixin: + """Validation methods related to LoRA/QLoRA configuration.""" + + @model_validator(mode="before") + @classmethod + def check_lr_groups(cls, data): + if data.get("lr_groups") and data.get("loraplus_lr_ratio"): + raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.") + return data + + @model_validator(mode="before") + @classmethod + def check_frozen(cls, data): + if ( + data.get("adapter") + and data.get("peft_layers_to_transform") + and data.get("unfrozen_parameters") + ): + raise ValueError( + "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_peft_layers_pattern(cls, data): + if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"): + raise ValueError( + "peft_layers_pattern requires peft_layers_to_transform to be set" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_qlora_unsloth(cls, data): + if ( + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") + ): + if data.get("adapter") == "lora" and data.get("load_in_8bit"): + raise ValueError( + "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_lora_axolotl_unsloth(cls, data): + is_lora_kernel = any( + data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"] + ) + is_unsloth_lora = any( + data.get(k) + for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"] + ) + if is_lora_kernel and is_unsloth_lora: + raise ValueError( + "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)" + ) + return data + + @model_validator(mode="after") + def check_fused_lora(self): + if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp: + raise ValueError("Fused modules are not supported with LoRA/QLoRA") + return self + + @model_validator(mode="before") + @classmethod + def warn_qlora_zero3_w_use_reentrant(cls, data): + if ( + data.get("adapter") == "qlora" + and data.get("gradient_checkpointing_kwargs", {}) + and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") + is False + and data.get("deepspeed", "") is not None + and "zero3" in data.get("deepspeed", "") + ): + # may result in: + # torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: + # Recomputed values for the following tensors have different metadata + # than during the forward pass. + LOG.warning( + "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_lora_kernels_8bit(cls, data): + if ( + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") + ): + if data.get("adapter") == "lora" and data.get("load_in_8bit"): + raise ValueError( + "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not " + "compatible with 8-bit LoRA a the moment." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_lora_kernels_dora(cls, data): + if ( + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") + ) and data.get("peft_use_dora"): + raise ValueError( + "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not " + "compatible with DoRA at the moment." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_lora_kernels_rl(cls, data): + if ( + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") + ) and data.get("rl"): + raise ValueError( + "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not " + "compatible with RL at the moment." + ) + return data + + +class RLValidationMixin: + """Validation methods related to RL training configuration.""" + + @model_validator(mode="before") + @classmethod + def check_sample_packing_w_rl(cls, data): + if data.get("sample_packing") and data.get("rl"): + raise ValueError("`sample_packing: true` does not work with RLHF training") + return data + + @model_validator(mode="before") + @classmethod + def check_kto_config(cls, data): + if data.get("rl") == "kto": + if data.get("sample_packing") or data.get("eval_sample_packing"): + raise ValueError("sample_packing is not supported with kto") + + if data.get("remove_unused_columns") is not False: + raise ValueError("Set `remove_unused_columns: False` when using kto") + return data + + @model_validator(mode="before") + @classmethod + def check_grpo_liger_sequence_parallel(cls, data): + if ( + data.get("rl") == "grpo" + and data.get("trl", {}) + and data.get("trl").get("use_liger_loss") + and data.get("context_parallel_size", 1) > 1 + ): + raise ValueError("GRPO + SP + Liger not currently supported") + return data + + @model_validator(mode="before") + @classmethod + def check_rl_config_gradient_checkpointing(cls, data): + # TODO: SalmanMohammadi + # Distributed RL with QLoRA + gradient checkpointing + # and use_reentrant = True is broken upstream in TRL + + if ( + data.get("rl") + and data.get("gradient_checkpointing") + and data.get("gradient_checkpointing_kwargs") + and data.get("gradient_checkpointing_kwargs").get("use_reentrant") + and data.get("load_in_4bit") + and data.get("adapter") == "qlora" + and data.get("capabilities") + and data.get("capabilities").get("n_gpu", 1) > 1 + ): + raise ValueError( + "The `use_reentrant: True` implementation of gradient checkpointing " + "is not supported for distributed RL training with QLoRA. Please set " + "`use_reentrant: False` in `gradient_checkpointing_kwargs`." + ) + return data + + +class OptimizationValidationMixin: + """Validation methods related to optimization and performance.""" + + @model_validator(mode="after") + def check_adamw_optimizer_params(self): + if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and ( + not self.optimizer or "adamw" not in str(self.optimizer).lower() + ): + LOG.warning("adamw hyperparameters found, but no adamw optimizer set") + return self + + @model_validator(mode="before") + @classmethod + def check_muon_deepspeed_fsdp(cls, data): + if data.get("optimizer") == "muon" and ( + data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config") + ): + raise ValueError( + "Muon optimizer is currently incompatible with DeepSpeed and FSDP" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_batch_flattening_fa(cls, data): + if data.get("batch_flattening"): + batch_flattening_auto = data.get("batch_flattening") == "auto" + if not data.get("flash_attention") and not batch_flattening_auto: + raise ValueError("batch_flattening requires flash attention") + if data.get("sample_packing") and not batch_flattening_auto: + raise ValueError("batch_flattening not compatible with sample_packing") + if data.get("micro_batch_size") == 1 and not batch_flattening_auto: + LOG.warning("batch_flattening has no effect with micro_batch_size == 1") + + if ( + batch_flattening_auto + and data.get("flash_attention") + and not data.get("sample_packing") + and data.get("micro_batch_size") > 1 + ): + data["batch_flattening"] = True + elif batch_flattening_auto: + data["batch_flattening"] = False + + return data + + @model_validator(mode="before") + @classmethod + def check_xentropy_patch_conflicts(cls, data): + if data.get("flash_attn_cross_entropy") and data.get( + "unsloth_cross_entropy_loss" + ): + raise ValueError( + "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_version(cls, data): + fsdp_config = data.get("fsdp_config", {}) + if fsdp_config and str(data.get("fsdp_version")) != "2": + LOG.info( + "FSDP1 will be deprecated in an upcoming release of Axolotl." + "We recommend that you use FSDP version 2 for better performance and compatibility. " + "Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp " + "For more details on migrating your config. " + ) + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp2_cpu_offload_pin_memory(cls, data): + if not (fsdp_config := data.get("fsdp_config")): + return data + + if fsdp_config.get("cpu_offload_pin_memory") is False: + if str(data.get("fsdp_version")) != "2": + raise ValueError( + "FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2" + ) + if not fsdp_config.get("offload_params"): + raise ValueError( + "disabling cpu_offload_pin_memory requires enabling offload_params" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp2_base_model_quant_rl(cls, data): + if data.get("fsdp_version") == 2 and data.get("rl") in [ + RLType.DPO, + RLType.KTO, + RLType.ORPO, + RLType.IPO, + ]: + if data.get("load_in_8bit") or data.get("load_in_4bit"): + raise ValueError( + f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1." + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_version_in_fsdp_config(cls, data): + fsdp_config = data.get("fsdp_config") or {} + if fsdp_config and fsdp_config.get("fsdp_version"): + LOG.warning( + "Configuring `fsdp_version` in `fsdp_config` is deprecated. " + "Please configure `fsdp_version` as a top-level field." + ) + data["fsdp_version"] = fsdp_config.pop("fsdp_version") + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_config_kwargs_prefix(cls, data): + if fsdp_config := data.get("fsdp_config"): + should_fix = False + for key, _ in fsdp_config.items(): + if key.startswith("fsdp_"): + should_fix = True + LOG.warning_once( + "Configuring FSDP fields with the `fsdp_` prefix is deprecated. " + "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`." + ) + if should_fix: + update_fsdp_config = {} + for key, value in fsdp_config.items(): + if key.startswith("fsdp_") and key != "fsdp_version": + update_fsdp_config[key.replace("fsdp_", "")] = value + else: + update_fsdp_config[key] = value + data["fsdp_config"] = update_fsdp_config + return data + + @model_validator(mode="after") + def check_fsdp_offload_w_8bit_optimizer(self): + if ( + hasattr(self, "fsdp_config") + and self.fsdp_config + and self.optimizer + and "8bit" in self.optimizer.value + and self.fsdp_config.offload_params + and str(self.fsdp_version) != "2" + ): + raise ValueError( + f"FSDP Offload not compatible with {str(self.optimizer.value)}" + ) + return self + + @model_validator(mode="after") + def check_fsdp2_w_8bit_optimizer(self): + if ( + hasattr(self, "fsdp_config") + and self.fsdp_config + and self.optimizer + and "8bit" in self.optimizer.value + and str(self.fsdp_version) == "2" + ): + if self.optimizer in ["adamw_8bit", "adamw_bnb_8bit"]: + # CUDA ops errors with bnb 8bit optimizer + FSDP2 + raise ValueError( + f"FSDP2 not compatible with {self.optimizer.value}, use `adamw_torch_8bit` instead" + ) + + return self + + @model_validator(mode="after") + def lr_groups_ao_optimizer(self): + if ( + self.loraplus_lr_ratio is not None + or self.embedding_lr_scale is not None + or self.embedding_lr is not None + or self.lr_groups is not None + ) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]: + # TODO(wing): remove this once ao>0.12.0 + # requires https://github.com/pytorch/ao/pull/2606 in an ao release + raise ValueError( + "lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not " + "supported with ao low-bit optimizers until ao>0.12.0. " + "Please refer to https://github.com/pytorch/ao/pull/2606." + ) + return self + + @model_validator(mode="before") + @classmethod + def check_tensor_parallel_size_update_ds_json(cls, data): + tensor_parallel_size = data.get("tensor_parallel_size") + if tensor_parallel_size is not None and tensor_parallel_size > 1: + if data.get("deepspeed"): + with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin: + ds_config = json.load(ds_fin) + should_save = False + if "tensor_parallel" not in ds_config: + ds_config["tensor_parallel"] = { + "autotp_size": tensor_parallel_size + } + should_save = True + if ( + "gather_16bit_weights_on_model_save" + not in ds_config["zero_optimization"] + ): + ds_config["zero_optimization"][ + "gather_16bit_weights_on_model_save" + ] = True + should_save = True + if should_save: + temp_dir = tempfile.mkdtemp() + with open( + Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8" + ) as ds_fout: + json.dump(ds_config, ds_fout, indent=4) + data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json") + + return data + + @model_validator(mode="before") + @classmethod + def check_deepcompile(cls, data): + deepcompile = data.get("deepcompile") + if deepcompile: + if not data.get("deepspeed"): + raise ValueError("DeepCompile is only supported with DeepSpeed") + with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin: + ds_config = json.load(ds_fin) + if "compile" not in ds_config: + ds_config["compile"] = {"deepcompile": True} + temp_dir = tempfile.mkdtemp() + with open( + Path(temp_dir) / "deepcompile_ds.json", "w", encoding="utf-8" + ) as ds_fout: + json.dump(ds_config, ds_fout, indent=4) + data["deepspeed"] = str(Path(temp_dir) / "deepcompile_ds.json") + + return data + + +class SystemValidationMixin: + """Validation methods related to system and hardware configuration.""" + + @model_validator(mode="before") + @classmethod + def check_mem_mismatch(cls, data): + if ( + data.get("max_memory") is not None + and data.get("gpu_memory_limit") is not None + ): + raise ValueError( + "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_deepspeed(cls, data): + if data.get("deepspeed") and data.get("fsdp"): + raise ValueError("deepspeed and fsdp cannot be used together.") + return data + + @model_validator(mode="before") + @classmethod + def check_model_quantization_config_vs_bnb(cls, data): + if data.get("model_quantization_config"): + if data.get("load_in_8bit") or data.get("load_in_4bit"): + raise ValueError( + "model_quantization_config and load_in_8bit or load_in_4bit cannot be used together." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_npu_config(cls, data): + if is_torch_npu_available(): + # check attention config + attn_list = ["flash_attention", "sdp_attention", "s2_attention"] + for attn in attn_list: + if data.get(attn): + raise NotImplementedError( + f"{attn} is currently not supported in Ascend npu, please disable this configuration." + ) + + # check quant config + if data.get("optimizer") is not None and "bit" in data.get("optimizer"): + optimizer = data.get("optimizer") + raise NotImplementedError( + f"{optimizer} is currently not supported in Ascend npu, choose another one please." + ) + + quant_list = ["load_in_8bit", "load_in_4bit"] + for quant in quant_list: + if data.get(quant): + raise NotImplementedError( + f"Quantification is currently not supported in Ascend npu, please disable {quant}." + ) + + # check dtype config + if data.get("tf32"): + raise NotImplementedError( + "tf32 dtype is currently not supported in Ascend npu, please disable this configuration" + ) + + return data + + +class ChatTemplateValidationMixin: + """Validation methods related to chat template configuration.""" + + @model_validator(mode="before") + @classmethod + def check_chat_template_config(cls, data): + # if chat_template is set to jinja, chat_template_jinja is required + if data.get("chat_template") == ChatTemplate.jinja and not data.get( + "chat_template_jinja" + ): + raise ValueError( + "chat_template_jinja is required when chat_template is set to jinja" + ) + + # If chat_template_jinja is set, set chat_template to jinja + if data.get("chat_template_jinja") and not data.get("chat_template"): + data["chat_template"] = ChatTemplate.jinja + + return data + + +class PretrainingValidationMixin: + """Validation methods related to pretraining configuration.""" + + @model_validator(mode="before") + @classmethod + def check_pretraining_w_max_steps(cls, data): + if data.get("pretraining_dataset") and not data.get("max_steps"): + raise ValueError( + "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_pretraining_w_group_by_length(cls, data): + if data.get("pretraining_dataset") and data.get("group_by_length"): + LOG.warning( + "You probably want to disable group_by_length as it will force a streamed dataset to download completely." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_pretraining_split_batches_accelerate(cls, data): + # alternatively set ACCELERATE_SPLIT_BATCHES=False + if data.get("pretraining_dataset"): + accelerator_config = data.get("accelerator_config", {}) + if not accelerator_config: + data["accelerator_config"] = { + "split_batches": False, + "dispatch_batches": False, + } + else: + if accelerator_config.get("split_batches") is None: + data["accelerator_config"]["split_batches"] = False + if accelerator_config.get("dispatch_batches") is None: + data["accelerator_config"]["dispatch_batches"] = False + return data + + @model_validator(mode="before") + @classmethod + def check_pretraining_w_val_set_size(cls, data): + if data.get("pretraining_dataset") and data.get("val_set_size"): + raise ValueError( + "val_set_size is not supported with pretraining_dataset. " + "Use test_datasets to specify evaluation datasets for pretraining." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_streaming_w_val_set_size(cls, data): + if data.get("streaming") and data.get("val_set_size"): + raise ValueError( + "val_set_size is not supported with streaming datasets. " + "Use test_datasets to specify evaluation datasets when streaming is enabled." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_streaming_w_max_steps(cls, data): + if data.get("streaming") and not data.get("max_steps"): + raise ValueError( + "max_steps must be set when using streaming datasets. " + "Trainer cannot infer dataset length for iterable datasets." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_streaming_w_multiple_datasets(cls, data): + if ( + data.get("streaming") + and data.get("sample_packing") + and data.get("datasets") + and len(data.get("datasets")) > 1 + ): + raise NotImplementedError( + "Sample packing with multiple streaming datasets is not yet supported" + ) + return data + + +class ModelCompatibilityValidationMixin: + """Validation methods for specific model compatibility.""" + + @model_validator(mode="after") + def check_falcon_fsdp(self): + if (self.base_model and "falcon" in self.base_model.lower()) and self.fsdp: + raise ValueError("FSDP is not supported for falcon models") + return self + + @model_validator(mode="after") + def check_mpt_checkpointing(self): + if ( + self.base_model and "mpt" in self.base_model.lower() + ) and self.gradient_checkpointing: + raise ValueError("gradient_checkpointing is not supported for MPT models") + return self + + @model_validator(mode="after") + def check_gradient_checkpointing_w_offload(self): + if self.gradient_checkpointing == "offload": + LOG.warning( + "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`" + ) + self.gradient_checkpointing = True + LOG.warning( + "`offload` now uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`" + ) + self.activation_offloading = True + if self.gradient_checkpointing == "offload_disk": + LOG.warning( + "`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`" + ) + self.gradient_checkpointing = True + self.activation_offloading = "disk" + return self + + @model_validator(mode="after") + def check_activation_offloading_wo_gc(self): + if self.activation_offloading and not self.gradient_checkpointing: + raise ValueError("activation_offloading requires gradient_checkpointing") + return self + + @model_validator(mode="after") + def check_better_transformers(self): + if self.flash_optimum is True: + if self.adapter: + LOG.warning( + "BetterTransformers probably doesn't work with PEFT adapters" + ) + if self.fp16 or self.bf16: + raise ValueError("AMP is not supported with BetterTransformer") + if self.float16 is not True and self.bfloat16 is not True: + LOG.warning( + "You should probably set bfloat16 or float16 to true to " + "load the model in float16 for BetterTransformers" + ) + return self + + @model_validator(mode="before") + @classmethod + def check_gptq_w_revision(cls, data): + if data.get("gptq") and data.get("revision_of_model"): + raise ValueError( + "revision_of_model is not supported for GPTQ models. " + + "Please download the model from HuggingFace Hub manually for correct branch, " + + "point to its path, and remove revision_of_model from the config." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_gpt_oss_fsdp_loading(cls, data): + if data.get("model_quantization_config", "") == "Mxfp4Config": + fsdp_config = data.get("fsdp_config") or {} + if fsdp_config.get("cpu_ram_efficient_loading", False) is True: + raise ValueError( + "FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization." + ) + return data + + +class ComplexValidationMixin: + """Complex validation methods that involve multiple systems.""" + + @field_validator("neftune_noise_alpha") + @classmethod + def validate_neftune_noise_alpha(cls, neftune_noise_alpha): + if neftune_noise_alpha is not None and neftune_noise_alpha <= 0.0: + raise ValueError("neftune_noise_alpha must be > 0.0") + return neftune_noise_alpha + + @model_validator(mode="after") + def check_rl_beta(self): + if self.dpo_beta and not self.rl_beta: + self.rl_beta = self.dpo_beta + del self.dpo_beta + return self + + @model_validator(mode="after") + def check_simpo_warmup(self): + if self.rl is RLType.SIMPO and self.warmup_ratio: + raise ValueError( + "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead" + ) + return self + + @model_validator(mode="after") + def check_relora(self): + if self.relora: + if not self.jagged_restart_steps: + raise ValueError("jagged_restart_steps must be set to use ReLoRA") + if self.adapter not in ("lora", "qlora"): + raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") + + if self.fsdp or self.fsdp_config: + raise ValueError("fsdp not supported with ReLoRA") + + if self.deepspeed: + raise ValueError("deepspeed not supported with ReLoRA") + + if self.lr_scheduler == "one_cycle": + raise ValueError( + "ReLoRA is not compatible with the one_cycle scheduler" + ) + + if self.flash_attn_fuse_mlp: + raise ValueError("Fused modules are not supported with ReLoRA") + return self + + @model_validator(mode="after") + def check_early_stopping(self): + if self.early_stopping_patience: + if not self.save_steps or not self.eval_steps: + raise ValueError( + "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps." + ) + if self.save_steps % self.eval_steps != 0: + raise ValueError( + "`early_stopping_patience` requires that eval_steps should evenly divide save_steps." + ) + return self + + @model_validator(mode="after") + def check_tensor_parallel_size(self): + if not self.tensor_parallel_size: + self.tensor_parallel_size = 1 + return self + + @model_validator(mode="after") + def check_context_parallel_size(self): + if self.sequence_parallel_degree and not self.context_parallel_size: + LOG.warning( + "`sequence_parallel_degree` is deprecated, use `context_parallel_size`" + ) + self.context_parallel_size = self.sequence_parallel_degree + if not self.context_parallel_size: + self.context_parallel_size = 1 + elif self.context_parallel_size > 1: + if not self.flash_attention: + raise ValueError( + "flash_attention: true must be set with context_parallel_size > 1" + ) + + if self.sample_packing and self.micro_batch_size > 1: + raise ValueError( + "micro_batch_size must be set to 1 when sample_packing is enabled " + "due to a `ring-flash-attn` requirement" + ) + + try: + import transformers.modeling_flash_attention_utils + from transformers.utils import is_flash_attn_greater_or_equal + + transformers.modeling_flash_attention_utils._flash_supports_window = ( + True + ) + sys.modules[ + "transformers.modeling_flash_attention_utils" + ]._flash_supports_window = True + sys.modules[ + "transformers.modeling_flash_attention_utils" + ]._flash_supports_window_size = True + sys.modules[ + "transformers.modeling_flash_attention_utils" + ].is_flash_attn_greater_or_equal = is_flash_attn_greater_or_equal + import ring_flash_attn # noqa: F401 # Required after monkey-patching + except ImportError as exception: + raise ImportError( + "context_parallel_size > 1 but ring_flash_attn is not installed. " + "Please install it with `pip install axolotl[ring-flash-attn] " + "or `pip install ring-flash-attn>=0.1.4`." + ) from exception + + LOG.warning( + "Sequence parallelism (SP) is enabled with " + f"context_parallel_size={self.context_parallel_size}. " + "Please note that logged losses may differ slightly to the non-SP " + "losses due to transformers Trainer implementation details. " + "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " + "for more details." + ) + + return self + + @model_validator(mode="after") + def validate_ring_attn_func(self): + if getattr(self, "context_parallel_size", 1) == 1: + return self + + if self.ring_attn_func is not None: + self.ring_attn_func = RingAttnFunc(self.ring_attn_func) + else: + # Default ring attention function selection + sample_packing = getattr(self, "sample_packing", False) + self.ring_attn_func = ( + RingAttnFunc.VARLEN_LLAMA3 + if sample_packing + else RingAttnFunc.BATCH_RING + ) + + return self + + def hint_gradient_checkpointing_dpo_lora_ddp(self): + if ( + (self.gradient_checkpointing is True or self.gradient_checkpointing is None) + and self.capabilities + and self.capabilities.get("n_gpu", 1) > 1 + and self.adapter in ("lora", "qlora") + and self.rl == RLType.DPO + and not self.fsdp + and not self.deepspeed + ): + LOG.warning( + "gradient_checkpointing with DPO + DDP + LoRA is not recommended." + ) + return self + + +class DistributedValidationMixin: + """validation for distributed training.""" + + @model_validator(mode="after") + def check_tensor_parallel_optimizer(self): + if self.tensor_parallel_size > 1: + if self.optimizer in ["paged_adamw_8bit", "adamw_8bit", "adamw_bnb_8bit"]: + raise ValueError( + "tensor_parallel_size is not supported with paged_adamw_8bit, adamw_8bit, and adamw_bnb_8bit optimizers" + ) + + return self + + +class GRPOVllmValidationMixin: + """Validation mixin for vllm when using GRPO.""" + + @model_validator(mode="after") + def check_vllm_mode_set(self): + if self.trl and self.trl.use_vllm and not self.trl.vllm_mode: + LOG.warning( + "vllm_mode must be set to either `server` or `colocate` when using vllm, using default value `server`" + ) + self.trl.vllm_mode = "server" + return self + + +class ValidationMixin( + DatasetValidationMixin, + AttentionValidationMixin, + TrainingValidationMixin, + LoRAValidationMixin, + RLValidationMixin, + OptimizationValidationMixin, + SystemValidationMixin, + ChatTemplateValidationMixin, + PretrainingValidationMixin, + ModelCompatibilityValidationMixin, + ComplexValidationMixin, + GRPOVllmValidationMixin, +): + """Full validation mixin for Axolotl configuration.""" diff --git a/src/axolotl/utils/schemas/vllm.py b/src/axolotl/utils/schemas/vllm.py index 0ae635589..518b8f62d 100644 --- a/src/axolotl/utils/schemas/vllm.py +++ b/src/axolotl/utils/schemas/vllm.py @@ -18,6 +18,10 @@ class VllmConfig(BaseModel): default=None, json_schema_extra={"description": "Tensor parallel size for VLLM"}, ) + data_parallel_size: int | None = Field( + default=None, + json_schema_extra={"description": "Data parallel size for VLLM"}, + ) gpu_memory_utilization: float | None = Field( default=0.9, json_schema_extra={"description": "GPU memory utilization for VLLM"}, diff --git a/src/axolotl/utils/tee.py b/src/axolotl/utils/tee.py new file mode 100644 index 000000000..7bc8efab0 --- /dev/null +++ b/src/axolotl/utils/tee.py @@ -0,0 +1,166 @@ +""" +Utilities for managing the debug log file and providing a file-only stream for logging +handlers. +""" + +from __future__ import annotations + +import io +import os +import sys +import threading +from pathlib import Path +from typing import TextIO, cast + +_lock = threading.Lock() +_file_handle: io.TextIOWrapper | None = None +_log_path: str | None = None +_tee_installed: bool = False +_orig_stdout: TextIO | None = None +_orig_stderr: TextIO | None = None + + +class _FileOnlyWriter(io.TextIOBase): + """A stream-like object that writes only to the tee file. + + Before the file is prepared, writes are dropped (no-op). + """ + + def write(self, s: str) -> int: # type: ignore[override] + with _lock: + if _file_handle is not None: + _file_handle.write(s) + return len(s) + return len(s) + + def flush(self) -> None: # type: ignore[override] + with _lock: + if _file_handle is not None: + try: + _file_handle.flush() + except Exception: + pass + + +file_only_stream: io.TextIOBase = _FileOnlyWriter() + + +class _StreamTee(io.TextIOBase): + """A minimal tee that mirrors writes to the debug log file. + + Installed only after the debug log is prepared; no buffering. + """ + + def __init__(self, stream: io.TextIOBase): + self._stream = stream + + def write(self, s: str) -> int: # type: ignore[override] + with _lock: + n = self._stream.write(s) + if _file_handle is not None: + _file_handle.write(s) + return n + + def flush(self) -> None: # type: ignore[override] + with _lock: + self._stream.flush() + if _file_handle is not None: + try: + _file_handle.flush() + except Exception: + pass + + @property + def encoding(self): # type: ignore[override] + return getattr(self._stream, "encoding", None) + + @property + def errors(self): # type: ignore[override] + return getattr(self._stream, "errors", None) + + def isatty(self): # type: ignore[override] + return getattr(self._stream, "isatty", lambda: False)() + + def fileno(self): # type: ignore[override] + if hasattr(self._stream, "fileno"): + return self._stream.fileno() + raise OSError("Underlying stream has no fileno") + + +def prepare_debug_log(cfg, filename: str = "debug.log") -> str: + """ + Prepare the debug log. + + Creates the output directory, handles append/truncate logic based on cfg, and opens + the debug log file for subsequent writes via file-only handlers. + """ + global _file_handle, _log_path, _tee_installed + + with _lock: + # If already initialized, reuse existing path + if _log_path is not None: + return _log_path + + output_dir = cfg.output_dir + os.makedirs(output_dir, exist_ok=True) + + log_path = Path(output_dir) / filename + append = bool( + cfg.get("resume_from_checkpoint") or cfg.get("auto_resume_from_checkpoints") + ) + + if not append: + log_path.unlink(missing_ok=True) + + fh = open(log_path, "a", encoding="utf-8") + fh.flush() + + _file_handle = fh + _log_path = str(log_path) + + # Install a tee so stdout/stderr are mirrored to the debug file + # Allow disabling via env for testing or advanced usage. + tee_enabled = os.getenv("AXOLOTL_TEE_STDOUT", "1").lower() not in { + "0", + "false", + "no", + } + if tee_enabled and not _tee_installed: + # Save originals so we can restore later (e.g., tests) + global _orig_stdout, _orig_stderr + _orig_stdout = sys.stdout + _orig_stderr = sys.stderr + sys.stdout = _StreamTee(cast(io.TextIOBase, sys.stdout)) + sys.stderr = _StreamTee(cast(io.TextIOBase, sys.stderr)) + _tee_installed = True + + return _log_path + + +def close_debug_log() -> None: + """Flush and close the debug log and uninstall the stdout/stderr tee. + + Safe to call even if not initialized. + """ + global _file_handle, _log_path, _tee_installed, _orig_stdout, _orig_stderr + with _lock: + # Restore original stdout/stderr if we installed a tee + if _tee_installed: + if _orig_stdout is not None: + sys.stdout = _orig_stdout + if _orig_stderr is not None: + sys.stderr = _orig_stderr + _tee_installed = False + _orig_stdout = None + _orig_stderr = None + + # Close the file handle if open + if _file_handle is not None: + try: + _file_handle.flush() + _file_handle.close() + except Exception: + pass + finally: + _file_handle = None + _log_path = None diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index 3526bd5b5..3f44a3429 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -31,7 +31,7 @@ def check_example_labels(example, tokenizer, text_only=False): # You can compare the input_ids and labels element-wise # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 colored_tokens = [] - for _, (input_id, label_id) in enumerate(zip(input_ids, labels)): + for _, (input_id, label_id) in enumerate(zip(input_ids, labels, strict=False)): decoded_input_token = tokenizer.decode(input_id) # Choose the color based on whether the label has the ignore value or not color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") diff --git a/src/axolotl/utils/train.py b/src/axolotl/utils/train.py new file mode 100644 index 000000000..ad3f72be4 --- /dev/null +++ b/src/axolotl/utils/train.py @@ -0,0 +1,47 @@ +"""Training utils for checkpoints""" + +from pathlib import Path + +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | None: + """ + Determine the checkpoint to resume from based on configuration. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + update: Whether to update the config with the determined checkpoint + + Returns: + Path to the checkpoint to resume from, or `None` if not resuming. + """ + last_checkpoint = None + checkpoints = sorted( + ( + p + for p in Path(cfg.output_dir).glob("checkpoint-*") + if p.name.split("-")[-1].isdigit() + ), + key=lambda p: int(p.name.split("-")[-1]), + ) + if checkpoints: + last_checkpoint = str(checkpoints[-1]) + if not update: + LOG.info(f"Resuming from last checkpoint at {last_checkpoint}") + return last_checkpoint + + if ( + cfg.resume_from_checkpoint is None + and cfg.auto_resume_from_checkpoints + and last_checkpoint is not None + ): + cfg.resume_from_checkpoint = last_checkpoint + LOG.info( + "Using auto-resume functionality to resume from checkpoint at " + f"{cfg.resume_from_checkpoint}" + ) + return cfg.resume_from_checkpoint diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 67f590a37..d97577d86 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -6,20 +6,20 @@ import os import random from contextlib import contextmanager from functools import partial +from tempfile import NamedTemporaryFile from typing import List, Optional import numpy as np import torch import torch.cuda -from accelerate.logging import get_logger from datasets import IterableDataset, disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.utils import is_torch_bf16_gpu_available -from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder -from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2 -from axolotl.utils.distributed import reduce_and_broadcast +from axolotl.utils.dict import DictDefault +from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support +from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = get_logger(__name__) @@ -278,7 +278,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): prior_len = None filter_map_kwargs = {} if not isinstance(train_dataset, IterableDataset): - filter_map_kwargs["num_proc"] = cfg.dataset_processes + filter_map_kwargs["num_proc"] = cfg.dataset_num_proc filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess drop_long_kwargs = {} @@ -318,7 +318,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if cfg.group_by_length: train_dataset = train_dataset.map( add_length, - num_proc=cfg.dataset_processes, + num_proc=cfg.dataset_num_proc, load_from_cache_file=not cfg.is_preprocess, desc="Group By Length", ) @@ -335,7 +335,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): ) train_dataset = train_dataset.map( pose_fn, - num_proc=cfg.dataset_processes, + num_proc=cfg.dataset_num_proc, load_from_cache_file=not cfg.is_preprocess, desc="Add position_id column (PoSE)", ) @@ -344,7 +344,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if eval_dataset: eval_dataset = eval_dataset.map( pose_fn, - num_proc=cfg.dataset_processes, + num_proc=cfg.dataset_num_proc, load_from_cache_file=not cfg.is_preprocess, desc="Add position_id column (PoSE)", ) @@ -382,6 +382,7 @@ def process_pretraining_datasets_for_packing( if not skip_position_ids: train_dataset = train_dataset.map( add_position_ids, + batched=True, desc="Add position_id column (Pretraining Sample Packing)", ) if drop_attention_mask: @@ -442,7 +443,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): - 1 ) * cfg.num_epochs - * cfg.sequence_parallel_degree + * cfg.context_parallel_size + * cfg.tensor_parallel_size ) LOG.debug( f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}" @@ -467,28 +469,38 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): bin_size=cfg.sample_packing_bin_size, sequential=cfg.sample_packing_sequentially, drop_last=True, + num_processes=cfg.dataset_prcoesses, + mp_start_method=cfg.sample_packing_mp_start_method or "fork", ) data_loader = DataLoader( train_dataset.remove_columns(["length"]), batch_sampler=sampler, ) - data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size + data_loader_len = max( + 1, len(data_loader) * cfg.micro_batch_size // cfg.batch_size + ) LOG.debug(f"data_loader_len: {data_loader_len}") # FIXME: is there a bug here somewhere? the total num steps depends # on the agreed on value for sample_packing_eff_est total_num_steps = int( math.floor( - data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree + data_loader_len + * cfg.num_epochs + * cfg.context_parallel_size + * cfg.tensor_parallel_size ) ) + if cfg.dataloader_drop_last: + # drop the last batch for each epoch + total_num_steps -= int(math.ceil(cfg.num_epochs)) def calc_sample_packing_eff_est(estimates: List[float]): LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") return max(estimates) sample_packing_actual_eff_all = reduce_and_broadcast( - lambda: sampler.efficiency(), # pylint: disable=unnecessary-lambda + lambda: sampler.efficiency(), calc_sample_packing_eff_est, ) sample_packing_eff_est = ( @@ -502,7 +514,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): math.ceil( len(train_dataset) * cfg.num_epochs - * cfg.sequence_parallel_degree + * cfg.context_parallel_size + * cfg.tensor_parallel_size / cfg.batch_size ) ) @@ -529,59 +542,126 @@ def setup_deepspeed_env(cfg, stage=None): ) os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + if isinstance(cfg.deepspeed, DictDefault): + with NamedTemporaryFile( + mode="w", delete=False, suffix=".json", prefix="deepspeed_config_" + ) as temp_file: + temp_file.write(json.dumps(cfg.deepspeed.to_dict(), indent=4)) + temp_file.close() + cfg.deepspeed = str(temp_file.name) os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed + os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str( + cfg.gradient_accumulation_steps + ) if stage: os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage) if stage == 3: os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" + + device_count = torch.cuda.device_count() + if device_count == 1: + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + os.environ.setdefault("MASTER_ADDR", "0.0.0.0") # nosec B104 + os.environ.setdefault("MASTER_PORT", "29500") + + # NOTE(djsaunde): The distribued state cannot be initialized prior to the + # ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior + # to model load. + if ( + int(os.environ.get("WORLD_SIZE", "1")) == 1 + and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1" + and cfg.use_ray is not True + ): + os.environ["WORLD_SIZE"] = "1" # force it in case not set + os.environ["LOCAL_RANK"] = "0" # force it in case not set + os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0") + import deepspeed.comm as dist + + dist.init_distributed( + dist_backend="nccl", auto_mpi_discovery=False, dist_init_required=True + ) + init_distributed_state() + # If we don't assign this, it doesn't actually get set in the accelerate weakref _ = HfTrainerDeepSpeedConfig(cfg.deepspeed) def setup_fsdp_envs(cfg): os.environ["ACCELERATE_USE_FSDP"] = "true" - if str(cfg.fsdp_config.fsdp_version) == "2": + + # TODO @SalmanMohammadi remove FSDP1 args in 0.12 + if str(cfg.fsdp_version) == "2": os.environ["FSDP_VERSION"] = "2" - if cfg.fsdp_config.fsdp_activation_checkpointing: + if cfg.fsdp_config.activation_checkpointing: os.environ["FSDP_ACTIVATION_CHECKPOINTING"] = "true" - if cfg.fsdp_config.fsdp_offload_params: + if cfg.fsdp_config.offload_params: os.environ["FSDP_OFFLOAD_PARAMS"] = "true" - if cfg.fsdp_config.fsdp_sync_module_states: + if cfg.fsdp_config.sync_module_states: os.environ["FSDP_SYNC_MODULE_STATES"] = "true" - if cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + if cfg.fsdp_config.cpu_ram_efficient_loading: os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true" - if cfg.fsdp_config.fsdp_use_orig_params: + if cfg.fsdp_config.use_orig_params: os.environ["FSDP_USE_ORIG_PARAMS"] = "true" - if cfg.fsdp_config.fsdp_state_dict_type: - os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type - if cfg.fsdp_config.fsdp_auto_wrap_policy: - os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.fsdp_auto_wrap_policy - if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap: + if cfg.fsdp_config.state_dict_type: + os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.state_dict_type + if cfg.fsdp_config.cpu_offload_pin_memory is not None: + os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str( + cfg.fsdp_config.cpu_offload_pin_memory + ).lower() + if cfg.fsdp_config.auto_wrap_policy: + os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.auto_wrap_policy + if cfg.fsdp_config.transformer_layer_cls_to_wrap: os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ( - cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap - ) - if cfg.fsdp_config.fsdp_reshard_after_forward is not None: - os.environ["FSDP_RESHARD_AFTER_FORWARD"] = ( - "true" if cfg.fsdp_config.fsdp_reshard_after_forward else "false" + cfg.fsdp_config.transformer_layer_cls_to_wrap ) + if cfg.fsdp_config.reshard_after_forward: + os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true" + + +def setup_parallelism_envs(cfg): + set_accelerate_parallelism_config = False + if cfg.tensor_parallel_size and cfg.tensor_parallel_size > 1: + set_accelerate_parallelism_config = True + os.environ["PARALLELISM_CONFIG_TP_SIZE"] = str(cfg.tensor_parallel_size) + if cfg.dp_shard_size and cfg.dp_shard_size > 1: + set_accelerate_parallelism_config = True + os.environ["PARALLELISM_CONFIG_DP_SHARD_SIZE"] = str(cfg.dp_shard_size) + if cfg.dp_replicate_size and cfg.dp_replicate_size > 1: + set_accelerate_parallelism_config = True + os.environ["PARALLELISM_CONFIG_DP_REPLICATE_SIZE"] = str(cfg.dp_replicate_size) + if cfg.context_parallel_size and cfg.context_parallel_size > 1: + set_accelerate_parallelism_config = True + os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size) + os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true" + if set_accelerate_parallelism_config: + os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true" def prepare_optim_env(cfg): if not check_cuda_p2p_ib_support(): if os.getenv("NCCL_P2P_DISABLE") is None: + LOG.warning("P2P support not detected, setting `NCCL_P2P_DISABLE=1`") os.environ["NCCL_P2P_DISABLE"] = "1" - if cfg.fsdp: + # TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12 + if cfg.fsdp or cfg.fsdp_config: + cfg.fsdp = True if not cfg.fsdp else cfg.fsdp setup_fsdp_envs(cfg) elif cfg.deepspeed: stage = None + deepspeed_config = None # check if the cfg.deepspeed is a file - if os.path.isfile(cfg.deepspeed): + if isinstance(cfg.deepspeed, DictDefault): + deepspeed_config = cfg.deepspeed + elif os.path.isfile(cfg.deepspeed): # parse with json with open(cfg.deepspeed, "r", encoding="utf-8") as fin: deepspeed_config = json.load(fin) + if deepspeed_config: stage = deepspeed_config.get("zero_optimization", {}).get("stage", None) setup_deepspeed_env(cfg, stage=stage) + setup_parallelism_envs(cfg) setup_torch_compile_env(cfg) if cfg.fp8: @@ -594,12 +674,6 @@ def prepare_optim_env(cfg): os.environ["ACCELERATE_MIXED_PRECISION"] = "no" -def prepare_opinionated_env(cfg): - if cfg.qlora_sharded_model_loading: - # model loading is forked after the tokenizer - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - def setup_trainer( cfg, train_dataset, @@ -629,12 +703,8 @@ def setup_trainer( A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based on the provided parameters. """ - if ( - cfg.torch_compile - and cfg.fsdp_config - and str(cfg.fsdp_config.fsdp_version) == "2" - ): - patch_evaluation_loop_for_fsdp2() + from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder + if cfg.rl: trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor) trainer_builder.model_ref = model_ref diff --git a/src/setuptools_axolotl_dynamic_dependencies.py b/src/setuptools_axolotl_dynamic_dependencies.py index 02a5b8083..3bb54cda8 100644 --- a/src/setuptools_axolotl_dynamic_dependencies.py +++ b/src/setuptools_axolotl_dynamic_dependencies.py @@ -9,7 +9,6 @@ from importlib.metadata import PackageNotFoundError, version from setuptools.command.build_py import build_py as _build_py -# pylint: disable=duplicate-code def parse_requirements(): _install_requires = [] _dependency_links = [] @@ -34,7 +33,6 @@ def parse_requirements(): try: xformers_version = [req for req in _install_requires if "xformers" in req][0] torchao_version = [req for req in _install_requires if "torchao" in req][0] - autoawq_version = [req for req in _install_requires if "autoawq" in req][0] if "Darwin" in platform.system(): # don't install xformers on MacOS @@ -64,7 +62,6 @@ def parse_requirements(): _install_requires.append("xformers==0.0.28.post2") else: _install_requires.append("xformers==0.0.28.post3") - _install_requires.pop(_install_requires.index(autoawq_version)) elif (major, minor) >= (2, 4): if patch == 0: _install_requires.pop(_install_requires.index(xformers_version)) diff --git a/tests/cli/test_cli_base.py b/tests/cli/test_cli_base.py index 6dbae045f..e28bbb75c 100644 --- a/tests/cli/test_cli_base.py +++ b/tests/cli/test_cli_base.py @@ -17,16 +17,23 @@ class BaseCliTest: command: Command to test (train/evaluate) """ # Test missing config file - result = cli_runner.invoke(cli, [command, "--no-accelerate"]) + result = cli_runner.invoke(cli, [command, "--launcher", "python"]) assert result.exit_code != 0 # Test non-existent config file - result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"]) + result = cli_runner.invoke( + cli, [command, "nonexistent.yml", "--launcher", "python"] + ) assert result.exit_code != 0 assert "Error: Invalid value for 'CONFIG'" in result.output def _test_basic_execution( - self, cli_runner, tmp_path: Path, valid_test_config: str, command: str + self, + cli_runner, + tmp_path: Path, + valid_test_config: str, + command: str, + train: bool = True, ): """Test basic execution with accelerate. @@ -35,24 +42,37 @@ class BaseCliTest: tmp_path: Temporary path fixture valid_test_config: Valid config fixture command: Command to test (train/evaluate) + train: Whether to test training (default) or evaluation """ config_path = tmp_path / "config.yml" config_path.write_text(valid_test_config) - with patch("subprocess.run") as mock: + mock_fn = "os.execvpe" if command == "train" else "subprocess.run" + + with patch(mock_fn) as mock: result = cli_runner.invoke(cli, [command, str(config_path)]) assert mock.called - assert mock.call_args.args[0] == [ + + expected = [ "accelerate", "launch", "-m", f"axolotl.cli.{command}", str(config_path), - "--debug-num-examples", - "0", + "--debug=False", + "--debug-text-only=False", + "--debug-num-examples=0", ] - assert mock.call_args.kwargs == {"check": True} + if train: + expected.append("--shard=False") + + if command == "train": + assert mock.call_args.args[0] == "accelerate" + assert mock.call_args.args[1] == expected + else: + assert mock.call_args.args[0] == expected + assert mock.call_args.kwargs == {"check": True} assert result.exit_code == 0 def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str): diff --git a/tests/cli/test_cli_evaluate.py b/tests/cli/test_cli_evaluate.py index d8eb41467..e8b88625a 100644 --- a/tests/cli/test_cli_evaluate.py +++ b/tests/cli/test_cli_evaluate.py @@ -18,7 +18,9 @@ class TestEvaluateCommand(BaseCliTest): def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config): """Test basic successful execution""" - self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate") + self._test_basic_execution( + cli_runner, tmp_path, valid_test_config, "evaluate", train=False + ) def test_evaluate_basic_execution_no_accelerate( self, cli_runner, tmp_path, valid_test_config @@ -33,7 +35,8 @@ class TestEvaluateCommand(BaseCliTest): [ "evaluate", str(config_path), - "--no-accelerate", + "--launcher", + "python", ], catch_exceptions=False, ) @@ -55,7 +58,8 @@ class TestEvaluateCommand(BaseCliTest): "2", "--sequence-len", "128", - "--no-accelerate", + "--launcher", + "python", ], catch_exceptions=False, ) @@ -65,3 +69,104 @@ class TestEvaluateCommand(BaseCliTest): cfg = mock_evaluate.call_args[0][0] assert cfg.micro_batch_size == 2 assert cfg.sequence_len == 128 + + def test_evaluate_with_launcher_args_torchrun( + self, cli_runner, tmp_path, valid_test_config + ): + """Test evaluate with torchrun launcher arguments""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--launcher", + "torchrun", + "--", + "--nproc_per_node=2", + "--nnodes=1", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to torchrun + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "torchrun" + assert "--nproc_per_node=2" in called_cmd + assert "--nnodes=1" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.evaluate" in called_cmd + + def test_evaluate_with_launcher_args_accelerate( + self, cli_runner, tmp_path, valid_test_config + ): + """Test evaluate with accelerate launcher arguments""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--launcher", + "accelerate", + "--", + "--config_file=accelerate_config.yml", + "--num_processes=4", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to accelerate + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + assert "--config_file=accelerate_config.yml" in called_cmd + assert "--num_processes=4" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.evaluate" in called_cmd + + def test_evaluate_backward_compatibility_no_launcher_args( + self, cli_runner, tmp_path, valid_test_config + ): + """Test that existing evaluate commands work without launcher args""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--launcher", + "accelerate", + "--micro-batch-size", + "2", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify no launcher args contamination + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + # Should not contain any extra launcher args + launcher_section = called_cmd[2 : called_cmd.index("-m")] + assert ( + len(launcher_section) == 0 + ) # No launcher args between 'launch' and '-m' diff --git a/tests/cli/test_cli_inference.py b/tests/cli/test_cli_inference.py index b8effa3d2..807dc7fa3 100644 --- a/tests/cli/test_cli_inference.py +++ b/tests/cli/test_cli_inference.py @@ -10,7 +10,7 @@ def test_inference_basic(cli_runner, config_path): with patch("axolotl.cli.inference.do_inference") as mock: result = cli_runner.invoke( cli, - ["inference", str(config_path), "--no-accelerate"], + ["inference", str(config_path), "--launcher", "python"], catch_exceptions=False, ) @@ -23,9 +23,124 @@ def test_inference_gradio(cli_runner, config_path): with patch("axolotl.cli.inference.do_inference_gradio") as mock: result = cli_runner.invoke( cli, - ["inference", str(config_path), "--no-accelerate", "--gradio"], + ["inference", str(config_path), "--launcher", "python", "--gradio"], catch_exceptions=False, ) assert mock.called assert result.exit_code == 0 + + +def test_inference_with_launcher_args_torchrun(cli_runner, config_path): + """Test inference with torchrun launcher arguments""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "inference", + str(config_path), + "--launcher", + "torchrun", + "--", + "--nproc_per_node=2", + "--nnodes=1", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to torchrun + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "torchrun" + assert "--nproc_per_node=2" in called_cmd + assert "--nnodes=1" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.inference" in called_cmd + + +def test_inference_with_launcher_args_accelerate(cli_runner, config_path): + """Test inference with accelerate launcher arguments""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "inference", + str(config_path), + "--launcher", + "accelerate", + "--", + "--config_file=accelerate_config.yml", + "--num_processes=4", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to accelerate + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + assert "--config_file=accelerate_config.yml" in called_cmd + assert "--num_processes=4" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.inference" in called_cmd + + +def test_inference_gradio_with_launcher_args(cli_runner, config_path): + """Test inference with gradio and launcher arguments""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "inference", + str(config_path), + "--launcher", + "accelerate", + "--gradio", + "--", + "--num_processes=2", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify both gradio flag and launcher args are present + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + assert "--num_processes=2" in called_cmd + assert "--gradio" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.inference" in called_cmd + + +def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path): + """Test that existing inference commands work without launcher args""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "inference", + str(config_path), + "--launcher", + "accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify no launcher args contamination + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + # Should not contain any extra launcher args + launcher_section = called_cmd[2 : called_cmd.index("-m")] + assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m' diff --git a/tests/cli/test_cli_interface.py b/tests/cli/test_cli_interface.py index 8b5fec17f..ebd91ea60 100644 --- a/tests/cli/test_cli_interface.py +++ b/tests/cli/test_cli_interface.py @@ -18,11 +18,10 @@ def test_build_command(): assert result == [ "accelerate", "launch", - "--learning-rate", - "0.0001", - "--batch-size", - "8", - "--debug", + "--learning-rate=0.0001", + "--batch-size=8", + "--debug=True", + "--use-fp16=False", ] @@ -38,7 +37,7 @@ def test_invalid_command_options(cli_runner): ], ) assert result.exit_code != 0 - assert "No such option" in result.output + assert "does not exist" in result.output def test_required_config_argument(cli_runner): diff --git a/tests/cli/test_cli_merge_sharded_fsdp_weights.py b/tests/cli/test_cli_merge_sharded_fsdp_weights.py index ec96b4ed4..de13b28ed 100644 --- a/tests/cli/test_cli_merge_sharded_fsdp_weights.py +++ b/tests/cli/test_cli_merge_sharded_fsdp_weights.py @@ -1,7 +1,5 @@ """pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" -# pylint: disable=duplicate-code - from unittest.mock import patch from axolotl.cli.main import cli @@ -11,9 +9,101 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path): """Test merge_sharded_fsdp_weights command without accelerate""" with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: result = cli_runner.invoke( - cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"] + cli, + ["merge-sharded-fsdp-weights", str(config_path), "--launcher", "python"], ) assert mock.called assert mock.call_args.kwargs["config"] == str(config_path) assert result.exit_code == 0 + + +def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun( + cli_runner, config_path +): + """Test merge-sharded-fsdp-weights with torchrun launcher arguments""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "merge-sharded-fsdp-weights", + str(config_path), + "--launcher", + "torchrun", + "--", + "--nproc_per_node=2", + "--nnodes=1", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to torchrun + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "torchrun" + assert "--nproc_per_node=2" in called_cmd + assert "--nnodes=1" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd + + +def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate( + cli_runner, config_path +): + """Test merge-sharded-fsdp-weights with accelerate launcher arguments""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "merge-sharded-fsdp-weights", + str(config_path), + "--launcher", + "accelerate", + "--", + "--config_file=accelerate_config.yml", + "--num_processes=4", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to accelerate + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + assert "--config_file=accelerate_config.yml" in called_cmd + assert "--num_processes=4" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd + + +def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args( + cli_runner, config_path +): + """Test that existing merge-sharded-fsdp-weights commands work without launcher args""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "merge-sharded-fsdp-weights", + str(config_path), + "--launcher", + "accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify no launcher args contamination + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + # Should not contain any extra launcher args + launcher_section = called_cmd[2 : called_cmd.index("-m")] + assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m' diff --git a/tests/cli/test_cli_sweeps.py b/tests/cli/test_cli_sweeps.py index 40b360717..1b14f5aca 100644 --- a/tests/cli/test_cli_sweeps.py +++ b/tests/cli/test_cli_sweeps.py @@ -2,7 +2,7 @@ unit tests for generating sweep configurations """ -from axolotl.cli.main import generate_sweep_configs +from axolotl.cli.utils import generate_sweep_configs def test_generate_sweep_configs_no_pairs(): diff --git a/tests/cli/test_cli_train.py b/tests/cli/test_cli_train.py index 473913599..1251ab3c0 100644 --- a/tests/cli/test_cli_train.py +++ b/tests/cli/test_cli_train.py @@ -18,7 +18,9 @@ class TestTrainCommand(BaseCliTest): def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config): """Test basic successful execution""" - self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train") + self._test_basic_execution( + cli_runner, tmp_path, valid_test_config, "train", train=True + ) def test_train_basic_execution_no_accelerate( self, cli_runner, tmp_path, valid_test_config @@ -37,7 +39,8 @@ class TestTrainCommand(BaseCliTest): [ "train", str(config_path), - "--no-accelerate", + "--launcher", + "python", ], catch_exceptions=False, ) @@ -59,11 +62,10 @@ class TestTrainCommand(BaseCliTest): [ "train", str(config_path), - "--learning-rate", - "1e-4", - "--micro-batch-size", - "2", - "--no-accelerate", + "--learning-rate=1e-4", + "--micro-batch-size=2", + "--launcher", + "python", ], catch_exceptions=False, ) @@ -73,3 +75,177 @@ class TestTrainCommand(BaseCliTest): cfg = mock_train.call_args[1]["cfg"] assert cfg["learning_rate"] == 1e-4 assert cfg["micro_batch_size"] == 2 + + def test_train_with_launcher_args_torchrun( + self, cli_runner, tmp_path, valid_test_config + ): + """Test train with torchrun launcher arguments""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("os.execvpe") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--launcher", + "torchrun", + "--", + "--nproc_per_node=2", + "--nnodes=1", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to torchrun + called_cmd = mock_subprocess.call_args.args[1] + assert called_cmd[0] == "torchrun" + assert "--nproc_per_node=2" in called_cmd + assert "--nnodes=1" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.train" in called_cmd + + def test_train_with_launcher_args_accelerate( + self, cli_runner, tmp_path, valid_test_config + ): + """Test train with accelerate launcher arguments""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("os.execvpe") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--launcher", + "accelerate", + "--", + "--config_file=accelerate_config.yml", + "--num_processes=4", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to accelerate + assert mock_subprocess.call_args.args[0] == "accelerate" + called_cmd = mock_subprocess.call_args.args[1] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + assert "--config_file=accelerate_config.yml" in called_cmd + assert "--num_processes=4" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.train" in called_cmd + + def test_train_backward_compatibility_no_launcher_args( + self, cli_runner, tmp_path, valid_test_config + ): + """Test that existing train commands work without launcher args""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("os.execvpe") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--launcher", + "accelerate", + "--learning-rate", + "1e-4", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify no launcher args contamination + assert mock_subprocess.call_args.args[0] == "accelerate" + called_cmd = mock_subprocess.call_args.args[1] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + # Should not contain any extra launcher args + launcher_section = called_cmd[2 : called_cmd.index("-m")] + assert ( + len(launcher_section) == 0 + ) # No launcher args between 'launch' and '-m' + + def test_train_mixed_args_with_launcher_args( + self, cli_runner, tmp_path, valid_test_config + ): + """Test train with both regular CLI args and launcher args""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("os.execvpe") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--launcher", + "torchrun", + "--learning-rate", + "2e-4", + "--micro-batch-size", + "4", + "--", + "--nproc_per_node=8", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + assert mock_subprocess.call_args.args[0] == "torchrun" + called_cmd = mock_subprocess.call_args.args[1] + # Verify launcher args + assert "--nproc_per_node=8" in called_cmd + # Verify axolotl args are also present + assert "--learning-rate=2e-4" in called_cmd + assert "--micro-batch-size=4" in called_cmd + + def test_train_cloud_with_launcher_args( + self, cli_runner, tmp_path, valid_test_config + ): + """Test train with cloud and launcher arguments""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + cloud_path = tmp_path / "cloud.yml" + cloud_path.write_text("provider: modal\ngpu: a100") + + with patch("axolotl.cli.cloud.do_cli_train") as mock_cloud_train: + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--cloud", + str(cloud_path), + "--launcher", + "torchrun", + "--", + "--nproc_per_node=4", + "--nnodes=2", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_cloud_train.assert_called_once() + + # Verify cloud training was called with launcher args + call_kwargs = mock_cloud_train.call_args.kwargs + assert call_kwargs["launcher"] == "torchrun" + assert call_kwargs["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"] diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index 2dab5bba9..431c35c3c 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,7 +1,5 @@ """pytest tests for axolotl CLI utils.""" -# pylint: disable=redefined-outer-name - import json from unittest.mock import Mock, patch @@ -25,7 +23,7 @@ MOCK_TREE_RESPONSE = { def mock_responses(): """Mock responses for API and file downloads""" - def mock_get(url, timeout=None): # pylint: disable=unused-argument + def mock_get(url, timeout=None): response = Mock() if "api.github.com" in url: response.text = json.dumps(MOCK_TREE_RESPONSE) @@ -72,3 +70,160 @@ def test_fetch_from_github_network_error(): with patch("requests.get", side_effect=requests.RequestException): with pytest.raises(requests.RequestException): fetch_from_github("examples/", None) + + +def assert_launcher_args_in_command( + mock_subprocess_call, + launcher: str, + expected_launcher_args: list[str], + command_module: str, +): + """ + Helper function to verify launcher arguments are properly passed in subprocess calls. + + Args: + mock_subprocess_call: The mock subprocess.run call + launcher: Expected launcher ("accelerate", "torchrun", etc.) + expected_launcher_args: List of expected launcher arguments + command_module: Expected module name (e.g., "axolotl.cli.train") + """ + assert mock_subprocess_call.called, "subprocess.run should have been called" + called_cmd = mock_subprocess_call.call_args.args[0] + + # Verify launcher + assert called_cmd[0] == launcher, ( + f"Expected launcher {launcher}, got {called_cmd[0]}" + ) + + # Verify launcher args are present + for arg in expected_launcher_args: + assert arg in called_cmd, ( + f"Expected launcher arg '{arg}' not found in command: {called_cmd}" + ) + + # Verify module is present + assert "-m" in called_cmd, "Expected -m flag for module execution" + assert command_module in called_cmd, ( + f"Expected module {command_module} not found in command: {called_cmd}" + ) + + +def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str): + """ + Helper function to verify no unwanted launcher arguments are present. + + Args: + mock_subprocess_call: The mock subprocess.run call + launcher: Expected launcher ("accelerate", "torchrun", etc.) + """ + assert mock_subprocess_call.called, "subprocess.run should have been called" + called_cmd = mock_subprocess_call.call_args.args[0] + + if launcher == "accelerate": + # For accelerate, launcher args should be between 'launch' and '-m' + launch_idx = called_cmd.index("launch") + m_idx = called_cmd.index("-m") + launcher_section = called_cmd[launch_idx + 1 : m_idx] + assert len(launcher_section) == 0, ( + f"Unexpected launcher args found: {launcher_section}" + ) + elif launcher == "torchrun": + # For torchrun, launcher args should be between 'torchrun' and '-m' + torchrun_idx = called_cmd.index("torchrun") + m_idx = called_cmd.index("-m") + launcher_section = called_cmd[torchrun_idx + 1 : m_idx] + assert len(launcher_section) == 0, ( + f"Unexpected launcher args found: {launcher_section}" + ) + + +@pytest.fixture +def common_launcher_args(): + """Fixture providing common launcher argument combinations for testing.""" + return { + "torchrun": ["--nproc_per_node=2", "--nnodes=1"], + "accelerate": ["--config_file=accelerate_config.yml", "--num_processes=4"], + } + + +def test_add_default_rdzv_args_with_endpoint(): + """Test that default RDZV args are added when rdzv_endpoint is present.""" + from axolotl.cli.utils.train import _add_default_rdzv_args + + launcher_args = ["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"] + result = _add_default_rdzv_args(launcher_args) + + # Should have added rdzv_backend + assert "--rdzv_backend" in result + assert "c10d" in result + + # Original args should still be present + assert "--nnodes=2" in result + assert "--rdzv_endpoint=127.0.0.1:29400" in result + + +def test_add_default_rdzv_args_with_existing_backend(): + """Test that existing rdzv_backend is not overridden.""" + from axolotl.cli.utils.train import _add_default_rdzv_args + + launcher_args = [ + "--nnodes=2", + "--rdzv_endpoint=127.0.0.1:29400", + "--rdzv_backend=static", + ] + result = _add_default_rdzv_args(launcher_args) + + # Should not add another rdzv_backend + backend_count = sum(1 for arg in result if "--rdzv_backend" in arg) + assert backend_count == 1 + assert "--rdzv_backend=static" in result + + +def test_add_default_rdzv_args_with_existing_id(): + """Test that existing rdzv_id is not overridden.""" + from axolotl.cli.utils.train import _add_default_rdzv_args + + launcher_args = [ + "--nnodes=2", + "--rdzv_endpoint=127.0.0.1:29400", + "--rdzv_id=my_job_123", + ] + result = _add_default_rdzv_args(launcher_args) + + # Should not add another rdzv_id + id_count = sum(1 for arg in result if "--rdzv_id" in arg) + assert id_count == 1 + assert "--rdzv_id=my_job_123" in result + + # Should still add rdzv_backend + assert "--rdzv_backend" in result + assert "c10d" in result + + +def test_add_default_rdzv_args_without_endpoint(): + """Test that no RDZV args are added when rdzv_endpoint is not present.""" + from axolotl.cli.utils.train import _add_default_rdzv_args + + launcher_args = ["--nnodes=2", "--nproc_per_node=4"] + result = _add_default_rdzv_args(launcher_args) + + # Should not add any rdzv args + assert "--rdzv_backend" not in result + assert result == launcher_args + + +def test_add_default_rdzv_args_with_all_existing(): + """Test that no defaults are added when all RDZV args are present.""" + from axolotl.cli.utils.train import _add_default_rdzv_args + + launcher_args = [ + "--nnodes=2", + "--rdzv_endpoint=127.0.0.1:29400", + "--rdzv_backend=static", + "--rdzv_id=existing_job", + ] + result = _add_default_rdzv_args(launcher_args) + + # Should not add any additional args + assert len(result) == len(launcher_args) + assert result == launcher_args diff --git a/tests/conftest.py b/tests/conftest.py index 12014e78e..d3b9407ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,32 +2,38 @@ import functools import importlib +import logging import os import shutil import sys import tempfile import time from pathlib import Path +from typing import Generator import datasets import pytest import requests +import torch from huggingface_hub import snapshot_download from huggingface_hub.errors import LocalEntryNotFoundError from tokenizers import AddedToken from transformers import AutoTokenizer +from axolotl.utils.dict import DictDefault + from tests.hf_offline_utils import ( enable_hf_offline, hf_offline_context, ) +logging.getLogger("filelock").setLevel(logging.CRITICAL) + def retry_on_request_exceptions(max_retries=3, delay=1): - # pylint: disable=duplicate-code def decorator(func): @functools.wraps(func) - def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements + def wrapper(*args, **kwargs): for attempt in range(max_retries): try: return func(*args, **kwargs) @@ -162,7 +168,7 @@ def download_argilla_distilabel_intel_orca_dpo_dataset(): # @disable_hf_offline # def dataset_fozzie_alpaca_dpo_dataset( # download_fozzie_alpaca_dpo_dataset, -# ): # pylint: disable=unused-argument,redefined-outer-name +# ): # return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train") # # @@ -170,7 +176,7 @@ def download_argilla_distilabel_intel_orca_dpo_dataset(): # @disable_hf_offline # def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff( # download_fozzie_alpaca_dpo_dataset, -# ): # pylint: disable=unused-argument,redefined-outer-name +# ): # return load_dataset( # "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff" # ) @@ -350,7 +356,7 @@ def download_llama32_1b_model_fixture(): @enable_hf_offline def tokenizer_huggyllama( download_huggyllama_model_fixture, -): # pylint: disable=unused-argument,redefined-outer-name +): tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") tokenizer.pad_token = "" @@ -361,7 +367,7 @@ def tokenizer_huggyllama( @enable_hf_offline def tokenizer_huggyllama_w_special_tokens( tokenizer_huggyllama, -): # pylint: disable=redefined-outer-name +): tokenizer_huggyllama.add_special_tokens( { "bos_token": "", @@ -377,7 +383,7 @@ def tokenizer_huggyllama_w_special_tokens( @enable_hf_offline def tokenizer_llama2_7b( download_llama2_model_fixture, -): # pylint: disable=unused-argument,redefined-outer-name +): tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf") return tokenizer @@ -387,7 +393,7 @@ def tokenizer_llama2_7b( @enable_hf_offline def tokenizer_mistral_7b_instruct( download_mlx_mistral_7b_model_fixture, -): # pylint: disable=unused-argument,redefined-outer-name +): return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq") @@ -409,7 +415,7 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct): @pytest.fixture -def temp_dir(): +def temp_dir() -> Generator[str, None, None]: # Create a temporary directory _temp_dir = tempfile.mkdtemp() yield _temp_dir @@ -417,6 +423,11 @@ def temp_dir(): shutil.rmtree(_temp_dir) +@pytest.fixture(scope="function", autouse=True) +def torch_manual_seed(): + torch.manual_seed(42) + + @pytest.fixture(scope="function", autouse=True) def cleanup_monkeypatches(): from transformers import Trainer @@ -428,9 +439,7 @@ def cleanup_monkeypatches(): # original_fa2_forward = LlamaFlashAttention2.forward original_llama_attn_forward = LlamaAttention.forward original_llama_forward = LlamaForCausalLM.forward - original_trainer_inner_training_loop = ( - Trainer._inner_training_loop # pylint: disable=protected-access - ) + original_trainer_inner_training_loop = Trainer._inner_training_loop original_trainer_training_step = Trainer.training_step # monkey patches can happen inside the tests yield @@ -438,9 +447,7 @@ def cleanup_monkeypatches(): # LlamaFlashAttention2.forward = original_fa2_forward LlamaAttention.forward = original_llama_attn_forward LlamaForCausalLM.forward = original_llama_forward - Trainer._inner_training_loop = ( # pylint: disable=protected-access - original_trainer_inner_training_loop - ) + Trainer._inner_training_loop = original_trainer_inner_training_loop Trainer.training_step = original_trainer_training_step # Reset other known monkeypatches @@ -476,7 +483,7 @@ def cleanup_monkeypatches(): @pytest.fixture def dataset_winglian_tiny_shakespeare( download_ds_fixture_bundle: Path, -): # pylint: disable=redefined-outer-name +): ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare" return datasets.load_from_disk(ds_path) @@ -484,7 +491,7 @@ def dataset_winglian_tiny_shakespeare( @pytest.fixture def dataset_tatsu_lab_alpaca( download_ds_fixture_bundle: Path, -): # pylint: disable=redefined-outer-name +): ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca" return datasets.load_from_disk(ds_path)["train"] @@ -492,7 +499,7 @@ def dataset_tatsu_lab_alpaca( @pytest.fixture def dataset_mhenrichsen_alpaca_2k_test( download_ds_fixture_bundle: Path, -): # pylint: disable=redefined-outer-name +): ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test" return datasets.load_from_disk(ds_path)["train"] @@ -500,7 +507,7 @@ def dataset_mhenrichsen_alpaca_2k_test( @pytest.fixture def dataset_argilla_ultrafeedback_binarized_preferences_cleaned( download_ds_fixture_bundle: Path, -): # pylint: disable=redefined-outer-name +): ds_path = ( download_ds_fixture_bundle / "argilla__ultrafeedback-binarized-preferences-cleaned" @@ -511,7 +518,7 @@ def dataset_argilla_ultrafeedback_binarized_preferences_cleaned( @pytest.fixture def dataset_fozziethebeat_alpaca_messages_2k_dpo_test( download_ds_fixture_bundle: Path, -): # pylint: disable=redefined-outer-name +): ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test" return datasets.load_from_disk(ds_path)["train"] @@ -519,7 +526,7 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test( @pytest.fixture def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff( download_ds_fixture_bundle: Path, -): # pylint: disable=redefined-outer-name +): ds_path = ( download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff" @@ -527,7 +534,23 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff( return datasets.load_from_disk(ds_path)["train"] -# # pylint: disable=redefined-outer-name,unused-argument +@pytest.fixture(name="min_base_cfg") +def fixture_min_base_cfg(): + return DictDefault( + base_model="HuggingFaceTB/SmolLM2-135M", + learning_rate=1e-3, + datasets=[ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + micro_batch_size=1, + gradient_accumulation_steps=1, + ) + + +# @pytest.mark.skipif( os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1", reason="Not running in CI cache preload", diff --git a/tests/constants.py b/tests/constants.py index e024e6920..cd75bd339 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -3,6 +3,7 @@ This module contains constants and configuration dictionaries used for datasets and other utilities in the Axolotl project, specifically for testing. """ + # Configuration for Alpaca Messages Dataset ALPACA_MESSAGES_CONFIG_OG = { "path": "fozziethebeat/alpaca_messages_2k_dpo_test", diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py index cde7b74ce..67481b2ad 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -1,7 +1,5 @@ """Unit tests for axolotl.core.builders""" -# pylint: disable=protected-access - import sys from pathlib import Path from unittest.mock import patch @@ -12,7 +10,7 @@ from axolotl.common.datasets import load_datasets from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.loaders import ModelLoader, load_tokenizer from axolotl.utils.config import normalize_config -from axolotl.utils.data.rl import load_prepare_preference_datasets +from axolotl.utils.data import prepare_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.schemas.enums import RLType @@ -64,7 +62,8 @@ def fixture_base_cfg(): "dataloader_num_workers": 1, "dataloader_pin_memory": True, "dataloader_prefetch_factor": 2, - "sequence_parallel_degree": 1, + "context_parallel_size": 1, + "tensor_parallel_size": 1, # Dtype "fp16": False, "bf16": False, @@ -81,6 +80,7 @@ def fixture_base_cfg(): "ddp_timeout": 1800, "ddp_bucket_cap_mb": 25, "ddp_broadcast_buffers": False, + "dataset_processes": 4, } ) @@ -279,7 +279,9 @@ class TestHFRLTrainerBuilder: # Other settings assert training_arguments.dataloader_num_workers == 1 assert training_arguments.dataloader_pin_memory is True - assert training_arguments.gradient_checkpointing is False + + # TODO(wing): restore once trl releases 0.22.0 + # assert training_arguments.gradient_checkpointing is True def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer): builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer) @@ -326,7 +328,6 @@ def rand_reward_func(prompts, completions) -> list[float]: ) def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path): - rewards_dir = tmp_path / "rewards_test" self._write_rewards_file(rewards_dir) @@ -439,6 +440,7 @@ def rand_reward_func(prompts, completions) -> list[float]: ] else: raise ValueError(f"Unhandled cfg_string: {cfg_string}") + cfg["dataset_num_proc"] = 4 if cfg_string == "grpo_cfg": rewards_dir = tmp_path / "rewards_test" @@ -451,15 +453,19 @@ def rand_reward_func(prompts, completions) -> list[float]: # Only use mock for the commented out configs if dataset_name is not None: with patch( - "axolotl.utils.data.rl.load_dataset_w_config" + "axolotl.utils.data.rl.load_dataset_with_config" ) as mock_load_dataset: mock_load_dataset.return_value = request.getfixturevalue( dataset_name ) - train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) + train_dataset, eval_dataset = prepare_preference_datasets( + cfg, tokenizer + ) else: # Load actual datasets for orpo_cfg and kto_cfg - train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) + train_dataset, eval_dataset = prepare_preference_datasets( + cfg, tokenizer + ) builder.train_dataset = train_dataset builder.eval_dataset = eval_dataset @@ -468,7 +474,7 @@ def rand_reward_func(prompts, completions) -> list[float]: assert trainer.optimizer_cls_and_kwargs is not None - from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module + from axolotl.contribs.mit.muon import ( Muon, MuonOptimizerFactory, ) @@ -550,7 +556,7 @@ class TestHFCausalTrainerBuilder: assert trainer.optimizer_cls_and_kwargs is not None - from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module + from axolotl.contribs.mit.muon import ( Muon, MuonOptimizerFactory, ) @@ -590,6 +596,6 @@ class TestTrainerClsPlugin: except TypeError as e: # Error raised if trainer_cls is None assert "'tuple' object has no attribute 'config'" not in str(e) - except Exception: # pylint: disable=broad-exception-caught + except Exception: # Another error happens, so we passed trainer_cls to builder pass diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py index 2ae59a15a..1ba05077c 100644 --- a/tests/e2e/integrations/test_cut_cross_entropy.py +++ b/tests/e2e/integrations/test_cut_cross_entropy.py @@ -4,7 +4,6 @@ Simple end-to-end test for Cut Cross Entropy integration import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils import get_pytorch_version @@ -13,8 +12,6 @@ from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists -# pylint: disable=duplicate-code - @pytest.fixture() def min_cfg(temp_dir): @@ -45,6 +42,7 @@ def min_cfg(temp_dir): "save_safetensors": True, "max_steps": 10, "bf16": "auto", + "save_first_step": False, } @@ -53,14 +51,12 @@ class TestCutCrossEntropyIntegration: e2e tests for cut_cross_entropy integration with Axolotl """ - # pylint: disable=redefined-outer-name def test_llama_w_cce(self, min_cfg, temp_dir): cfg = DictDefault(min_cfg) cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): @@ -70,7 +66,6 @@ class TestCutCrossEntropyIntegration: train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) - # pylint: disable=redefined-outer-name def test_qwen2_w_cce(self, temp_dir): cfg = DictDefault( { @@ -100,13 +95,13 @@ class TestCutCrossEntropyIntegration: "save_safetensors": True, "max_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): @@ -134,8 +129,7 @@ class TestCutCrossEntropyIntegration: cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): diff --git a/tests/e2e/integrations/test_fp8.py b/tests/e2e/integrations/test_fp8.py new file mode 100644 index 000000000..7db63cc4d --- /dev/null +++ b/tests/e2e/integrations/test_fp8.py @@ -0,0 +1,61 @@ +""" +Simple end-to-end smoke tests for FP8 mixed precision training +""" + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_model_output_exists, require_torch_2_7_0 + + +class FP8IntegrationTestCase: + """ + e2e smoke tests for FP8 mixed precision training with Axolotl + """ + + @require_torch_2_7_0 + def test_fp8_single_gpu_smoke(self, temp_dir): + """Smoke test for single GPU FP8 + torch.compile training""" + + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 512, + "val_set_size": 0.05, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, # Very short smoke test + "micro_batch_size": 1, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "sdp_attention": True, + "pad_to_seq_len": True, + "sample_packing": True, + "fp8": True, + "torch_compile": True, + "save_safetensors": True, + "save_first_step": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py index 45d7200fb..b85505caa 100644 --- a/tests/e2e/integrations/test_hooks.py +++ b/tests/e2e/integrations/test_hooks.py @@ -5,7 +5,6 @@ e2e tests to make sure all the hooks are fired on the plugin import os from pathlib import Path -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.integrations.base import BasePlugin from axolotl.train import train @@ -29,85 +28,81 @@ class LogHooksPlugin(BasePlugin): except FileNotFoundError: pass - def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument + def post_trainer_create(self, cfg, trainer): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("post_trainer_create\n") - def pre_model_load(self, cfg): # pylint: disable=unused-argument + def pre_model_load(self, cfg): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("pre_model_load\n") - def post_model_build(self, cfg, model): # pylint: disable=unused-argument + def post_model_build(self, cfg, model): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("post_model_build\n") - def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument + def pre_lora_load(self, cfg, model): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("pre_lora_load\n") - def post_lora_load(self, cfg, model): # pylint: disable=unused-argument + def post_lora_load(self, cfg, model): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("post_lora_load\n") - def post_model_load(self, cfg, model): # pylint: disable=unused-argument + def post_model_load(self, cfg, model): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("post_model_load\n") - def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument + def create_optimizer(self, cfg, trainer): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("create_optimizer\n") - def get_trainer_cls(self, cfg): # pylint: disable=unused-argument + def get_trainer_cls(self, cfg): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("get_trainer_cls\n") - def create_lr_scheduler( - self, cfg, trainer, optimizer, num_training_steps - ): # pylint: disable=unused-argument + def create_lr_scheduler(self, cfg, trainer, optimizer, num_training_steps): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("create_lr_scheduler\n") - def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument + def add_callbacks_pre_trainer(self, cfg, model): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("add_callbacks_pre_trainer\n") return [] - def add_callbacks_post_trainer( - self, cfg, trainer - ): # pylint: disable=unused-argument + def add_callbacks_post_trainer(self, cfg, trainer): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("add_callbacks_post_trainer\n") return [] - def post_train(self, cfg, model): # pylint: disable=unused-argument + def post_train(self, cfg, model): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: f.write("post_train\n") - def post_train_unload(self, cfg): # pylint: disable=unused-argument + def post_train_unload(self, cfg): with open( self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8" ) as f: @@ -120,7 +115,6 @@ class TestPluginHooks: """ def test_plugin_hooks(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -154,14 +148,14 @@ class TestPluginHooks: "max_steps": 5, "flash_attention": True, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py index dad777947..d89044247 100644 --- a/tests/e2e/integrations/test_kd.py +++ b/tests/e2e/integrations/test_kd.py @@ -5,11 +5,9 @@ e2e tests for kd trainer support in Axolotl from pathlib import Path import pytest +import yaml +from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port -from axolotl.cli.args import TrainerCliArgs -from axolotl.common.datasets import load_datasets -from axolotl.train import train -from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.dict import DictDefault from tests.e2e.utils import check_tensorboard, require_torch_2_5_1 @@ -18,8 +16,8 @@ from tests.e2e.utils import check_tensorboard, require_torch_2_5_1 @pytest.fixture(name="kd_min_cfg") def min_cfg(temp_dir): return { - "base_model": "osllmai-community/Llama-3.2-1B", - "tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer", + "base_model": "Qwen/Qwen3-0.6B", + "tokenizer_config": "winglian/qwen3-14b-math", "plugins": [ "axolotl.integrations.kd.KDPlugin", "axolotl.integrations.liger.LigerPlugin", @@ -27,25 +25,27 @@ def min_cfg(temp_dir): "liger_rms_norm": True, "liger_glu_activation": True, "torch_compile": True, - "chat_template": "llama3", + "chat_template": "qwen3", "kd_trainer": True, "kd_ce_alpha": 0.1, "kd_alpha": 0.9, "kd_temperature": 1.0, + "kd_beta": 0.0, + "kd_normalize_topk": True, "dataloader_prefetch_factor": 8, "dataloader_num_workers": 4, "dataloader_pin_memory": True, "datasets": [ { - "path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", - "type": "axolotl.integrations.kd.chat_template", - "field_messages": "messages_combined", + "path": "winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized", + "type": "chat_template", "split": "train", - "logprobs_field": "llm_text_generation_vllm_logprobs", - "temperature": 1.0, - "preprocess_shards": 2, + "split_thinking": True, + "eot_tokens": ["<|im_end|>"], + "data_files": ["train/batch-000000.parquet"], }, ], + "skip_prepare_dataset": True, "val_set_size": 0.0, "sequence_len": 2048, "sample_packing": True, @@ -67,6 +67,7 @@ def min_cfg(temp_dir): "output_dir": temp_dir, "save_safetensors": True, "use_tensorboard": True, + "save_first_step": False, } @@ -80,14 +81,24 @@ class TestKnowledgeDistillation: @require_torch_2_5_1 def test_llama_kd(self, temp_dir, kd_min_cfg): cfg = DictDefault(kd_min_cfg) - # pylint: disable=duplicate-code - cfg = validate_config(cfg) - prepare_plugins(cfg) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, dataset_meta=dataset_meta) + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "1", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + assert (Path(temp_dir) / "model.safetensors").exists() check_tensorboard( temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high" @@ -108,17 +119,30 @@ class TestKnowledgeDistillation: "lora_r": 16, "lora_alpha": 32, "lora_dropout": 0.0, + "lora_modules_to_save": ["embed_tokens", "lm_head"], + "lora_mlp_kernel": False, + "lora_qkv_kernel": False, + "lora_o_kernel": False, } | kd_min_cfg ) - # pylint: disable=duplicate-code - cfg = validate_config(cfg) - prepare_plugins(cfg) - normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, dataset_meta=dataset_meta) + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "1", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) assert (Path(temp_dir) / "adapter_model.safetensors").exists() check_tensorboard( temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high" diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index 8ecfc4746..285969963 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -2,7 +2,6 @@ Simple end-to-end test for Liger integration """ -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins, validate_config @@ -18,7 +17,6 @@ class LigerIntegrationTestCase: @require_torch_2_4_1 def test_llama_wo_flce(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -51,21 +49,20 @@ class LigerIntegrationTestCase: "save_safetensors": True, "bf16": "auto", "max_steps": 5, + "save_first_step": False, } ) - # pylint: disable=duplicate-code + cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @require_torch_2_4_1 def test_llama_w_flce(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -98,14 +95,14 @@ class LigerIntegrationTestCase: "save_safetensors": True, "bf16": "auto", "max_steps": 5, + "save_first_step": False, } ) - # pylint: disable=duplicate-code + cfg = validate_config(cfg) prepare_plugins(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py index 20bf821bf..dceecea9f 100644 --- a/tests/e2e/integrations/test_llm_compressor.py +++ b/tests/e2e/integrations/test_llm_compressor.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins, validate_config @@ -82,14 +81,14 @@ class TestLLMCompressorIntegration: }, "save_compressed": save_compressed, }, + "save_first_step": False, } ) prepare_plugins(cfg) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) try: train(cfg=cfg, dataset_meta=dataset_meta) diff --git a/tests/e2e/kernels/test_geglu.py b/tests/e2e/kernels/test_geglu.py index 90403ab4a..78ba74c0e 100644 --- a/tests/e2e/kernels/test_geglu.py +++ b/tests/e2e/kernels/test_geglu.py @@ -19,8 +19,15 @@ def test_geglu_forward_shape(): assert out.device == gate.device -def test_geglu_forward_values(): +@pytest.mark.flaky(retries=1, delay=5) +@pytest.mark.parametrize( + "torch_seed", + [0, 42], +) +def test_geglu_forward_values(torch_seed): """Test GEGLU forward pass matches PyTorch reference implementation.""" + torch.manual_seed(torch_seed) + gate = torch.randn(2, 3, 64, device="cuda") up = torch.randn(2, 3, 64, device="cuda") @@ -33,6 +40,7 @@ def test_geglu_forward_values(): assert torch.allclose(triton_out, torch_out, rtol=1e-3) +@pytest.mark.flaky(retries=1, delay=5) @pytest.mark.parametrize( "torch_seed", [0, 42], @@ -77,6 +85,6 @@ def test_geglu_inplace_preservation(): assert not torch.equal(gate, gate_copy), "Gate should be modified in-place" assert not torch.equal(up, up_copy), "Up should be modified in-place" - assert not torch.equal( - grad_output, grad_copy - ), "Grad output should be modified in-place" + assert not torch.equal(grad_output, grad_copy), ( + "Grad output should be modified in-place" + ) diff --git a/tests/e2e/kernels/test_lora.py b/tests/e2e/kernels/test_lora.py index 5ad186cbf..9baceb668 100644 --- a/tests/e2e/kernels/test_lora.py +++ b/tests/e2e/kernels/test_lora.py @@ -1,7 +1,5 @@ """Tests for LoRA custom autograd.""" -# pylint: disable=invalid-name,redefined-outer-name - import pytest import torch from bitsandbytes.functional import QuantState @@ -64,6 +62,7 @@ def sample_tensors(): batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16 ), "W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16), + "b": torch.randn(out_dim, device="cuda", dtype=torch.float16), "scale": 0.5, "shapes": { "batch": batch_size, @@ -103,23 +102,24 @@ def mock_proj(): def test_get_lora_parameters(mock_proj): """Tests get_lora_parameters function""" # Test with LoRA enabled - W, _, A, B, s = get_lora_parameters(mock_proj) + W, b, _, A, B, s = get_lora_parameters(mock_proj) assert isinstance(W, torch.Tensor) assert W.shape == (128, 64) + assert b.shape == (128,) assert A.shape == (8, 64) assert B.shape == (128, 8) assert s == 0.5 # Test with LoRA disabled mock_proj.disable_adapters = True - W, _, A, B, s = get_lora_parameters(mock_proj) + W, b, _, A, B, s = get_lora_parameters(mock_proj) assert A is None and B is None and s is None # Test with merged state mock_proj.disable_adapters = False mock_proj.merged = True - W, _, A, B, s = get_lora_parameters(mock_proj) + W, b, _, A, B, s = get_lora_parameters(mock_proj) assert A is None and B is None and s is None @@ -127,6 +127,7 @@ def test_matmul_lora(sample_tensors): """Tests matmul_lora function""" X = sample_tensors["X"] W = sample_tensors["W"] + b = sample_tensors["b"] scale = sample_tensors["scale"] shapes = sample_tensors["shapes"] @@ -138,19 +139,20 @@ def test_matmul_lora(sample_tensors): B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16) # Test base matmul - out1 = matmul_lora(X, W, None, None, None, None) - expected1 = torch.matmul(X, W.t()) + out1 = matmul_lora(X, W, b, None, None, None, None) + matmul = torch.matmul(X, W.t()) + expected1 = matmul + b assert torch.allclose(out1, expected1, rtol=1e-3) # Test with LoRA - out2 = matmul_lora(X, W, None, A, B, scale) + out2 = matmul_lora(X, W, b, None, A, B, scale) lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t()) - expected2 = expected1 + lora_term + expected2 = matmul + lora_term + b assert torch.allclose(out2, expected2, rtol=1e-3) # Test 3D input reshaping X_3d = X.clone() - out3 = matmul_lora(X_3d, W, None, A, B, scale) + out3 = matmul_lora(X_3d, W, b, None, A, B, scale) assert out3.shape == (X.shape[0], X.shape[1], W.shape[0]) @@ -175,16 +177,19 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward output = LoRA_MLP.apply( X, gate_proj.weight, + gate_proj.bias, None, # gate_quant None, # gate_A None, # gate_B None, # gate_scale up_proj.weight, + up_proj.bias, None, # up_quant None, # up_A None, # up_B None, # up_scale down_proj.weight, + down_proj.bias, None, # down_quant None, # down_A None, # down_B @@ -243,16 +248,19 @@ def test_lora_mlp_with_adapters( output = LoRA_MLP.apply( X, gate_proj.weight, + gate_proj.bias, None, gate_A, gate_B, scale, up_proj.weight, + up_proj.bias, None, up_A, up_B, scale, down_proj.weight, + down_proj.bias, None, down_A, down_B, @@ -323,6 +331,7 @@ def test_lora_qkv(sample_tensors): X.requires_grad = True # Test without LoRA adapters + Q1, K1, V1 = LoRA_QKV.apply( X, q_weight, @@ -330,16 +339,19 @@ def test_lora_qkv(sample_tensors): None, None, None, + None, k_weight, None, None, None, None, + None, v_weight, None, None, None, None, + None, True, ) @@ -356,16 +368,19 @@ def test_lora_qkv(sample_tensors): X, q_weight, None, + None, q_A, q_B, scale, k_weight, None, + None, k_A, k_B, scale, v_weight, None, + None, v_A, v_B, scale, @@ -399,6 +414,7 @@ def test_lora_o(sample_tensors): """Tests LoRA output projection""" X = sample_tensors["X"] W = sample_tensors["W"] + b = sample_tensors["b"] scale = sample_tensors["scale"] shapes = sample_tensors["shapes"] @@ -411,7 +427,7 @@ def test_lora_o(sample_tensors): # Test forward pass X.requires_grad = True - output = LoRA_O.apply(X, W, None, A, B, scale) + output = LoRA_O.apply(X, W, b, None, A, B, scale) assert output.shape == (X.shape[0], X.shape[1], W.shape[0]) @@ -425,6 +441,7 @@ def test_with_quantization(sample_tensors, mock_quantstate): """Tests LoRA with quantized weights""" X = sample_tensors["X"] # [batch, seq, hidden] W = sample_tensors["W"] # [out, hidden] + b = sample_tensors["b"] # [out] scale = 0.5 shapes = sample_tensors["shapes"] @@ -436,13 +453,13 @@ def test_with_quantization(sample_tensors, mock_quantstate): B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16) # Test matmul with quantization - out = matmul_lora(X, W, mock_quantstate, A, B, scale) + out = matmul_lora(X, W, b, mock_quantstate, A, B, scale) assert out.shape == (X.shape[0], X.shape[1], W.shape[0]) assert not torch.isnan(out).any() # Test with different batch sizes X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16) - out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale) + out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale) assert out2.shape == (4, 6, W.shape[0]) assert not torch.isnan(out2).any() @@ -459,11 +476,12 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out): """Tests various input shapes and dimensions""" X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16) W = torch.randn(out, hidden, device="cuda", dtype=torch.float16) + b = torch.randn(out, device="cuda", dtype=torch.float16) A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16) B = torch.randn(out, rank, device="cuda", dtype=torch.float16) scale = 0.5 - result = matmul_lora(X, W, None, A, B, scale) + result = matmul_lora(X, W, b, None, A, B, scale) assert result.shape == (batch, seq, out) @@ -471,6 +489,7 @@ def test_gradient_flow(sample_tensors): """Tests gradient flow through LoRA layers""" X = sample_tensors["X"].clone() W = sample_tensors["W"].clone() + b = sample_tensors["b"].clone() scale = sample_tensors["scale"] shapes = sample_tensors["shapes"] @@ -486,7 +505,7 @@ def test_gradient_flow(sample_tensors): B.requires_grad = True # Forward pass - out = matmul_lora(X, W, None, A, B, scale) + out = matmul_lora(X, W, b, None, A, B, scale) loss = out.sum() # Backward pass diff --git a/tests/e2e/kernels/test_quantize.py b/tests/e2e/kernels/test_quantize.py index ea91407ef..60396584c 100644 --- a/tests/e2e/kernels/test_quantize.py +++ b/tests/e2e/kernels/test_quantize.py @@ -1,7 +1,5 @@ """Tests for quantization utility functions.""" -# pylint: disable=invalid-name - import torch from bitsandbytes.functional import QuantState diff --git a/tests/e2e/kernels/test_swiglu.py b/tests/e2e/kernels/test_swiglu.py index 60fdafb79..58d5e04a7 100644 --- a/tests/e2e/kernels/test_swiglu.py +++ b/tests/e2e/kernels/test_swiglu.py @@ -1,7 +1,5 @@ """Tests for SwiGLU activation function Triton kernels.""" -# pylint: disable=duplicate-code - import torch import torch.nn.functional as F @@ -74,6 +72,6 @@ def test_swiglu_inplace_preservation(): assert not torch.equal(gate, gate_copy), "Gate should be modified in-place" assert not torch.equal(up, up_copy), "Up should be modified in-place" - assert not torch.equal( - grad_output, grad_copy - ), "Grad output should be modified in-place" + assert not torch.equal(grad_output, grad_copy), ( + "Grad output should be modified in-place" + ) diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py index e90def2b7..a005e6742 100644 --- a/tests/e2e/multigpu/patched/test_sp.py +++ b/tests/e2e/multigpu/patched/test_sp.py @@ -54,6 +54,7 @@ class TestSequenceParallelism: "micro_batch_size": micro_batch_size, "gradient_accumulation_steps": 2, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_8bit", "lr_scheduler": "cosine", @@ -66,8 +67,9 @@ class TestSequenceParallelism: "logging_steps": 1, "weight_decay": 0.0, "use_tensorboard": True, - "sequence_parallel_degree": 2, + "context_parallel_size": 2, "ring_attn_func": ring_attn_func, + "save_first_step": False, } ) @@ -91,7 +93,10 @@ class TestSequenceParallelism: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", threshold, "Train Loss is too high" + temp_dir + "/runs", + "train/train_loss", + threshold, + "Train Loss (%s) is too high", ) @pytest.mark.parametrize( @@ -100,13 +105,13 @@ class TestSequenceParallelism: (True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func (False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func # (False, 2, True, "batch_zigzag", 2.5), - (False, 2, False, None, 2.5), # defaults to batch_ring ring_attn_func + # (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func ], ids=[ "sample_packing, varlen_llama3 ring_attn_func", "no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func", # "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func", - "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func", + # "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func", ], ) def test_sequence_parallel_training( diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py index 42c3c00c8..881d75c25 100644 --- a/tests/e2e/multigpu/solo/test_flex.py +++ b/tests/e2e/multigpu/solo/test_flex.py @@ -31,7 +31,6 @@ class TestPackedFlex: @require_torch_2_6_0 def test_loss_llama(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -54,12 +53,14 @@ class TestPackedFlex: "gradient_accumulation_steps": 2, "gradient_checkpointing": True, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "max_steps": 2, "use_tensorboard": True, "save_strategy": "no", + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -85,5 +86,5 @@ class TestPackedFlex: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index 8ea2e3ce4..b48eb30e1 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -80,7 +80,7 @@ def start_vllm( cmd_env = env.copy() cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json}) # start `trl vllm-serve` command in the background and capture the process id - process = subprocess.Popen( # pylint: disable=consider-using-with + process = subprocess.Popen( cmd, env=cmd_env, stdout=subprocess.DEVNULL if quiet else subprocess.PIPE, @@ -105,7 +105,7 @@ def start_vllm( print(f"{i}: VLLM server failed to start: {str(exc)}") # also check if the process.pid is still running - if not process.poll() is None: + if process.poll() is not None: break time.sleep(period_seconds) @@ -141,6 +141,7 @@ def recursive_kill(process: subprocess.Popen): os.kill(process.pid, 9) +@pytest.mark.skip(reason="flaky vllm tests in modal") class TestGRPO: """ Test case for GRPO training using multilpe GPUs @@ -222,6 +223,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) @@ -296,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "lora_alpha": 16, "lora_dropout": 0.05, "lora_target_linear": True, - "sequence_parallel_degree": 2, + "context_parallel_size": 2, "flash_attention": True, "sequence_len": 1024, "special_tokens": { @@ -309,12 +311,14 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "warmup_steps": 10, "val_set_size": 0.0, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.0001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) @@ -400,12 +404,14 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "warmup_steps": 10, "val_set_size": 0.0, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.0001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py index 379562e40..504659a3a 100644 --- a/tests/e2e/multigpu/test_eval.py +++ b/tests/e2e/multigpu/test_eval.py @@ -21,7 +21,6 @@ class TestMultiGPUEval: """ def test_eval_sample_packing(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -38,12 +37,13 @@ class TestMultiGPUEval: "lora_dropout": 0.05, "lora_target_linear": True, "lora_modules_to_save": ["embed_tokens", "lm_head"], - "val_set_size": 0.004, + "val_set_size": 0.05, "special_tokens": {"pad_token": "<|endoftext|>"}, "datasets": [ { "path": "teknium/GPT4-LLM-Cleaned", "type": "alpaca", + "split": "train[:5%]", }, ], "num_epochs": 1, @@ -51,6 +51,7 @@ class TestMultiGPUEval: "micro_batch_size": 2, "gradient_accumulation_steps": 2, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_8bit", "lr_scheduler": "cosine", @@ -65,6 +66,7 @@ class TestMultiGPUEval: "logging_steps": 1, "weight_decay": 0.0, "use_tensorboard": True, + "save_first_step": False, } ) @@ -90,7 +92,6 @@ class TestMultiGPUEval: check_tensorboard(temp_dir + "/runs", "eval/loss", 2.5, "Eval Loss is too high") def test_eval(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -107,12 +108,13 @@ class TestMultiGPUEval: "lora_dropout": 0.05, "lora_target_linear": True, "lora_modules_to_save": ["embed_tokens", "lm_head"], - "val_set_size": 0.0004, + "val_set_size": 0.01, "special_tokens": {"pad_token": "<|endoftext|>"}, "datasets": [ { "path": "teknium/GPT4-LLM-Cleaned", "type": "alpaca", + "split": "train[:5%]", }, ], "num_epochs": 1, @@ -120,6 +122,7 @@ class TestMultiGPUEval: "micro_batch_size": 2, "gradient_accumulation_steps": 2, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_8bit", "lr_scheduler": "cosine", @@ -134,6 +137,7 @@ class TestMultiGPUEval: "logging_steps": 1, "weight_decay": 0.0, "use_tensorboard": True, + "save_first_step": False, } ) diff --git a/tests/e2e/multigpu/test_fp8_fsdp2.py b/tests/e2e/multigpu/test_fp8_fsdp2.py new file mode 100644 index 000000000..dc369f3de --- /dev/null +++ b/tests/e2e/multigpu/test_fp8_fsdp2.py @@ -0,0 +1,119 @@ +"""Test module for FP8 mixed precision with FSDP2 multi-GPU functionality.""" + +import os +from pathlib import Path + +import torch +import yaml +from accelerate.test_utils import execute_subprocess_async +from tbparse import SummaryReader +from transformers.testing_utils import get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0 + +AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent + + +def verify_fp8_training_success(temp_dir): + """Verify that FP8 training completed successfully by checking artifacts and loss.""" + output_path = Path(temp_dir) + + model_files = list(output_path.glob("*.bin")) + list( + output_path.glob("*.safetensors") + ) + assert len(model_files) > 0, "No model files found - training may have failed" + + checkpoint_files = list(output_path.glob("checkpoint-*")) + assert len(checkpoint_files) > 0, ( + "No checkpoint files found - training may have failed" + ) + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + if tb_log_path: + event_files = sorted(os.listdir(tb_log_path)) + if event_files: + event_file = os.path.join(tb_log_path, event_files[0]) + reader = SummaryReader(event_file) + df = reader.scalars + train_loss_df = df[df.tag == "train/train_loss"] + if len(train_loss_df) > 0: + final_loss = train_loss_df.value.values[-1] + assert not torch.isnan(torch.tensor(final_loss)), ( + f"Training loss is NaN: {final_loss}" + ) + + +class TestFP8FSDP2: + """Test class for FP8 mixed precision with FSDP2 functionality.""" + + @require_torch_2_7_0 + @require_hopper + def test_fp8_fsdp2_smoke(self, temp_dir): + """Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training""" + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 512, + "val_set_size": 0.05, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, # Very short smoke test + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", # Use standard optimizer for stability + "lr_scheduler": "cosine", + "sdp_attention": True, + "pad_to_seq_len": True, + "sample_packing": True, + # FP8 configuration + "fp8": True, + "fp8_enable_fsdp_float8_all_gather": True, + "torch_compile": True, + # FSDP2 configuration + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "LlamaDecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "save_safetensors": True, + "save_first_step": False, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_fp8_training_success(temp_dir) diff --git a/tests/e2e/multigpu/test_fsdp1.py b/tests/e2e/multigpu/test_fsdp1.py new file mode 100644 index 000000000..cb92c80b5 --- /dev/null +++ b/tests/e2e/multigpu/test_fsdp1.py @@ -0,0 +1,324 @@ +"""Test module for FSDP1 multi-GPU functionality.""" + +import os +from pathlib import Path + +import pytest +import torch +import yaml +from accelerate.test_utils import execute_subprocess_async +from tbparse import SummaryReader +from transformers.testing_utils import get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import most_recent_subdir + +AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent + + +def verify_training_success(temp_dir): + """Verify that training completed successfully by checking artifacts and loss.""" + output_path = Path(temp_dir) + + model_files = list(output_path.glob("*.bin")) + list( + output_path.glob("*.safetensors") + ) + assert len(model_files) > 0, "No model files found - training may have failed" + + checkpoint_files = list(output_path.glob("checkpoint-*")) + assert len(checkpoint_files) > 0, ( + "No checkpoint files found - training may have failed" + ) + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + if tb_log_path: + event_files = sorted(os.listdir(tb_log_path)) + if event_files: + event_file = os.path.join(tb_log_path, event_files[0]) + reader = SummaryReader(event_file) + df = reader.scalars + train_loss_df = df[df.tag == "train/train_loss"] + if len(train_loss_df) > 0: + final_loss = train_loss_df.value.values[-1] + assert not torch.isnan(torch.tensor(final_loss)), ( + f"Training loss is NaN: {final_loss}" + ) + + +class TestFSDP1: + """Test class for FSDP1 functionality.""" + + @pytest.mark.parametrize( + "fsdp_cpu_ram_efficient_loading", + [True, False], + ) + def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": "1", + "fsdp_config": { + "fsdp_offload_params": False, + "fsdp_cpu_ram_efficient_loading": fsdp_cpu_ram_efficient_loading, + "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_sharding_strategy": "FULL_SHARD", + "fsdp_sync_module_states": True, + "fsdp_use_orig_params": False, + }, + "use_tensorboard": True, + "bf16": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + + @pytest.mark.parametrize( + "adapter_config", + [ + { + "adapter": "lora", + "load_in_4bit": False, + }, + { + "adapter": "qlora", + "load_in_4bit": True, + }, + ], + ) + def test_lora_sft(self, temp_dir, adapter_config): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "adapter": adapter_config["adapter"], + "load_in_4bit": adapter_config["load_in_4bit"], + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": "1", + "fsdp_config": { + "fsdp_offload_params": False, + "fsdp_cpu_ram_efficient_loading": True, + "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_sharding_strategy": "FULL_SHARD", + "fsdp_sync_module_states": True, + "fsdp_use_orig_params": False, + }, + "use_tensorboard": True, + "bf16": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + + def test_dpo_fft(self, temp_dir): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "rl": "dpo", + "chat_template": "chatml", + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "split": "train", + "type": "chatml.intel", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": "1", + "fsdp_config": { + "fsdp_offload_params": False, + "fsdp_cpu_ram_efficient_loading": True, + "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_sharding_strategy": "FULL_SHARD", + "fsdp_sync_module_states": True, + "fsdp_use_orig_params": False, + }, + "use_tensorboard": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + + @pytest.mark.parametrize( + "adapter_config", + [ + { + "adapter": "lora", + "load_in_4bit": False, + }, + { + "adapter": "qlora", + "load_in_4bit": True, + }, + ], + ) + def test_dpo_lora(self, temp_dir, adapter_config): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "load_in_4bit": adapter_config["load_in_4bit"], + "rl": "dpo", + "chat_template": "chatml", + "sequence_len": 2048, + "adapter": adapter_config["adapter"], + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.01, + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "split": "train", + "type": "chatml.intel", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": "1", + "fsdp_config": { + "fsdp_offload_params": False, + "fsdp_cpu_ram_efficient_loading": True, + "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_sharding_strategy": "FULL_SHARD", + "fsdp_sync_module_states": True, + "fsdp_use_orig_params": False, + }, + "use_tensorboard": True, + "bf16": "auto", + "tf32": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) diff --git a/tests/e2e/multigpu/test_fsdp2.py b/tests/e2e/multigpu/test_fsdp2.py new file mode 100644 index 000000000..8b7ee710e --- /dev/null +++ b/tests/e2e/multigpu/test_fsdp2.py @@ -0,0 +1,480 @@ +"""Test module for FSDP2 multi-GPU functionality.""" + +import os +from pathlib import Path + +import pytest +import torch +import yaml +from accelerate.test_utils import execute_subprocess_async +from tbparse import SummaryReader +from transformers.testing_utils import get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0 + +AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent + + +def verify_training_success(temp_dir): + """Verify that training completed successfully by checking artifacts and loss.""" + output_path = Path(temp_dir) + + model_files = list(output_path.glob("*.bin")) + list( + output_path.glob("*.safetensors") + ) + assert len(model_files) > 0, "No model files found - training may have failed" + + checkpoint_files = list(output_path.glob("checkpoint-*")) + assert len(checkpoint_files) > 0, ( + "No checkpoint files found - training may have failed" + ) + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + if tb_log_path: + event_files = sorted(os.listdir(tb_log_path)) + if event_files: + event_file = os.path.join(tb_log_path, event_files[0]) + reader = SummaryReader(event_file) + df = reader.scalars + train_loss_df = df[df.tag == "train/train_loss"] + if len(train_loss_df) > 0: + final_loss = train_loss_df.value.values[-1] + assert not torch.isnan(torch.tensor(final_loss)), ( + f"Training loss is NaN: {final_loss}" + ) + + +class TestFSDP2: + """Test class for FSDP2 functionality.""" + + @require_torch_2_7_0 + @pytest.mark.parametrize( + "fsdp_cpu_ram_efficient_loading", + [True, False], + ) + def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": fsdp_cpu_ram_efficient_loading, + "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "bf16": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + + @require_torch_2_7_0 + @pytest.mark.parametrize("peft_use_dora", [True, False]) + def test_lora_sft(self, temp_dir, peft_use_dora): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "peft_use_dora": peft_use_dora, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "bf16": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + + @require_torch_2_7_0 + def test_lora_sft_kernels(self, temp_dir): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_target_linear": True, + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "bf16": True, + "lora_mlp_kernel": True, + "lora_qkv_kernel": True, + "lora_o_kernel": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + + @require_torch_2_7_0 + def test_qlora_sft(self, temp_dir): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "bf16": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + + @require_torch_2_7_0 + def test_qlora_sft_kernels(self, temp_dir): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "load_in_4bit": True, + "adapter": "qlora", + "lora_r": 8, + "lora_alpha": 16, + "lora_target_linear": True, + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "bf16": True, + "lora_mlp_kernel": True, + "lora_qkv_kernel": True, + "lora_o_kernel": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + + @require_torch_2_7_0 + def test_dpo_fft(self, temp_dir): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "rl": "dpo", + "chat_template": "chatml", + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "split": "train", + "type": "chatml.intel", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + + @require_torch_2_7_0 + def test_dpo_lora(self, temp_dir): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "rl": "dpo", + "chat_template": "chatml", + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "split": "train", + "type": "chatml.intel", + }, + ], + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py index 9bff25f40..51ec68b11 100644 --- a/tests/e2e/multigpu/test_gemma3.py +++ b/tests/e2e/multigpu/test_gemma3.py @@ -29,7 +29,6 @@ class TestMultiGPUGemma3: """ def test_lora_ddp_packed(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "axolotl-mirrors/gemma-3-4b-pt", @@ -64,12 +63,14 @@ class TestMultiGPUGemma3: }, "gradient_accumulation_steps": 2, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.0001, "optimizer": "adamw_8bit", "lr_scheduler": "cosine", "flash_attention": True, "use_tensorboard": True, "bf16": True, + "save_first_step": False, } ) @@ -91,5 +92,5 @@ class TestMultiGPUGemma3: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 9c4bf5054..ffdbad942 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -35,7 +35,6 @@ class TestMultiGPULlama: """ def test_lora_ddp(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -62,12 +61,14 @@ class TestMultiGPULlama: "gradient_accumulation_steps": 2, # "gradient_checkpointing": True, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_8bit", "lr_scheduler": "cosine", "flash_attention": True, "use_tensorboard": True, "bf16": True, + "save_first_step": False, } ) @@ -89,7 +90,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.8, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -97,7 +98,6 @@ class TestMultiGPULlama: [1, 2], ) def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -127,12 +127,14 @@ class TestMultiGPULlama: "gradient_accumulation_steps": gradient_accumulation_steps, # "gradient_checkpointing": True, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_8bit", "lr_scheduler": "cosine", "flash_attention": True, "use_tensorboard": True, "bf16": True, + "save_first_step": False, } ) @@ -154,11 +156,10 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) def test_dpo_lora_ddp(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -198,8 +199,9 @@ class TestMultiGPULlama: "max_steps": 2, "micro_batch_size": 2, "gradient_accumulation_steps": 2, - # "gradient_checkpointing": True, + "gradient_checkpointing": False, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "warmup_steps": 0, "learning_rate": 0.00001, "optimizer": "adamw_8bit", @@ -207,6 +209,7 @@ class TestMultiGPULlama: "flash_attention": True, "use_tensorboard": True, "bf16": True, + "save_first_step": False, } ) @@ -232,11 +235,10 @@ class TestMultiGPULlama: temp_dir + "/runs", "train/train_loss", loss_threshold, - "Train Loss is too high", + "Train Loss (%s) is too high", ) def test_dpo_qlora_ddp(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -276,8 +278,9 @@ class TestMultiGPULlama: "max_steps": 2, "micro_batch_size": 2, "gradient_accumulation_steps": 2, - # "gradient_checkpointing": True, + "gradient_checkpointing": False, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "warmup_steps": 0, "learning_rate": 0.00001, "optimizer": "adamw_8bit", @@ -285,6 +288,7 @@ class TestMultiGPULlama: "flash_attention": True, "use_tensorboard": True, "bf16": True, + "save_first_step": False, } ) @@ -310,7 +314,7 @@ class TestMultiGPULlama: temp_dir + "/runs", "train/train_loss", loss_threshold, - "Train Loss is too high", + "Train Loss (%s) is too high", ) @pytest.mark.parametrize( @@ -318,7 +322,6 @@ class TestMultiGPULlama: [1, 2], ) def test_fsdp(self, temp_dir, gradient_accumulation_steps): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -340,6 +343,7 @@ class TestMultiGPULlama: "gradient_accumulation_steps": gradient_accumulation_steps, # "gradient_checkpointing": True, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", @@ -349,7 +353,6 @@ class TestMultiGPULlama: "auto_wrap", ], "fsdp_config": { - "fsdp_limit_all_gathers": True, "fsdp_offload_params": False, "fsdp_sync_module_states": True, "fsdp_use_orig_params": False, @@ -359,6 +362,8 @@ class TestMultiGPULlama: "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", }, "use_tensorboard": True, + "seed": 42, + "save_first_step": False, } ) @@ -380,15 +385,17 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( "fsdp_state_dict_type", - ["FULL_STATE_DICT", "SHARDED_STATE_DICT"], + [ + "FULL_STATE_DICT", + # "SHARDED_STATE_DICT", # not supported since intermediate checkpoints fail with fsdp1 + ], ) def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -407,11 +414,13 @@ class TestMultiGPULlama: }, ], "num_epochs": 1, - "max_steps": 2, + "max_steps": 3, + "save_steps": 2, "micro_batch_size": 2, "gradient_accumulation_steps": 2, # "gradient_checkpointing": True, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", @@ -421,7 +430,6 @@ class TestMultiGPULlama: "auto_wrap", ], "fsdp_config": { - "fsdp_limit_all_gathers": True, "fsdp_offload_params": False, "fsdp_sync_module_states": True, "fsdp_use_orig_params": False, @@ -431,6 +439,7 @@ class TestMultiGPULlama: "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", }, "use_tensorboard": True, + "save_first_step": False, } ) @@ -452,7 +461,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @require_torch_2_6_0 @@ -467,7 +476,6 @@ class TestMultiGPULlama: def test_fsdp2_packed( self, temp_dir, attention_backend, fsdp_reshard_after_forward ): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -491,6 +499,7 @@ class TestMultiGPULlama: "gradient_accumulation_steps": 2, "gradient_checkpointing": True, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_torch_8bit", "lr_scheduler": "cosine", @@ -508,6 +517,7 @@ class TestMultiGPULlama: "fsdp_reshard_after_forward": fsdp_reshard_after_forward, }, "use_tensorboard": True, + "save_first_step": False, } ) if attention_backend == "flash": @@ -533,11 +543,11 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high" ) + @pytest.mark.skip("regression failure from v4.57.0") def test_fsdp_qlora_prequant_packed(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16", @@ -573,6 +583,7 @@ class TestMultiGPULlama: "gradient_accumulation_steps": 2, # "gradient_checkpointing": True, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", @@ -582,16 +593,16 @@ class TestMultiGPULlama: "auto_wrap", ], "fsdp_config": { - "fsdp_limit_all_gathers": True, "fsdp_offload_params": False, "fsdp_sync_module_states": True, "fsdp_use_orig_params": False, "fsdp_cpu_ram_efficient_loading": True, "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer", - "fsdp_state_dict_type": "SHARDED_STATE_DICT", + "fsdp_state_dict_type": "FULL_STATE_DICT", "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", }, "use_tensorboard": True, + "save_first_step": False, } ) @@ -613,7 +624,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -635,7 +646,6 @@ class TestMultiGPULlama: def test_ds_zero3_packed( self, temp_dir, gradient_accumulation_steps, deepspeed, qlora ): - # pylint: disable=duplicate-code if qlora: adapter = { "adapter": "qlora", @@ -669,12 +679,14 @@ class TestMultiGPULlama: "micro_batch_size": 1, "gradient_accumulation_steps": gradient_accumulation_steps, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "flash_attention": True, "deepspeed": str(AXOLOTL_ROOT / deepspeed), "use_tensorboard": True, + "save_first_step": False, **adapter, } ) @@ -697,7 +709,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -709,7 +721,6 @@ class TestMultiGPULlama: [True, False], ) def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora): - # pylint: disable=duplicate-code if qlora: adapter = { "adapter": "qlora", @@ -743,12 +754,15 @@ class TestMultiGPULlama: "micro_batch_size": 1, "gradient_accumulation_steps": gradient_accumulation_steps, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "flash_attention": True, "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"), "use_tensorboard": True, + "seed": 42, + "save_first_step": False, **adapter, } ) @@ -771,7 +785,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -783,7 +797,6 @@ class TestMultiGPULlama: [True, False], ) def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora): - # pylint: disable=duplicate-code if qlora: adapter = { "adapter": "qlora", @@ -817,12 +830,14 @@ class TestMultiGPULlama: "micro_batch_size": 1, "gradient_accumulation_steps": gradient_accumulation_steps, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "flash_attention": True, "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"), "use_tensorboard": True, + "save_first_step": False, **adapter, } ) @@ -845,14 +860,13 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high" ) @pytest.mark.skip( reason="fix untrained tokens brittle with lots of edge cases in latest transformers" ) def test_fix_untrained_tokens(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -891,6 +905,7 @@ class TestMultiGPULlama: "save_safetensors": True, # "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"), "use_tensorboard": True, + "save_first_step": False, } ) @@ -912,5 +927,5 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/multigpu/test_locking.py b/tests/e2e/multigpu/test_locking.py new file mode 100644 index 000000000..42502dfa3 --- /dev/null +++ b/tests/e2e/multigpu/test_locking.py @@ -0,0 +1,192 @@ +"""Tests for FileLockLoader class.""" + +import tempfile +import threading +import time +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from axolotl.utils.data.lock import FileLockLoader +from axolotl.utils.dict import DictDefault + + +class TestFileLockLoader: + """Class with tests for FileLockLoader.""" + + @pytest.fixture + def temp_dir(self): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as tmp_dir: + yield Path(tmp_dir) + + @pytest.fixture + def cfg(self, temp_dir): + """Create a test configuration.""" + return DictDefault({"dataset_prepared_path": str(temp_dir)}) + + @pytest.fixture + def loader(self, cfg): + """Create a FileLockLoader instance for testing.""" + return FileLockLoader(cfg) + + def test_load_first_process(self, loader): + """Test load() when no ready flag exists (first process).""" + mock_load_fn = Mock(return_value="test_data") + + result = loader.load(mock_load_fn) + + # Should call the load function + mock_load_fn.assert_called_once() + assert result == "test_data" + + # Should create the ready flag + assert loader.ready_flag_path.exists() + + def test_load_subsequent_process(self, loader): + """Test load() when ready flag already exists (subsequent process).""" + # Create ready flag first + loader.ready_flag_path.touch() + + mock_load_fn = Mock(return_value="loaded_data") + + result = loader.load(mock_load_fn) + + # Should still call load function (to load the prepared data) + mock_load_fn.assert_called_once() + assert result == "loaded_data" + + def test_load_concurrent_processes(self, cfg): + """Test that concurrent processes coordinate correctly.""" + results = [] + call_count = 0 + + def slow_load_fn(): + nonlocal call_count + call_count += 1 + time.sleep(0.1) # Simulate slow loading + return f"data_{call_count}" + + def worker(): + loader = FileLockLoader(cfg) + result = loader.load(slow_load_fn) + results.append(result) + + # Start multiple threads simultaneously + threads = [threading.Thread(target=worker) for _ in range(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Only one thread should have done the initial loading + # All should return data, but the load function should be called + # once by the first process and once by each subsequent process + assert len(results) == 3 + assert all(result.startswith("data_") for result in results) + + @patch("time.sleep") + def test_load_waiting_for_ready_flag(self, mock_sleep, loader): + """Test that processes wait for the ready flag to appear.""" + mock_load_fn = Mock(return_value="waiting_data") + mock_ready_flag_path = Mock() + exists_call_count = 0 + + def mock_exists(): + nonlocal exists_call_count + exists_call_count += 1 + + if exists_call_count == 1: + # First check: ready flag exists (not first process) + return True + if exists_call_count <= 3: + # While loop checks: flag doesn't exist yet + return False + return True + + mock_ready_flag_path.exists.side_effect = mock_exists + + # Replace the ready_flag_path with our mock + original_path = loader.ready_flag_path + loader.ready_flag_path = mock_ready_flag_path + + try: + result = loader.load(mock_load_fn) + finally: + # Restore original path + loader.ready_flag_path = original_path + + # Should have slept twice while waiting + assert mock_sleep.call_count == 2 + mock_sleep.assert_called_with(1) + + # Should eventually call load function + mock_load_fn.assert_called_once() + assert result == "waiting_data" + + def test_complete_workflow_with_cleanup(self, loader): + """Test the complete load -> cleanup workflow.""" + mock_load_fn = Mock(return_value="test_data") + + # First process calls load (this should set up counter) + result = loader.load(mock_load_fn) + assert result == "test_data" + assert loader.ready_flag_path.exists() + assert loader.counter_path.exists() + + # Cleanup should remove everything since there's only one process + loader.cleanup() + assert not loader.ready_flag_path.exists() + assert not loader.counter_path.exists() + + def test_multiple_processes_workflow(self, loader): + """Test workflow with multiple processes.""" + # Simulate multiple processes by manually setting up counter + loader.ready_flag_path.touch() + loader.counter_path.write_text("3") # 3 processes + + # First process cleanup + loader.cleanup() + assert loader.ready_flag_path.exists() + assert loader.counter_path.read_text().strip() == "2" + + # Second process cleanup + loader.cleanup() + assert loader.ready_flag_path.exists() + assert loader.counter_path.read_text().strip() == "1" + + # Last process cleanup + loader.cleanup() + assert not loader.ready_flag_path.exists() + assert not loader.counter_path.exists() + + def test_load_exception_handling(self, loader): + """Test behavior when load_fn raises an exception.""" + + def failing_load_fn(): + raise ValueError("Load failed") + + with pytest.raises(ValueError, match="Load failed"): + loader.load(failing_load_fn) + + # Ready flag should not be created on failure + assert not loader.ready_flag_path.exists() + + def test_file_lock_called(self, loader): + """Test that FileLock is properly used.""" + mock_load_fn = Mock(return_value="locked_data") + + with patch("axolotl.utils.data.lock.FileLock") as mock_filelock: + mock_context = MagicMock() + mock_filelock.return_value.__enter__ = Mock(return_value=mock_context) + mock_filelock.return_value.__exit__ = Mock(return_value=None) + + loader.load(mock_load_fn) + + # Verify FileLock was called with correct path + mock_filelock.assert_called_once_with(str(loader.lock_file_path)) + + # Verify context manager was used + mock_filelock.return_value.__enter__.assert_called_once() + mock_filelock.return_value.__exit__.assert_called_once() diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py deleted file mode 100644 index fa4efa32b..000000000 --- a/tests/e2e/multigpu/test_qwen2.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -E2E tests for multigpu qwen2 -""" - -from pathlib import Path - -import pytest -import yaml -from accelerate.test_utils import execute_subprocess_async -from transformers.testing_utils import get_torch_dist_unique_port - -from axolotl.utils.dict import DictDefault - - -class TestMultiGPUQwen2: - """ - Test case for Llama models using LoRA - """ - - @pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"]) - def test_qlora_fsdp_dpo(self, base_model, temp_dir): - # pylint: disable=duplicate-code - cfg = DictDefault( - { - "base_model": base_model, - "load_in_4bit": True, - "rl": "dpo", - "chat_template": "chatml", - "sequence_len": 2048, - "adapter": "qlora", - "lora_r": 8, - "lora_alpha": 16, - "lora_dropout": 0.05, - "lora_target_linear": True, - "val_set_size": 0.01, - "datasets": [ - { - "path": "Intel/orca_dpo_pairs", - "split": "train", - "type": "chatml.intel", - }, - ], - "num_epochs": 1, - "max_steps": 2, - "warmup_steps": 20, - "micro_batch_size": 2, - "gradient_accumulation_steps": 2, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch_fused", - "lr_scheduler": "cosine", - "flash_attention": True, - "bf16": "auto", - "tf32": True, - # "gradient_checkpointing": True, - "gradient_checkpointing_kwargs": { - "use_reentrant": False, - }, - "fsdp": [ - "full_shard", - "auto_wrap", - ], - "fsdp_config": { - "fsdp_limit_all_gathers": True, - "fsdp_offload_params": False, - "fsdp_sync_module_states": True, - "fsdp_use_orig_params": False, - "fsdp_cpu_ram_efficient_loading": False, - "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", - "fsdp_state_dict_type": "FULL_STATE_DICT", - "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", - "fsdp_sharding_strategy": "FULL_SHARD", - }, - } - ) - - # write cfg to yaml file - Path(temp_dir).mkdir(parents=True, exist_ok=True) - with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: - fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) - - execute_subprocess_async( - [ - "axolotl", - "train", - str(Path(temp_dir) / "config.yaml"), - "--num-processes", - "2", - "--main-process-port", - f"{get_torch_dist_unique_port()}", - ] - ) diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py index f2c812eb5..df41b1444 100644 --- a/tests/e2e/multigpu/test_ray.py +++ b/tests/e2e/multigpu/test_ray.py @@ -10,7 +10,10 @@ from accelerate.test_utils import execute_subprocess_async from axolotl.utils.dict import DictDefault -from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0 +from tests.e2e.utils import ( + check_tensorboard, + require_torch_2_7_0, +) AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent @@ -20,9 +23,8 @@ class TestMultiGPURay: Test cases for AnyScale Ray post training """ - @require_torch_lt_2_6_0 + @require_torch_2_7_0 def test_lora_ddp(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -48,6 +50,7 @@ class TestMultiGPURay: "micro_batch_size": 4, "gradient_accumulation_steps": 2, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_8bit", "lr_scheduler": "cosine", @@ -55,6 +58,7 @@ class TestMultiGPURay: "use_tensorboard": True, "use_ray": True, "ray_num_workers": 2, + "save_first_step": False, } ) @@ -75,16 +79,15 @@ class TestMultiGPURay: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) - @require_torch_lt_2_6_0 + @require_torch_2_7_0 @pytest.mark.parametrize( "gradient_accumulation_steps", [1, 2], ) def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -107,12 +110,14 @@ class TestMultiGPURay: "micro_batch_size": 1, "gradient_accumulation_steps": gradient_accumulation_steps, "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", "learning_rate": 0.00001, "optimizer": "adamw_torch", "lr_scheduler": "cosine", "flash_attention": True, "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"), "use_tensorboard": True, + "save_first_step": False, } ) @@ -133,5 +138,72 @@ class TestMultiGPURay: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" + ) + + @require_torch_2_7_0 + @pytest.mark.parametrize( + "gradient_accumulation_steps", + [1, 2], + ) + def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps): + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "sample_packing": True, + "pad_to_sequence_len": True, + "sequence_len": 1024, + "val_set_size": 0.01, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 1, + "gradient_accumulation_steps": gradient_accumulation_steps, + "output_dir": temp_dir, + "dataset_prepared_path": temp_dir + "/last_run_prepared", + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "LlamaDecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "save_first_step": False, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--use-ray", + "--ray-num-workers", + "2", + ] + ) + + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/multigpu/test_tp.py b/tests/e2e/multigpu/test_tp.py new file mode 100644 index 000000000..9891a0906 --- /dev/null +++ b/tests/e2e/multigpu/test_tp.py @@ -0,0 +1,68 @@ +"""multigpu e2e test for tensor parallelism.""" + +from pathlib import Path + +import pytest +import yaml +from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_tensorboard, require_torch_2_7_0 + + +class TestTensorParallel: + """Test class for Tensor Parallel functionality.""" + + @pytest.mark.skip( + reason="TP doesn't work with models with tied weights (embeddings)" + ) + @require_torch_2_7_0 + def test_fft_sft(self, temp_dir): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "tensor_parallel_size": 2, + "lr_scheduler": "cosine", + "flash_attention": True, + "use_tensorboard": True, + "bf16": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high" + ) diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py index 76c383a92..73f883858 100644 --- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py +++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py @@ -1,7 +1,5 @@ """Integration tests for LoRA activation and attention kernels.""" -# pylint: disable=redefined-outer-name - from pathlib import Path import pytest @@ -25,7 +23,9 @@ from axolotl.loaders.model import ModelLoader from axolotl.loaders.tokenizer import load_tokenizer from axolotl.monkeypatch.lora_kernels import ( apply_lora_kernel_patches, + find_self_attn_in_layer, get_attention_cls_from_config, + get_layers, patch_self_attn_lora, ) from axolotl.utils.dict import DictDefault @@ -86,7 +86,7 @@ def test_attention_patching_integration(model_name, attention_cls): cfg = DictDefault({"base_model": model_name}) # Store the original implementation - original_forward = getattr(attention_cls, "forward") + original_forward = attention_cls.forward # Apply patch patch_self_attn_lora(cfg) @@ -102,7 +102,7 @@ def test_attention_patching_integration(model_name, attention_cls): assert hasattr(attention_cls, "_original_forward") # Clean up - setattr(attention_cls, "forward", original_forward) + attention_cls.forward = original_forward delattr(attention_cls, "_original_forward") @@ -160,7 +160,7 @@ def test_geglu_model_integration(): """Test GeGLU activation with Gemma model.""" model = AutoModelForCausalLM.from_pretrained( "trl-internal-testing/tiny-Gemma2ForCausalLM", - torch_dtype=torch.float16, + dtype=torch.float16, device_map="cuda:0", ) peft_config = get_peft_config( @@ -377,9 +377,9 @@ def test_model_architecture(model_config): # Verify correct activation function layer = patched_model.model.model.layers[0] - assert ( - layer.mlp.forward.__func__ is model_config["expected_activation"] - ), f"Wrong activation for {model_config['name']}" + assert layer.mlp.forward.__func__ is model_config["expected_activation"], ( + f"Wrong activation for {model_config['name']}" + ) # Test forward pass inputs = get_test_inputs(model) @@ -388,13 +388,12 @@ def test_model_architecture(model_config): patched_output = patched_model(inputs).logits # Check outputs match - assert torch.allclose( - original_output, patched_output, rtol=1e-4 - ), f"Outputs don't match for {model_config['name']}" + assert torch.allclose(original_output, patched_output, rtol=1e-4), ( + f"Outputs don't match for {model_config['name']}" + ) -# pylint: disable=duplicate-code -def test_kernel_training_integration(): +def test_kernel_training_integration(temp_dir): """Test model loading with kernel patches enabled.""" from axolotl.cli.utils import load_model_and_tokenizer @@ -424,6 +423,14 @@ def test_kernel_training_integration(): } ) + # Write cfg to yaml file + path = Path(temp_dir) / "config.yaml" + with open(path, "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + # Load config + cfg = load_cfg(str(path)) + # Load model model, _, _ = load_model_and_tokenizer(cfg=cfg) @@ -501,3 +508,69 @@ def test_kernel_training_integration_auto_enable(temp_dir): break assert found_patched_attn + + +def test_kernel_training_integration_dropout_non_zero(temp_dir): + """Test model loading with dropout non-zero should not patch.""" + + from axolotl.cli.utils import load_model_and_tokenizer + + # Create minimal config + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_config": "HuggingFaceTB/SmolLM2-135M", + "learning_rate": 0.000001, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.1, + "lora_target_linear": True, + "sequence_len": 1024, + } + ) + + # Write cfg to yaml file + path = Path(temp_dir) / "config.yaml" + with open(path, "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + # Load config + cfg = load_cfg(str(path)) + + # Get original attention class + attention_cls = get_attention_cls_from_config(cfg) + + # Store original state before patching + original_forward_method = attention_cls.forward + + # Load model + model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg) + + # We call modelloader as that's where the patches are applied + # despite the fact that we're not using it to load the model + model_loader = ModelLoader(cfg, tokenizer) + + # Apply patch + model_loader.patch_manager._apply_self_attention_lora_patch() + + # Verify patch was not applied + assert attention_cls.forward == original_forward_method + + # Apply apply_lora_kernel_patches + model_loader.patch_manager._apply_lora_kernel_patch(model) + + # Verify patch was not applied + layers = get_layers(model) + for layer in layers: + for self_attn in find_self_attn_in_layer(layer): + assert not hasattr(self_attn, "apply_qkv") + assert not hasattr(self_attn, "apply_o") diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index 490ce77fb..ef28cc406 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -4,7 +4,6 @@ E2E tests for multipack fft llama using 4d attention masks import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -20,7 +19,6 @@ class Test4dMultipackLlama(unittest.TestCase): @with_temp_dir def test_sdp_lora_packing(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -56,19 +54,18 @@ class Test4dMultipackLlama(unittest.TestCase): "save_steps": 3, "eval_steps": 4, "fp16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_torch_lora_packing(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -104,12 +101,12 @@ class Test4dMultipackLlama(unittest.TestCase): "save_steps": 3, "eval_steps": 4, "fp16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py index 45107b871..e8006c162 100644 --- a/tests/e2e/patched/test_activation_checkpointing.py +++ b/tests/e2e/patched/test_activation_checkpointing.py @@ -6,7 +6,6 @@ import pytest import transformers from torch.utils.checkpoint import checkpoint -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -33,10 +32,9 @@ class TestActivationCheckpointing: def test_activation_checkpointing_offload( self, temp_dir, - fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name + fix_checkpoint_after_test, gradient_checkpointing, ): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -70,13 +68,14 @@ class TestActivationCheckpointing: "bf16": True, "save_safetensors": True, "gradient_checkpointing": gradient_checkpointing, + "save_first_step": False, + "dataset_num_proc": 4, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_cli_integrations.py b/tests/e2e/patched/test_cli_integrations.py index 6c908faf1..6eba92689 100644 --- a/tests/e2e/patched/test_cli_integrations.py +++ b/tests/e2e/patched/test_cli_integrations.py @@ -10,7 +10,6 @@ from axolotl.cli.config import load_cfg from axolotl.utils.dict import DictDefault -# pylint: disable=duplicate-code class TestPluginArgs: """ test class for plugin args loaded from the config file diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index e66b67e6d..9f4699854 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -5,7 +5,6 @@ E2E tests for lora llama import pytest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -24,7 +23,6 @@ class TestFAXentropyLlama: [1, 4], ) def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -63,6 +61,7 @@ class TestFAXentropyLlama: "optimizer": "adamw_8bit", "lr_scheduler": "cosine", "use_tensorboard": True, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -73,12 +72,11 @@ class TestFAXentropyLlama: cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py index bd80221ce..cc5091403 100644 --- a/tests/e2e/patched/test_falcon_samplepack.py +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -6,7 +6,6 @@ import unittest import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -23,7 +22,6 @@ class TestFalconPatched(unittest.TestCase): @pytest.mark.skip(reason="no tiny models for testing with safetensors") @with_temp_dir def test_qlora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "illuin/tiny-random-FalconForCausalLM", @@ -59,12 +57,12 @@ class TestFalconPatched(unittest.TestCase): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -72,7 +70,6 @@ class TestFalconPatched(unittest.TestCase): @pytest.mark.skip(reason="no tiny models for testing with safetensors") @with_temp_dir def test_ft(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "illuin/tiny-random-FalconForCausalLM", @@ -101,12 +98,12 @@ class TestFalconPatched(unittest.TestCase): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_flattening.py b/tests/e2e/patched/test_flattening.py new file mode 100644 index 000000000..2c247d406 --- /dev/null +++ b/tests/e2e/patched/test_flattening.py @@ -0,0 +1,81 @@ +""" +E2E tests for flattening batches +""" + +import pytest +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from ..utils import check_model_output_exists, check_tensorboard + + +class TestFAFlattening: + """ + Test case for Llama models using LoRA w batch flattening + """ + + @pytest.mark.parametrize( + "gradient_accumulation_steps", + [1, 4], + ) + def test_lora_packing_flattening(self, temp_dir, gradient_accumulation_steps): + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "sequence_len": 1024, + "batch_flattening": True, + "flash_attention": True, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.05, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "chat_template": "chatml", + "datasets": [ + { + "path": "mlabonne/FineTome-100k", + "field_messages": "conversations", + "message_field_content": "value", + "message_field_role": "from", + "type": "chat_template", + "split": "train[:2%]", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "save_steps": 5, + "micro_batch_size": 2, + "gradient_accumulation_steps": gradient_accumulation_steps, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "use_tensorboard": True, + "save_first_step": False, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + + cfg = validate_config(cfg) + normalize_config(cfg) + + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high" + ) diff --git a/tests/e2e/patched/test_fsdp2_qlora.py b/tests/e2e/patched/test_fsdp2_qlora.py new file mode 100644 index 000000000..de9c929e1 --- /dev/null +++ b/tests/e2e/patched/test_fsdp2_qlora.py @@ -0,0 +1,30 @@ +"""Integration tests for FSDP2 Params4bit patches.""" + +import pytest +from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam + + +class TestFSDPPatchIntegration: + """Test FSDP patch integration.""" + + @pytest.mark.integration + def test_fsdp2_init_patches(self): + """Test that all patches can be applied together.""" + from axolotl.monkeypatch.fsdp2_qlora import ( + apply_init_sharded_param_patch, + apply_init_unsharded_param_patch, + ) + + original_init_sharded = FSDPParam._init_sharded_param + original_init_unsharded = FSDPParam.init_unsharded_param + + # Apply patches + apply_init_sharded_param_patch() + apply_init_unsharded_param_patch() + + assert FSDPParam._init_sharded_param != original_init_sharded, ( + "_init_sharded_param was not patched" + ) + assert FSDPParam.init_unsharded_param != original_init_unsharded, ( + "init_unsharded_param was not patched" + ) diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index 49478f10c..f0c5df18a 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -7,7 +7,6 @@ import unittest import pytest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -24,13 +23,11 @@ class TestFusedLlama(unittest.TestCase): @with_temp_dir def test_fft_packing(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", "flash_attention": True, "pad_to_sequence_len": True, - "flash_attn_fuse_qkv": True, "flash_attn_fuse_mlp": True, "sample_packing": True, "sequence_len": 1024, @@ -54,6 +51,7 @@ class TestFusedLlama(unittest.TestCase): "max_steps": 10, "save_steps": 5, "eval_steps": 5, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -62,8 +60,7 @@ class TestFusedLlama(unittest.TestCase): cfg.fp16 = True cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index 327bb13f8..0dd748945 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -6,7 +6,6 @@ import unittest import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -23,7 +22,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): @with_temp_dir def test_lora_s2_attn(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -59,20 +57,19 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): "save_steps": 5, "eval_steps": 5, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_fft_s2_attn(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -102,13 +99,13 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): "save_steps": 5, "eval_steps": 5, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index 1bad677b9..1833c750b 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -7,7 +7,6 @@ import unittest import pytest from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -23,7 +22,6 @@ class TestLoraLlama(unittest.TestCase): @with_temp_dir def test_lora_packing(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -56,6 +54,7 @@ class TestLoraLlama(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -65,8 +64,7 @@ class TestLoraLlama(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -74,7 +72,6 @@ class TestLoraLlama(unittest.TestCase): @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available") @with_temp_dir def test_lora_gptq_packed(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ", @@ -110,12 +107,12 @@ class TestLoraLlama(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index 994b9dfca..e03941b07 100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -4,13 +4,12 @@ E2E tests for lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault -from ..utils import check_model_output_exists, with_temp_dir +from ..utils import check_model_output_exists, require_torch_2_6_0, with_temp_dir class TestMistral(unittest.TestCase): @@ -18,9 +17,9 @@ class TestMistral(unittest.TestCase): Test case for Llama models using LoRA """ + @require_torch_2_6_0 @with_temp_dir def test_lora_packing(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2", @@ -56,19 +55,18 @@ class TestMistral(unittest.TestCase): "save_steps": 3, "eval_steps": 4, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_ft_packing(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2", @@ -98,12 +96,12 @@ class TestMistral(unittest.TestCase): "save_steps": 3, "eval_steps": 4, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index 6a84069ef..3517ff3db 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -4,7 +4,6 @@ E2E tests for mixtral import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -20,7 +19,6 @@ class TestMixtral(unittest.TestCase): @with_temp_dir def test_qlora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "hf-internal-testing/Mixtral-tiny", @@ -53,19 +51,18 @@ class TestMixtral(unittest.TestCase): "save_steps": 3, "eval_steps": 4, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_ft(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "hf-internal-testing/Mixtral-tiny", @@ -92,12 +89,12 @@ class TestMixtral(unittest.TestCase): "save_steps": 3, "eval_steps": 4, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index 5ea88b001..aaaaf5fe2 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -45,6 +45,7 @@ class TestModelPatches(unittest.TestCase): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -78,6 +79,7 @@ class TestModelPatches(unittest.TestCase): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) cfg = validate_config(cfg) @@ -87,5 +89,5 @@ class TestModelPatches(unittest.TestCase): assert ( "torch.jit" - in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access + in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ ) diff --git a/tests/e2e/patched/test_peft_embeddings.py b/tests/e2e/patched/test_peft_embeddings.py index d4f59a128..374ef97d8 100644 --- a/tests/e2e/patched/test_peft_embeddings.py +++ b/tests/e2e/patched/test_peft_embeddings.py @@ -15,7 +15,6 @@ class TestLlamaPeftEmbeddings: """ def test_peft_embeddings_upcast(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -49,6 +48,7 @@ class TestLlamaPeftEmbeddings: "bf16": "auto", "save_safetensors": True, "embeddings_skip_upcast": True, + "save_first_step": False, } ) diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index ee2a3ffb4..77b2d99e5 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -4,7 +4,6 @@ E2E tests for lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -20,7 +19,6 @@ class TestPhiMultipack(unittest.TestCase): @with_temp_dir def test_ft_packed(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "microsoft/phi-1_5", @@ -55,20 +53,19 @@ class TestPhiMultipack(unittest.TestCase): "eval_steps": 3, "save_steps": 4, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_qlora_packed(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "microsoft/phi-1_5", @@ -107,13 +104,13 @@ class TestPhiMultipack(unittest.TestCase): "eval_steps": 3, "save_steps": 4, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index cc1f3ddee..747b79dc7 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -7,7 +7,6 @@ import subprocess from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -23,7 +22,6 @@ class TestResumeLlama: @require_torch_2_6_0 def test_resume_lora_packed(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -59,6 +57,7 @@ class TestResumeLlama: "max_steps": 15, "use_tensorboard": True, "save_safetensors": True, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -67,8 +66,7 @@ class TestResumeLlama: cfg.fp16 = True cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) @@ -78,7 +76,6 @@ class TestResumeLlama: } ) normalize_config(resume_cfg) - cli_args = TrainerCliArgs() train(cfg=resume_cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py deleted file mode 100644 index 2b4d11b30..000000000 --- a/tests/e2e/patched/test_sp.py +++ /dev/null @@ -1,480 +0,0 @@ -"""Tests for sequence parallelism functionality.""" - -# pylint: disable=redefined-outer-name,unused-argument - -import functools -import sys -from unittest.mock import MagicMock, patch - -import pytest -import torch -from accelerate.state import PartialState - -from axolotl.monkeypatch.ring_attn import ( - get_ring_attn_group, - register_ring_attn, - set_ring_attn_group, -) -from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism -from axolotl.utils.dict import DictDefault -from axolotl.utils.schemas.enums import RingAttnFunc -from axolotl.utils.schemas.trl import TRLConfig - - -@pytest.fixture -def partial_state(): - """Create a real PartialState instance for testing.""" - state = PartialState() - return state - - -@pytest.fixture(name="cfg") -def fixture_cfg(): - cfg = DictDefault( - { - "base_model": "HuggingFaceTB/SmolLM2-135M", - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "micro_batch_size": 1, - "gradient_accumulation_steps": 1, - "learning_rate": 1e-3, - "output_dir": "./model-out", - "sequence_len": 512, - "special_tokens": { - "pad_token": "<|endoftext|>", - }, - } - ) - - return cfg - - -@pytest.fixture -def sequence_parallel_batch(): - """Create a test batch for sequence parallelism tests.""" - batch_size = 1 - seq_len = 8 - - # Create test tensors - input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len) - attention_mask = torch.ones(batch_size, seq_len) - position_ids = torch.arange(seq_len).expand(batch_size, seq_len) - labels = input_ids.clone() - - # Create test batch - batch = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "labels": labels, - } - - return batch - - -class TestRingAttention: - """Tests for the ring attention functionality.""" - - @patch("torch.distributed.get_rank") - @patch("torch.distributed.get_world_size") - def test_get_ring_attn_group_no_registration( - self, mock_world_size, mock_rank, partial_state - ): - """Test that get_ring_attn_group raises RuntimeError when no group has been registered.""" - # Setup mocks - mock_world_size.return_value = 4 - mock_rank.return_value = 0 - - # Verify that RuntimeError is raised when no group is registered - with pytest.raises( - RuntimeError, match="register_ring_attn\\(\\) not yet called" - ): - get_ring_attn_group() - - @patch("torch.distributed.new_group") - @patch("torch.distributed.get_rank") - @patch("torch.distributed.get_world_size") - def test_register_ring_attn( - self, mock_world_size, mock_rank, mock_new_group, partial_state - ): - """Test that ring attention groups are created correctly.""" - # Setup mocks - mock_world_size.return_value = 8 # 8 GPUs total - mock_rank.return_value = 3 # GPU #3 - mock_group = MagicMock() - mock_new_group.return_value = mock_group - - # Call register_ring_attn with size 4 - register_ring_attn( - sequence_parallel_degree=4, - heads_k_stride=1, - ring_attn_func=RingAttnFunc.VARLEN_LLAMA3, - ) - - # Verify the number of calls without examining the arguments - assert mock_new_group.call_count == 2 - - # Verify that new_group was called - mock_new_group.assert_called() - - # Clean up - set_ring_attn_group(None) - - -class TestConfigValidation: - """Tests for validating sequence parallelism configurations.""" - - @pytest.fixture(autouse=True) - def setup_mocks(self, monkeypatch): - """Set up mocks for all tests in this class.""" - # Mock the ring_flash_attn module - monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock()) - - @pytest.fixture - def base_cfg(self): - """Create a base configuration for testing.""" - return DictDefault( - { - "base_model": "HuggingFaceTB/SmolLM2-135M", - "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}], - "micro_batch_size": 1, - "gradient_accumulation_steps": 1, - "learning_rate": 1e-3, - "output_dir": "./model-out", - "sequence_len": 512, - "special_tokens": {"pad_token": "<|endoftext|>"}, - } - ) - - @pytest.mark.parametrize( - "config_updates, expected_values, should_pass, error_msg", - [ - # Valid configuration - ( - {"sequence_parallel_degree": 2, "flash_attention": True}, - {"sequence_parallel_degree": 2, "flash_attention": True}, - True, - None, - ), - # Default sequence_parallel_degree - ({}, {"sequence_parallel_degree": 1}, True, None), - # Invalid: sequence_parallel_degree > 1 without flash_attention - ( - {"sequence_parallel_degree": 2, "flash_attention": False}, - None, - False, - "flash_attention: true must be set", - ), - # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1 - ( - { - "sequence_parallel_degree": 2, - "flash_attention": True, - "sample_packing": True, - "micro_batch_size": 2, - "pad_to_sequence_len": True, - }, - None, - False, - "micro_batch_size must be set to 1", - ), - # Valid: Basic GRPO config - ( - { - "sequence_parallel_degree": 2, - "flash_attention": True, - "micro_batch_size": 2, - "trl": {"use_liger_loss": True}, - }, - { - "sequence_parallel_degree": 2, - "flash_attention": True, - "micro_batch_size": 2, - "trl": TRLConfig(use_liger_loss=True), - }, - True, - "GRPO + SP + Liger not currently supported", - ), - # Invalid: GRPO config with Liger loss - ( - { - "rl": "grpo", - "sequence_parallel_degree": 2, - "flash_attention": True, - "micro_batch_size": 2, - "trl": {"use_liger_loss": True}, - }, - None, - False, - "GRPO + SP + Liger not currently supported", - ), - ], - ids=[ - "valid_config", - "default_sp_degree", - "without_flash_attention", - "sample_packing_with_large_batch", - "valid_grpo", - "grpo_with_liger_loss", - ], - ) - def test_sequence_parallel_config_validation( - self, base_cfg, config_updates, expected_values, should_pass, error_msg - ): - """Test various sequence parallelism configuration scenarios.""" - from axolotl.utils.schemas.config import AxolotlInputConfig - - # Apply updates to base config - cfg = base_cfg - cfg.update(config_updates) - - if should_pass: - # Should validate without errors - config = AxolotlInputConfig(**cfg) - - # Check expected values - for key, value in expected_values.items(): - assert getattr(config, key) == value - else: - # Should raise exception - with pytest.raises(ValueError) as excinfo: - AxolotlInputConfig(**cfg) - assert error_msg in str(excinfo.value) - - @pytest.mark.parametrize( - "ring_attn_func, sample_packing, expected_func", - [ - (None, True, RingAttnFunc.VARLEN_LLAMA3), - (None, False, RingAttnFunc.BATCH_RING), - ], - ids=["default_with_sample_packing", "default_without_sample_packing"], - ) - def test_ring_attn_func_validation( - self, base_cfg, ring_attn_func, sample_packing, expected_func - ): - """Test ring_attn_func validation and defaults.""" - from axolotl.utils.schemas.config import AxolotlInputConfig - - # Apply updates to base config - cfg = base_cfg | { - "sequence_parallel_degree": 2, - "flash_attention": True, - "sample_packing": sample_packing, - } - - if ring_attn_func is not None: - cfg["ring_attn_func"] = ring_attn_func - - # Should validate without errors - config = AxolotlInputConfig(**cfg) - - # Check ring_attn_func value - assert config.ring_attn_func.value == expected_func - - def test_invalid_ring_attn_func(self, base_cfg): - """Test that an invalid ring_attn_func is rejected.""" - from axolotl.utils.schemas.config import AxolotlInputConfig - - # Invalid configuration with invalid ring_attn_func - cfg = base_cfg | { - "sequence_parallel_degree": 2, - "flash_attention": True, - "ring_attn_func": "INVALID_FUNC", - } - - # Should raise ValidationError - with pytest.raises(ValueError) as excinfo: - AxolotlInputConfig(**cfg) - - # Verify error message - assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value) - - -class TestApplySequenceParallelism: - """Tests for the apply_sequence_parallelism function.""" - - @pytest.fixture(autouse=True) - def mock_distributed(self, monkeypatch): - """Mock torch.distributed functions for testing.""" - # Mock is_initialized to return True - monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True) - - # Mock get_rank to return 0 by default - monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0) - - # Mock get_world_size to return 2 by default - monkeypatch.setattr( - torch.distributed, "get_world_size", lambda *args, **kwargs: 2 - ) - - # Mock the process group - monkeypatch.setattr( - "axolotl.monkeypatch.ring_attn.get_ring_attn_group", - MagicMock, - ) - - # Mock update_ring_attn_params - monkeypatch.setattr( - "axolotl.monkeypatch.ring_attn.update_ring_attn_params", - lambda **kwargs: None, - ) - - @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") - def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch): - """Test that function returns original batch when world size is 1.""" - mock_get_ring_attn_group.return_value = 0 - - result, _, _ = apply_sequence_parallelism( - batch=sequence_parallel_batch, - local_rank=0, - local_world_size=1, - gradient_accumulation_steps=1, - ring_attn_func=RingAttnFunc.BATCH_RING, - ) - - # Should return the original batch unchanged - assert result == sequence_parallel_batch - - @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") - def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch): - """Test BATCH_RING sharding for rank 0 in a 2-process group.""" - mock_get_ring_attn_group.return_value = 0 - - batch = sequence_parallel_batch - seq_len = batch["input_ids"].size(1) - - result, _, _ = apply_sequence_parallelism( - batch=batch, - local_rank=0, - local_world_size=2, - gradient_accumulation_steps=1, - ring_attn_func=RingAttnFunc.BATCH_RING, - ) - - # Check that sequence dimension was sharded correctly - assert result["input_ids"].shape[1] == seq_len // 2 - assert result["attention_mask"].shape[1] == seq_len // 2 - - # Verify content: rank 0 should get the first half of the sequence - assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2]) - assert torch.equal( - result["position_ids"], batch["position_ids"][:, : seq_len // 2] - ) - - @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") - def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch): - """Test BATCH_RING sharding for rank 1 in a 2-process group.""" - mock_get_ring_attn_group.return_value = 0 - - batch = sequence_parallel_batch - seq_len = batch["input_ids"].size(1) - original_input_ids = batch["input_ids"].clone() - - result, _, _ = apply_sequence_parallelism( - batch=batch, - local_rank=1, - local_world_size=2, - gradient_accumulation_steps=1, - ring_attn_func=RingAttnFunc.BATCH_RING, - ) - - # Verify content: rank 1 should get the second half of the sequence - assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :]) - - # TODO(djsaunde): add back once implemented. - # def test_batch_zigzag(self, sequence_parallel_batch): - # """Test BATCH_ZIGZAG sharding pattern.""" - # batch = sequence_parallel_batch - # original_input_ids = batch["input_ids"].clone() - # seq_len = batch["input_ids"].size(1) - - # # Test rank 0 - # result_rank0 = apply_sequence_parallelism( - # batch={k: v.clone() for k, v in batch.items()}, - # local_rank=0, - # local_world_size=2, - # ring_attn_func=RingAttnFunc.BATCH_ZIGZAG, - # ) - - # # Test rank 1 - # result_rank1 = apply_sequence_parallelism( - # batch={k: v.clone() for k, v in batch.items()}, - # local_rank=1, - # local_world_size=2, - # ring_attn_func=RingAttnFunc.BATCH_ZIGZAG, - # ) - - # # Checks for both ranks - # assert result_rank0["input_ids"].shape[1] == seq_len // 2 - # assert result_rank1["input_ids"].shape[1] == seq_len // 2 - - # # For a 2-rank system with 8 tokens, check specific zigzag pattern - # # Rank 0 should get chunks [0, 1] and [6, 7] - # # Rank 1 should get chunks [2, 3] and [4, 5] - # if seq_len == 8: - # # Create expected tensors for comparison - # rank0_expected = torch.cat( - # [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1 - # ) - - # rank1_expected = torch.cat( - # [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1 - # ) - - # assert torch.equal(result_rank0["input_ids"], rank0_expected) - # assert torch.equal(result_rank1["input_ids"], rank1_expected) - - @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") - def test_partial_application( - self, mock_get_ring_attn_group, sequence_parallel_batch - ): - """Test that we can create a partially applied version of the function.""" - mock_get_ring_attn_group.return_value = 0 - - batch = sequence_parallel_batch - original_input_ids = batch["input_ids"].clone() - - # Create a partially applied function - rank0_ring_parallel = functools.partial( - apply_sequence_parallelism, - local_rank=0, - local_world_size=2, - gradient_accumulation_steps=1, - ring_attn_func=RingAttnFunc.BATCH_RING, - ) - - # Use the partially applied function - result, _, _ = rank0_ring_parallel(batch=batch) - - # Verify it works as expected - assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2 - assert torch.equal( - result["input_ids"], - original_input_ids[:, : original_input_ids.shape[1] // 2], - ) - - def test_missing_position_ids(self, sequence_parallel_batch): - """Test handling of batch without position_ids.""" - # Create a batch without position_ids - batch = { - k: v for k, v in sequence_parallel_batch.items() if k != "position_ids" - } - original_input_ids = batch["input_ids"].clone() - - # This should run without error even though position_ids is missing - result, _, _ = apply_sequence_parallelism( - batch=batch, - local_rank=0, - local_world_size=2, - gradient_accumulation_steps=1, - ring_attn_func=RingAttnFunc.BATCH_RING, - ) - - # Verification should pass - assert "position_ids" in result - assert result["input_ids"].shape[1] == result["position_ids"].shape[1] - assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2 diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 46f5b6614..bf00e8a5f 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -4,7 +4,6 @@ e2e tests for unsloth qlora import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -13,7 +12,6 @@ from axolotl.utils.dict import DictDefault from ..utils import check_model_output_exists, check_tensorboard -# pylint: disable=duplicate-code @pytest.mark.skip( reason="Unsloth integration will be broken going into latest transformers" ) @@ -63,19 +61,19 @@ class TestUnslothQLoRA: "lr_scheduler": "cosine", "use_tensorboard": True, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high" ) def test_unsloth_llama_qlora_unpacked(self, temp_dir): @@ -114,19 +112,19 @@ class TestUnslothQLoRA: "lr_scheduler": "cosine", "use_tensorboard": True, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -170,17 +168,17 @@ class TestUnslothQLoRA: "lr_scheduler": "cosine", "use_tensorboard": True, "fp16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py index b33869b1c..abe8fb69a 100644 --- a/tests/e2e/solo/test_flex.py +++ b/tests/e2e/solo/test_flex.py @@ -6,7 +6,6 @@ import unittest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -23,7 +22,6 @@ class TestPackedFlex(unittest.TestCase): @require_torch_2_6_0 @with_temp_dir def test_loss_llama(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -50,6 +48,7 @@ class TestPackedFlex(unittest.TestCase): "lr_scheduler": "cosine", "max_steps": 5, "use_tensorboard": True, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -59,11 +58,10 @@ class TestPackedFlex(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py index cff8313f3..be77684ba 100644 --- a/tests/e2e/solo/test_relora_llama.py +++ b/tests/e2e/solo/test_relora_llama.py @@ -5,7 +5,6 @@ E2E tests for relora llama import unittest from pathlib import Path -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -21,7 +20,6 @@ class TestReLoraLlama(unittest.TestCase): @with_temp_dir def test_relora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -35,9 +33,10 @@ class TestReLoraLlama(unittest.TestCase): "lora_alpha": 16, "lora_dropout": 0.05, "lora_target_modules": ["q_proj", "v_proj"], - "relora_steps": 50, - "relora_warmup_steps": 10, - "relora_anneal_steps": 10, + "relora": True, + "jagged_restart_steps": 50, + "jagged_restart_warmup_steps": 10, + "jagged_restart_anneal_steps": 10, "relora_prune_ratio": 0.9, "relora_cpu_offload": True, "val_set_size": 0.0, @@ -66,19 +65,19 @@ class TestReLoraLlama(unittest.TestCase): "lr_scheduler": "cosine", "save_safetensors": True, "use_tensorboard": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg) - assert ( - Path(temp_dir) / "checkpoint-100/relora/model.safetensors" - ).exists(), "Relora model checkpoint not found" + assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists(), ( + "Relora model checkpoint not found" + ) check_tensorboard( temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high" diff --git a/tests/e2e/test_activation_offloading.py b/tests/e2e/test_activation_offloading.py new file mode 100644 index 000000000..9df85ab31 --- /dev/null +++ b/tests/e2e/test_activation_offloading.py @@ -0,0 +1,80 @@ +""" +E2E tests for activation offloading +""" + +import pytest + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists + + +class TestActivationOffloading: + """ + E2E test cases for activation offloading + """ + + @pytest.mark.parametrize( + "adapter", + ["lora", "qlora", None], + ) + def test_activation_offloading( + self, + temp_dir, + adapter, + ): + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "sequence_len": 1024, + "val_set_size": 0.0, + "special_tokens": { + "pad_token": "<|endoftext|>", + "eos_token": "<|im_end|>", + }, + "datasets": [ + { + "chat_template": "chatml", + "path": "mlabonne/FineTome-100k", + "type": "chat_template", + "split": "train[:10%]", + "field_messages": "conversations", + "message_field_role": "from", + "message_field_content": "value", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": True, + "bf16": "auto", + "save_safetensors": True, + "gradient_checkpointing": True, + "activation_offloading": True, + "save_first_step": False, + "lora_r": 8, + "lora_alpha": 16, + "lora_target_linear": True, + } + ) + if adapter == "lora": + cfg["adapter"] = "lora" + if adapter == "qlora": + cfg["adapter"] = "qlora" + cfg["load_in_4bit"] = True + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py index d882286cc..e11be8265 100644 --- a/tests/e2e/test_deepseekv3.py +++ b/tests/e2e/test_deepseekv3.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -26,7 +25,6 @@ class TestDeepseekV3: [True, False], ) def test_lora_deepseekv3(self, temp_dir, sample_packing): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "axolotl-ai-co/DeepSeek-V3-11M", @@ -68,12 +66,12 @@ class TestDeepseekV3: "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "adapter_model.safetensors").exists() @@ -84,7 +82,6 @@ class TestDeepseekV3: [True, False], ) def test_fft_deepseekv3(self, temp_dir, sample_packing): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "axolotl-ai-co/DeepSeek-V3-11M", @@ -118,12 +115,12 @@ class TestDeepseekV3: "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py new file mode 100644 index 000000000..cc3d8070b --- /dev/null +++ b/tests/e2e/test_diffusion.py @@ -0,0 +1,139 @@ +"""E2E smoke test for diffusion training plugin.""" + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_model_output_exists + + +class TestDiffusion: + """Test case for diffusion training plugin.""" + + def test_diffusion_smoke_test(self, temp_dir): + """ + Smoke test for diffusion training to ensure the plugin loads and trains without + error. + """ + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 256, + "val_set_size": 0.1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "bf16": True, + "save_safetensors": True, + "save_first_step": False, + "logging_steps": 1, + "eval_steps": 3, + # Diffusion-specific config + "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"], + "diffusion": { + # sample generation + "generate_samples": True, + "generation_interval": 1, + "num_generation_samples": 1, + "generation_steps": 2, + "generation_max_length": 32, + "generation_temperature": 0.0, + # training-specific + "mask_token_id": 16, + "eps": 1e-3, + "importance_weighting": False, + }, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + + def test_diffusion_sft_labels(self, temp_dir): + """Test that diffusion training properly handles SFT data with labels.""" + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "trust_remote_code": True, + "sequence_len": 256, + "val_set_size": 0.1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "bf16": True, + "save_safetensors": True, + "save_first_step": False, + "logging_steps": 1, + "eval_steps": 2, + # Diffusion-specific config + "plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"], + "diffusion": { + # sample generation + "generate_samples": True, + "generation_interval": 1, + "num_generation_samples": 1, + "generation_steps": 2, + "generation_max_length": 32, + "generation_temperature": 0.0, + # training-specific + "mask_token_id": 16, + "eps": 1e-3, + "importance_weighting": True, + }, + # Ensure we have proper SFT labels + "train_on_inputs": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + # Verify that the dataset has labels + sample = dataset_meta.train_dataset[0] + assert "labels" in sample, "SFT dataset should have labels" + + # Check that some labels are -100 (prompt tokens) + labels = sample["labels"] + if hasattr(labels, "tolist"): + labels = labels.tolist() + assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens" + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index e9f70758b..8f577ef47 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -1,6 +1,4 @@ -""" -E2E tests for lora llama -""" +"""E2E tests for lora llama""" import unittest from pathlib import Path @@ -23,7 +21,6 @@ class TestDPOLlamaLora(unittest.TestCase): @with_temp_dir def test_dpo_lora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -58,6 +55,7 @@ class TestDPOLlamaLora(unittest.TestCase): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -71,7 +69,6 @@ class TestDPOLlamaLora(unittest.TestCase): @with_temp_dir def test_dpo_nll_lora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -107,6 +104,7 @@ class TestDPOLlamaLora(unittest.TestCase): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -120,7 +118,6 @@ class TestDPOLlamaLora(unittest.TestCase): @with_temp_dir def test_dpo_use_weighting(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -156,6 +153,7 @@ class TestDPOLlamaLora(unittest.TestCase): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -170,7 +168,6 @@ class TestDPOLlamaLora(unittest.TestCase): @pytest.mark.skip("kto_pair no longer supported in trl") @with_temp_dir def test_kto_pair_lora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -205,6 +202,7 @@ class TestDPOLlamaLora(unittest.TestCase): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -218,7 +216,6 @@ class TestDPOLlamaLora(unittest.TestCase): @with_temp_dir def test_ipo_lora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -253,6 +250,7 @@ class TestDPOLlamaLora(unittest.TestCase): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -266,7 +264,6 @@ class TestDPOLlamaLora(unittest.TestCase): @with_temp_dir def test_orpo_lora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -304,6 +301,7 @@ class TestDPOLlamaLora(unittest.TestCase): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) @@ -318,7 +316,6 @@ class TestDPOLlamaLora(unittest.TestCase): @pytest.mark.skip(reason="Fix the implementation") @with_temp_dir def test_kto_lora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -372,6 +369,7 @@ class TestDPOLlamaLora(unittest.TestCase): "warmup_steps": 5, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, } ) diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index f1297fcf3..633e449ef 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -4,7 +4,6 @@ E2E tests for llama pretrain import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -20,7 +19,6 @@ class TestEmbeddingsLrScale(unittest.TestCase): @with_temp_dir def test_train_w_embedding_lr_scale(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -49,13 +47,13 @@ class TestEmbeddingsLrScale(unittest.TestCase): "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -66,7 +64,6 @@ class TestEmbeddingsLrScale(unittest.TestCase): @with_temp_dir def test_train_w_embedding_lr(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -95,12 +92,12 @@ class TestEmbeddingsLrScale(unittest.TestCase): "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_evaluate.py b/tests/e2e/test_evaluate.py index 6271bba28..3b0ab1450 100644 --- a/tests/e2e/test_evaluate.py +++ b/tests/e2e/test_evaluate.py @@ -13,7 +13,6 @@ class TestE2eEvaluate: """Test cases for evaluate CLI""" def test_evaluate(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -36,6 +35,7 @@ class TestE2eEvaluate: "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "max_steps": 20, + "save_first_step": False, } ) diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py index 7ea7e30f4..1a363fe6a 100644 --- a/tests/e2e/test_falcon.py +++ b/tests/e2e/test_falcon.py @@ -6,7 +6,6 @@ import unittest import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -23,7 +22,6 @@ class TestFalcon(unittest.TestCase): @pytest.mark.skip(reason="no tiny models for testing with safetensors") @with_temp_dir def test_lora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "illuin/tiny-random-FalconForCausalLM", @@ -61,13 +59,13 @@ class TestFalcon(unittest.TestCase): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -75,7 +73,6 @@ class TestFalcon(unittest.TestCase): @pytest.mark.skip(reason="no tiny models for testing with safetensors") @with_temp_dir def test_lora_added_vocab(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "illuin/tiny-random-FalconForCausalLM", @@ -117,13 +114,13 @@ class TestFalcon(unittest.TestCase): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -131,7 +128,6 @@ class TestFalcon(unittest.TestCase): @pytest.mark.skip(reason="no tiny models for testing with safetensors") @with_temp_dir def test_ft(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "illuin/tiny-random-FalconForCausalLM", @@ -159,13 +155,13 @@ class TestFalcon(unittest.TestCase): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_gemma2.py b/tests/e2e/test_gemma2.py index 65732a737..9e9f1a9cc 100644 --- a/tests/e2e/test_gemma2.py +++ b/tests/e2e/test_gemma2.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -23,7 +22,6 @@ class TestGemma2: [True, False], ) def test_lora_gemma2(self, temp_dir, sample_packing): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "axolotl-ai-co/gemma-2-33M", @@ -69,8 +67,7 @@ class TestGemma2: ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "adapter_model.safetensors").exists() @@ -80,7 +77,6 @@ class TestGemma2: [True, False], ) def test_fft_gemma2(self, temp_dir, sample_packing): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "axolotl-ai-co/gemma-2-33M", @@ -121,8 +117,7 @@ class TestGemma2: ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py index d790fa156..6cd999242 100644 --- a/tests/e2e/test_gemma3_text.py +++ b/tests/e2e/test_gemma3_text.py @@ -6,7 +6,6 @@ from pathlib import Path import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -23,7 +22,6 @@ class TestGemma3Text: [True, False], ) def test_lora_gemma3_text(self, temp_dir, sample_packing): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "axolotl-ai-co/gemma-3-34M", @@ -64,12 +62,12 @@ class TestGemma3Text: "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "adapter_model.safetensors").exists() @@ -79,7 +77,6 @@ class TestGemma3Text: [True, False], ) def test_fft_gemma3_text(self, temp_dir, sample_packing): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "axolotl-ai-co/gemma-3-34M", @@ -115,12 +112,12 @@ class TestGemma3Text: "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() diff --git a/tests/e2e/test_imports.py b/tests/e2e/test_imports.py index 050e4dfb3..4c01e50be 100644 --- a/tests/e2e/test_imports.py +++ b/tests/e2e/test_imports.py @@ -11,11 +11,7 @@ class TestImports(unittest.TestCase): """ def test_import_causal_trainer(self): - from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401 - HFCausalTrainerBuilder, - ) + pass def test_import_rl_trainer(self): - from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401 - HFRLTrainerBuilder, - ) + pass diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index 455e17532..de085cbe2 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -2,7 +2,6 @@ E2E tests for llama """ -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -17,7 +16,6 @@ class TestLlama: """ def test_fft_trust_remote_code(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -46,19 +44,18 @@ class TestLlama: "sample_packing": True, "bf16": True, "save_safetensors": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) def test_fix_untrained_tokens(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -94,19 +91,18 @@ class TestLlama: "sample_packing": True, "bf16": True, "save_safetensors": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) def test_fix_untrained_tokens_already_trained(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -139,19 +135,18 @@ class TestLlama: "sample_packing": True, "bf16": True, "save_safetensors": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) def test_batch_flattening(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -180,13 +175,13 @@ class TestLlama: "batch_flattening": True, "bf16": True, "save_safetensors": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index ec1e164a4..f0daa9dd6 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -1,10 +1,7 @@ -""" -E2E tests for llama pretrain -""" +"""E2E tests for llama pretrain""" import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -14,23 +11,17 @@ from .utils import check_model_output_exists, check_tensorboard class TestPretrainLlama: - """ - Test case for Llama models w pretraining - """ + """Test case for Llama models w pretraining""" @pytest.mark.parametrize( - "sample_packing", - [True, False], - ) - @pytest.mark.parametrize( - "pretrain_multipack_attn", - [True, False], + ("sample_packing", "pretrain_multipack_attn"), + [ + (False, False), + (True, True), + (True, False), + ], ) def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn): - if not sample_packing and pretrain_multipack_attn: - return - - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -38,7 +29,7 @@ class TestPretrainLlama: "sequence_len": 1024, "sample_packing": sample_packing, "pretrain_multipack_attn": pretrain_multipack_attn, - "dataset_processes": 1, + "dataset_num_proc": 1, "special_tokens": { "pad_token": "<|endoftext|>", }, @@ -61,22 +52,22 @@ class TestPretrainLlama: "save_safetensors": True, "bf16": "auto", "use_tensorboard": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) - loss_threshold = 3.5 + loss_threshold = 3.6 if sample_packing and not pretrain_multipack_attn: loss_threshold = 6.5 check_tensorboard( temp_dir + "/runs", "train/train_loss", loss_threshold, - "Train Loss is too high", + "Train Loss (%s) is too high", ) diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index 32657c156..0cc927f76 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -4,7 +4,6 @@ E2E tests for lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -20,7 +19,6 @@ class TestLlamaVision(unittest.TestCase): @with_temp_dir def test_lora_llama_vision_text_only_dataset(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "axolotl-ai-co/Llama-3.2-39M-Vision", @@ -55,20 +53,19 @@ class TestLlamaVision(unittest.TestCase): "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_lora_llama_vision_multimodal_dataset(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "axolotl-ai-co/Llama-3.2-39M-Vision", @@ -102,12 +99,12 @@ class TestLlamaVision(unittest.TestCase): "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_load_model.py b/tests/e2e/test_load_model.py index 5061945b4..7c5389a58 100644 --- a/tests/e2e/test_load_model.py +++ b/tests/e2e/test_load_model.py @@ -52,15 +52,15 @@ class TestLoadModelUtils: "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", + "tensor_parallel_size": 1, + "context_parallel_size": 1, } ) - self.model_loader = ( # pylint: disable=attribute-defined-outside-init - ModelLoader( - cfg=self.cfg, - tokenizer="", - inference=False, - reference_model=True, - ) + self.model_loader = ModelLoader( + cfg=self.cfg, + tokenizer="", + inference=False, + reference_model=True, ) @pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"]) @@ -72,7 +72,7 @@ class TestLoadModelUtils: self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune ): self.cfg.output_dir = temp_dir - self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all + self.model_loader.tokenizer = load_tokenizer(self.cfg) self.model_loader.load() self.model_loader._convert_embedding_modules_dtype( embedding_modules, dist_dtype, before_kbit_train_or_finetune diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 999625070..b6ee393df 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -4,7 +4,6 @@ E2E tests for lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -20,7 +19,6 @@ class TestLoraLlama(unittest.TestCase): @with_temp_dir def test_lora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -50,13 +48,13 @@ class TestLoraLlama(unittest.TestCase): "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", "max_steps": 5, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py index efffb4547..67935377d 100644 --- a/tests/e2e/test_mamba.py +++ b/tests/e2e/test_mamba.py @@ -6,7 +6,6 @@ import unittest import pytest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -23,7 +22,6 @@ class TestMamba(unittest.TestCase): @with_temp_dir def test_fft(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "state-spaces/mamba-130m", @@ -52,13 +50,13 @@ class TestMamba(unittest.TestCase): "save_steps": 10, "eval_steps": None, "save_safetensors": False, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index 98a82a5f0..08b3b05af 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -6,7 +6,6 @@ import unittest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -22,7 +21,6 @@ class TestMistral(unittest.TestCase): @with_temp_dir def test_lora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2", @@ -56,20 +54,19 @@ class TestMistral(unittest.TestCase): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_ft(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2", @@ -97,6 +94,7 @@ class TestMistral(unittest.TestCase): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -106,8 +104,7 @@ class TestMistral(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index b551e431a..c46cf906d 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -7,7 +7,6 @@ import unittest import torch from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -23,7 +22,6 @@ class TestMixtral(unittest.TestCase): @with_temp_dir def test_qlora_w_fa2(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "hf-internal-testing/Mixtral-tiny", @@ -62,13 +60,13 @@ class TestMixtral(unittest.TestCase): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( @@ -79,7 +77,6 @@ class TestMixtral(unittest.TestCase): @with_temp_dir def test_qlora_wo_fa2(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "hf-internal-testing/Mixtral-tiny", @@ -118,13 +115,13 @@ class TestMixtral(unittest.TestCase): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( @@ -135,7 +132,6 @@ class TestMixtral(unittest.TestCase): @with_temp_dir def test_16bit_lora_w_fa2(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "hf-internal-testing/Mixtral-tiny", @@ -173,6 +169,7 @@ class TestMixtral(unittest.TestCase): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -182,8 +179,7 @@ class TestMixtral(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( @@ -194,7 +190,6 @@ class TestMixtral(unittest.TestCase): @with_temp_dir def test_16bit_lora_wo_fa2(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "hf-internal-testing/Mixtral-tiny", @@ -232,6 +227,7 @@ class TestMixtral(unittest.TestCase): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) @@ -241,8 +237,7 @@ class TestMixtral(unittest.TestCase): cfg.bf16 = True else: cfg.fp16 = True - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( @@ -253,7 +248,6 @@ class TestMixtral(unittest.TestCase): @with_temp_dir def test_ft(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "hf-internal-testing/Mixtral-tiny", @@ -278,6 +272,7 @@ class TestMixtral(unittest.TestCase): "max_steps": 20, "save_steps": 10, "eval_steps": 10, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -287,8 +282,7 @@ class TestMixtral(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index e812a5f7e..dbea92a5b 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -4,7 +4,6 @@ E2E tests for custom optimizers using Llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -14,6 +13,7 @@ from .utils import ( check_model_output_exists, require_torch_2_5_1, require_torch_2_6_0, + require_torch_2_7_0, with_temp_dir, ) @@ -25,7 +25,6 @@ class TestCustomOptimizers(unittest.TestCase): @with_temp_dir def test_optimi_adamw(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -56,13 +55,13 @@ class TestCustomOptimizers(unittest.TestCase): "optimizer": "optimi_adamw", "max_steps": 5, "lr_scheduler": "cosine", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -71,7 +70,6 @@ class TestCustomOptimizers(unittest.TestCase): @with_temp_dir @require_torch_2_5_1 def test_adopt_adamw(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -102,13 +100,13 @@ class TestCustomOptimizers(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adopt_adamw", "lr_scheduler": "cosine", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -117,7 +115,6 @@ class TestCustomOptimizers(unittest.TestCase): @with_temp_dir @require_torch_2_5_1 def test_muon(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -149,21 +146,62 @@ class TestCustomOptimizers(unittest.TestCase): "optimizer": "muon", "lr_scheduler": "cosine", "weight_decay": 0.01, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) assert "Muon" in trainer.optimizer.optimizer.__class__.__name__ + @with_temp_dir + @require_torch_2_7_0 + def test_dion(self, temp_dir): + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "model_type": "AutoModelForCausalLM", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "val_set_size": 0.0, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "dion", + "dion_lr": 0.01, + "dion_momentum": 0.95, + "lr_scheduler": "cosine", + "weight_decay": 0.01, + "save_first_step": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + assert "Dion" in trainer.optimizer.optimizer.__class__.__name__ + @with_temp_dir def test_fft_schedule_free_adamw(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -188,14 +226,13 @@ class TestCustomOptimizers(unittest.TestCase): "lr_scheduler": "constant", "save_safetensors": True, "max_steps": 10, + "save_first_step": False, } ) - # pylint: disable=duplicate-code cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @@ -203,7 +240,6 @@ class TestCustomOptimizers(unittest.TestCase): @with_temp_dir @require_torch_2_6_0 def test_came_pytorch(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "JackFram/llama-68m", @@ -237,13 +273,13 @@ class TestCustomOptimizers(unittest.TestCase): "adam_epsilon2": 1e-16, "max_steps": 5, "lr_scheduler": "cosine", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py index 12e272888..7cb979ce6 100644 --- a/tests/e2e/test_packing_loss.py +++ b/tests/e2e/test_packing_loss.py @@ -6,7 +6,6 @@ import unittest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -22,7 +21,6 @@ class TestPackedLlama(unittest.TestCase): @with_temp_dir def test_loss_packed(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -49,6 +47,7 @@ class TestPackedLlama(unittest.TestCase): "lr_scheduler": "cosine", "max_steps": 5, "use_tensorboard": True, + "save_first_step": False, } ) if is_torch_bf16_gpu_available(): @@ -58,11 +57,10 @@ class TestPackedLlama(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index f8b43ad32..ae2210249 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -4,7 +4,6 @@ E2E tests for lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -20,7 +19,6 @@ class TestPhi(unittest.TestCase): @with_temp_dir def test_phi_ft(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "microsoft/phi-1_5", @@ -54,19 +52,18 @@ class TestPhi(unittest.TestCase): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir def test_phi_qlora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "microsoft/phi-1_5", @@ -104,12 +101,12 @@ class TestPhi(unittest.TestCase): "save_steps": 10, "eval_steps": 10, "bf16": "auto", + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_preprocess.py b/tests/e2e/test_preprocess.py new file mode 100644 index 000000000..4aa4cb6c2 --- /dev/null +++ b/tests/e2e/test_preprocess.py @@ -0,0 +1,58 @@ +"""E2E Test the preprocess cli""" + +from pathlib import Path + +import yaml +from accelerate.test_utils import execute_subprocess_async + +from axolotl.utils.dict import DictDefault + +AXOLOTL_ROOT = Path(__file__).parent.parent.parent + + +class TestPreprocess: + """test cases for preprocess""" + + def test_w_deepspeed(self, temp_dir): + """make sure preproces doesn't choke when using deepspeed in the config""" + + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "bf16": "auto", + "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"), + "dataset_prepared_path": temp_dir + "/last_run_prepared", + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "preprocess", + str(Path(temp_dir) / "config.yaml"), + ] + ) + + assert (Path(temp_dir) / "last_run_prepared").exists() diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py index eb81959a2..9d83aabbc 100644 --- a/tests/e2e/test_process_reward_model_smollm2.py +++ b/tests/e2e/test_process_reward_model_smollm2.py @@ -4,7 +4,6 @@ E2E tests for process reward model w/ lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -20,7 +19,6 @@ class TestProcessRewardSmolLM2(unittest.TestCase): @with_temp_dir def test_prm(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -50,12 +48,12 @@ class TestProcessRewardSmolLM2(unittest.TestCase): "use_tensorboard": True, "special_tokens": {"pad_token": "<|endoftext|>"}, "seed": 42, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( diff --git a/tests/e2e/test_profiler.py b/tests/e2e/test_profiler.py new file mode 100644 index 000000000..ab273b981 --- /dev/null +++ b/tests/e2e/test_profiler.py @@ -0,0 +1,113 @@ +""" +e2e gpu test for the pytorch profiler callback +""" + +from pathlib import Path + +import pytest + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="profiler_base_cfg") +def fixture_profiler_base_cfg(): + cfg = DictDefault( + base_model="HuggingFaceTB/SmolLM2-135M", + tokenizer_type="AutoTokenizer", + sequence_len=1024, + load_in_8bit=True, + adapter="lora", + lora_r=8, + lora_alpha=16, + lora_dropout=0.05, + lora_target_linear=True, + val_set_size=0.02, + special_tokens={"pad_token": "<|endoftext|>"}, + datasets=[ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + num_epochs=1, + micro_batch_size=2, + gradient_accumulation_steps=1, + learning_rate=0.00001, + optimizer="adamw_torch_fused", + lr_scheduler="cosine", + ) + return cfg + + +class TestProfiler: + """ + test cases for the pytorch profiler callback + """ + + def test_profiler_saves(self, profiler_base_cfg, temp_dir): + cfg = profiler_base_cfg | DictDefault( + output_dir=temp_dir, + max_steps=5, + profiler_steps=3, + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "snapshot.pickle").exists() + + def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir): + cfg = profiler_base_cfg | DictDefault( + output_dir=temp_dir, + max_steps=5, + profiler_steps=3, + profiler_steps_start=1, + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "snapshot.pickle").exists() + + @pytest.mark.parametrize( + "profiler_steps_start", + [3, 5], + ) + def test_profiler_saves_past_end( + self, profiler_base_cfg, temp_dir, profiler_steps_start + ): + cfg = profiler_base_cfg | DictDefault( + output_dir=temp_dir, + max_steps=5, + profiler_steps=3, + profiler_steps_start=profiler_steps_start, + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "snapshot.pickle").exists() + + def test_profiler_never_started(self, profiler_base_cfg, temp_dir): + cfg = profiler_base_cfg | DictDefault( + output_dir=temp_dir, + max_steps=5, + profiler_steps=3, + profiler_steps_start=6, + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert not (Path(temp_dir) / "snapshot.pickle").exists() diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py index f9e7993be..2f8398ef7 100644 --- a/tests/e2e/test_qat.py +++ b/tests/e2e/test_qat.py @@ -2,26 +2,22 @@ E2E tests for QAT """ -import unittest from pathlib import Path -from axolotl.cli.args import TrainerCliArgs -from axolotl.common.datasets import load_datasets +from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault -from .utils import check_model_output_exists, with_temp_dir +from .utils import check_model_output_exists, check_tensorboard -class TestQATLlama(unittest.TestCase): +class TestQATLlama: """ Test case for QAT Llama models """ - @with_temp_dir - def test_qat_lora(self, temp_dir): - # pylint: disable=duplicate-code + def test_qat(self, temp_dir): cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -47,7 +43,7 @@ class TestQATLlama(unittest.TestCase): "qat": { "quantize_embedding": True, "activation_dtype": "int8", - "weight_dtype": "int8", + "weight_dtype": "int4", "group_size": 8, }, "num_epochs": 1, @@ -60,12 +56,78 @@ class TestQATLlama(unittest.TestCase): "max_steps": 5, "save_safetensors": True, "bf16": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg) + + def test_qat_dpo(self, temp_dir): + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "sequence_len": 2048, + "sample_packing": False, + "eval_sample_packing": False, + "pad_to_sequence_len": True, + "val_set_size": 0.01, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "rl": "dpo", + "chat_template": "chatml", + "datasets": [ + { + "path": "fozziethebeat/alpaca_messages_2k_dpo_test", + "type": "chat_template.default", + "field_messages": "conversation", + "field_chosen": "chosen", + "field_rejected": "rejected", + "message_field_role": "role", + "message_field_content": "content", + "roles": { + "system": ["system"], + "user": ["user"], + "assistant": ["assistant"], + }, + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 2, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "warmup_steps": 0, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "use_tensorboard": True, + "bf16": True, + "qat": { + "quantize_embedding": True, + "activation_dtype": "int8", + "weight_dtype": "int4", + "group_size": 8, + }, + "save_first_step": False, + } + ) + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_preference_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg) + + loss_threshold = 2.3 + check_tensorboard( + temp_dir + "/runs", + "train/train_loss", + loss_threshold, + "Train Loss (%s) is too high", + ) diff --git a/tests/e2e/test_quantization.py b/tests/e2e/test_quantization.py index 500b7e556..706279c6c 100644 --- a/tests/e2e/test_quantization.py +++ b/tests/e2e/test_quantization.py @@ -5,42 +5,41 @@ Tests for axolotl.utils.quantization import pytest import torch from torch import nn -from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor -from torchao.quantization.granularity import PerAxis, PerGroup -from torchao.quantization.linear_activation_quantized_tensor import ( - LinearActivationQuantizedTensor, -) +from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization.qat.embedding import FakeQuantizedEmbedding from torchao.quantization.qat.linear import FakeQuantizedLinear from torchao.quantization.quant_api import ( - Int4DynamicActivationInt4WeightConfig, - Int4WeightOnlyConfig, - Int8DynamicActivationInt8WeightConfig, - Int8WeightOnlyConfig, - UIntXWeightOnlyConfig, + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt4WeightConfig, ) +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor from transformers import AutoModelForCausalLM from transformers.trainer_callback import TrainerState from axolotl.utils.callbacks.qat import QATCallback from axolotl.utils.quantization import ( - convert_qat_model_for_ptq, - get_ptq_config, + convert_qat_model, + get_quantization_config, prepare_model_for_qat, - quantize_model_for_ptq, + quantize_model, ) -from axolotl.utils.schemas.enums import TorchIntDType +from axolotl.utils.schemas.enums import TorchAOQuantDType from axolotl.utils.schemas.quantization import QATConfig -from tests.e2e.utils import require_torch_2_6_0 +from tests.e2e.utils import ( + require_torch_2_8_0, + requires_cuda_ge_8_9, + requires_sm_ge_100, +) @pytest.fixture() def model(): dummy_model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceTB/SmolLM2-135M", - device_map="cuda", - torch_dtype=torch.bfloat16, + "Qwen/Qwen2-0.5B", + device_map="auto", + dtype=torch.bfloat16, ) with torch.device(dummy_model.device): dummy_model.model.embed_tokens = torch.nn.Embedding( @@ -48,45 +47,56 @@ def model(): dummy_model.model.embed_tokens.weight.shape[1], dtype=dummy_model.model.embed_tokens.weight.dtype, ) - return dummy_model + yield dummy_model + del dummy_model ptq_config_test_cases = [ - # weight_dtype, activation_dtype, group_size, expected_type, expected_params + # weight_dtype, activation_dtype, group_size, expected_type ( - TorchIntDType.uint4, + TorchAOQuantDType.int4, + TorchAOQuantDType.int8, None, - None, - UIntXWeightOnlyConfig, - {"dtype": torch.uint4, "group_size": None}, - ), - (TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}), - (TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}), - ( - TorchIntDType.int4, - TorchIntDType.int4, - None, - Int4DynamicActivationInt4WeightConfig, - {}, + Int8DynamicActivationInt4WeightConfig, ), ( - TorchIntDType.int8, - TorchIntDType.int8, + TorchAOQuantDType.float8_e4m3fn, + TorchAOQuantDType.float8_e4m3fn, None, - Int8DynamicActivationInt8WeightConfig, - {}, + Float8DynamicActivationFloat8WeightConfig, + ), + ( + TorchAOQuantDType.int4, + TorchAOQuantDType.float8_e4m3fn, + None, + Float8DynamicActivationInt4WeightConfig, ), ] ptq_test_cases = [ - # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception - (TorchIntDType.int8, None, 8, False, None), - (TorchIntDType.int4, None, 4, True, None), - (TorchIntDType.uint4, None, 8, False, None), - (TorchIntDType.int4, TorchIntDType.int4, 8, False, None), - (TorchIntDType.int8, TorchIntDType.int8, 8, True, None), - (TorchIntDType.int8, None, None, False, ValueError), - (TorchIntDType.int4, None, None, False, ValueError), + # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class + (TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor), + ( + TorchAOQuantDType.int4, + TorchAOQuantDType.int8, + 8, + False, + None, + LinearActivationQuantizedTensor, + ), + # ( + # TorchAOQuantDType.int4, + # TorchAOQuantDType.float8_e4m3fn, + # None, + # False, + # None, + # Int4Tensor, + # ), + (TorchAOQuantDType.int4, None, None, False, None, Int4Tensor), + # Deprecated configs + (TorchAOQuantDType.int8, None, 8, False, ValueError, None), + (TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None), + (TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None), ] @@ -96,44 +106,132 @@ class TestQuantization: """ @pytest.mark.parametrize( - "weight_dtype,activation_dtype,group_size,expected_type,expected_params", + "weight_dtype,activation_dtype,group_size,expected_type", ptq_config_test_cases, ) - @require_torch_2_6_0 + @requires_cuda_ge_8_9 + @require_torch_2_8_0 def test_get_ptq_config( - self, weight_dtype, activation_dtype, group_size, expected_type, expected_params + self, weight_dtype, activation_dtype, group_size, expected_type ): - config = get_ptq_config(weight_dtype, activation_dtype, group_size) - + config = get_quantization_config(weight_dtype, activation_dtype, group_size) assert isinstance(config, expected_type) - for param_name, param_value in expected_params.items(): - if isinstance(param_value, (PerAxis, PerGroup)): - if isinstance(param_value, PerAxis): - assert isinstance(getattr(config, param_name), PerAxis) - assert getattr(config, param_name).axis == param_value.axis - else: - assert isinstance(getattr(config, param_name), PerGroup) - assert ( - getattr(config, param_name).group_size == param_value.group_size - ) - else: - assert getattr(config, param_name) == param_value + @requires_cuda_ge_8_9 + @require_torch_2_8_0 + def test_get_ptq_config_int4_weight_only(self): + from torchao.quantization.quant_api import Int4WeightOnlyConfig + + config = get_quantization_config(TorchAOQuantDType.int4, None, 4) + assert isinstance(config, Int4WeightOnlyConfig) @pytest.mark.parametrize( - "weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4] + "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class", + ptq_test_cases, ) + @requires_cuda_ge_8_9 + @require_torch_2_8_0 + def test_quantize_model_for_ptq( + self, + model, + weight_dtype, + activation_dtype, + group_size, + quantize_embedding, + expected_exception, + expected_tensor_class, + ): + if expected_exception: + with pytest.raises(expected_exception): + quantize_model( + model, + weight_dtype, + group_size, + activation_dtype, + quantize_embedding, + ) + else: + quantize_model( + model, weight_dtype, group_size, activation_dtype, quantize_embedding + ) + if quantize_embedding: + assert isinstance( + model.model.embed_tokens.weight, expected_tensor_class + ), "Embedding weight should be quantized" + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + assert isinstance(child.weight, expected_tensor_class) + + @require_torch_2_8_0 + @requires_sm_ge_100 + def test_quantize_model_for_ptq_fp8( + self, + model, + ): + from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( + Float8Tensor, + QuantizeTensorToFloat8Kwargs, + ) + + quantize_model( + model, + TorchAOQuantDType.float8_e4m3fn, + None, + TorchAOQuantDType.float8_e4m3fn, + ) + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + assert isinstance(child.weight, Float8Tensor) + assert child.weight.act_quant_kwargs is not None and isinstance( + child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs + ) + + @require_torch_2_8_0 + @requires_sm_ge_100 + def test_quantize_model_for_ptq_nvfp4( + self, + model, + ): + from torchao.prototype.mx_formats.nvfp4_tensor import ( + NVFP4Tensor, + QuantizeTensorToNVFP4Kwargs, + ) + + quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4) + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + assert isinstance(child.weight, NVFP4Tensor) + assert child.weight.act_quant_kwargs is not None and isinstance( + child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs + ) + @pytest.mark.parametrize( - "activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8] + "weight_dtype,activation_dtype,group_size,quantize_embedding", + [ + (TorchAOQuantDType.int4, None, 8, False), + (TorchAOQuantDType.int4, None, 16, True), + (TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False), + (TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True), + ( + TorchAOQuantDType.float8_e4m3fn, + TorchAOQuantDType.float8_e4m3fn, + None, + False, + ), + (TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True), + ], ) - @pytest.mark.parametrize("group_size", [4, 8]) - @pytest.mark.parametrize("quantize_embedding", [False, True]) - @require_torch_2_6_0 + @require_torch_2_8_0 + @requires_cuda_ge_8_9 def test_prepare_model_for_qat( self, model, weight_dtype, activation_dtype, group_size, quantize_embedding - ): # pylint: disable=redefined-outer-name + ): prepare_model_for_qat( - model, weight_dtype, group_size, activation_dtype, quantize_embedding + model, + weight_dtype, + group_size, + activation_dtype, + quantize_embedding, ) if quantize_embedding: assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) @@ -142,17 +240,19 @@ class TestQuantization: model.model.embed_tokens.weight_fake_quantizer.config.dtype == weight_dtype.value ) - assert ( - model.model.embed_tokens.weight_fake_quantizer.config.group_size - == group_size - ) + if group_size: + assert ( + model.model.embed_tokens.weight_fake_quantizer.config.group_size + == group_size + ) for child in list(model.children()): if isinstance(child, torch.nn.Linear): assert isinstance(child, FakeQuantizedLinear) assert hasattr(child, "weight_fake_quantizer") assert child.weight_fake_quantizer.config.dtype == weight_dtype.value - assert child.weight_fake_quantizer.config.group_size == group_size + if group_size: + assert child.weight_fake_quantizer.config.group_size == group_size if activation_dtype: assert hasattr(child, "activation_fake_quantizer") assert ( @@ -162,47 +262,40 @@ class TestQuantization: else: assert child.activation_fake_quantizer is None - @pytest.mark.parametrize( - "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception", - ptq_test_cases, - ) - @require_torch_2_6_0 - def test_quantize_model_for_ptq( - self, - model, - weight_dtype, - activation_dtype, - group_size, - quantize_embedding, - expected_exception, - ): # pylint: disable=redefined-outer-name - if expected_exception: - with pytest.raises(expected_exception): - quantize_model_for_ptq( - model, - weight_dtype, - group_size, - activation_dtype, - quantize_embedding, - ) - else: - quantize_model_for_ptq( - model, weight_dtype, group_size, activation_dtype, quantize_embedding - ) - if quantize_embedding: - assert isinstance( - model.model.embed_tokens.weight, AffineQuantizedTensor - ), "Embedding weight should be quantized" - for child in list(model.children()): - if isinstance(child, torch.nn.Linear): - if activation_dtype: - assert isinstance( - child.weight, LinearActivationQuantizedTensor - ), "Linear weight should be quantized with activation quantization" - else: - assert isinstance( - child.weight, AffineQuantizedTensor - ), "Linear weight should be quantized without activation quantization" + @require_torch_2_8_0 + @requires_cuda_ge_8_9 + def test_convert_qat_model(self, model): + config = QATConfig( + weight_dtype="int4", + activation_dtype="int8", + group_size=8, + quantize_embedding=True, + ) + + # quantize model for qat + prepare_model_for_qat( + model, + config.weight_dtype, + config.group_size, + config.activation_dtype, + config.quantize_embedding, + ) + + assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert isinstance(model.lm_head, FakeQuantizedLinear) + + # apply conversion + convert_qat_model( + model, + config.quantize_embedding, + ) + # ensure modules have been swapped out + assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert not isinstance(model.lm_head, FakeQuantizedLinear) + + # ensure weights have been quantized + assert isinstance(model.model.embed_tokens.weight, nn.Parameter) + assert isinstance(model.lm_head.weight, nn.Parameter) class TestQuantizationCallback: @@ -216,12 +309,10 @@ class TestQuantizationCallback: global_step=0, ) - @require_torch_2_6_0 - def test_qat_callback_fake_quant_after_n_steps( - self, model, trainer_state - ): # pylint: disable=redefined-outer-name + @require_torch_2_8_0 + def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state): cfg = QATConfig( - weight_dtype="int8", + weight_dtype="int4", activation_dtype="int8", group_size=8, quantize_embedding=True, @@ -268,12 +359,10 @@ class TestQuantizationCallback: assert model.model.embed_tokens.weight_fake_quantizer.enabled assert model.lm_head.weight_fake_quantizer.enabled - @require_torch_2_6_0 - def test_qat_callback_fake_quant_after_n_steps_is_none( - self, model, trainer_state - ): # pylint: disable=redefined-outer-name + @require_torch_2_8_0 + def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state): cfg = QATConfig( - weight_dtype="int8", + weight_dtype="int4", activation_dtype="int8", group_size=8, quantize_embedding=True, @@ -306,45 +395,3 @@ class TestQuantizationCallback: # quantization should be enabled from the get-go assert model.model.embed_tokens.weight_fake_quantizer.enabled assert model.lm_head.weight_fake_quantizer.enabled - - -class TestConvertQATModelForPTQ: - """ - Test convert_qat_model_for_ptq - """ - - @require_torch_2_6_0 - def test_convert_qat_model_for_ptq( - self, model - ): # pylint: disable=redefined-outer-name - config = QATConfig( - weight_dtype="int8", - activation_dtype="int8", - group_size=8, - quantize_embedding=True, - ) - - # quantize model for qat - prepare_model_for_qat( - model, - config.weight_dtype, - config.group_size, - config.activation_dtype, - config.quantize_embedding, - ) - - assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) - assert isinstance(model.lm_head, FakeQuantizedLinear) - - # apply conversion - convert_qat_model_for_ptq( - model, - quantize_embedding=config.quantize_embedding, - ) - # ensure modules have been swapped out - assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) - assert not isinstance(model.lm_head, FakeQuantizedLinear) - - # ensure weights have been quantized - assert isinstance(model.model.embed_tokens.weight, nn.Parameter) - assert isinstance(model.lm_head.weight, nn.Parameter) diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py index aa8b9f6c0..1c75d817b 100644 --- a/tests/e2e/test_qwen.py +++ b/tests/e2e/test_qwen.py @@ -19,7 +19,6 @@ class TestE2eQwen: @pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"]) def test_dpo(self, base_model, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": base_model, @@ -59,6 +58,7 @@ class TestE2eQwen: "bf16": "auto", "tf32": True, "gradient_checkpointing": True, + "save_first_step": False, } ) diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py index 55405d58c..cc768b173 100644 --- a/tests/e2e/test_reward_model_smollm2.py +++ b/tests/e2e/test_reward_model_smollm2.py @@ -4,7 +4,6 @@ E2E tests for reward model lora llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -20,7 +19,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase): @with_temp_dir def test_rm_lora(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -59,15 +57,15 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase): "gradient_checkpointing": True, "warmup_ratio": 0.1, "use_tensorboard": True, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high" ) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_save_first_step.py b/tests/e2e/test_save_first_step.py new file mode 100644 index 000000000..ce2d3f145 --- /dev/null +++ b/tests/e2e/test_save_first_step.py @@ -0,0 +1,100 @@ +""" +E2E tests for relora llama +""" + +import unittest +from pathlib import Path + +import pytest + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists, with_temp_dir + + +class TestSaveFirstStepCallback(unittest.TestCase): + """Test cases for save_first_step callback config.""" + + @with_temp_dir + def test_save_first_step(self, temp_dir): + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 512, + "val_set_size": 0.02, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": True, + "bf16": True, + "save_safetensors": True, + "save_first_step": True, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg) + + @with_temp_dir + def test_no_save_first_step(self, temp_dir): + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 512, + "val_set_size": 0.02, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": True, + "bf16": True, + "save_safetensors": True, + "save_first_step": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + with pytest.raises(AssertionError): + check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg) diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py index e468081b1..5b9c56288 100644 --- a/tests/e2e/test_schedulers.py +++ b/tests/e2e/test_schedulers.py @@ -4,7 +4,6 @@ E2E tests for custom schedulers using Llama import unittest -from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config @@ -20,7 +19,6 @@ class TestCustomSchedulers(unittest.TestCase): @with_temp_dir def test_rex_scheduler(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -52,13 +50,13 @@ class TestCustomSchedulers(unittest.TestCase): "lr_scheduler": "rex", "warmup_steps": 5, "cosine_min_lr_ratio": 0.05, + "save_first_step": False, } ) cfg = validate_config(cfg) normalize_config(cfg) - cli_args = TrainerCliArgs() - dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/test_streaming.py b/tests/e2e/test_streaming.py new file mode 100644 index 000000000..5dccf00dd --- /dev/null +++ b/tests/e2e/test_streaming.py @@ -0,0 +1,73 @@ +"""E2E tests for streaming dataset functionality""" + +# pylint: disable=duplicate-code + +import pytest + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists, check_tensorboard + + +class TestStreamingDatasets: + """Test case for streaming datasets""" + + @pytest.mark.parametrize( + "sample_packing", + [True, False], + ) + def test_streaming_dataset(self, temp_dir, sample_packing): + """Test streaming datasets""" + + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "flash_attention": True, + "sequence_len": 1024, + "sample_packing": sample_packing, + "pretrain_multipack_attn": sample_packing, + "streaming_multipack_buffer_size": 10000, + "dataset_processes": 1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + # Streaming config + "streaming": True, + "max_steps": 3, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + "use_tensorboard": True, + "save_first_step": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + + # Verify training actually happened by checking loss decrease + check_tensorboard( + temp_dir + "/runs", + "train/train_loss", + 3.0, + "Train Loss (%s) is too high", + ) diff --git a/tests/e2e/test_tokenizer.py b/tests/e2e/test_tokenizer.py new file mode 100644 index 000000000..a65c17ac3 --- /dev/null +++ b/tests/e2e/test_tokenizer.py @@ -0,0 +1,63 @@ +""" +e2e test for saving the tokenizer +""" + +from unittest.mock import patch + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_model_output_exists + + +def test_tokenizer_no_save_jinja_files(temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.02, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "chat_template": "chatml", + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_first_step": False, + "fp16": False, + "tokenizer_save_jinja_files": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + with patch("axolotl.train.execute_training"): + train(cfg=cfg, dataset_meta=dataset_meta) + + check_model_output_exists(temp_dir, cfg) + with open(f"{temp_dir}/tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_config = f.read() + assert "chat_template" in tokenizer_config diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 65069eb16..a2dd8bc5e 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -2,6 +2,7 @@ helper utils for tests """ +import importlib.util import os import shutil import tempfile @@ -77,6 +78,30 @@ def require_torch_2_6_0(test_case): return unittest.skipUnless(is_min_2_6_0(), "test requires torch>=2.6.0")(test_case) +def require_torch_2_7_0(test_case): + """ + Decorator marking a test that requires torch >= 2.7.0 + """ + + def is_min_2_7_0(): + torch_version = version.parse(torch.__version__) + return torch_version >= version.parse("2.7.0") + + return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case) + + +def require_torch_2_8_0(test_case): + """ + Decorator marking a test that requires torch >= 2.7.0 + """ + + def is_min_2_8_0(): + torch_version = version.parse(torch.__version__) + return torch_version >= version.parse("2.8.0") + + return unittest.skipUnless(is_min_2_8_0(), "test requires torch>=2.8.0")(test_case) + + def require_torch_lt_2_6_0(test_case): """ Decorator marking a test that requires torch < 2.6.0 @@ -95,12 +120,7 @@ def require_vllm(test_case): """ def is_vllm_installed(): - try: - import vllm # pylint: disable=unused-import # noqa: F401 - - return True - except ImportError: - return False + return importlib.util.find_spec("vllm") is not None return unittest.skipUnless( is_vllm_installed(), "test requires vllm to be installed" @@ -113,25 +133,46 @@ def require_llmcompressor(test_case): """ def is_llmcompressor_installed(): - try: - import llmcompressor # pylint: disable=unused-import # noqa: F401 - - return True - except ImportError: - return False + return importlib.util.find_spec("llmcompressor") is not None return unittest.skipUnless( is_llmcompressor_installed(), "test requires llmcompressor to be installed" )(test_case) +def requires_sm_ge_100(test_case): + is_sm_ge_100 = ( + torch.cuda.is_available() + and torch.version.cuda + and torch.cuda.get_device_capability() >= (10, 0) + ) + return unittest.skipUnless(is_sm_ge_100, "test requires sm>=100")(test_case) + + +def requires_cuda_ge_8_9(test_case): + is_cuda_ge_8_9 = ( + torch.cuda.is_available() + and torch.version.cuda + and torch.cuda.get_device_capability() >= (8, 9) + ) + return unittest.skipUnless(is_cuda_ge_8_9, "test requires cuda>=8.9")(test_case) + + def is_hopper(): compute_capability = torch.cuda.get_device_capability() return compute_capability == (9, 0) +def require_hopper(test_case): + return unittest.skipUnless(is_hopper(), "test requires h100/hopper GPU")(test_case) + + def check_tensorboard( - temp_run_dir: str, tag: str, lt_val: float, assertion_err: str + temp_run_dir: str, + tag: str, + lt_val: float, + assertion_err: str, + rtol: float = 0.02, ) -> None: """ helper function to parse and check tensorboard logs @@ -139,8 +180,9 @@ def check_tensorboard( tb_log_path = most_recent_subdir(temp_run_dir) event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0]) reader = SummaryReader(event_file) - df = reader.scalars # pylint: disable=invalid-name - df = df[(df.tag == tag)] # pylint: disable=invalid-name + df = reader.scalars + df = df[(df.tag == tag)] + lt_val = (1 + rtol) * lt_val if "%s" in assertion_err: assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1] else: diff --git a/tests/hf_offline_utils.py b/tests/hf_offline_utils.py index 385e61f18..0e4a2f067 100644 --- a/tests/hf_offline_utils.py +++ b/tests/hf_offline_utils.py @@ -20,7 +20,7 @@ def reload_modules(hf_hub_offline): importlib.reload(huggingface_hub.constants) huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline importlib.reload(datasets.config) - setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline) + datasets.config.HF_HUB_OFFLINE = hf_hub_offline reset_sessions() diff --git a/tests/integrations/test_diffusion.py b/tests/integrations/test_diffusion.py new file mode 100644 index 000000000..141d8d150 --- /dev/null +++ b/tests/integrations/test_diffusion.py @@ -0,0 +1,274 @@ +"""Tests for diffusion trainer integration.""" + +# pylint: disable=redefined-outer-name,protected-access + +from unittest.mock import Mock + +import pytest +import torch + +from axolotl.integrations.diffusion import DiffusionTrainer +from axolotl.integrations.diffusion.utils import create_bidirectional_attention_mask +from axolotl.utils.dict import DictDefault + + +@pytest.fixture +def mock_tokenizer(): + """Create a mock tokenizer.""" + tokenizer = Mock() + tokenizer.bos_token_id = 1 + tokenizer.eos_token_id = 2 + tokenizer.pad_token_id = 0 + return tokenizer + + +@pytest.fixture +def diffusion_config(): + """Create a diffusion config.""" + return DictDefault( + { + "diffusion": { + "mask_token_id": 32000, + "eps": 1e-3, + "importance_weighting": False, + }, + "sample_packing": False, + } + ) + + +@pytest.fixture +def diffusion_trainer_instance(mock_tokenizer, diffusion_config): + """Create a diffusion trainer instance for testing methods directly.""" + # Create a minimal trainer instance just for testing methods + trainer = object.__new__(DiffusionTrainer) # Bypass __init__ + trainer.cfg = diffusion_config + trainer._special_token_ids = {0, 1, 2} # pad, bos, eos + trainer.processing_class = mock_tokenizer + trainer.store_metrics = Mock() # Mock metrics storage + return trainer + + +class TestDiffusionTrainer: + """Test the DiffusionTrainer class.""" + + def test_forward_process_basic(self, diffusion_trainer_instance): + """Test basic forward process without labels.""" + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + + noisy_batch, masked_indices, p_mask = ( + diffusion_trainer_instance._forward_process(input_ids, eps=0.1) + ) + + # Check shapes + assert noisy_batch.shape == input_ids.shape + assert masked_indices.shape == input_ids.shape + assert p_mask.shape == input_ids.shape + + # Check that special tokens are not masked + special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0) + assert not masked_indices[special_token_positions].any() + + # Check that mask token is applied + mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id + masked_positions = masked_indices + if masked_positions.any(): + assert (noisy_batch[masked_positions] == mask_token_id).all() + + def test_forward_process_with_labels(self, diffusion_trainer_instance): + """Test forward process with SFT labels.""" + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) + + noisy_batch, masked_indices, p_mask = ( + diffusion_trainer_instance._forward_process( + input_ids, labels=labels, eps=0.1 + ) + ) + + # Check shapes + assert noisy_batch.shape == input_ids.shape + assert masked_indices.shape == input_ids.shape + assert p_mask.shape == input_ids.shape + + # Check that only answer tokens can be masked (where labels != -100) + non_answer_mask = labels == -100 + + # No masking should occur on non-answer tokens + assert not masked_indices[non_answer_mask].any() + + # p_mask should be the same for all positions (sampled timestep), + # but masking is only applied to answer tokens + assert p_mask.shape == input_ids.shape + # Verify that masked_indices respects the answer mask + assert not masked_indices[non_answer_mask].any() + + def test_forward_process_with_attention_mask(self, diffusion_trainer_instance): + """Test forward process with attention mask.""" + input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long) + attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long) + + _, masked_indices, p_mask = diffusion_trainer_instance._forward_process( + input_ids, attention_mask=attention_mask, eps=0.1 + ) + + # Check that padding tokens are not masked + padding_positions = attention_mask == 0 + assert not masked_indices[padding_positions].any() + assert (p_mask[padding_positions] == 0).all() + + def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance): + """Test bidirectional attention mask without sample packing.""" + input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long) + + mask = create_bidirectional_attention_mask(input_ids) + + # Should be all-to-all attention + expected_shape = (1, 1, 4, 4) + assert mask.shape == expected_shape + assert mask.all() + + def test_bidirectional_attention_mask_with_packing( + self, diffusion_trainer_instance + ): + """Test bidirectional attention mask with sample packing.""" + diffusion_trainer_instance.cfg.sample_packing = True + input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long) + # Sample IDs: first sample (1), second sample (2) + attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long) + + mask = create_bidirectional_attention_mask( + input_ids, attention_mask, sample_packing=True + ) + + # Check that tokens within same sample can attend to each other + # but not across samples + assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other + assert mask[0, 0, 1, 2].item() + assert not mask[0, 0, 0, 3].item() # Can't attend across samples + assert not mask[0, 0, 2, 4].item() + assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other + + def test_compute_loss_basic(self, diffusion_trainer_instance): + """Test basic loss computation.""" + # Mock model that returns logits + mock_model = Mock() + mock_outputs = Mock() + vocab_size = 1000 + seq_len = 5 + mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) + mock_model.return_value = mock_outputs + mock_model.training = True + + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + + loss, outputs = diffusion_trainer_instance._compute_diffusion_loss( + mock_model, input_ids + ) + + # Check that loss is computed + assert isinstance(loss, torch.Tensor) + assert loss.requires_grad + assert outputs == mock_outputs + + # Check that metrics were stored + diffusion_trainer_instance.store_metrics.assert_called_once() + + def test_compute_loss_sft(self, diffusion_trainer_instance): + """Test loss computation with SFT labels.""" + # Mock model + mock_model = Mock() + mock_outputs = Mock() + vocab_size = 1000 + seq_len = 5 + mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) + mock_model.return_value = mock_outputs + mock_model.training = True + diffusion_trainer_instance.cfg.datasets = Mock() + + input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) + labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) + + loss, _ = diffusion_trainer_instance._compute_diffusion_loss( + mock_model, input_ids, labels=labels + ) + + # Check that loss is computed + assert isinstance(loss, torch.Tensor) + assert loss.requires_grad + + # Check that SFT metrics were added + call_args = diffusion_trainer_instance.store_metrics.call_args[0][0] + assert "answer_ratio" in call_args + assert "avg_answer_length" in call_args + + def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance): + """Test loss computation when no tokens are masked.""" + # Mock model + mock_model = Mock() + mock_outputs = Mock() + vocab_size = 1000 + seq_len = 3 + mock_outputs.logits = torch.randn(1, seq_len, vocab_size) + mock_model.return_value = mock_outputs + mock_model.training = True + + # Only special tokens (which won't be masked) + input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long) + + loss, _ = diffusion_trainer_instance._compute_diffusion_loss( + mock_model, input_ids + ) + + # Loss should be zero when no tokens are masked + assert loss.item() == 0.0 + assert loss.requires_grad + + def test_cache_special_token_ids(self, mock_tokenizer): + """Test caching of special token IDs.""" + trainer = object.__new__(DiffusionTrainer) + trainer.processing_class = mock_tokenizer + trainer._cache_special_token_ids() + assert trainer._special_token_ids == {0, 1, 2} + + def test_cache_special_token_ids_no_tokenizer(self): + """Test caching when no tokenizer is available.""" + trainer = object.__new__(DiffusionTrainer) + trainer.processing_class = None + trainer._cache_special_token_ids() + + assert trainer._special_token_ids == set() + + def test_main_compute_loss_interface(self, diffusion_trainer_instance): + """Test the main compute_loss interface.""" + # Mock model + mock_model = Mock() + mock_outputs = Mock() + mock_outputs.logits = torch.randn(1, 5, 1000) + mock_model.return_value = mock_outputs + mock_model.training = True + + inputs = { + "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long), + "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long), + "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long), + } + + # Test without return_outputs + loss = diffusion_trainer_instance.compute_loss(mock_model, inputs) + assert isinstance(loss, torch.Tensor) + + # Test with return_outputs + loss, outputs = diffusion_trainer_instance.compute_loss( + mock_model, inputs, return_outputs=True + ) + assert isinstance(loss, torch.Tensor) + assert outputs == mock_outputs + + def test_missing_input_ids_raises_error(self, diffusion_trainer_instance): + """Test that missing input_ids raises ValueError.""" + mock_model = Mock() + inputs = {"attention_mask": torch.tensor([[1, 1, 1]])} + + with pytest.raises(ValueError, match="input_ids is required"): + diffusion_trainer_instance.compute_loss(mock_model, inputs) diff --git a/tests/integrations/test_diffusion_callback.py b/tests/integrations/test_diffusion_callback.py new file mode 100644 index 000000000..3e8785fe0 --- /dev/null +++ b/tests/integrations/test_diffusion_callback.py @@ -0,0 +1,92 @@ +"""Tests for diffusion generation callback dataloader selection and triggering.""" + +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from axolotl.integrations.diffusion import DiffusionGenerationCallback + + +class DummyTrainer: + """Minimal trainer double with required attributes/methods for the callback.""" + + def __init__(self, use_eval: bool): + # Config used by callback + self.cfg = SimpleNamespace( + diffusion=SimpleNamespace( + generation_interval=1, + num_generation_samples=1, + generation_max_length=32, + generation_steps=4, + generation_temperature=0.0, + mask_token_id=16, + ), + use_wandb=False, + ) + + # Model/tokenizer are passed through to generate_samples; not used here + self.model = Mock() + self.processing_class = Mock() + + # Datasets and loaders + self.eval_dataset = object() if use_eval else None + self._train_loader = object() + self._eval_loader = object() + + # State for world process check + self.state = SimpleNamespace(is_world_process_zero=True) + + # Track which loader was requested + self.requested: list[str] = [] + + def get_train_dataloader(self): + self.requested.append("train") + return self._train_loader + + def get_eval_dataloader(self): + self.requested.append("eval") + return self._eval_loader + + +@pytest.mark.parametrize("use_eval", [False, True]) +def test_callback_uses_correct_dataloader(monkeypatch, use_eval): + trainer = DummyTrainer(use_eval=use_eval) + callback = DiffusionGenerationCallback(trainer) + + captured = {} + + # Patch generate_samples in the callback module's namespace + def fake_generate_samples(**kwargs): + captured["dataloader"] = kwargs.get("dataloader") + # Return one dummy sample to exercise logging path + return [ + { + "original": "o", + "masked": "m", + "generated": "g", + "mask_ratio": 0.5, + "masked_tokens": 1, + "total_tokens": 2, + } + ] + + monkeypatch.setattr( + "axolotl.integrations.diffusion.callbacks.generate_samples", + fake_generate_samples, + ) + + # Trigger at step 1 (interval=1) + args = SimpleNamespace() + state = SimpleNamespace(global_step=1) + control = SimpleNamespace() + + callback.on_step_end(args=args, state=state, control=control) + + # Assert the expected dataloader path was used + if use_eval: + assert trainer.requested[0] == "eval" + assert captured["dataloader"] is trainer._eval_loader + else: + assert trainer.requested[0] == "train" + assert captured["dataloader"] is trainer._train_loader diff --git a/tests/integrations/test_liger.py b/tests/integrations/test_liger.py index 5c4bd1028..d7b171ec2 100644 --- a/tests/integrations/test_liger.py +++ b/tests/integrations/test_liger.py @@ -10,7 +10,6 @@ from axolotl.utils.config import prepare_plugins, validate_config from axolotl.utils.dict import DictDefault -# pylint: disable=duplicate-code @pytest.fixture(name="minimal_liger_cfg") def fixture_cfg(): return DictDefault( @@ -30,7 +29,6 @@ def fixture_cfg(): ) -# pylint: disable=too-many-public-methods class TestValidation: """ Test the validation module for liger diff --git a/tests/monkeypatch/test_mistral_tokenizer_patch.py b/tests/monkeypatch/test_mistral_tokenizer_patch.py new file mode 100644 index 000000000..cb82c0890 --- /dev/null +++ b/tests/monkeypatch/test_mistral_tokenizer_patch.py @@ -0,0 +1,35 @@ +"""Integration tests for MistralCommonTokenizer patches.""" + +import pytest + + +class TestMistralTokenizerPatchIntegration: + """Test MistralCommonTokenizer patch integration.""" + + @pytest.mark.integration + def test_mistral_tokenizer_image_patch(self): + """Test that MistralCommonTokenizer image patch can be applied.""" + try: + from transformers.tokenization_mistral_common import MistralCommonTokenizer + except ImportError: + pytest.skip("MistralCommonTokenizer not available") + + from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import ( + apply_mistral_tokenizer_image_patch, + ) + + # Store original method + original_apply_chat_template = MistralCommonTokenizer.apply_chat_template + + # Apply patch + apply_mistral_tokenizer_image_patch() + + # Verify patch was applied + assert ( + MistralCommonTokenizer.apply_chat_template != original_apply_chat_template + ), "apply_chat_template was not patched" + + # Verify the method is still callable + assert callable(MistralCommonTokenizer.apply_chat_template), ( + "Patched method is not callable" + ) diff --git a/tests/monkeypatch/test_pixtral_flash_attention_patch.py b/tests/monkeypatch/test_pixtral_flash_attention_patch.py new file mode 100644 index 000000000..285fde41e --- /dev/null +++ b/tests/monkeypatch/test_pixtral_flash_attention_patch.py @@ -0,0 +1,77 @@ +"""Integration tests for Pixtral Flash Attention patches.""" + +import pytest +import torch + + +class TestPixtralFlashAttentionPatchIntegration: + """Test Pixtral Flash Attention patch integration.""" + + @pytest.mark.integration + def test_pixtral_flash_attention_patch(self): + """Test that Pixtral Flash Attention patch can be applied and works correctly.""" + try: + from transformers import modeling_flash_attention_utils + except ImportError: + pytest.skip("Flash Attention utils not available") + + from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import ( + apply_patch_is_packed_sequence, + ) + + # Store original method + original_is_packed_sequence = modeling_flash_attention_utils._is_packed_sequence + + # Apply patch and get unpatch function + unpatch_fn = apply_patch_is_packed_sequence() + + # Verify patch was applied + assert ( + modeling_flash_attention_utils._is_packed_sequence + != original_is_packed_sequence + ), "_is_packed_sequence was not patched" + + # Test the patched function with 1D position_ids + patched_fn = modeling_flash_attention_utils._is_packed_sequence + + # Test 1D position_ids 1 sequence + position_ids_1d = torch.tensor([0, 1, 2, 3]) + result = patched_fn(position_ids_1d, batch_size=1) + assert isinstance(result, bool), "Function should return a boolean" + assert result is False, "1D sequential position_ids should not be packed" + + # Test 1D packed 2 sequences + position_ids_1d_packed = torch.tensor([0, 1, 2, 0, 1, 2]) + result = patched_fn(position_ids_1d_packed, batch_size=1) + assert isinstance(result, bool), "Function should return a boolean" + assert result is True, "1D packed position_ids should be detected as packed" + + # Test 2D packed 2 sequences + position_ids_2d_packed = torch.tensor([[0, 1, 2, 3, 0, 1]]) + result = patched_fn(position_ids_2d_packed, batch_size=1) + assert isinstance(result, bool), "Function should return a boolean" + assert result is True, "2D packed position_ids should be detected as packed" + + # Test 2D 1 sequence + position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5]]) + result = patched_fn(position_ids_2d_normal, batch_size=1) + assert isinstance(result, bool), "Function should return a boolean" + assert result is False, "2D sequential position_ids should not be packed" + + # Test 2D batch size 2 + position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]]) + result = patched_fn(position_ids_2d_normal, batch_size=2) + assert isinstance(result, bool), "Function should return a boolean" + assert result is False, "2D position_ids batch 2 should not be packed" + + # Test None case + result = patched_fn(None, batch_size=1) + assert isinstance(result, bool), "Function should return a boolean" + assert result is False, "None position_ids should return False" + + # Test unpatch function + unpatch_fn() + assert ( + modeling_flash_attention_utils._is_packed_sequence + == original_is_packed_sequence + ), "unpatch function did not restore original method" diff --git a/tests/monkeypatch/test_qwen3_next_modeling_patch.py b/tests/monkeypatch/test_qwen3_next_modeling_patch.py new file mode 100644 index 000000000..91d9fc1cf --- /dev/null +++ b/tests/monkeypatch/test_qwen3_next_modeling_patch.py @@ -0,0 +1,111 @@ +"""Integration tests for Qwen3 Next modeling patches.""" + +import pytest +import torch + +# Skip entire module if qwen3_next not available +qwen3_next = pytest.importorskip("transformers.models.qwen3_next.modeling_qwen3_next") + + +class TestQwen3NextModelingPatchIntegration: + """Test Qwen3 Next modeling patch integration.""" + + @pytest.mark.integration + def test_qwen3_next_decoder_layer_patch(self): + """Test that Qwen3Next decoder layer patch can be applied.""" + from axolotl.monkeypatch.models.qwen3_next.modeling import ( + patch_qwen3_next_decoder_layer, + ) + + # Store original method + original_forward = qwen3_next.Qwen3NextDecoderLayer.forward + + # Apply patch and get unpatch function + unpatch_fn = patch_qwen3_next_decoder_layer() + + # Verify patch was applied + assert qwen3_next.Qwen3NextDecoderLayer.forward != original_forward, ( + "decoder layer forward method was not patched" + ) + + # Verify the method is still callable + assert callable(qwen3_next.Qwen3NextDecoderLayer.forward), ( + "Patched method is not callable" + ) + + # Test unpatch function + if unpatch_fn: + unpatch_fn() + assert qwen3_next.Qwen3NextDecoderLayer.forward == original_forward, ( + "unpatch function did not restore original method" + ) + + @pytest.mark.integration + def test_qwen3_next_gateddelta_layer_patch(self): + """Test that Qwen3Next GatedDeltaNet patch can be applied.""" + from axolotl.monkeypatch.models.qwen3_next.modeling import ( + patch_qwen3_next_gateddelta_layer, + ) + + # Store original method + original_forward = qwen3_next.Qwen3NextGatedDeltaNet.forward + + # Apply patch and get unpatch function + unpatch_fn = patch_qwen3_next_gateddelta_layer() + + # Verify patch was applied + assert qwen3_next.Qwen3NextGatedDeltaNet.forward != original_forward, ( + "GatedDeltaNet forward method was not patched" + ) + + # Verify the method is still callable + assert callable(qwen3_next.Qwen3NextGatedDeltaNet.forward), ( + "Patched method is not callable" + ) + + # Test unpatch function + if unpatch_fn: + unpatch_fn() + assert qwen3_next.Qwen3NextGatedDeltaNet.forward == original_forward, ( + "unpatch function did not restore original method" + ) + + @pytest.mark.integration + def test_qwen3_next_imports_patch(self): + """Test that Qwen3Next imports patch can be applied without errors.""" + from axolotl.monkeypatch.models.qwen3_next.modeling import ( + patch_qwen3_next_imports, + ) + + # Apply patch - should not raise any exceptions even if modules unavailable + unpatch_fn = patch_qwen3_next_imports() + + # Test that unpatch function is returned (or None if skipped) + assert unpatch_fn is None or callable(unpatch_fn), ( + "patch_qwen3_next_imports should return None or callable unpatch function" + ) + + @pytest.mark.integration + def test_qwen3_next_modeling_packing_patch(self): + """Test that all Qwen3Next modeling patches can be applied together.""" + from axolotl.monkeypatch.models.qwen3_next.modeling import ( + patch_qwen3_next_modeling_packing, + ) + + # This should not raise any exceptions + patch_qwen3_next_modeling_packing() + + +@pytest.mark.integration +def test_get_cu_seqlens_utility(): + """Test the get_cu_seqlens utility function.""" + from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens + + # Test with simple position_ids + position_ids = torch.tensor([[0, 1, 2, 0, 1]]) + cu_seqlens = get_cu_seqlens(position_ids) + assert cu_seqlens.dtype == torch.int32, "Should be int32 dtype" + + # Should return tensor with start positions and total length + expected = torch.tensor([0, 3, 5], dtype=torch.int32) + assert torch.equal(cu_seqlens, expected), f"Expected {expected}, got {cu_seqlens}" diff --git a/tests/monkeypatch/test_trainer_accelerator_args.py b/tests/monkeypatch/test_trainer_accelerator_args.py new file mode 100644 index 000000000..fab2597f0 --- /dev/null +++ b/tests/monkeypatch/test_trainer_accelerator_args.py @@ -0,0 +1,26 @@ +""" +Unit tests for trainer accelerator args monkeypatch +""" + +import unittest + +from axolotl.monkeypatch.trainer_accelerator_args import ( + check_create_accelerate_code_is_patchable, +) + + +class TestTrainerAcceleratorArgs(unittest.TestCase): + """ + Unit test class for trainer accelerator args monkeypatch + """ + + def test_check_create_accelerate_code_is_patchable(self): + """ + Test that the upstream transformers code is still patchable. + This will fail if the patched code changes upstream. + """ + assert check_create_accelerate_code_is_patchable() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/monkeypatch/test_trainer_context_parallel_patch.py b/tests/monkeypatch/test_trainer_context_parallel_patch.py new file mode 100644 index 000000000..84c883e91 --- /dev/null +++ b/tests/monkeypatch/test_trainer_context_parallel_patch.py @@ -0,0 +1,66 @@ +"""Tests for the HF Trainer context parallel patch.""" + +import pytest +from transformers import Trainer + +from axolotl.monkeypatch.transformers.trainer_context_parallel import ( + GUARD_PATTERN, + PATCHED_GUARD, + patch_prepare_context_parallel_inputs, +) + + +@pytest.fixture +def restore_trainer_prepare_method(): + """Ensure Trainer._prepare_context_parallel_inputs is restored after a test.""" + original_method = getattr( + Trainer, + "_original_prepare_context_parallel_inputs", + Trainer._prepare_context_parallel_inputs, + ) + patched_attr_present = hasattr( + Trainer, "_axolotl_prepare_context_parallel_inputs_patched" + ) + + yield + + Trainer._prepare_context_parallel_inputs = original_method + if patched_attr_present: + delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched") + if hasattr(Trainer, "_original_prepare_context_parallel_inputs"): + delattr(Trainer, "_original_prepare_context_parallel_inputs") + if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"): + delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source") + + +def test_patch_attention_guard(restore_trainer_prepare_method): + """Patch should swap the guard to allow sdpa or flash attention.""" + # Ensure we start from the unpatched method + if hasattr(Trainer, "_original_prepare_context_parallel_inputs"): + Trainer._prepare_context_parallel_inputs = ( + Trainer._original_prepare_context_parallel_inputs + ) + delattr(Trainer, "_original_prepare_context_parallel_inputs") + if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"): + delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched") + + patch_prepare_context_parallel_inputs() + + patched_method = Trainer._prepare_context_parallel_inputs + assert patched_method is not None + assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False) + + source = Trainer._axolotl_prepare_context_parallel_inputs_source + assert GUARD_PATTERN not in source + assert PATCHED_GUARD in source + + +def test_patch_is_idempotent(restore_trainer_prepare_method): + """Calling the patch twice should leave the same patched function in place.""" + patch_prepare_context_parallel_inputs() + first_patched = Trainer._prepare_context_parallel_inputs + + patch_prepare_context_parallel_inputs() + second_patched = Trainer._prepare_context_parallel_inputs + + assert first_patched is second_patched diff --git a/tests/monkeypatch/test_trainer_loss_calc.py b/tests/monkeypatch/test_trainer_loss_calc.py new file mode 100644 index 000000000..c72cb621b --- /dev/null +++ b/tests/monkeypatch/test_trainer_loss_calc.py @@ -0,0 +1,26 @@ +"""Unit tests for trainer loss calc monkeypatch.""" + +import unittest + +from axolotl.monkeypatch.transformers.trainer_loss_calc import ( + check_evaluation_loop_is_patchable, + check_maybe_log_save_evaluate_is_patchable, +) + + +class TestTrainerLossCalc(unittest.TestCase): + """ + Unit test class for trainer loss calc monkeypatch + """ + + def test_trainer_loss_calc_is_patchable(self): + """ + Test that the upstream transformers code is still patchable. This will fail if + the patched code changes upstream. + """ + assert check_evaluation_loop_is_patchable() + assert check_maybe_log_save_evaluate_is_patchable() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/monkeypatch/test_voxtral_modeling_patch.py b/tests/monkeypatch/test_voxtral_modeling_patch.py new file mode 100644 index 000000000..878bbc185 --- /dev/null +++ b/tests/monkeypatch/test_voxtral_modeling_patch.py @@ -0,0 +1,43 @@ +"""Integration tests for Voxtral modeling patches.""" + +import pytest + + +class TestVoxtralModelingPatchIntegration: + """Test Voxtral modeling patch integration.""" + + @pytest.mark.integration + def test_voxtral_conditional_generation_patch(self): + """Test that Voxtral conditional generation patch can be applied.""" + try: + from transformers.models.voxtral.modeling_voxtral import ( + VoxtralForConditionalGeneration, + ) + except ImportError: + pytest.skip("VoxtralForConditionalGeneration not available") + + from axolotl.monkeypatch.models.voxtral.modeling import ( + patch_voxtral_conditional_generation_forward, + ) + + # Store original method + original_forward = VoxtralForConditionalGeneration.forward + + # Apply patch and get unpatch function + unpatch_fn = patch_voxtral_conditional_generation_forward() + + # Verify patch was applied + assert VoxtralForConditionalGeneration.forward != original_forward, ( + "forward method was not patched" + ) + + # Verify the method is still callable + assert callable(VoxtralForConditionalGeneration.forward), ( + "Patched method is not callable" + ) + + # Test unpatch function + unpatch_fn() + assert VoxtralForConditionalGeneration.forward == original_forward, ( + "unpatch function did not restore original method" + ) diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py index 2c28a71ea..21299ed98 100644 --- a/tests/patched/test_validation.py +++ b/tests/patched/test_validation.py @@ -1,4 +1,3 @@ -# pylint: disable=too-many-lines """Module for testing the validation module""" import os @@ -49,7 +48,6 @@ class BaseValidation: self._caplog = caplog -# pylint: disable=too-many-public-methods class TestValidation(BaseValidation): """ Test the validation module @@ -241,7 +239,7 @@ class TestValidation(BaseValidation): def test_lr_as_float(self, minimal_cfg): cfg = ( - DictDefault( # pylint: disable=unsupported-binary-operation + DictDefault( { "learning_rate": "5e-5", } @@ -303,7 +301,7 @@ class TestValidation(BaseValidation): ) cfg = ( - DictDefault( # pylint: disable=unsupported-binary-operation + DictDefault( { "load_in_8bit": True, } @@ -315,7 +313,7 @@ class TestValidation(BaseValidation): validate_config(cfg) cfg = ( - DictDefault( # pylint: disable=unsupported-binary-operation + DictDefault( { "gptq": True, } @@ -327,7 +325,7 @@ class TestValidation(BaseValidation): validate_config(cfg) cfg = ( - DictDefault( # pylint: disable=unsupported-binary-operation + DictDefault( { "load_in_4bit": False, } @@ -339,7 +337,7 @@ class TestValidation(BaseValidation): validate_config(cfg) cfg = ( - DictDefault( # pylint: disable=unsupported-binary-operation + DictDefault( { "load_in_4bit": True, } @@ -361,7 +359,7 @@ class TestValidation(BaseValidation): ) cfg = ( - DictDefault( # pylint: disable=unsupported-binary-operation + DictDefault( { "load_in_8bit": True, } @@ -373,7 +371,7 @@ class TestValidation(BaseValidation): validate_config(cfg) cfg = ( - DictDefault( # pylint: disable=unsupported-binary-operation + DictDefault( { "gptq": True, } @@ -385,7 +383,7 @@ class TestValidation(BaseValidation): validate_config(cfg) cfg = ( - DictDefault( # pylint: disable=unsupported-binary-operation + DictDefault( { "load_in_4bit": True, } @@ -692,7 +690,7 @@ class TestValidation(BaseValidation): "bf16": True, "capabilities": {"bf16": False}, "env_capabilities": { - "torch_version": "2.5.1", + "torch_version": "2.6.0", }, } ) @@ -1202,7 +1200,7 @@ class TestValidation(BaseValidation): cfg, capabilities=capabilities, env_capabilities=env_capabilities ) - env_capabilities = {"torch_version": "2.5.1"} + env_capabilities = {"torch_version": "2.6.0"} capabilities = {"bf16": False} _ = validate_config( cfg, capabilities=capabilities, env_capabilities=env_capabilities @@ -1244,7 +1242,7 @@ class TestTorchCompileValidation(BaseValidation): | minimal_cfg ) - env_capabilities = {"torch_version": "2.5.1"} + env_capabilities = {"torch_version": "2.6.0"} capabilities = {"bf16": True} updated_cfg = validate_config( cfg, capabilities=capabilities, env_capabilities=env_capabilities @@ -1690,3 +1688,18 @@ class TestValidationMLflow(BaseValidation): assert new_cfg.use_mlflow is True os.environ.pop("MLFLOW_EXPERIMENT_NAME", None) + + +class TestDataloaderValidation(BaseValidation): + """ + tests for dataloader_* sane defaults + """ + + def test_dataloader_auto_defaults(self, minimal_cfg): + cfg = minimal_cfg + + new_cfg = validate_config(cfg, {"n_gpu": 8}, {"torch_version": "2.6.0"}) + + assert new_cfg.dataloader_num_workers == 8 + assert new_cfg.dataloader_pin_memory is True + assert new_cfg.dataloader_prefetch_factor == 256 diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py index fe59e00d8..0af7b3e93 100644 --- a/tests/prompt_strategies/conftest.py +++ b/tests/prompt_strategies/conftest.py @@ -30,7 +30,6 @@ def fixture_assistant_dataset(): @pytest.fixture(name="sharegpt_dataset") def fixture_sharegpt_dataset(): - # pylint: disable=duplicate-code return Dataset.from_list( [ { @@ -47,7 +46,6 @@ def fixture_sharegpt_dataset(): @pytest.fixture(name="basic_dataset") def fixture_basic_dataset(): - # pylint: disable=duplicate-code return Dataset.from_list( [ { @@ -65,7 +63,6 @@ def fixture_basic_dataset(): @pytest.fixture(name="toolcalling_dataset") def fixture_toolcalling_dataset(): - # pylint: disable=duplicate-code return Dataset.from_list( [ { @@ -112,7 +109,7 @@ def fixture_toolcalling_dataset(): @enable_hf_offline def fixture_llama3_tokenizer( download_llama3_8b_instruct_model_fixture, -): # pylint: disable=unused-argument,redefined-outer-name +): tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct") return tokenizer @@ -129,7 +126,7 @@ def fixture_smollm2_tokenizer(): @enable_hf_offline def fixture_mistralv03_tokenizer( download_mlx_mistral_7b_model_fixture, -): # pylint: disable=unused-argument,redefined-outer-name +): tokenizer = AutoTokenizer.from_pretrained( "mlx-community/Mistral-7B-Instruct-v0.3-4bit" ) @@ -143,6 +140,12 @@ def fixture_phi35_tokenizer(): return tokenizer +@pytest.fixture(name="phi4_tokenizer", scope="session", autouse=True) +def fixture_phi4_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning") + return tokenizer + + @pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True) def fixture_gemma2_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit") @@ -150,6 +153,39 @@ def fixture_gemma2_tokenizer(): return tokenizer +@pytest.fixture(name="magistral_tokenizer") +def fixture_magistral_tokenizer(): + from axolotl.utils.mistral import HFMistralTokenizer + + tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Magistral-Small-2506") + return tokenizer + + +@pytest.fixture(name="devstral_tokenizer") +def fixture_devstral_tokenizer(): + from axolotl.utils.mistral import HFMistralTokenizer + + tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Devstral-Small-2505") + return tokenizer + + +@pytest.fixture(name="devstral_1_1_tokenizer") +def fixture_devstral_1_1_tokenizer(): + from axolotl.utils.mistral import HFMistralTokenizer + + tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Devstral-Small-2507") + return tokenizer + + +@pytest.fixture(name="qwen3_tokenizer") +def qwen3_tokenizer_fixture( + download_qwen3_half_billion_model, +): # pylint: disable=unused-argument,redefined-outer-name + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + + return tokenizer + + @pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja") def fixture_mistralv03_chat_template_jinja_w_system() -> str: return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n' diff --git a/tests/prompt_strategies/messages/test_chat.py b/tests/prompt_strategies/messages/test_chat.py index a4c2ae67f..f083232a8 100644 --- a/tests/prompt_strategies/messages/test_chat.py +++ b/tests/prompt_strategies/messages/test_chat.py @@ -2,7 +2,6 @@ tests for chat_template prompt strategy """ -# pylint: disable=duplicate-code import unittest from axolotl.prompt_strategies.messages.chat import load @@ -53,9 +52,9 @@ class TestMessagesChatLlama3: # fmt: on LOG.debug(f"Expected input_ids: {expected_input_ids}") LOG.debug(f"Actual input_ids: {input_ids}") - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert input_ids == expected_input_ids, ( + f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + ) if __name__ == "__main__": diff --git a/tests/prompt_strategies/test_alpaca.py b/tests/prompt_strategies/test_alpaca.py index 78f783747..b96ebce19 100644 --- a/tests/prompt_strategies/test_alpaca.py +++ b/tests/prompt_strategies/test_alpaca.py @@ -30,7 +30,6 @@ def fixture_alpaca_dataset(): @pytest.fixture(name="tokenizer") @enable_hf_offline def fixture_tokenizer(): - # pylint: disable=all tokenizer = AutoTokenizer.from_pretrained( "casperhansen/mistral-7b-instruct-v0.1-awq" ) diff --git a/tests/prompt_strategies/test_chat_template_ds_schema_unification.py b/tests/prompt_strategies/test_chat_template_ds_schema_unification.py new file mode 100644 index 000000000..4f4e32208 --- /dev/null +++ b/tests/prompt_strategies/test_chat_template_ds_schema_unification.py @@ -0,0 +1,63 @@ +""" +Tests for chat template prompt strategy with schema unification for none fields +""" + +import json + +import pytest +from datasets import Dataset + +from axolotl.prompt_strategies.chat_template import StrategyLoader +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="messages_w_tools") +def fixture_messages_w_tools(): + jsons = """ +{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false} +{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false} +{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false} + """.strip().split("\n") + rows = [json.loads(row) for row in jsons] + return Dataset.from_list(rows) + + +@pytest.fixture(name="qwen3_prompt_strategy") +def qwen3_chat_template_strategy(qwen3_tokenizer): + cfg = DictDefault( + sequence_len=2048, + chat_template="qwen3", + eot_tokens=["<|im_end|>"], + ) + ds_cfg = DictDefault( + type="chat_template", + ) + load = StrategyLoader() + strat = load(qwen3_tokenizer, cfg, ds_cfg) + return strat + + +class TestSchemaUnification: + """ + Test class on handling null fields for tool calling + """ + + def test_schema_unification_single_prompt( + self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer + ): + for row in messages_w_tools: + inputs = qwen3_prompt_strategy.tokenize_prompt(row) + decoded = qwen3_tokenizer.decode(inputs["input_ids"]) + tool_call = decoded.split("")[-1].split("")[0] + assert '"message": null' not in tool_call + assert '"theta": null' not in tool_call + + def test_schema_unification_batched( + self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer + ): + rows = messages_w_tools.map(qwen3_prompt_strategy.tokenize_prompt, batched=True) + for row in rows: + decoded = qwen3_tokenizer.decode(row["input_ids"]) + tool_call = decoded.split("")[-1].split("")[0] + assert '"message": null' not in tool_call + assert '"theta": null' not in tool_call diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 371ccf616..90e0e274b 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -67,9 +67,9 @@ class TestAssistantChatTemplateLlama3: # fmt: on LOG.debug(f"Expected input_ids: {expected_input_ids}") LOG.debug(f"Actual input_ids: {input_ids}") - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert input_ids == expected_input_ids, ( + f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + ) def test_llama3(self, llama3_tokenizer, assistant_dataset): LOG.info("Testing llama-3 with assistant dataset") @@ -109,9 +109,9 @@ class TestAssistantChatTemplateLlama3: # fmt: on LOG.debug(f"Expected input_ids: {expected_input_ids}") LOG.debug(f"Actual input_ids: {input_ids}") - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert input_ids == expected_input_ids, ( + f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + ) def test_phi35(self, phi35_tokenizer, assistant_dataset): LOG.info("Testing phi-3.5 with assistant dataset") @@ -161,15 +161,15 @@ class TestAssistantChatTemplateLlama3: # fmt: on LOG.debug(f"Expected input_ids: {expected_input_ids}") LOG.debug(f"Actual input_ids: {input_ids}") - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert input_ids == expected_input_ids, ( + f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + ) LOG.debug(f"Expected labels : {expected_labels}") LOG.debug(f"Actual labels : {labels}") - assert ( - labels == expected_labels - ), f"Input IDs mismatch: {labels} != {expected_labels}" + assert labels == expected_labels, ( + f"Input IDs mismatch: {labels} != {expected_labels}" + ) def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset): LOG.info("Testing llama-3 with assistant dataset including training data") @@ -234,7 +234,7 @@ class TestSharegptChatTemplateLlama3: def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts") - # pylint: disable=duplicate-code + strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, @@ -285,16 +285,16 @@ class TestSharegptChatTemplateLlama3: LOG.debug(f"Expected labels: {expected_labels}") LOG.debug(f"Actual labels: {labels}") - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" - assert ( - labels == expected_labels - ), f"Labels mismatch: {labels} != {expected_labels}" + assert input_ids == expected_input_ids, ( + f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + ) + assert labels == expected_labels, ( + f"Labels mismatch: {labels} != {expected_labels}" + ) def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 human prompts") - # pylint: disable=duplicate-code + strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, @@ -345,16 +345,16 @@ class TestSharegptChatTemplateLlama3: LOG.debug(f"Expected labels: {expected_labels}") LOG.debug(f"Actual labels: {labels}") - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" - assert ( - labels == expected_labels - ), f"Labels mismatch: {labels} != {expected_labels}" + assert input_ids == expected_input_ids, ( + f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + ) + assert labels == expected_labels, ( + f"Labels mismatch: {labels} != {expected_labels}" + ) def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts") - # pylint: disable=duplicate-code + strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, @@ -409,12 +409,12 @@ class TestSharegptChatTemplateLlama3: LOG.debug(f"Expected labels: {expected_labels}") LOG.debug(f"Actual labels: {labels}") - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" - assert ( - labels == expected_labels - ), f"Labels mismatch: {labels} != {expected_labels}" + assert input_ids == expected_input_ids, ( + f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + ) + assert labels == expected_labels, ( + f"Labels mismatch: {labels} != {expected_labels}" + ) class TestAssistantToolCallingChatTemplateLlama32Vision: @@ -481,13 +481,13 @@ class TestAssistantToolCallingChatTemplateLlama32Vision: ] # fmt: on - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert input_ids == expected_input_ids, ( + f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + ) - assert ( - labels == expected_labels - ), f"Labels mismatch: {labels} != {expected_labels}" + assert labels == expected_labels, ( + f"Labels mismatch: {labels} != {expected_labels}" + ) def test_llama32vision_train_on_tools( self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja @@ -495,7 +495,6 @@ class TestAssistantToolCallingChatTemplateLlama32Vision: LOG.info( "Testing assistant style datasets with tool_calling with llama-32 chat template, training on tools" ) - # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( ChatTemplatePrompter( @@ -549,13 +548,13 @@ class TestAssistantToolCallingChatTemplateLlama32Vision: ] # fmt: on - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert input_ids == expected_input_ids, ( + f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + ) - assert ( - labels == expected_labels - ), f"Labels mismatch: {labels} != {expected_labels}" + assert labels == expected_labels, ( + f"Labels mismatch: {labels} != {expected_labels}" + ) if __name__ == "__main__": diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index 7f011f954..fd39a4305 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -2,8 +2,6 @@ tests for chat_template prompt strategy """ -# pylint: disable=too-many-lines - from copy import deepcopy import pytest @@ -33,15 +31,14 @@ PARAMETRIZE_PARAMS = [ "mistralv03_tokenizer_chat_template_jinja", "[/INST]", ), - # TODO: temporarily skip gemma due to gemma3 template - # Re-enable on new chat_template implementation for perf - # ( - # "gemma2_tokenizer", - # "jinja", - # "gemma2_tokenizer_chat_template_jinja", - # "", - # ), + ( + "gemma2_tokenizer", + "jinja", + "gemma2_tokenizer_chat_template_jinja", + "", + ), ("phi35_tokenizer", "phi_35", None, "<|end|>"), + ("phi4_tokenizer", "phi_4", None, "<|im_end|>"), ] @@ -95,15 +92,11 @@ class TestChatTemplateConfigurations: if ( turn_idx == 0 and turn.get("from") in ["system", "context"] - and ( - "mistral" in tokenizer.name_or_path.lower() - or "gemma" - in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template - ) + and ("mistral" in tokenizer.name_or_path.lower()) ): - assert ( - start_idx == -1 and end_idx == -1 - ), "Expected system message to be skipped" + assert start_idx == -1 and end_idx == -1, ( + "Expected system message to be skipped" + ) return True return False @@ -160,7 +153,9 @@ class TestChatTemplateConfigurations: assert all( label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx] - ), f"Expected labels for input '{response}' to be ignored, but got {labels[start_idx:end_idx]}" + ), ( + f"Expected labels for input '{response}' to be ignored, but got {labels[start_idx:end_idx]}" + ) LOG.debug("Full labels: %s", labels) LOG.debug("Full input_ids: %s", input_ids) @@ -220,11 +215,15 @@ class TestChatTemplateConfigurations: if is_assistant: assert all( label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}" + ), ( + f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}" + ) else: assert all( label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx] - ), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}" + ), ( + f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}" + ) def test_roles_to_train_human_assistant_only( self, @@ -281,11 +280,15 @@ class TestChatTemplateConfigurations: if should_be_labelled: assert all( label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx] - ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}" + ), ( + f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}" + ) else: assert all( label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx] - ), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}" + ), ( + f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}" + ) def test_roles_to_train_all( self, @@ -332,13 +335,15 @@ class TestChatTemplateConfigurations: continue decoded_response = tokenizer.decode(input_ids[start_idx:end_idx]) - assert ( - response in decoded_response - ), f"Response {response} not found in index {start_idx}:{end_idx} decoded:{decoded_response}" + assert response in decoded_response, ( + f"Response {response} not found in index {start_idx}:{end_idx} decoded:{decoded_response}" + ) assert all( label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx] - ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}" + ), ( + f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}" + ) def test_empty_roles_to_train( self, @@ -376,9 +381,9 @@ class TestChatTemplateConfigurations: # Verify that no labels are set when roles_to_train is empty LOG.debug("Full labels: %s", labels) - assert all( - label == IGNORE_TOKEN_ID for label in labels - ), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty" + assert all(label == IGNORE_TOKEN_ID for label in labels), ( + "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty" + ) def test_train_on_eos_all( self, @@ -422,9 +427,9 @@ class TestChatTemplateConfigurations: assert len(eos_indices) > 0, "Expected at least one EOS token in the input" for eos_idx in eos_indices: - assert ( - labels[eos_idx] != IGNORE_TOKEN_ID - ), f"Expected EOS token at index {eos_idx} to be labeled" + assert labels[eos_idx] != IGNORE_TOKEN_ID, ( + f"Expected EOS token at index {eos_idx} to be labeled" + ) def test_train_on_eos_turn( self, @@ -482,9 +487,9 @@ class TestChatTemplateConfigurations: while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id: eos_idx += 1 - assert eos_idx < len( - input_ids - ), f"Could not find EOS token after '{response}'" + assert eos_idx < len(input_ids), ( + f"Could not find EOS token after '{response}'" + ) LOG.debug( f"Turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}, eos_idx={eos_idx}" @@ -497,13 +502,13 @@ class TestChatTemplateConfigurations: # Verify EOS token labeling based on role is_assistant = turn["from"] == "assistant" if is_assistant: - assert ( - labels[eos_idx] != IGNORE_TOKEN_ID - ), f"Expected EOS token after assistant response '{response}' to be labeled" + assert labels[eos_idx] != IGNORE_TOKEN_ID, ( + f"Expected EOS token after assistant response '{response}' to be labeled" + ) else: - assert ( - labels[eos_idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token after non-assistant input '{response}' to not be labeled" + assert labels[eos_idx] == IGNORE_TOKEN_ID, ( + f"Expected EOS token after non-assistant input '{response}' to not be labeled" + ) def test_train_on_eos_last( self, @@ -550,12 +555,12 @@ class TestChatTemplateConfigurations: # Check that only the last EOS token is labeled for idx in eos_indices[:-1]: - assert ( - labels[idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token at index {idx} to not be labeled" - assert ( - labels[last_eos_idx] != IGNORE_TOKEN_ID - ), f"Expected last EOS token at index {last_eos_idx} to be labeled" + assert labels[idx] == IGNORE_TOKEN_ID, ( + f"Expected EOS token at index {idx} to not be labeled" + ) + assert labels[last_eos_idx] != IGNORE_TOKEN_ID, ( + f"Expected last EOS token at index {last_eos_idx} to be labeled" + ) def test_train_on_eos_none( self, @@ -599,9 +604,9 @@ class TestChatTemplateConfigurations: assert len(eos_indices) > 0, "Expected at least one EOS token in the input" for eos_idx in eos_indices: - assert ( - labels[eos_idx] == IGNORE_TOKEN_ID - ), f"Expected EOS token at index {eos_idx} to not be labeled" + assert labels[eos_idx] == IGNORE_TOKEN_ID, ( + f"Expected EOS token at index {eos_idx} to not be labeled" + ) def test_drop_system_message( self, @@ -639,9 +644,9 @@ class TestChatTemplateConfigurations: # Check if system message is not present in input_ids system_message = "You are an AI assistant." decoded_message = tokenizer.decode(input_ids) - assert ( - system_message not in decoded_message - ), "Expected system message to be dropped" + assert system_message not in decoded_message, ( + "Expected system message to be dropped" + ) def test_custom_roles( self, @@ -716,7 +721,9 @@ class TestChatTemplateConfigurations: else: assert all( label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx] - ), f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID" + ), ( + f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID" + ) def test_message_field_training( self, @@ -781,13 +788,13 @@ class TestChatTemplateConfigurations: def verify_labels(labels_span, should_train, context_message): """Helper to verify if a span of labels matches expected training state""" if should_train: - assert all( - label != IGNORE_TOKEN_ID for label in labels_span - ), f"Expected all labels for {context_message} to be set, but got {labels_span}" + assert all(label != IGNORE_TOKEN_ID for label in labels_span), ( + f"Expected all labels for {context_message} to be set, but got {labels_span}" + ) else: - assert all( - label == IGNORE_TOKEN_ID for label in labels_span - ), f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}" + assert all(label == IGNORE_TOKEN_ID for label in labels_span), ( + f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}" + ) # Process all turns and verify labeling for i, turn in enumerate(modified_dataset[0]["messages"]): @@ -866,9 +873,9 @@ class TestChatTemplateConfigurations: actual_labels = labels[ start_idx : start_idx + len(token_offsets_masked) ] - assert ( - actual_labels == expected_labels - ), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}" + assert actual_labels == expected_labels, ( + f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}" + ) # Verify each detail section for detail in adjusted_train_details: @@ -935,36 +942,14 @@ class TestChatTemplateConfigurations: "messages", ) - if chat_template == "llama3": - assert variables == {"role", "content"}, ( - f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" - f"Got: {variables}\n" - f"Chat template: {actual_jinja_template}" - ) - elif chat_template == "chatml": - assert variables == {"role", "content"}, ( - f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" - f"Got: {variables}\n" - f"Chat template: {actual_jinja_template}" - ) - elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer": - assert variables == {"role", "content", "tool_call_id", "tool_calls"}, ( - f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n" - f"Got: {variables}\n" - f"Chat template: {actual_jinja_template}" - ) - elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer": - assert variables == {"role", "content"}, ( - f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" - f"Got: {variables}\n" - f"Chat template: {actual_jinja_template}" - ) - elif chat_template == "phi_35": - assert variables == {"role", "content"}, ( - f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" - f"Got: {variables}\n" - f"Chat template: {actual_jinja_template}" - ) + # Special case for Mistral with additional tool variables + if chat_template == "jinja" and tokenizer == "mistralv03_tokenizer": + expected_variables = {"role", "content", "tool_call_id", "tool_calls"} + # Most chat templates use the standard role and content variables + elif chat_template in ["llama3", "chatml", "phi_35", "phi_4"] or ( + chat_template == "jinja" and tokenizer == "gemma2_tokenizer" + ): + expected_variables = {"role", "content"} else: LOG.warning( f"Unsupported chat template: {chat_template} with {chat_template_jinja}" @@ -973,13 +958,19 @@ class TestChatTemplateConfigurations: f"Unsupported chat template: {chat_template} with {chat_template_jinja}" ) + assert variables == expected_variables, ( + f"Expected variables: {expected_variables} from {tokenizer}/{chat_template}\n" + f"Got: {variables}\n" + f"Chat template: {actual_jinja_template}" + ) + def test_eot_tokens_conflict_with_eos_token( self, tokenizer, chat_template, chat_template_jinja, eos_token, - basic_dataset, # pylint: disable=unused-argument + basic_dataset, request, ): """Test that an error is raised when eot_tokens contains eos_token and train_on_eot/train_on_eos conflict""" @@ -1026,7 +1017,7 @@ class TestChatTemplateConfigurations: chat_template, chat_template_jinja, eos_token, - basic_dataset, # pylint: disable=unused-argument + basic_dataset, request, ): """Test that eot_tokens inherits from eos_token when not specified""" @@ -1053,12 +1044,12 @@ class TestChatTemplateConfigurations: ) # In backward compatibility mode, eot_tokens should be derived from eos_token - assert strategy.eot_tokens == [ - tokenizer.eos_token - ], f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}" - assert ( - strategy.train_on_eot == "turn" - ), f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}" + assert strategy.eot_tokens == [tokenizer.eos_token], ( + f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}" + ) + assert strategy.train_on_eot == "turn", ( + f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}" + ) def test_token_not_in_template( self, @@ -1112,7 +1103,7 @@ class TestChatTemplateConfigurations: tokenizer, chat_template, chat_template_jinja, - eos_token, # pylint: disable=unused-argument + eos_token, basic_dataset, request, ): @@ -1178,13 +1169,13 @@ class TestChatTemplateConfigurations: ) if is_after_assistant: - assert ( - labels[eot_idx] != IGNORE_TOKEN_ID - ), f"Expected EOT token after assistant turn at index {eot_idx} to be labeled" + assert labels[eot_idx] != IGNORE_TOKEN_ID, ( + f"Expected EOT token after assistant turn at index {eot_idx} to be labeled" + ) else: - assert ( - labels[eot_idx] == IGNORE_TOKEN_ID - ), f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled" + assert labels[eot_idx] == IGNORE_TOKEN_ID, ( + f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled" + ) def test_multiple_train_on_eot_settings( self, @@ -1245,9 +1236,9 @@ class TestChatTemplateConfigurations: i for i, token_id in enumerate(input_ids) if token_id == eos_token_id ] - assert ( - len(eos_indices) > 0 - ), "Expected at least one EOS/EOT token in the input" + assert len(eos_indices) > 0, ( + "Expected at least one EOS/EOT token in the input" + ) # Check labeling for each EOS/EOT token for idx, eos_idx in enumerate(eos_indices): @@ -1273,10 +1264,167 @@ class TestChatTemplateConfigurations: ) if expected_label: - assert ( - labels[eos_idx] == IGNORE_TOKEN_ID - ), f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'" + assert labels[eos_idx] == IGNORE_TOKEN_ID, ( + f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'" + ) else: - assert ( - labels[eos_idx] != IGNORE_TOKEN_ID - ), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'" + assert labels[eos_idx] != IGNORE_TOKEN_ID, ( + f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'" + ) + + +class TestChatTemplateToolCalling: + """ + Test class for tool calling functionality with chat templates. + """ + + def test_tool_calling_with_llama4_template( + self, + llama3_tokenizer, + ): + LOG.info("Testing tool calling with llama3 tokenizer and llama4 chat template") + + # Create tool calling dataset + tool_calling_dataset = [ + { + "tools": [ + { + "type": "function", + "function": { + "name": "xml_escape", + "description": 'Replaces any "<", ">", or "&" characters in the input string with their corresponding XML entities.', + "parameters": { + "type": "object", + "properties": { + "s": { + "type": "string", + "description": "The input string to be XML-escaped.", + } + }, + "required": ["s"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "multiples", + "description": "Generates a list of all the multiples of a number that are less than a given limit.", + "parameters": { + "type": "object", + "properties": { + "number": { + "type": "integer", + "description": "The number to find multiples of.", + }, + "limit": { + "type": "integer", + "description": "The upper limit for the multiples.", + }, + }, + "required": ["number", "limit"], + }, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "Can you help me find multiples of 5 that are less than 20?", + }, + { + "role": "assistant", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "multiples", + "arguments": { + "number": 5, + "limit": 20, + }, + }, + } + ], + }, + {"role": "tool", "name": "multiples", "content": "5,10,15"}, + { + "role": "assistant", + "content": "The multiples of 5 less than 20 are: 5, 10, and 15.", + }, + ], + } + ] + + # Setup tokenizer with llama4 chat template + tokenizer = deepcopy(llama3_tokenizer) + + # Add EOS token to the tokenizer + eot_token = "<|eot_id|>" + tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]}) + + strategy = ChatTemplateStrategy( + ChatTemplatePrompter( + tokenizer, + chat_template=get_chat_template("llama4"), + message_property_mappings={"role": "role", "content": "content"}, + field_messages="messages", + field_tools="tools", + ), + tokenizer=tokenizer, + train_on_inputs=False, + sequence_len=512, + roles_to_train=["assistant"], + eot_tokens=[eot_token], + ) + + res = strategy.tokenize_prompt(tool_calling_dataset[0]) + input_ids = res["input_ids"] + labels = res["labels"] + + # Verify that the input_ids contain expected tokens + assert len(input_ids) > 0, "Input IDs should not be empty" + assert len(labels) == len(input_ids), "Labels should match input_ids length" + + # Decode the full conversation to verify structure + decoded_conversation = tokenizer.decode(input_ids) + + # Verify tool calling structure is present in the decoded conversation + assert '"type": "function",' in decoded_conversation, ( + "Tool type function should be in conversation" + ) + assert '"name": "multiples",' in decoded_conversation, ( + "Tool function name should be in conversation" + ) + + assert ( + '<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>' + in decoded_conversation + ), "Assistant tool call should be in conversation" + assert "<|header_start|>ipython<|header_end|>" in decoded_conversation, ( + "IPython header should be in conversation" + ) + assert '"5,10,15"' in decoded_conversation, ( + "Tool response should be in conversation" + ) + + # Get conversation turns to verify labeling + turns = strategy.get_conversation_thread(tool_calling_dataset[0]) + tools = strategy._get_tools(tool_calling_dataset[0]) + + # Check that assistant responses are properly labeled + for i, turn in enumerate(tool_calling_dataset[0]["messages"]): + if turn["role"] == "assistant": + start_idx, end_idx = strategy.find_turn( + turns=turns, turn_idx=i, tools=tools + ) + + assert start_idx != -1 and end_idx != -1, ( + f"Assistant turn {i} should be found" + ) + + # Verify that assistant responses have proper labels + turn_labels = labels[start_idx:end_idx] + assert all(label != IGNORE_TOKEN_ID for label in turn_labels), ( + f"Assistant turn {i} should be unmasked" + ) diff --git a/tests/prompt_strategies/test_chat_templates_mistral.py b/tests/prompt_strategies/test_chat_templates_mistral.py new file mode 100644 index 000000000..85aa72111 --- /dev/null +++ b/tests/prompt_strategies/test_chat_templates_mistral.py @@ -0,0 +1,851 @@ +"""Test chat templates for mistral-common wrapper tokenizer""" + +import unittest +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from axolotl.utils.mistral import HFMistralTokenizer + + +# fmt: off +@pytest.mark.parametrize( + ("tokenizer_str", "assistant_toolcall_ids", "tool_result_ids"), + ( + ("magistral_tokenizer", (9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2), (7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8)), + ("devstral_tokenizer", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2), (7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8)), + ("devstral_1_1_tokenizer", (9, 44627, 3684, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2,), (7, 1049, 1044, 1050, 8)), + ) +) +# fmt: on +def test_mistral_chat_template( + tokenizer_str: str, + assistant_toolcall_ids: tuple[int, ...], + tool_result_ids: tuple[int, ...], + request: pytest.FixtureRequest, +): + """Test chat template with the Magistral/Devstral tokenizer""" + + from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy + + tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str) + + # check bos, eos, pad, unk are accessible properties + assert tokenizer.bos_token_id == 1 + assert tokenizer.eos_token_id == 2 + assert tokenizer.pad_token_id == 11 + assert tokenizer.unk_token_id == 0 + + assert tokenizer.pad_token == "" + assert tokenizer.eos_token == "" + assert tokenizer.bos_token == "" + assert tokenizer.unk_token == "" + + strategy = MistralStrategy( + MistralPrompter( + tokenizer, + chat_template=None, + message_property_mappings={"role": "role", "content": "content"}, + ), + tokenizer=tokenizer, + train_on_inputs=False, + train_on_eos="turn", + sequence_len=512, + roles_to_train=["assistant"], + ) + + # test chat template masking without system prompt + res = strategy.tokenize_prompt( + { + "messages": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great, thank you!"}, + ] + } + ) + + assert res["input_ids"] == [ + 1, # bos + 3, # [INST] + 22177, # Hello + 1044, # , + 2606, # how + 1584, # are + 1636, # you + 1063, # ? + 4, # [/INST] + 1073, # I + 4525, # 'm + 6965, # doing + 4824, # great + 1044, # , + 15412, # thank + 1636, # you + 1033, # ! + 2, # + ] + + assert res["labels"] == [ + -100, # bos + -100, # [INST] + -100, # Hello + -100, # , + -100, # how + -100, # are + -100, # you + -100, # ? + -100, # [/INST] + 1073, # I + 4525, # 'm + 6965, # doing + 4824, # great + 1044, # , + 15412, # thank + 1636, # you + 1033, # ! + 2, # + ] + + # test chat template masking with system prompt + res = strategy.tokenize_prompt( + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great, thank you!"}, + ] + } + ) + + assert res["input_ids"] == [ + 1, # bos + 17, # [SYSTEM_PROMPT] + 4568, # You + 1584, # are + 1261, # a + 20351, # helpful + 27089, # assistant + 1046, # . + 18, # [/SYSTEM_PROMPT] + 3, # [INST] + 22177, # Hello + 1044, # , + 2606, # how + 1584, # are + 1636, # you + 1063, # ? + 4, # [/INST] + 1073, # I + 4525, # 'm + 6965, # doing + 4824, # great + 1044, # , + 15412, # thank + 1636, # you + 1033, # ! + 2, # + ] + + assert res["labels"] == [ + -100, # bos + -100, # [SYSTEM_PROMPT] + -100, # You + -100, # are + -100, # a + -100, # helpful + -100, # assistant + -100, # . + -100, # [/SYSTEM_PROMPT] + -100, # [INST] + -100, # Hello + -100, # , + -100, # how + -100, # are + -100, # you + -100, # ? + -100, # [/INST] + 1073, # I + 4525, # 'm + 6965, # doing + 4824, # great + 1044, # , + 15412, # thank + 1636, # you + 1033, # ! + 2, # + ] + + # test chat template with tools + res = strategy.tokenize_prompt( + { + "tools": [ + { + "type": "function", + "function": { + "name": "multiples", + "description": "Generates a list of all the multiples of a number that are less than a given limit.", + "parameters": { + "type": "object", + "properties": { + "number": { + "type": "integer", + "description": "The number to find multiples of.", + }, + "limit": { + "type": "integer", + "description": "The upper limit for the multiples.", + }, + }, + "required": ["number", "limit"], + }, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "Hey, can you give me a breakdown of how to throw an awesome themed party? Like, what themes work best, and how can I set everything up to really wow my guests? I want some ideas on decorations, food, and activities that will make the party unforgettable!", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call12345", + "type": "function", + "function": { + "name": "multiples", + "arguments": { + "number": 16, + "limit": 2, + }, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call12345", + "name": "multiples", + "content": "1,2", + }, + {"role": "assistant", "content": "The multiples of 16 is 1 and 2."}, + ], + } + ) + + # fmt: off + assert res["input_ids"] == [ + 1, # bos + 5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt + 3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user + *assistant_toolcall_ids, # assistant tool calling + *tool_result_ids, # tool result + 1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant + 2 # eos + ] + + assert res["labels"] == [ + -100, # bos + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt + -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt + *assistant_toolcall_ids, # assistant tool calling + *([-100] * len(tool_result_ids)), # tool result + 1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant + 2 # eos + ] + # fmt: on + + # test chat template with tokenize=False + res = tokenizer.apply_chat_template( + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great, thank you!"}, + ], + tokenize=False, + ) + + assert res == "[INST]Hello, how are you?[/INST]I'm doing great, thank you!" + + # test encode + res = tokenizer.encode("Hello, how are you?", add_special_tokens=True) + assert res == [ + 1, # bos + 22177, # Hello + 1044, # , + 2606, # how + 1584, # are + 1636, # you + 1063, # ? + 2, # eos + ] + + # test decode no skip special tokens + decoded_res = tokenizer.decode(res, skip_special_tokens=False) + + assert decoded_res == "Hello, how are you?" + + # test decode skip special tokens + decoded_res = tokenizer.decode(res, skip_special_tokens=True) + assert decoded_res == "Hello, how are you?" + + # test encode no special tokens + res = tokenizer.encode("Hello, how are you?", add_special_tokens=False) + assert res == [ + 22177, # Hello + 1044, # , + 2606, # how + 1584, # are + 1636, # you + 1063, # ? + ] + + # test convert ids to tokens + res = tokenizer.convert_ids_to_tokens(res) + # spacing are needed as we are converting without decoding + assert res == ["Hello", ",", " how", " are", " you", "?"] + + +@pytest.mark.skip(reason="TODO, fix for new HF wrapper call") +def test_magistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer"): + """Test the MistralTokenizer pad method""" + from axolotl.utils.collators.core import IGNORE_INDEX + + magistral_pad_token_id = 11 # taken from tokenizer.pad_token_id + + # Test padding with input_ids and labels only + features = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + {"input_ids": [7, 8], "labels": [9, 10]}, + ] + + result = magistral_tokenizer.pad(features, padding=True, return_tensors="pt") + + # Check that input_ids are padded correctly + assert result["input_ids"].shape == (2, 3) + assert result["input_ids"].tolist() == [[1, 2, 3], [7, 8, magistral_pad_token_id]] + + # Check that labels are padded correctly + assert result["labels"].shape == (2, 3) + assert result["labels"].tolist() == [[4, 5, 6], [9, 10, IGNORE_INDEX]] + + # Check that attention_mask and position_ids are NOT created + assert "attention_mask" not in result + assert "position_ids" not in result + + # Test padding with attention_mask + features_with_attention = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "attention_mask": [1, 1, 1]}, + {"input_ids": [7, 8], "labels": [9, 10], "attention_mask": [1, 1]}, + ] + + result = magistral_tokenizer.pad( + features_with_attention, padding=True, return_tensors="pt" + ) + + # Check that attention_mask is padded correctly + assert result["attention_mask"].shape == (2, 3) + assert result["attention_mask"].tolist() == [[1, 1, 1], [1, 1, 0]] + + # Test padding with position_ids + features_with_position = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "position_ids": [0, 1, 2]}, + {"input_ids": [7, 8], "labels": [9, 10], "position_ids": [0, 1]}, + ] + + result = magistral_tokenizer.pad( + features_with_position, padding=True, return_tensors="pt" + ) + + # Check that position_ids are padded correctly (continuing sequence) + assert result["position_ids"].shape == (2, 3) + assert result["position_ids"].tolist() == [[0, 1, 2], [0, 1, 2]] + + # Test padding with all fields + features_all = [ + { + "input_ids": [1, 2, 3], + "labels": [4, 5, 6], + "attention_mask": [1, 1, 1], + "position_ids": [0, 1, 2], + }, + { + "input_ids": [7, 8], + "labels": [9, 10], + "attention_mask": [1, 1], + "position_ids": [0, 1], + }, + ] + + result = magistral_tokenizer.pad(features_all, padding=True, return_tensors="pt") + + # All fields should be present and correctly padded + assert "input_ids" in result + assert "labels" in result + assert "attention_mask" in result + assert "position_ids" in result + + # Test padding with all sequences same length + features_same_length = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + {"input_ids": [7, 8, 9], "labels": [10, 11, 12]}, + ] + + result = magistral_tokenizer.pad( + features_same_length, padding=True, return_tensors="pt" + ) + + # Check match when no padding is needed + assert result["input_ids"][0].tolist() == features_same_length[0]["input_ids"] + assert result["labels"][0].tolist() == features_same_length[0]["labels"] + + assert result["input_ids"][1].tolist() == features_same_length[1]["input_ids"] + assert result["labels"][1].tolist() == features_same_length[1]["labels"] + + # Test padding with max_length parameter + result = magistral_tokenizer.pad( + features, padding="max_length", max_length=5, return_tensors="pt" + ) + + # Should pad to max_length + assert result["input_ids"].shape == (2, 5) + assert result["labels"].shape == (2, 5) + + # Test numpy return type + result = magistral_tokenizer.pad(features, padding=True, return_tensors="np") + + # Should return numpy arrays + import numpy as np + + assert isinstance(result["input_ids"], np.ndarray) + assert isinstance(result["labels"], np.ndarray) + + # Test unsupported field rejection + features_unsupported = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "unsupported_field": [7, 8, 9]}, + ] + + with pytest.raises(NotImplementedError, match="unsupported_field"): + magistral_tokenizer.pad(features_unsupported, padding=True, return_tensors="pt") + + # Test token_type_ids rejection + features_token_type = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "token_type_ids": [0, 0, 0]}, + ] + + with pytest.raises(ValueError, match="token_type_ids is not supported"): + magistral_tokenizer.pad(features_token_type, padding=True, return_tensors="pt") + + +def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"): + """Test tool calling with the Magistral tokenizer""" + from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy + + strategy = MistralStrategy( + MistralPrompter( + magistral_tokenizer, + chat_template=None, + message_property_mappings={"role": "role", "content": "content"}, + ), + tokenizer=magistral_tokenizer, + train_on_inputs=False, + train_on_eos="turn", + sequence_len=512, + roles_to_train=["assistant"], + ) + + # Test basic tool calling with single function + basic_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + }, + "required": ["location"], + }, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "What's the weather like in San Francisco?", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call12345", + "type": "function", + "function": { + "name": "get_weather", + "arguments": { + "location": "San Francisco, CA", + }, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call12345", + "name": "get_weather", + "content": "Sunny, 72°F", + }, + { + "role": "assistant", + "content": "The weather in San Francisco is sunny and 72°F.", + }, + ], + } + + res = strategy.tokenize_prompt(basic_tool_calling) + + # Basic validation + assert "input_ids" in res + assert "labels" in res + assert len(res["input_ids"]) > 0 + assert len(res["labels"]) == len(res["input_ids"]) + + # Decode and verify structure + decoded = magistral_tokenizer.decode(res["input_ids"]) + assert ( + '[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS]' + in decoded + ) + assert ( + '[TOOL_CALLS]get_weather[CALL_ID]call12345[ARGS]{"location": "San Francisco, CA"}' + in decoded + ) + assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]Sunny, 72°F[/TOOL_RESULTS]" in decoded + assert "The weather in San Francisco is sunny and 72°F." in decoded + + # Test multiple tool calls in sequence + multi_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "add_numbers", + "description": "Add two numbers together", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "multiply_numbers", + "description": "Multiply two numbers", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "number", "description": "First number"}, + "y": {"type": "number", "description": "Second number"}, + }, + "required": ["x", "y"], + }, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "Add 5 and 3, then multiply the result by 2", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call12345", + "type": "function", + "function": { + "name": "add_numbers", + "arguments": {"a": 5, "b": 3}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call12345", + "name": "add_numbers", + "content": "8", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call23456", + "type": "function", + "function": { + "name": "multiply_numbers", + "arguments": {"x": 8, "y": 2}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call23456", + "name": "multiply_numbers", + "content": "16", + }, + { + "role": "assistant", + "content": "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.", + }, + ], + } + + res = strategy.tokenize_prompt(multi_tool_calling) + + # Validation + assert len(res["input_ids"]) > 0 + assert len(res["labels"]) == len(res["input_ids"]) + + decoded = magistral_tokenizer.decode(res["input_ids"]) + assert ( + '[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "add_numbers", "description": "Add two numbers together", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "First number"}, "b": {"type": "number", "description": "Second number"}}, "required": ["a", "b"]}}}, {"type": "function", "function": {"name": "multiply_numbers", "description": "Multiply two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "First number"}, "y": {"type": "number", "description": "Second number"}}, "required": ["x", "y"]}}}][/AVAILABLE_TOOLS]' + in decoded + ) + assert ( + '[TOOL_CALLS]add_numbers[CALL_ID]call12345[ARGS]{"a": 5, "b": 3}' in decoded + ) + assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]8[/TOOL_RESULTS]" in decoded + assert ( + '[TOOL_CALLS]multiply_numbers[CALL_ID]call23456[ARGS]{"x": 8, "y": 2}' + in decoded + ) + assert "[TOOL_RESULTS]call23456[TOOL_CONTENT]16[/TOOL_RESULTS]" in decoded + assert ( + "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16." + in decoded + ) + + # Test tool calling with system message + system_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "search_database", + "description": "Search for information in database", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + }, + "required": ["query"], + }, + }, + }, + ], + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant with access to a database.", + }, + { + "role": "user", + "content": "Find information about Python programming", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "search123", + "type": "function", + "function": { + "name": "search_database", + "arguments": {"query": "Python programming"}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "search123", + "name": "search_database", + "content": "Python is a high-level programming language known for its simplicity.", + }, + { + "role": "assistant", + "content": "Based on the database search, Python is a high-level programming language known for its simplicity and readability.", + }, + ], + } + + res = strategy.tokenize_prompt(system_tool_calling) + + # Validation + assert len(res["input_ids"]) > 0 + assert len(res["labels"]) == len(res["input_ids"]) + + decoded = magistral_tokenizer.decode(res["input_ids"]) + + assert ( + '[SYSTEM_PROMPT]You are a helpful assistant with access to a database.[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "search_database", "description": "Search for information in database", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"]}}}][/AVAILABLE_TOOLS]' + in decoded + ) + + # Test error handling - missing tool response + incomplete_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "get_time", + "description": "Get current time", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "What time is it?", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "time12345", + "type": "function", + "function": { + "name": "get_time", + "arguments": {}, + }, + } + ], + }, + { + "role": "assistant", + "content": "The current time is 12:00 PM.", + }, + ], + } + + from mistral_common.exceptions import InvalidMessageStructureException + + try: + strategy.tokenize_prompt(incomplete_tool_calling) + except InvalidMessageStructureException as e: + assert "Not the same number of function calls and responses" in str(e) + + +@pytest.mark.skip(reason="TODO, fix for new HF wrapper call") +def test_magistral_tokenizer_call_method( + magistral_tokenizer: "HFMistralTokenizer", llama3_tokenizer: "PreTrainedTokenizer" +): + """Test the __call__ method behavior matches HuggingFace standards""" + from copy import deepcopy + + import numpy as np + import torch + + hf_tokenizer = deepcopy(llama3_tokenizer) + hf_tokenizer.pad_token = hf_tokenizer.eos_token + + test_text = "Hello, how are you?" + batch_texts = ["Hello world", "How are you?"] + + # Test single string with return_tensors=None + hf_result: dict[str, list[int]] = hf_tokenizer(test_text, return_tensors=None) + mistral_result: dict[str, list[int]] = magistral_tokenizer( + test_text, return_tensors=None + ) + + assert isinstance(mistral_result, dict) + assert set(mistral_result.keys()) == {"input_ids", "attention_mask"} + assert isinstance(mistral_result["input_ids"], type(hf_result["input_ids"])) # list + assert isinstance( + mistral_result["attention_mask"], type(hf_result["attention_mask"]) + ) + assert len(mistral_result["input_ids"]) == len(mistral_result["attention_mask"]) + assert np.all(mistral_result["attention_mask"]) + assert len(np.array(mistral_result["input_ids"]).shape) == 1 # 1D array + + # Test single string with return_tensors='pt' + hf_result_pt: dict[str, torch.Tensor] = hf_tokenizer(test_text, return_tensors="pt") + mistral_result_pt: dict[str, torch.Tensor] = magistral_tokenizer( + test_text, return_tensors="pt" + ) + + # Check structure and types + assert isinstance(mistral_result_pt["input_ids"], torch.Tensor) + assert isinstance(mistral_result_pt["attention_mask"], torch.Tensor) + + # Check shapes match (don't compare token dimension) + assert len(hf_result_pt["input_ids"].shape) == len( + mistral_result_pt["input_ids"].shape + ) + assert hf_result_pt["input_ids"].shape[0] == mistral_result_pt["input_ids"].shape[0] + assert ( + mistral_result_pt["attention_mask"].shape + == mistral_result_pt["input_ids"].shape + ) + assert torch.all(mistral_result_pt["attention_mask"] == 1) + + # Test batch input with padding + hf_batch: dict[str, torch.Tensor] = hf_tokenizer( + batch_texts, return_tensors="pt", padding=True + ) + mistral_batch: dict[str, torch.Tensor] = magistral_tokenizer( + batch_texts, return_tensors="pt", padding=True + ) + + # Check batch behavior + assert len(hf_batch["input_ids"].shape) == len(mistral_batch["input_ids"].shape) + assert hf_batch["input_ids"].shape[0] == mistral_batch["input_ids"].shape[0] + assert mistral_batch["attention_mask"].shape == mistral_batch["input_ids"].shape + assert torch.any( + mistral_batch["attention_mask"][0] == 0 + ) # padding in shorter sequence + assert torch.all( + mistral_batch["attention_mask"][1] == 1 + ) # no padding in longer sequence + + # Test numpy tensors + mistral_result_np: dict[str, np.ndarray] = magistral_tokenizer( + test_text, return_tensors="np" + ) + assert isinstance(mistral_result_np["input_ids"], np.ndarray) + assert isinstance(mistral_result_np["attention_mask"], np.ndarray) + + # Test consistency with encode() + encoded: list[int] = magistral_tokenizer.encode(test_text, add_special_tokens=True) + called: dict[str, torch.Tensor] = magistral_tokenizer( + test_text, return_tensors="pt" + ) + assert encoded == called["input_ids"][0].tolist() + + # Test Error handling + with pytest.raises(ValueError, match="Unsupported kwargs"): + magistral_tokenizer(test_text, unsupported_param=True) + + with pytest.raises( + ValueError, match="return_tensors='pt' or 'np' requires padding or truncation" + ): + magistral_tokenizer(batch_texts, return_tensors="pt") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/prompt_strategies/test_chat_templates_thinking.py b/tests/prompt_strategies/test_chat_templates_thinking.py index 79429b731..054012e00 100644 --- a/tests/prompt_strategies/test_chat_templates_thinking.py +++ b/tests/prompt_strategies/test_chat_templates_thinking.py @@ -4,15 +4,12 @@ Tests for splitting reasoning/thinking from content into separate field import pytest from datasets import Dataset -from transformers import AutoTokenizer from axolotl.prompt_strategies.chat_template import ( load, ) from axolotl.utils.dict import DictDefault -from tests.hf_offline_utils import enable_hf_offline - @pytest.fixture(name="messages_w_reasoning") def messages_w_reasoning_fixture(): @@ -58,23 +55,12 @@ def messages_w_reasoning_fixture(): ) -@pytest.fixture(name="qwen3_tokenizer") -@enable_hf_offline -def qwen3_tokenizer_fixture( - download_qwen3_half_billion_model, -): # pylint: disable=unused-argument - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") - - return tokenizer - - class TestSplitThinking: """ test class to make sure datasets with reasoning content conforms to the chat_template strategy """ def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer): - # pylint: disable=duplicate-code strategy = load( qwen3_tokenizer, DictDefault( @@ -133,6 +119,6 @@ class TestSplitThinking: 198, # \n ] # fmt: on - assert ( - input_ids == expected_input_ids - ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + assert input_ids == expected_input_ids, ( + f"Input IDs mismatch: {input_ids} != {expected_input_ids}" + ) diff --git a/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py b/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py new file mode 100644 index 000000000..7de21b940 --- /dev/null +++ b/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py @@ -0,0 +1,214 @@ +""" +Tests for handling json tool content +""" + +import json + +import pytest +from datasets import Dataset + +from axolotl.prompt_strategies.chat_template import ( + load, +) +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="qwen3_instruct_prompt_strategy") +def qwen3_instruct_chat_template_strategy(qwen3_tokenizer): + strategy = load( + qwen3_tokenizer, + DictDefault( + { + "train_on_inputs": False, + "sequence_len": 512, + } + ), + DictDefault( + { + "chat_template": "qwen3", + "message_field_role": "role", + "message_field_content": "content", + "message_property_mappings": { + "role": "role", + "content": "content", + }, + "roles": { + "user": ["user"], + "assistant": ["assistant"], + "system": ["system"], + }, + "field_messages": "messages", + } + ), + ) + return strategy + + +class TestQwen3IdenticalConversationArgs: + """ + Test Qwen3 tools is identical between JSON and dict + """ + + @pytest.fixture(name="conversation_dict_args_dataset") + def fixture_conversation_dict_args_dataset(self): + """ + Provides a dataset with conversation where arguments is a dict. + """ + user_content = "What is the weather in Boston?" + function_name = "get_current_weather" + arguments_dict = {"location": "Boston, MA", "unit": "celsius"} + + data = [ + { + "messages": [ + {"role": "user", "content": user_content}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": function_name, + "arguments": arguments_dict, # dict格式 + } + } + ], + }, + ], + } + ] + return Dataset.from_list(data) + + @pytest.fixture(name="conversation_str_args_dataset") + def fixture_conversation_str_args_dataset(self): + """ + Provides a dataset with conversation where arguments is a JSON string. + """ + user_content = "What is the weather in Boston?" + function_name = "get_current_weather" + arguments_dict = {"location": "Boston, MA", "unit": "celsius"} + arguments_str = json.dumps(arguments_dict) + + data = [ + { + "messages": [ + {"role": "user", "content": user_content}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": function_name, + "arguments": arguments_str, # str格式 + } + } + ], + }, + ], + } + ] + return Dataset.from_list(data) + + @pytest.fixture(name="conversation_mixed_time_types_dataset") + def fixture_conversation_mixed_time_types_dataset(self): + """ + Provides a dataset where 'time' field has different types in different tool calls. + """ + data = [ + { + "messages": [ + { + "role": "user", + "content": "Get weather information at different times", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "func1", + "arguments": json.dumps( + {"time": "2025-08-01"} + ), # string type + } + }, + { + "function": { + "name": "func2", + "arguments": json.dumps( + {"time": 1690876800} + ), # number type + } + }, + ], + }, + ], + } + ] + return Dataset.from_list(data) + + def test_dict_and_str_args_produce_identical_output( + self, + conversation_dict_args_dataset, + conversation_str_args_dataset, + qwen3_instruct_prompt_strategy, + qwen3_tokenizer, + ): + """ + Tests that after tokenization and decoding, the outputs for both + dict and string `arguments` are exactly the same. + """ + processed_dict_args = conversation_dict_args_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages"], + ) + + processed_str_args = conversation_str_args_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages"], + ) + + decoded_prompt_from_dict = qwen3_tokenizer.decode( + processed_dict_args[0]["input_ids"] + ) + + decoded_prompt_from_str = qwen3_tokenizer.decode( + processed_str_args[0]["input_ids"] + ) + + assert decoded_prompt_from_dict == decoded_prompt_from_str, ( + f"Dict format output:\n{decoded_prompt_from_dict}\n" + f"String format output:\n{decoded_prompt_from_str}" + ) + + assert ( + processed_dict_args[0]["input_ids"] == processed_str_args[0]["input_ids"] + ), "The tokenized input_ids should be identical for dict and str arguments" + + def test_str_args_with_mixed_time_types_no_error( + self, + conversation_mixed_time_types_dataset, + qwen3_instruct_prompt_strategy, + qwen3_tokenizer, + ): + """ + Tests that when 'time' field has different types (string vs number) + in different tool calls, str format arguments don't cause errors. + """ + processed = conversation_mixed_time_types_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages"], + ) + + assert len(processed) == 1 + assert "input_ids" in processed[0] + assert len(processed[0]["input_ids"]) > 0 + + decoded = qwen3_tokenizer.decode(processed[0]["input_ids"]) + assert "2025-08-01" in decoded, "String time value should be present" + assert "1690876800" in decoded, "Number time value should be present" diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py index e5f30a6c4..e570cfc9d 100644 --- a/tests/prompt_strategies/test_dpo_chat_templates.py +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -16,7 +16,6 @@ from tests.hf_offline_utils import enable_hf_offline @pytest.fixture(name="assistant_dataset") def fixture_assistant_dataset(): - # pylint: disable=duplicate-code return Dataset.from_list( [ { @@ -49,7 +48,6 @@ def fixture_assistant_dataset(): @pytest.fixture(name="custom_assistant_dataset") def fixture_custom_assistant_dataset(): - # pylint: disable=duplicate-code return Dataset.from_list( [ { @@ -102,7 +100,6 @@ class TestAssistantDPOChatTemplateLlama3: """ def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset): - # pylint: disable=duplicate-code transform_fn, _ = default( DictDefault( { @@ -127,7 +124,6 @@ class TestAssistantDPOChatTemplateLlama3: assert result["rejected"] == "party on<|eot_id|>" def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset): - # pylint: disable=duplicate-code transform_fn, _ = default( DictDefault( { @@ -168,7 +164,6 @@ class TestAssistantDPOChatTemplatePhi3: """ def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset): - # pylint: disable=duplicate-code transform_fn, _ = default( DictDefault( { @@ -198,7 +193,6 @@ class TestAssistantDPOChatTemplateGemma: """ def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset): - # pylint: disable=duplicate-code transform_fn, _ = default( DictDefault( { diff --git a/tests/prompt_strategies/test_dpo_chatml.py b/tests/prompt_strategies/test_dpo_chatml.py index b313a4b64..2c089067f 100644 --- a/tests/prompt_strategies/test_dpo_chatml.py +++ b/tests/prompt_strategies/test_dpo_chatml.py @@ -6,8 +6,9 @@ import unittest import pytest +from axolotl.loaders.tokenizer import load_tokenizer from axolotl.prompt_strategies.dpo import load as load_dpo -from axolotl.utils.data.rl import load_prepare_preference_datasets +from axolotl.utils.data.rl import prepare_preference_datasets from axolotl.utils.dict import DictDefault from tests.hf_offline_utils import enable_hf_offline @@ -55,7 +56,8 @@ class TestDPOChatml: # test that dpo.load works load_dpo("chatml", cfg) # now actually load the datasets with the strategy - train_ds, _ = load_prepare_preference_datasets(cfg) + tokenizer = load_tokenizer(cfg) + train_ds, _ = prepare_preference_datasets(cfg, tokenizer) assert train_ds[0]["prompt"].startswith("<|im_start|>") assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n") assert "chosen" in train_ds[0] diff --git a/tests/prompt_strategies/test_stepwise.py b/tests/prompt_strategies/test_stepwise.py index 2abe4ae18..ad3f7531f 100644 --- a/tests/prompt_strategies/test_stepwise.py +++ b/tests/prompt_strategies/test_stepwise.py @@ -20,7 +20,6 @@ class TestStepWiseSupervisedPromptTokenizingStrategy: @pytest.fixture() def stepwise_supervised_dataset(self): - # pylint: disable=duplicate-code return Dataset.from_list( [ { diff --git a/tests/test_chunked_xentropy.py b/tests/test_chunked_xentropy.py new file mode 100644 index 000000000..56ac1b168 --- /dev/null +++ b/tests/test_chunked_xentropy.py @@ -0,0 +1,40 @@ +""" +test suite for chunked cross entropy +""" + +import pytest +import torch +from torch import nn + +from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss + + +@pytest.fixture +def chunked_fixtures(): + model_dim = 512 + vocab_size = 1024 * 256 + seq_len = 2048 + batch_size = 1 + + lm_head = nn.Linear(model_dim, vocab_size) + hidden_state = torch.randn(batch_size, seq_len, model_dim) + labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) + return lm_head, hidden_state, labels, vocab_size + + +def test_chunked_forward(chunked_fixtures): + lm_head, hidden_state, labels, vocab_size = chunked_fixtures + lm_loss = get_causal_lm_loss() + + logits = lm_head(hidden_state) + + chunked_lm_loss = lm_loss(logits, labels) + + logits_flattened = logits.view(-1, vocab_size) + labels_flattened = labels.view(-1) + + loss = nn.functional.cross_entropy( + logits_flattened.float(), labels_flattened, reduction="mean" + ) + + assert torch.allclose(chunked_lm_loss, loss, atol=1e-2, rtol=1e-2) diff --git a/tests/test_data.py b/tests/test_data.py index 6d583cfd3..99ed06336 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -6,7 +6,7 @@ import unittest from transformers import LlamaTokenizer -from axolotl.utils.data import encode_pretraining, md5 +from axolotl.utils.data import encode_streaming, md5 from tests.hf_offline_utils import enable_hf_offline @@ -39,7 +39,7 @@ class TestEncodePretraining(unittest.TestCase): "hello, hello", ] } - result = encode_pretraining(self.tokenizer, self.max_tokens, examples) + result = encode_streaming(examples, self.tokenizer, self.max_tokens) self.assertEqual(len(result["input_ids"]), 3) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index bd77591cf..bd1c8f2c2 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,10 +1,9 @@ -""" -Test dataset loading under various conditions. -""" +"""Test dataset loading under various conditions.""" import shutil import tempfile from pathlib import Path +from typing import Any, Generator from unittest.mock import patch import pytest @@ -12,8 +11,9 @@ from datasets import Dataset from huggingface_hub import snapshot_download from transformers import PreTrainedTokenizer -from axolotl.utils.data import load_tokenized_prepared_datasets -from axolotl.utils.data.rl import load_prepare_preference_datasets +from axolotl.loaders.tokenizer import load_tokenizer +from axolotl.utils.data.rl import prepare_preference_datasets +from axolotl.utils.data.sft import _load_tokenized_prepared_datasets from axolotl.utils.dict import DictDefault from tests.constants import ( @@ -28,7 +28,9 @@ class TestDatasetPreparation: """Test a configured dataloader.""" @pytest.fixture - def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer: + def tokenizer( + self, tokenizer_huggyllama + ) -> Generator[PreTrainedTokenizer, Any, Any]: tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS) yield tokenizer_huggyllama @@ -63,7 +65,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 2000 assert "input_ids" in dataset.features @@ -107,7 +112,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 2000 assert "input_ids" in dataset.features @@ -133,10 +141,14 @@ class TestDatasetPreparation: "type": "alpaca", }, ], + "dataset_num_proc": 4, } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 1 assert "input_ids" in dataset.features @@ -145,7 +157,7 @@ class TestDatasetPreparation: @enable_hf_offline def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture): - """Usual use case. Verify a directory of parquet files can be loaded.""" + """Usual use case. Verify a directory of parquet files can be loaded.""" with tempfile.TemporaryDirectory() as tmp_dir: tmp_ds_dir = Path(tmp_dir) / "tmp_dataset" tmp_ds_dir.mkdir() @@ -168,10 +180,14 @@ class TestDatasetPreparation: "type": "alpaca", }, ], + "dataset_num_proc": 4, } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 1 assert "input_ids" in dataset.features @@ -203,10 +219,14 @@ class TestDatasetPreparation: "type": "alpaca", }, ], + "dataset_num_proc": 4, } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 1 assert "input_ids" in dataset.features @@ -232,10 +252,14 @@ class TestDatasetPreparation: "type": "alpaca", }, ], + "dataset_num_proc": 4, } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 1 assert "input_ids" in dataset.features @@ -261,10 +285,14 @@ class TestDatasetPreparation: "type": "alpaca", }, ], + "dataset_num_proc": 4, } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 1 assert "input_ids" in dataset.features @@ -286,7 +314,8 @@ class TestDatasetPreparation: } ) - train_dataset, _ = load_prepare_preference_datasets(cfg) + tokenizer = load_tokenizer(cfg) + train_dataset, _ = prepare_preference_datasets(cfg, tokenizer) assert len(train_dataset) == 1800 assert "conversation" not in train_dataset.features @@ -318,7 +347,10 @@ class TestDatasetPreparation: } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 2000 assert "input_ids" in dataset.features @@ -338,17 +370,20 @@ class TestDatasetPreparation: "rl": "dpo", "chat_template": "llama3", "datasets": [ALPACA_MESSAGES_CONFIG_REVISION], + "dataset_num_proc": 4, } ) - # pylint: disable=duplicate-code - with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset: + with patch( + "axolotl.utils.data.rl.load_dataset_with_config" + ) as mock_load_dataset: # Set up the mock to return different values on successive calls mock_load_dataset.return_value = ( dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff ) - train_dataset, _ = load_prepare_preference_datasets(cfg) + tokenizer = load_tokenizer(cfg) + train_dataset, _ = prepare_preference_datasets(cfg, tokenizer) assert len(train_dataset) == 1800 assert "conversation" not in train_dataset.features @@ -393,16 +428,18 @@ class TestDatasetPreparation: ) with patch( - "axolotl.utils.data.shared.load_dataset_w_config" + "axolotl.utils.data.shared.load_dataset_with_config" ) as mock_load_dataset: # Set up the mock to return different values on successive calls mock_load_dataset.return_value = ( dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff ) - dataset, _ = load_tokenized_prepared_datasets( - tokenizer, cfg, prepared_path - ) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", + str(prepared_path), + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 2000 assert "input_ids" in dataset.features @@ -434,10 +471,14 @@ class TestDatasetPreparation: "type": "alpaca", }, ], + "dataset_num_proc": 4, } ) - dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) assert len(dataset) == 2000 assert "input_ids" in dataset.features diff --git a/tests/test_dict.py b/tests/test_dict.py index 0bcf8ca7b..19a370199 100644 --- a/tests/test_dict.py +++ b/tests/test_dict.py @@ -21,26 +21,26 @@ class DictDefaultTest(unittest.TestCase): } ) - assert ( - cfg.key_a.key_b == "value_a" - ), "DictDefault should return value for existing nested keys" + assert cfg.key_a.key_b == "value_a", ( + "DictDefault should return value for existing nested keys" + ) - assert ( - cfg.key_c == "value_c" - ), "DictDefault should return value for existing keys" + assert cfg.key_c == "value_c", ( + "DictDefault should return value for existing keys" + ) - assert ( - cfg.key_d[0] == "value_d" - ), "DictDefault should return value for existing keys in list" + assert cfg.key_d[0] == "value_d", ( + "DictDefault should return value for existing keys in list" + ) - assert ( - "value_e" in cfg.key_d - ), "DictDefault should support in operator for existing keys in list" + assert "value_e" in cfg.key_d, ( + "DictDefault should support in operator for existing keys in list" + ) def test_dict_or_operator(self): cfg = DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"}) - cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation + cfg = cfg | DictDefault( { "key_a": {"key_b": "value_a"}, "key_c": "value_c", @@ -49,9 +49,9 @@ class DictDefaultTest(unittest.TestCase): } ) - assert ( - cfg.key_a.key_b == "value_b" - ), "DictDefault should support OR operator for existing nested keys" + assert cfg.key_a.key_b == "value_b", ( + "DictDefault should support OR operator for existing nested keys" + ) assert cfg.key_c == "value_c", "DictDefault should not delete existing key" @@ -60,9 +60,9 @@ class DictDefaultTest(unittest.TestCase): "value_e", ], "DictDefault should not overwrite existing keys in list" - assert ( - cfg.key_f == "value_g" - ), "DictDefault should support OR operator for existing key" + assert cfg.key_f == "value_g", ( + "DictDefault should support OR operator for existing key" + ) def test_dict_missingkey(self): cfg = DictDefault({}) @@ -72,9 +72,9 @@ class DictDefaultTest(unittest.TestCase): def test_dict_or(self): cfg = DictDefault({}) | DictDefault({}) - assert ( - cfg.random_key is None - ), "DictDefault should return None for missing keys after | operation" + assert cfg.random_key is None, ( + "DictDefault should return None for missing keys after | operation" + ) def test_dict_nested_missingparentkey(self): """ diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py index 29672c9e5..a519db525 100644 --- a/tests/test_exact_deduplication.py +++ b/tests/test_exact_deduplication.py @@ -5,7 +5,6 @@ Additionally, this test suite includes tests for functions that indirectly call `deduplicate_and_log_datasets` during the execution of the preprocess command. """ -import hashlib import unittest from unittest.mock import patch @@ -14,8 +13,7 @@ from datasets import Dataset from axolotl.loaders import load_processor, load_tokenizer from axolotl.utils.config import normalize_config, validate_config -from axolotl.utils.data import prepare_dataset -from axolotl.utils.data.rl import load_prepare_preference_datasets +from axolotl.utils.data import prepare_datasets, prepare_preference_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.dict import DictDefault @@ -43,9 +41,9 @@ def verify_deduplication(actual_dataset, expected_dataset, dataset_name): assert actual_rows == expected_rows, f"Mismatch in {dataset_name} dataset" # Verify size consistency - assert len(actual_rows) == len( - actual_dataset - ), f"Size mismatch in {dataset_name} dataset after deduplication" + assert len(actual_rows) == len(actual_dataset), ( + f"Size mismatch in {dataset_name} dataset after deduplication" + ) class TestDeduplicateIndividualFunctions(unittest.TestCase): @@ -71,36 +69,14 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase): self.expected_dataset = Dataset.from_dict(self.expected_data) def test_deduplication(self): - train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset) - _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset) + train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset) + eval_dataset, _ = deduplicate_and_log_datasets( + dataset=self.dataset, dataset_name="eval" + ) verify_deduplication(train_dataset, self.expected_dataset, "train_dataset") verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset") - def test_datasets_are_none(self): - # Test when both datasets are None - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=None, eval_dataset=None - ) - self.assertIsNone(train_dataset, "Expected train_dataset to be None") - self.assertIsNone(eval_dataset, "Expected eval_dataset to be None") - - def test_only_train_is_none(self): - # Test when only train_dataset is None - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=None, eval_dataset=self.dataset - ) - self.assertIsNone(train_dataset, "Expected train_dataset to be None") - verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset") - - def test_only_eval_is_none(self): - # Test when only eval_dataset is None - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=self.dataset, eval_dataset=None - ) - self.assertIsNone(eval_dataset, "Expected eval_dataset to be None") - verify_deduplication(train_dataset, self.expected_dataset, "train_dataset") - def test_exact_duplicates(self): # Test when datasets are exact duplicates duplicate_data = { @@ -115,8 +91,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase): expected_dataset = Dataset.from_dict(expected_data) # Run deduplication - train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) - _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset) + train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) + eval_dataset, _ = deduplicate_and_log_datasets( + dataset=dataset, dataset_name="eval" + ) verify_deduplication(train_dataset, expected_dataset, "train_dataset") verify_deduplication(eval_dataset, expected_dataset, "eval_dataset") @@ -139,8 +117,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase): expected_dataset = Dataset.from_dict(expected_data) # Run deduplication - train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset) - _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset) + train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) + eval_dataset, _ = deduplicate_and_log_datasets( + dataset=dataset, dataset_name="eval" + ) verify_deduplication(train_dataset, expected_dataset, "train_dataset") verify_deduplication(eval_dataset, expected_dataset, "eval_dataset") @@ -169,8 +149,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase): expected_dataset_eval = Dataset.from_dict(expected_data_eval) # Run deduplication - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=dataset, eval_dataset=dataset + train_dataset, eval_dataset = deduplicate_and_log_datasets( + dataset=dataset, other_dataset=dataset ) verify_deduplication(train_dataset, expected_dataset_train, "train_dataset") @@ -206,8 +186,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase): expected_dataset_eval = Dataset.from_dict(expected_data_eval) # Run deduplication - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=dataset_train, eval_dataset=dataset_eval + train_dataset, eval_dataset = deduplicate_and_log_datasets( + dataset=dataset_train, other_dataset=dataset_eval ) verify_deduplication(train_dataset, expected_dataset_train, "train_dataset") @@ -230,6 +210,7 @@ class TestDeduplicateRLDataset: ALPACA_MESSAGES_CONFIG_REVISION, ALPACA_MESSAGES_CONFIG_REVISION, ], + "dataset_num_proc": 4, } ) yield fixture @@ -243,9 +224,10 @@ class TestDeduplicateRLDataset: ): """Verify that loading with deduplication removes duplicates.""" - # pylint: disable=duplicate-code with ( - patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset, + patch( + "axolotl.utils.data.rl.load_dataset_with_config" + ) as mock_load_dataset, patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer, ): # Set up the mock to return different values on successive calls @@ -255,7 +237,8 @@ class TestDeduplicateRLDataset: ] mock_load_tokenizer.return_value = tokenizer_huggyllama - train_dataset, _ = load_prepare_preference_datasets(cfg) + tokenizer = load_tokenizer(cfg) + train_dataset, _ = prepare_preference_datasets(cfg, tokenizer) # Verify that the dataset has been deduplicated assert len(train_dataset) == 1800, "Dataset was not properly deduplicated" @@ -267,9 +250,10 @@ class TestDeduplicateRLDataset: dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer_huggyllama, ): - # pylint: disable=duplicate-code with ( - patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset, + patch( + "axolotl.utils.data.rl.load_dataset_with_config" + ) as mock_load_dataset, patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer, ): # Set up the mock to return different values on successive calls @@ -279,14 +263,15 @@ class TestDeduplicateRLDataset: ] mock_load_tokenizer.return_value = tokenizer_huggyllama - cfg.dataset_exact_deduplication = False # Load the dataset without deduplication - train_dataset, _ = load_prepare_preference_datasets(cfg) + cfg.dataset_exact_deduplication = False + tokenizer = load_tokenizer(cfg) + train_dataset, _ = prepare_preference_datasets(cfg, tokenizer) # Verify that the dataset retains duplicates - assert ( - len(train_dataset) == 1800 * 2 - ), "Dataset deduplication occurred when it should not have" + assert len(train_dataset) == 1800 * 2, ( + "Dataset deduplication occurred when it should not have" + ) class TestDeduplicateNonRL(unittest.TestCase): @@ -335,7 +320,7 @@ class TestDeduplicateNonRL(unittest.TestCase): ) # Prepare dataset using the prepare_dataset function - train_dataset, _, _, _ = prepare_dataset( + train_dataset, _, _, _ = prepare_datasets( self.cfg_1, tokenizer, processor=processor, @@ -362,7 +347,7 @@ class TestDeduplicateNonRL(unittest.TestCase): ) # Prepare dataset using the prepare_dataset function - _, eval_dataset, _, _ = prepare_dataset( + _, eval_dataset, _, _ = prepare_datasets( self.cfg_1, tokenizer, processor=processor, @@ -389,7 +374,7 @@ class TestDeduplicateNonRL(unittest.TestCase): ) # Prepare dataset using the prepare_dataset function - train_dataset, eval_dataset, _, _ = prepare_dataset( + train_dataset, eval_dataset, _, _ = prepare_datasets( self.cfg_1, tokenizer, processor=processor, @@ -428,41 +413,8 @@ class TestWrongCollisions(unittest.TestCase): self.eval_dataset = Dataset.from_dict(self.eval_data) self.dataset = Dataset.from_dict(self.dataset_data) - @patch( - "axolotl.utils.data.utils.sha256", - side_effect=lambda x: ( - hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest() - if "sample 5" in x - else hashlib.sha256(x.encode("utf-8")).hexdigest() - ), - ) - def test_deduplication_wrong_collision_train_eval(self, _mock_sha256): - dedup_train, dedup_eval, _ = deduplicate_and_log_datasets( - train_dataset=self.train_dataset, eval_dataset=self.eval_dataset - ) - self.assertEqual( - len(dedup_train), - 2, - "train dataset should not deduplicate rows with forced hash collisions but different labels.", - ) - self.assertEqual( - len(dedup_eval), - 2, - "Eval dataset should not deduplicate rows with forced hash collisions but different labels.", - ) - self.assertEqual( - len(dedup_eval), - len(self.eval_dataset), - "The output eval dataset should have the same number of rows as the input eval dataset.", - ) - self.assertEqual( - str(dedup_eval), - str(self.eval_dataset), - "The string representation of the output eval dataset should be identical to the input eval dataset.", - ) - def test_deduplication_dataset_only(self): - _, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset) + dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset) self.assertEqual( len(dedup_dataset), 3, "Dataset should have all original values" ) diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 7313a8267..f516d0ca4 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -9,6 +9,7 @@ from transformers.utils.import_utils import is_torch_mps_available from axolotl.loaders import ModelLoader from axolotl.utils.dict import DictDefault +from axolotl.utils.distributed import _get_parallel_config_kwargs class TestModelsUtils: @@ -16,7 +17,7 @@ class TestModelsUtils: def setup_method(self) -> None: # load config - self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init + self.cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", "model_type": "AutoModelForCausalLM", @@ -29,20 +30,16 @@ class TestModelsUtils: "device_map": "auto", } ) - self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init - spec=PreTrainedTokenizerBase - ) - self.inference = False # pylint: disable=attribute-defined-outside-init - self.reference_model = True # pylint: disable=attribute-defined-outside-init + self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + self.inference = False + self.reference_model = True # init ModelLoader - self.model_loader = ( # pylint: disable=attribute-defined-outside-init - ModelLoader( - cfg=self.cfg, - tokenizer=self.tokenizer, - inference=self.inference, - reference_model=self.reference_model, - ) + self.model_loader = ModelLoader( + cfg=self.cfg, + tokenizer=self.tokenizer, + inference=self.inference, + reference_model=self.reference_model, ) def test_set_device_map_config(self): @@ -50,7 +47,7 @@ class TestModelsUtils: device_map = self.cfg.device_map if is_torch_mps_available(): device_map = "mps" - # pylint: disable=protected-access + self.model_loader._set_device_map_config() if is_deepspeed_zero3_enabled(): assert "device_map" not in self.model_loader.model_kwargs @@ -77,7 +74,6 @@ class TestModelsUtils: self.cfg.gptq = gptq self.cfg.adapter = adapter - # pylint: disable=protected-access self.model_loader._set_quantization_config() if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq: assert not ( @@ -171,3 +167,42 @@ class TestModelsUtils: message_property_mappings={"content": "different_content"}, ) assert "Conflicting message content fields" in str(exc_info.value) + + @pytest.mark.parametrize( + "world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected", + [ + (16, 2, 2, 2, 2, True, (2, 2, 2, 2)), + (16, 1, 1, None, None, True, (0, 0, 16, 1)), + (16, 2, 2, 2, None, True, (2, 2, 2, 2)), + (16, 2, 2, None, 2, True, (2, 2, 2, 2)), + (16, 1, 1, None, 2, True, (0, 0, 8, 2)), + (2, 1, 1, None, None, True, (0, 0, 2, 1)), + ], + ) + def test_get_parallel_config_kwargs( + self, + world_size, + tensor_parallel_size, + context_parallel_size, + dp_shard_size, + dp_replicate_size, + is_fsdp, + expected, + ): + res = _get_parallel_config_kwargs( + world_size, + tensor_parallel_size, + context_parallel_size, + dp_shard_size, + dp_replicate_size, + is_fsdp, + ) + + if expected[0] > 1: + assert res["tp_size"] == expected[0] + if expected[1] > 1: + assert res["cp_size"] == expected[1] + if expected[2] > 1: + assert res["dp_shard_size"] == expected[2] + if expected[3] > 1: + assert res["dp_replicate_size"] == expected[3] diff --git a/tests/test_logging_config_file_capture.py b/tests/test_logging_config_file_capture.py new file mode 100644 index 000000000..44b0ee5e6 --- /dev/null +++ b/tests/test_logging_config_file_capture.py @@ -0,0 +1,103 @@ +import logging +import tempfile + +import pytest + + +def read(path: str) -> str: + with open(path, "r", encoding="utf-8") as f: + return f.read() + + +@pytest.fixture(autouse=True) +def _reset_logging_state(): + # Ensure a clean slate for logging between tests + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.shutdown() + # Note: dictConfig in configure_logging will set up handlers again + yield + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.shutdown() + + +def test_axolotl_logs_captured_at_all_levels(monkeypatch): + from axolotl.logging_config import configure_logging + from axolotl.utils import tee + from axolotl.utils.logging import get_logger + + with tempfile.TemporaryDirectory() as td: + # Avoid stdout tee in this test to simplify interaction with pytest capture + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0") + configure_logging() + path = tee.prepare_debug_log( + type("Cfg", (), {"output_dir": td, "get": lambda *_: False}) + ) + + log = get_logger("axolotl.test") + log.info("AX-INFO") + log.debug("AX-DEBUG") + tee.file_only_stream.flush() + + data = read(path) + assert "AX-INFO" in data + assert "AX-DEBUG" in data + tee.close_debug_log() + + +def test_third_party_logs_filtered_and_warning_captured(monkeypatch): + from axolotl.logging_config import configure_logging + from axolotl.utils import tee + + with tempfile.TemporaryDirectory() as td: + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0") + configure_logging() + path = tee.prepare_debug_log( + type("Cfg", (), {"output_dir": td, "get": lambda *_: False}) + ) + + # Third-party logger (non-axolotl) + other = logging.getLogger("thirdparty.lib") + other.info("TP-INFO") + other.warning("TP-WARN") + + # Simulate Python warnings routed through logging + logging.getLogger("py.warnings").warning("PY-WARN") + + # Push through buffers + tee.file_only_stream.flush() + + data = read(path) + # INFO from non-axolotl should be filtered out (not present) + assert "TP-INFO" not in data + # WARNING+ should be present + assert "TP-WARN" in data + # Python warnings captured (via py.warnings logger) + assert "PY-WARN" in data + tee.close_debug_log() + tee.close_debug_log() + + +def test_prepare_debug_log_idempotent_and_no_duplicate(monkeypatch): + from axolotl.logging_config import configure_logging + from axolotl.utils import tee + from axolotl.utils.logging import get_logger + + with tempfile.TemporaryDirectory() as td: + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0") + configure_logging() + cfg = type("Cfg", (), {"output_dir": td, "get": lambda *_: False}) + p1 = tee.prepare_debug_log(cfg) + p2 = tee.prepare_debug_log(cfg) + assert p1 == p2 + + log = get_logger("axolotl.test") + marker = "UNIQUE-MARKER-12345" + log.info(marker) + tee.file_only_stream.flush() + + data = read(p1) + # Ensure the marker appears once (not duplicated via propagation) + assert data.count(marker) == 1 + tee.close_debug_log() diff --git a/tests/test_lora.py b/tests/test_lora.py index 6edcdd88e..50cbea9bc 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -6,7 +6,6 @@ from axolotl.loaders import ModelLoader, load_tokenizer from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault -# pylint: disable=duplicate-code minimal_config = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py index ea98bf97d..f0d3a2d72 100644 --- a/tests/test_normalize_config.py +++ b/tests/test_normalize_config.py @@ -5,7 +5,11 @@ Test classes for checking functionality of the cfg normalization import unittest from unittest.mock import patch -from axolotl.utils.config import normalize_cfg_datasets, normalize_config +from axolotl.utils.config import ( + normalize_cfg_datasets, + normalize_config, + validate_config, +) from axolotl.utils.dict import DictDefault @@ -23,6 +27,13 @@ class NormalizeConfigTestCase(unittest.TestCase): "num_epochs": 1, "micro_batch_size": 1, "gradient_accumulation_steps": 1, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "learning_rate": 0.0001, } ) @@ -90,3 +101,99 @@ class NormalizeConfigTestCase(unittest.TestCase): self.assertTrue(cfg.bf16) self.assertFalse(cfg.fp16) + + def test_migrate_fsdp_config(self): + """Test basic FSDP config migration with and without fsdp_version""" + cfg_with_version = self._get_base_cfg() | DictDefault( + { + "fsdp_config": { + "fsdp_version": 2, + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_offload_params": False, + "fsdp_cpu_ram_efficient_loading": True, + } + } + ) + + cfg_with_version = validate_config(cfg_with_version) + + self.assertEqual(cfg_with_version.fsdp_version, 2) + self.assertEqual( + cfg_with_version.fsdp_config.auto_wrap_policy, "TRANSFORMER_BASED_WRAP" + ) + self.assertEqual(cfg_with_version.fsdp_config.offload_params, False) + self.assertEqual(cfg_with_version.fsdp_config.cpu_ram_efficient_loading, True) + + self.assertNotIn("fsdp_auto_wrap_policy", cfg_with_version.fsdp_config) + self.assertNotIn("fsdp_offload_params", cfg_with_version.fsdp_config) + self.assertNotIn("fsdp_cpu_ram_efficient_loading", cfg_with_version.fsdp_config) + self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config) + self.assertNotIn("version", cfg_with_version.fsdp_config) + + cfg_without_version = self._get_base_cfg() | DictDefault( + { + "fsdp_config": { + "fsdp_auto_wrap_policy": "SIZE_BASED_WRAP", + "fsdp_offload_params": True, + } + } + ) + + cfg_without_version = validate_config(cfg_without_version) + + self.assertNotIn("fsdp_version", cfg_without_version) + self.assertEqual( + cfg_without_version.fsdp_config.auto_wrap_policy, "SIZE_BASED_WRAP" + ) + self.assertEqual(cfg_without_version.fsdp_config.offload_params, True) + + self.assertNotIn("fsdp_auto_wrap_policy", cfg_without_version.fsdp_config) + self.assertNotIn("fsdp_offload_params", cfg_without_version.fsdp_config) + + def test_migrate_fsdp_config_no_fsdp_config(self): + """Test that function doesn't crash when no fsdp_config is present""" + cfg = self._get_base_cfg() + + cfg = validate_config(cfg) + + self.assertNotIn("fsdp_config", cfg) + self.assertNotIn("fsdp_version", cfg) + + def test_migrate_fsdp_config_empty_fsdp_config(self): + """Test migration with empty fsdp_config""" + cfg = self._get_base_cfg() | DictDefault({"fsdp_config": {}}) + + cfg = validate_config(cfg) + + self.assertNotIn("fsdp_version", cfg) + self.assertEqual(cfg.fsdp_config, {}) + + def test_migrate_fsdp_config_mixed_keys(self): + """Test migration with a mix of fsdp_ and non-fsdp_ keys""" + cfg = self._get_base_cfg() | DictDefault( + { + "fsdp_config": { + "fsdp_version": 1, + "fsdp_state_dict_type": "FULL_STATE_DICT", + "mixed_precision_policy": "fp16", + "activation_checkpointing": True, + "fsdp_reshard_after_forward": False, + } + } + ) + + cfg = validate_config(cfg) + + self.assertEqual(cfg.fsdp_version, 1) + self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT") + self.assertEqual(cfg.fsdp_config.reshard_after_forward, False) + self.assertEqual(cfg.fsdp_config.mixed_precision_policy, "fp16") + self.assertEqual(cfg.fsdp_config.activation_checkpointing, True) + + # Check original fsdp_ keys are removed + self.assertNotIn("fsdp_version", cfg.fsdp_config) + self.assertNotIn("fsdp_state_dict_type", cfg.fsdp_config) + self.assertNotIn("fsdp_reshard_after_forward", cfg.fsdp_config) + + # Ensure no duplicate version key + self.assertNotIn("version", cfg.fsdp_config) diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py index 2b03c62f8..a5db7cbe0 100644 --- a/tests/test_packed_batch_sampler.py +++ b/tests/test_packed_batch_sampler.py @@ -8,7 +8,7 @@ from transformers import AutoTokenizer from axolotl.datasets import TokenizedPromptDataset from axolotl.prompt_strategies.completion import load from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq -from axolotl.utils.data.utils import drop_long_seq_in_dataset +from axolotl.utils.data.utils import handle_long_seq_in_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths @@ -48,7 +48,13 @@ class TestBatchedSamplerPacking: max_seq_length, sequential, ): - import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 + from axolotl.monkeypatch.data.batch_dataset_fetcher import ( + apply_multipack_dataloader_patch, + remove_multipack_dataloader_patch, + ) + + # Apply the patch for multipack handling + apply_multipack_dataloader_patch() dataset = dataset_winglian_tiny_shakespeare["train"] @@ -70,7 +76,7 @@ class TestBatchedSamplerPacking: ) train_dataset = concatenate_datasets([dataset_wrapper]) - train_dataset = drop_long_seq_in_dataset(train_dataset, cfg) + train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg) lengths = get_dataset_lengths(train_dataset) batch_sampler = MultipackBatchSampler( @@ -81,12 +87,13 @@ class TestBatchedSamplerPacking: group_size=100000, bin_size=200, sequential=sequential, + drop_last=False, ) loader = DataLoader( train_dataset, batch_sampler=batch_sampler, - collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg + collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( tokenizer=tokenizer, padding=True, pad_to_multiple_of=max_seq_length, @@ -100,10 +107,14 @@ class TestBatchedSamplerPacking: for pack in batch: batch_idxs.extend(pack) - for batch in loader: - assert batch["input_ids"].numel() <= batch_size * max_seq_length - assert batch["input_ids"].shape[1] == max_seq_length + try: + for batch in loader: + assert batch["input_ids"].numel() <= batch_size * max_seq_length + assert batch["input_ids"].shape[1] == max_seq_length - original_idxs = set(range(len(train_dataset))) - assert original_idxs == set(batch_idxs) - assert len(batch_idxs) == len(set(batch_idxs)) + original_idxs = set(range(len(train_dataset))) + assert original_idxs == set(batch_idxs) + assert len(batch_idxs) == len(set(batch_idxs)) + finally: + # Clean up: remove the patch after the test + remove_multipack_dataloader_patch() diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py index 8b29eab21..953d523af 100644 --- a/tests/test_packed_dataset.py +++ b/tests/test_packed_dataset.py @@ -1,16 +1,11 @@ """Module for testing dataset sequence packing""" import unittest -from pathlib import Path -from datasets import Dataset, load_dataset from transformers import AutoTokenizer from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets -from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset -from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy -from axolotl.prompters import AlpacaPrompter from axolotl.train import setup_model_and_trainer from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault @@ -26,7 +21,6 @@ class TestPacking(unittest.TestCase): @enable_hf_offline def setUp(self) -> None: - # pylint: disable=duplicate-code self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer.add_special_tokens( { @@ -36,46 +30,8 @@ class TestPacking(unittest.TestCase): } ) - def test_increments_attention(self): - prompter = AlpacaPrompter("chat") - strat = AlpacaPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - dateset = load_dataset( - "json", - data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"), - )["train"] - dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset))) - - constant_len_dataset = ConstantLengthDataset( - self.tokenizer, - [dataset], - seq_length=2048, - ) - packed_dataset = Dataset.from_list(list(constant_len_dataset)) - example = packed_dataset[0] - next_bos_index = ( - example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1 - ) # add one since we sliced - - # first example doesn't have mask reset - assert example["input_ids"][0] == self.tokenizer.bos_token_id - assert example["attention_mask"][0] == 1 - assert example["position_ids"][0] == 0 - assert example["position_ids"][1] == 1 - - # but subsequent one does - assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id - assert example["attention_mask"][next_bos_index] == 2 - assert example["position_ids"][next_bos_index] == 0 - assert example["position_ids"][next_bos_index + 1] == 1 - @with_temp_dir def test_lora_packing(self, temp_dir): - # pylint: disable=duplicate-code cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -99,6 +55,7 @@ class TestPacking(unittest.TestCase): "type": "alpaca", }, ], + "dataset_num_proc": 4, "num_epochs": 1, "max_steps": 20, "save_steps": 10, @@ -126,9 +83,7 @@ class TestPacking(unittest.TestCase): _, ) = setup_model_and_trainer(cfg, dataset_meta) - sampler = trainer._get_eval_sampler( # pylint: disable=protected-access - trainer.eval_dataset - ) + sampler = trainer._get_eval_sampler(trainer.eval_dataset) assert "MultipackBatchSampler" in sampler.__class__.__name__ assert ( "V2BatchSamplerDataCollatorForSeq2Seq" @@ -139,9 +94,7 @@ class TestPacking(unittest.TestCase): batch = next(dataloader_iter) assert batch["input_ids"].shape == (1, 8192) - sampler = trainer._get_train_sampler( # pylint: disable=protected-access - trainer.train_dataset - ) + sampler = trainer._get_train_sampler(trainer.train_dataset) assert "MultipackBatchSampler" in sampler.__class__.__name__ assert ( "V2BatchSamplerDataCollatorForSeq2Seq" diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py index 115813df2..0458f7ba2 100644 --- a/tests/test_packed_pretraining.py +++ b/tests/test_packed_pretraining.py @@ -9,7 +9,7 @@ import torch from datasets import IterableDataset from torch.utils.data import DataLoader -from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset +from axolotl.utils.data import get_dataset_wrapper, wrap_streaming_dataset from axolotl.utils.dict import DictDefault @@ -76,16 +76,12 @@ class TestPretrainingPacking: cfg.pretraining_dataset[0]["type"] or "pretrain", ) - # pylint: disable=duplicate-code original_bsz = cfg.micro_batch_size - train_dataset = wrap_pretraining_dataset( + train_dataset = wrap_streaming_dataset( dataset, tokenizer_huggyllama, cfg, ds_wrapper_partial, - max_tokens=cfg.sequence_len, - batch_size=cfg.micro_batch_size, - seed=cfg.seed or 42, ) trainer_loader = DataLoader( diff --git a/tests/test_perplexity.py b/tests/test_perplexity.py index 9a1c9b223..8f4306994 100644 --- a/tests/test_perplexity.py +++ b/tests/test_perplexity.py @@ -1,7 +1,5 @@ """unit tests for perplexity eval callback""" -# pylint: disable=redefined-outer-name - from pytest import fixture from transformers.models.auto.modeling_auto import AutoModelForCausalLM from transformers.models.auto.tokenization_auto import AutoTokenizer diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 5e5de4ff8..672643a92 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -64,7 +64,7 @@ class TestPromptTokenizationStrategies: tests the interface between the user and assistant parts """ prompter = NoSystemPrompter() - # pylint: disable=duplicate-code + strat = AlpacaPromptTokenizingStrategy( prompter, tokenizer_huggyllama_w_special_tokens, @@ -85,7 +85,7 @@ class TestPromptTokenizationStrategies: """ tests the interface between the user and assistant parts """ - # pylint: disable=duplicate-code + prompter = AlpacaPrompter() strat = AlpacaPromptTokenizingStrategy( prompter, @@ -171,7 +171,7 @@ class Llama2ChatTokenizationTest: # from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT # broken as of 23/7/20 # see https://github.com/huggingface/transformers/pull/24935 - # pylint: disable=C0103 + DEFAULT_SYSTEM_PROMPT = """\ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. @@ -201,7 +201,7 @@ If a question does not make any sense, or is not factually coherent, explain why + user_input[1:-1], generated_responses=answers, ) - # pylint: disable=W0212 + hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf) assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)] diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index 92664cca8..c783a68db 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -22,7 +22,7 @@ class TestCosineConstantLr(unittest.TestCase): self.constant_lr_ratio = 0.8 self._lr = 0.01 self.optimizer = SGD([torch.tensor(1)], lr=self._lr) - self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init + self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( self.optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.train_steps, diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 000000000..2c1f9f936 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,238 @@ +"""Test streaming configuration and data loading functionality.""" + +import unittest +from unittest.mock import Mock, patch + +from datasets import IterableDataset + +from axolotl.utils.config import validate_config +from axolotl.utils.data.sft import ( + _prepare_streaming_dataset, + prepare_datasets, +) +from axolotl.utils.dict import DictDefault + + +class TestStreamingConfig(unittest.TestCase): + """Test streaming configuration and deprecation handling.""" + + def test_streaming_multipack_buffer_size_deprecation(self): + """Test that pretrain_multipack_buffer_size is properly deprecated.""" + # Test with old config name + cfg_old = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "pretrain_multipack_buffer_size": 5000, + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + "sequence_len": 256, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 0.0001, + } + ) + + with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm: + validated_cfg = validate_config(cfg_old) + self.assertIn("pretrain_multipack_buffer_size` is deprecated", cm.output[0]) + + self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 5000) + self.assertIsNone( + getattr(validated_cfg, "pretrain_multipack_buffer_size", None) + ) + + def test_streaming_multipack_buffer_size_new(self): + """Test that new streaming_multipack_buffer_size works correctly.""" + cfg_new = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "streaming_multipack_buffer_size": 7000, + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + "sequence_len": 256, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 0.0001, + } + ) + + validated_cfg = validate_config(cfg_new) + self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 7000) + + def test_both_buffer_sizes_raises_error(self): + """Test that having both old and new buffer size configs raises an error.""" + cfg_both = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "pretrain_multipack_buffer_size": 5000, + "streaming_multipack_buffer_size": 7000, + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + "sequence_len": 256, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 0.0001, + } + ) + + with self.assertRaises(ValueError) as cm: + validate_config(cfg_both) + self.assertIn("both are set", str(cm.exception)) + + +class TestStreamingDatasetPreparation(unittest.TestCase): + """Test dataset preparation with streaming configuration.""" + + def setUp(self): + self.tokenizer = Mock() + self.tokenizer.pad_token_id = 0 + self.tokenizer.eos_token_id = 1 + + @patch("axolotl.utils.data.sft._prepare_streaming_dataset") + def test_prepare_datasets_with_streaming_true(self, mock_prepare_streaming): + """Test that streaming=True triggers streaming dataset preparation.""" + cfg = DictDefault( + { + "streaming": True, + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + } + ) + + mock_prepare_streaming.return_value = (Mock(), None, 100, []) + + prepare_datasets(cfg, self.tokenizer) + + mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None) + + @patch("axolotl.utils.data.sft._prepare_streaming_dataset") + def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_streaming): + """Test that pretraining_dataset triggers streaming dataset preparation.""" + cfg = DictDefault( + { + "pretraining_dataset": "test/dataset", + } + ) + + mock_prepare_streaming.return_value = (Mock(), None, 100, []) + + prepare_datasets(cfg, self.tokenizer) + + mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None) + + @patch("axolotl.utils.data.sft._prepare_standard_dataset") + def test_prepare_datasets_without_streaming(self, mock_prepare_standard): + """Test that without streaming, standard dataset preparation is used.""" + cfg = DictDefault( + { + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + } + ) + + mock_prepare_standard.return_value = (Mock(), None, 100, []) + + prepare_datasets(cfg, self.tokenizer) + + mock_prepare_standard.assert_called_once_with(cfg, self.tokenizer, None) + + +class TestStreamingWithSamplePacking(unittest.TestCase): + """Test streaming dataset preparation with sample packing.""" + + def setUp(self): + self.tokenizer = Mock() + self.tokenizer.pad_token_id = 0 + self.tokenizer.eos_token_id = 1 + + @patch("axolotl.utils.data.sft._load_streaming_dataset") + def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_streaming): + """Test that streaming SFT with sample_packing sets default split.""" + cfg = DictDefault( + { + "streaming": True, + "sample_packing": True, + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + "sequence_len": 256, + "micro_batch_size": 1, + } + ) + + mock_load_streaming.return_value = Mock(spec=IterableDataset) + + with patch("axolotl.utils.data.sft._load_and_prepare_datasets"): + _prepare_streaming_dataset(cfg, self.tokenizer, None) + + # Check that the dataset config has split set to 'train' + call_args = mock_load_streaming.call_args + dataset_config = call_args[0][0] + self.assertEqual(dataset_config.split, "train") + + def test_multipack_attn_forced_true_for_sft(self): + """Test that multipack_attn is forced to True for SFT with sample packing.""" + from axolotl.utils.data.streaming import wrap_streaming_dataset + + cfg = DictDefault( + { + "sample_packing": True, + "pretrain_multipack_attn": False, # Should be overridden for SFT + "pretraining_dataset": None, # This makes it SFT + "sequence_len": 256, + "micro_batch_size": 1, + "streaming_multipack_buffer_size": 1000, + "seed": 42, + } + ) + + mock_dataset = Mock() + mock_dataset.features = None # For streaming datasets + mock_dataset.__iter__ = Mock(return_value=iter([])) # Empty iterator + mock_dataset.map = Mock(return_value=mock_dataset) + mock_ds_wrapper = Mock() + + with patch( + "axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq" + ) as mock_collator: + with patch("axolotl.utils.data.streaming.encode_packed_streaming"): + wrap_streaming_dataset( + mock_dataset, self.tokenizer, cfg, mock_ds_wrapper + ) + + # Check that multipack_attn=True was used in the collator + mock_collator.assert_called_once() + call_kwargs = mock_collator.call_args[1] + self.assertTrue(call_kwargs["multipack_attn"]) + + def test_multipack_attn_respects_config_for_pretraining(self): + """Test that multipack_attn respects config for pretraining datasets.""" + from axolotl.utils.data.streaming import wrap_streaming_dataset + + cfg = DictDefault( + { + "sample_packing": True, + "pretrain_multipack_attn": False, # Should be respected for pretraining + "pretraining_dataset": "test/dataset", # This makes it pretraining + "sequence_len": 256, + "micro_batch_size": 1, + "streaming_multipack_buffer_size": 1000, + "seed": 42, + } + ) + + mock_dataset = Mock() + mock_dataset.features = None # For streaming datasets + mock_dataset.__iter__ = Mock(return_value=iter([])) # Empty iterator + mock_dataset.map = Mock(return_value=mock_dataset) + mock_ds_wrapper = Mock() + + with patch( + "axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq" + ) as mock_collator: + with patch("axolotl.utils.data.streaming.encode_packed_streaming"): + wrap_streaming_dataset( + mock_dataset, self.tokenizer, cfg, mock_ds_wrapper + ) + + # Check that multipack_attn=False was used (respecting config) + mock_collator.assert_called_once() + call_kwargs = mock_collator.call_args[1] + self.assertFalse(call_kwargs["multipack_attn"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_train.py b/tests/test_train.py new file mode 100644 index 000000000..2c29b58ee --- /dev/null +++ b/tests/test_train.py @@ -0,0 +1,39 @@ +"""Test for batch size calculation for multi-gpu training.""" + +import pytest + +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="train_base_cfg") +def fixture_train_base_cfg(min_base_cfg): + return ( + DictDefault( + micro_batch_size=2, + gradient_accumulation_steps=4, + sequence_len=2048, + sample_packing=True, + num_epochs=1, + ) + | min_base_cfg + ) + + +class TestTrain: + """test class for train related tests""" + + @pytest.mark.parametrize( + "world_size, expected_batch_size", + [ + (1, 8), + (4, 32), + ], + ) + def test_batch_size_ddp( + self, train_base_cfg, monkeypatch, world_size, expected_batch_size + ): + monkeypatch.setenv("WORLD_SIZE", str(world_size)) + cfg = validate_config(train_base_cfg) + normalize_config(cfg) + assert cfg.batch_size == expected_batch_size diff --git a/tests/test_utils_tee.py b/tests/test_utils_tee.py new file mode 100644 index 000000000..e2c153667 --- /dev/null +++ b/tests/test_utils_tee.py @@ -0,0 +1,107 @@ +import os +import tempfile + + +def _dummy_cfg(output_dir: str, append: bool = False): + # Minimal object with attributes used by prepare_debug_log + class Cfg: + def __init__(self, out, append): + self.output_dir = out + self._append = append + + def get(self, key, default=None): + if key in {"resume_from_checkpoint", "auto_resume_from_checkpoints"}: + return self._append + return default + + return Cfg(output_dir, append) + + +def read(path: str) -> str: + with open(path, "r", encoding="utf-8") as f: + return f.read() + + +def test_file_only_stream_writes_after_prepare(monkeypatch): + from axolotl.utils import tee + + with tempfile.TemporaryDirectory() as td: + # Avoid stdout tee in this test + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0") + cfg = _dummy_cfg(td, append=False) + + # before prepare: writing to file_only_stream creates no file + tee.file_only_stream.write("before\n") + tee.file_only_stream.flush() + assert not os.path.exists(os.path.join(td, "debug.log")) + + # prepare and write + path = tee.prepare_debug_log(cfg) + assert os.path.basename(path) == "debug.log" + tee.file_only_stream.write("hello\n") + tee.file_only_stream.flush() + + content = read(path) + assert "hello" in content + + tee.close_debug_log() + + +def test_stdout_is_mirrored_after_prepare(capsys, monkeypatch): + from axolotl.utils import tee + + with tempfile.TemporaryDirectory() as td: + cfg = _dummy_cfg(td, append=False) + try: + # Install tee while capture is disabled so stdout tee wraps real stdout. + with capsys.disabled(): + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "1") + path = tee.prepare_debug_log(cfg) + import sys + + print("printed-line") + sys.stdout.flush() + + # Now verify file contains the line + content = read(path) + assert "printed-line" in content + finally: + tee.close_debug_log() + + +def test_truncate_vs_append_behavior(monkeypatch): + from axolotl.utils import tee + + with tempfile.TemporaryDirectory() as td: + # Avoid stdout tee in this test + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0") + # First run creates file with A + cfg = _dummy_cfg(td, append=False) + _ = tee.prepare_debug_log(cfg) + try: + tee.file_only_stream.write("A\n") + tee.file_only_stream.flush() + finally: + tee.close_debug_log() + + # Second run with append=False truncates + cfg2 = _dummy_cfg(td, append=False) + path2 = tee.prepare_debug_log(cfg2) + try: + tee.file_only_stream.write("B\n") + tee.file_only_stream.flush() + content = read(path2) + assert "A\n" not in content and "B\n" in content + finally: + tee.close_debug_log() + + # Third run with append=True preserves existing + cfg3 = _dummy_cfg(td, append=True) + path3 = tee.prepare_debug_log(cfg3) + try: + tee.file_only_stream.write("C\n") + tee.file_only_stream.flush() + content = read(path3) + assert "B\n" in content and "C\n" in content + finally: + tee.close_debug_log() diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py index ba142f3bf..3d3b5db96 100644 --- a/tests/test_validation_dataset.py +++ b/tests/test_validation_dataset.py @@ -24,7 +24,6 @@ def fixture_cfg(): ) -# pylint: disable=too-many-public-methods (duplicate-code) class BaseValidation: """ Base validation module to setup the log capture @@ -73,7 +72,7 @@ class TestValidationCheckDatasetConfig(BaseValidation): "compute_capability": "8.0", }, env_capabilities={ - "torch_version": "2.5.1", + "torch_version": "2.6.0", }, ) @@ -128,7 +127,7 @@ class TestValidationCheckDatasetConfig(BaseValidation): "compute_capability": "8.0", }, env_capabilities={ - "torch_version": "2.5.1", + "torch_version": "2.6.0", }, ) @@ -184,7 +183,7 @@ class TestValidationCheckDatasetConfig(BaseValidation): "compute_capability": "8.0", }, env_capabilities={ - "torch_version": "2.5.1", + "torch_version": "2.6.0", }, ) @@ -241,7 +240,7 @@ class TestValidationCheckDatasetConfig(BaseValidation): "compute_capability": "8.0", }, env_capabilities={ - "torch_version": "2.5.1", + "torch_version": "2.6.0", }, ) diff --git a/tests/utils/schemas/validation/test_activation_offloading.py b/tests/utils/schemas/validation/test_activation_offloading.py new file mode 100644 index 000000000..433133a80 --- /dev/null +++ b/tests/utils/schemas/validation/test_activation_offloading.py @@ -0,0 +1,35 @@ +"""Test for config validation for activation offloading.""" + +from axolotl.utils.config import validate_config +from axolotl.utils.dict import DictDefault + + +class TestActivationOffloading: + """ + Test cases for activation offloading schema validation + """ + + def test_gc_converts_offload_wo_lora(self, min_base_cfg): + cfg = ( + DictDefault( + gradient_checkpointing="offload", + ) + | min_base_cfg + ) + + cfg = validate_config(cfg) + assert cfg.gradient_checkpointing is True + assert cfg.activation_offloading is True + + def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg): + cfg = ( + DictDefault( + gradient_checkpointing=True, + activation_offloading=True, + ) + | min_base_cfg + ) + + cfg = validate_config(cfg) + assert cfg.gradient_checkpointing is True + assert cfg.activation_offloading is True diff --git a/tests/utils/schemas/validation/test_default_values.py b/tests/utils/schemas/validation/test_default_values.py new file mode 100644 index 000000000..332dfe77f --- /dev/null +++ b/tests/utils/schemas/validation/test_default_values.py @@ -0,0 +1,21 @@ +"""Tests for default values for configurations""" + +from axolotl.utils.config import validate_config +from axolotl.utils.dict import DictDefault + + +class TestDefaultConfigValues: + """Tests for default values for configurations""" + + def test_pad_to_sequence_len(self, min_base_cfg): + """Tests that sample packing automatically sets pad_to_sequence_len to True""" + cfg = ( + DictDefault( + sample_packing=True, + ) + | min_base_cfg + ) + + cfg = validate_config(cfg) + + assert cfg.pad_to_sequence_len is True diff --git a/tests/utils/schemas/validation/test_fsdp.py b/tests/utils/schemas/validation/test_fsdp.py new file mode 100644 index 000000000..65f9c66a3 --- /dev/null +++ b/tests/utils/schemas/validation/test_fsdp.py @@ -0,0 +1,150 @@ +""" +tests for pydantic fsdp validation +""" + +import pytest + +from axolotl.utils.config import validate_config +from axolotl.utils.dict import DictDefault + + +class TestFSDPValidation: + """ + test class for pydantic fsdp validation + """ + + def test_fsdp_version_in_fsdp_config(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + fsdp_config={ + "fsdp_version": 2, + }, + ) + cfg = validate_config( + cfg, + ) + assert cfg.fsdp_version == 2 + assert cfg.fsdp_config.fsdp_version is None + + def test_fsdp_offload_w_8bit_optim(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + fsdp_config={ + "offload_params": True, + }, + optimizer="adamw_8bit", + fsdp_version=1, + ) + with pytest.raises( + ValueError, match="FSDP Offload not compatible with adamw_8bit" + ): + validate_config(cfg) + + def test_fsdp2_w_8bit_optim(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + fsdp_config={ + "offload_params": True, + }, + optimizer="adamw_8bit", + fsdp_version=2, + ) + with pytest.raises( + ValueError, + match="FSDP2 not compatible with adamw_8bit, use `adamw_torch_8bit` instead", + ): + validate_config(cfg) + + def test_fsdp2_w_cpu_ram_efficient_loading(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + load_in_8bit=True, + adapter="lora", + fsdp_config={ + "cpu_ram_efficient_loading": True, + }, + fsdp_version=2, + ) + validated_cfg = validate_config(cfg) + assert validated_cfg.fsdp_version == 2 + assert validated_cfg.fsdp_config.cpu_ram_efficient_loading is True + + def test_fsdp2_cpu_offload_pin_memory_requires_offload_params(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + fsdp_config={ + "cpu_offload_pin_memory": False, + "offload_params": False, + }, + fsdp_version=2, + ) + with pytest.raises( + ValueError, + match="disabling cpu_offload_pin_memory requires enabling offload_params", + ): + validate_config(cfg) + + def test_fsdp1_cpu_offload_pin_memory_not_supported(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + fsdp_config={ + "cpu_offload_pin_memory": False, + "offload_params": True, + }, + fsdp_version=1, + ) + with pytest.raises( + ValueError, + match="FSDP1 does not support disabling cpu_offload_pin_memory, please set `fsdp_version` to 2", + ): + validate_config(cfg) + + def test_fsdp2_cpu_offload_pin_memory_w_offload_params(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + fsdp_config={ + "cpu_offload_pin_memory": False, + "offload_params": True, + }, + fsdp_version=2, + ) + validated_cfg = validate_config(cfg) + assert validated_cfg.fsdp_config.cpu_offload_pin_memory is False + assert validated_cfg.fsdp_config.offload_params is True + + def test_fsdp_prefixes_removed(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + fsdp_config={ + "fsdp_version": 2, + "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer", + "fsdp_reshard_after_forward": True, + } + ) + cfg = validate_config(cfg) + assert cfg.fsdp_version == 2 + assert cfg.fsdp_config.fsdp_version is None + for keys in cfg.fsdp_config.keys(): + assert not keys.startswith("fsdp_") + assert cfg.fsdp_config.auto_wrap_policy == "TRANSFORMER_BASED_WRAP" + assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer" + assert cfg.fsdp_config.reshard_after_forward is True + + @pytest.mark.parametrize( + "rl", + [ + "dpo", + "kto", + "orpo", + "ipo", + ], + ) + def test_fsdp2_dpo(self, min_base_cfg, rl): + cfg = min_base_cfg | DictDefault( + fsdp_version=2, + fsdp_config={ + "reshard_after_forward": True, + }, + rl=rl, + load_in_8bit=True, + adapter="lora", + remove_unused_columns=False, + ) + with pytest.raises( + ValueError, + match="FSDP2 does not support load_in_8bit or load_in_4bit with ", + ): + validate_config(cfg) diff --git a/tests/utils/test_import_helper.py b/tests/utils/test_import_helper.py new file mode 100644 index 000000000..e1ab8bec5 --- /dev/null +++ b/tests/utils/test_import_helper.py @@ -0,0 +1,37 @@ +""" +test cases for axolotl.utils.import_helper +""" + +import pytest + +from axolotl.utils.import_helper import get_cls_from_module_str + + +def test_get_cls_from_module_str(): + cls = get_cls_from_module_str("axolotl.core.trainers.base.AxolotlTrainer") + assert cls.__name__ == "AxolotlTrainer" + + +def test_get_cls_from_module_str_empty_string(): + with pytest.raises(ValueError, match="module_str must be a non-empty string"): + get_cls_from_module_str("") + + +def test_get_cls_from_module_str_whitespace_only(): + with pytest.raises(ValueError, match="module_str must be a non-empty string"): + get_cls_from_module_str(" ") + + +def test_get_cls_from_module_str_invalid_format(): + with pytest.raises(ValueError, match="Invalid module string format"): + get_cls_from_module_str("single_part") + + +def test_get_cls_from_module_str_nonexistent_module(): + with pytest.raises(ImportError, match="Failed to import module"): + get_cls_from_module_str("nonexistent.module.Class") + + +def test_get_cls_from_module_str_nonexistent_class(): + with pytest.raises(AttributeError, match="Class 'NonExistentClass' not found"): + get_cls_from_module_str("axolotl.core.trainers.base.NonExistentClass") diff --git a/tests/utils/test_train.py b/tests/utils/test_train.py new file mode 100644 index 000000000..a1f6f6088 --- /dev/null +++ b/tests/utils/test_train.py @@ -0,0 +1,24 @@ +"""test for train checkpoint utils""" + +import os + +from axolotl.utils.dict import DictDefault +from axolotl.utils.train import determine_last_checkpoint + + +def test_determine_last_checkpoint(temp_dir): + cfg = DictDefault( + output_dir=temp_dir, + ) + for cpt_idx in [1, 9, 10, 20]: + os.makedirs( + os.path.join(cfg.output_dir, f"checkpoint-{cpt_idx}"), exist_ok=True + ) + + last_checkpoint = determine_last_checkpoint(cfg, update=False) + assert last_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20") + + cfg.resume_from_checkpoint = None + cfg.auto_resume_from_checkpoints = True + determine_last_checkpoint(cfg, update=True) + assert cfg.resume_from_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")