Compare commits: testingci...quantize-p (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 5a51852af1 |  |
.axolotl-complete.bash (deleted)

@@ -1,41 +0,0 @@
-#!/bin/bash
-
-_axolotl_completions() {
-    local cur prev
-    COMPREPLY=()
-    cur="${COMP_WORDS[COMP_CWORD]}"
-    prev="${COMP_WORDS[COMP_CWORD-1]}"
-
-    # If we're completing the first argument (the command)
-    if [[ $COMP_CWORD -eq 1 ]]; then
-        mapfile -t COMPREPLY < <(compgen -W "delinearize-llama4 fetch lm-eval merge-sharded-fsdp-weights quantize vllm-serve evaluate inference merge-lora preprocess train" -- "$cur")
-        return 0
-    fi
-
-    # Commands that should complete with directories and YAML files
-    local -a yaml_commands=("merge-sharded-fsdp-weights" "quantize" "vllm-serve" "evaluate" "inference" "merge-lora" "preprocess" "train")
-
-    # Check if previous word is in our list
-    if [[ " ${yaml_commands[*]} " =~ (^|[[:space:]])$prev($|[[:space:]]) ]]; then
-        # Use filename completion which handles directories properly
-        compopt -o filenames
-        mapfile -t COMPREPLY < <(compgen -f -- "$cur")
-
-        # Filter to only include directories and YAML files
-        local -a filtered=()
-        for item in "${COMPREPLY[@]}"; do
-            if [[ -d "$item" ]] || [[ "$item" == *.yaml ]] || [[ "$item" == *.yml ]]; then
-                filtered+=("$item")
-            fi
-        done
-        COMPREPLY=("${filtered[@]}")
-
-        return 0
-    fi
-
-    # Default: no completion
-    return 0
-}
-
-# Remove the -o nospace option - let filenames handle it
-complete -F _axolotl_completions axolotl
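The removed script registers `_axolotl_completions` for the `axolotl` CLI. A quick way to exercise such a completion function in a shell — a hypothetical session, using the install path from the Dockerfile hunk later in this diff:

```bash
# Load the completion function and its registration into the current shell
source /root/.axolotl-complete.bash

# Confirm the registration created by `complete -F _axolotl_completions axolotl`
complete -p axolotl
# -> complete -F _axolotl_completions axolotl

# Now `axolotl tr<TAB>` completes to `train`, and
# `axolotl train <TAB>` offers only directories and *.yaml/*.yml files.
```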
.github/workflows/base.yml (4 changes)

@@ -17,7 +17,7 @@ on:
 
 jobs:
   build-base:
-    if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
+    if: github.repository_owner == 'axolotl-ai-cloud'
     timeout-minutes: 480
     # this job needs to be run on self-hosted GPU runners...
     runs-on: ubuntu-latest-m

@@ -108,7 +108,7 @@ jobs:
            PYTORCH_VERSION=${{ matrix.pytorch }}
            TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
   build-base-uv:
-    if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
+    if: github.repository_owner == 'axolotl-ai-cloud'
     timeout-minutes: 480
     runs-on: ubuntu-latest-m
     strategy:
.github/workflows/lint.yml (2 changes)

@@ -3,7 +3,6 @@ on:
   # check on PRs, and manual triggers
   merge_group:
   pull_request:
-    types: [opened, synchronize, reopened, ready_for_review]
     paths:
       - '**.py'
       - 'requirements.txt'

@@ -17,7 +16,6 @@ jobs:
   pre-commit:
     name: pre-commit
     runs-on: ubuntu-latest
-    if: ${{ !github.event.pull_request.draft }}
     steps:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v5
.github/workflows/multi-gpu-e2e.yml (2 changes)

@@ -21,7 +21,7 @@ concurrency:
 
 jobs:
   test-axolotl-multigpu:
-    if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
+    if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
     strategy:
       fail-fast: false
       matrix:
.github/workflows/preview-docs.yml (3 changes)

@@ -2,7 +2,7 @@ name: Preview
 on:
   workflow_dispatch:
   pull_request:
-    types: [opened, synchronize, reopened, ready_for_review]
+    types: [opened, synchronize, reopened]
 
     # Run the workflow only when one of these files changes
     paths:

@@ -25,7 +25,6 @@ permissions:
 jobs:
   preview:
     runs-on: ubuntu-latest
-    if: ${{ !github.event.pull_request.draft }}
     steps:
       - name: Check out repository
         uses: actions/checkout@v4
.github/workflows/tests-nightly.yml (2 changes)

@@ -52,7 +52,7 @@ jobs:
 
       - name: Install PyTorch
         run: |
-          pip3 install torch==${{ matrix.pytorch_version }} torchvision
+          pip3 install torch==${{ matrix.pytorch_version }}
 
       - name: Update requirements.txt
         run: |
.github/workflows/tests.yml (13 changes)

@@ -13,7 +13,6 @@ on:
       - 'cicd/cicd.sh'
       - 'cicd/Dockerfile.jinja'
   pull_request:
-    types: [opened, synchronize, reopened, ready_for_review]
     paths:
       - '**.py'
       - 'requirements.txt'

@@ -35,7 +34,6 @@ jobs:
   pre-commit:
     name: pre-commit
     runs-on: ubuntu-latest
-    if: ${{ !github.event.pull_request.draft }}
     steps:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v5

@@ -49,7 +47,6 @@ jobs:
   pytest:
     name: PyTest
     runs-on: ubuntu-latest
-    if: ${{ !github.event.pull_request.draft }}
     # needs: [preload-cache]
     strategy:
       fail-fast: false

@@ -81,7 +78,7 @@ jobs:
 
       - name: Install PyTorch
         run: |
-          pip3 install torch==${{ matrix.pytorch_version }} torchvision
+          pip3 install torch==${{ matrix.pytorch_version }}
 
       - name: Install dependencies
         run: |

@@ -124,7 +121,6 @@ jobs:
   pytest-sdist:
     name: PyTest from Source Dist
     runs-on: ubuntu-latest
-    if: ${{ !github.event.pull_request.draft }}
     strategy:
       fail-fast: false
       matrix:

@@ -155,7 +151,7 @@ jobs:
 
       - name: Install PyTorch
         run: |
-          pip3 install torch==${{ matrix.pytorch_version }} torchvision
+          pip3 install torch==${{ matrix.pytorch_version }}
 
       - name: Install dependencies
         run: |

@@ -189,7 +185,7 @@ jobs:
 
   docker-e2e-tests-1st:
     # Run this job first as a gate for running the remainder of the test matrix
-    if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
+    if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
     # this job needs to be run on self-hosted GPU runners...
     runs-on: [self-hosted, modal]
     timeout-minutes: 120

@@ -239,7 +235,7 @@ jobs:
           modal run cicd.e2e_tests
 
   docker-e2e-tests:
-    if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
+    if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...
     runs-on: [self-hosted, modal]
     timeout-minutes: 120

@@ -293,7 +289,6 @@ jobs:
     runs-on: [self-hosted, modal]
     timeout-minutes: 90
     needs: [docker-e2e-tests]
-    if: ${{ !github.event.pull_request.draft }}
 
     strategy:
       fail-fast: false
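A side note on the `[skip e2e]` guard kept in the `if:` expressions above: it checks the message of the first commit in the push payload (`github.event.commits[0].message`), so the marker must appear in that commit to skip the GPU jobs. An illustrative use:

```bash
# Skip the self-hosted GPU e2e matrix for a docs-only change
git commit -m "docs: clarify README wording [skip e2e]"
git push
```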
@@ -27,7 +27,7 @@ repos:
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.17.0
+    rev: v1.16.1
     hooks:
       - id: mypy
         additional_dependencies:
@@ -268,8 +268,6 @@ website:
         - docs/batch_vs_grad.qmd
         - docs/dataset_preprocessing.qmd
         - docs/multipack.qmd
-        - docs/mixed_precision.qmd
-        - docs/gradient_accumulation.qmd
 
       - section: "Advanced Features"
         contents:
@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
 ENV HF_HOME="{{ HF_HOME }}"
 
 RUN apt-get update && \
-    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
+    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
 
 WORKDIR /workspace
 
@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
 ENV AXOLOTL_DATASET_PROCESSES="8"
 
 RUN apt-get update && \
-    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
+    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
 
 WORKDIR /workspace
 
@@ -7,9 +7,9 @@
     "reduce_bucket_size": "auto",
     "stage3_prefetch_bucket_size": "auto",
     "stage3_param_persistence_threshold": "auto",
-    "max_live_parameters": 0,
-    "max_reuse_distance": 0,
-    "gather_16bit_weights_on_model_save": true
+    "stage3_max_live_parameters": 0,
+    "stage3_max_reuse_distance": 0,
+    "stage3_gather_16bit_weights_on_model_save": true
   },
   "bf16": {
     "enabled": "auto"
@@ -7,9 +7,9 @@
     "reduce_bucket_size": "auto",
     "stage3_prefetch_bucket_size": "auto",
     "stage3_param_persistence_threshold": "auto",
-    "max_live_parameters": 0,
-    "max_reuse_distance": 0,
-    "gather_16bit_weights_on_model_save": true
+    "stage3_max_live_parameters": 0,
+    "stage3_max_reuse_distance": 0,
+    "stage3_gather_16bit_weights_on_model_save": true
   },
   "bf16": {
     "enabled": true
@@ -17,9 +17,9 @@
     "reduce_bucket_size": "auto",
     "stage3_prefetch_bucket_size": "auto",
     "stage3_param_persistence_threshold": "auto",
-    "max_live_parameters": 0,
-    "max_reuse_distance": 0,
-    "gather_16bit_weights_on_model_save": true
+    "stage3_max_live_parameters": 0,
+    "stage3_max_reuse_distance": 0,
+    "stage3_gather_16bit_weights_on_model_save": true
   },
   "bf16": {
     "enabled": true
@@ -13,9 +13,9 @@
     "reduce_bucket_size": "auto",
     "stage3_prefetch_bucket_size": "auto",
     "stage3_param_persistence_threshold": "auto",
-    "max_live_parameters": 0,
-    "max_reuse_distance": 0,
-    "gather_16bit_weights_on_model_save": true
+    "stage3_max_live_parameters": 0,
+    "stage3_max_reuse_distance": 0,
+    "stage3_gather_16bit_weights_on_model_save": true
   },
   "bf16": {
     "enabled": true
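All four JSON hunks above apply the same rename: DeepSpeed spells its ZeRO stage-3 tuning knobs with a `stage3_` prefix (e.g. `stage3_max_live_parameters`), matching the other `stage3_*` keys already present in these files. For reference, a minimal `zero_optimization` block with the corrected spelling — values mirror the configs above, and the `"stage": 3` line is assumed since the hunks do not show it:

```json
{
  "zero_optimization": {
    "stage": 3,
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 0,
    "stage3_max_reuse_distance": 0,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "bf16": {
    "enabled": "auto"
  }
}
```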
@@ -10,9 +10,7 @@ ARG PYTORCH_VERSION="2.1.2"
 ENV PYTORCH_VERSION=$PYTORCH_VERSION
 
 RUN apt-get update && \
-    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
-    rm -rf /var/cache/apt/archives && \
-    rm -rf /var/lib/apt/lists/*
+    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs
 
 WORKDIR /workspace
 

@@ -25,17 +23,17 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
         pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
-    fi && \
-    python scripts/unsloth_install.py | sh && \
-    python scripts/cutcrossentropy_install.py | sh && \
-    pip install pytest && \
-    pip cache purge
+    fi
 
-# fix so that git fetch/pull from remote works with shallow clone
+RUN python scripts/unsloth_install.py | sh
+RUN python scripts/cutcrossentropy_install.py | sh
 
+# So we can test the Docker image
+RUN pip install pytest
 
+# fix so that git fetch/pull from remote works
 RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
-    git config --get remote.origin.fetch && \
-    git config --global credential.helper store
+    git config --get remote.origin.fetch
 
-COPY .axolotl-complete.bash /root/.axolotl-complete.bash
-RUN chmod +x /root/.axolotl-complete.bash && \
-    echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc
+# helper for huggingface-login cli
+RUN git config --global credential.helper store
@@ -16,16 +16,12 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
 ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
 
 RUN apt-get update \
-    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
-    && rm -rf /var/cache/apt/archives \
-    && rm -rf /var/lib/apt/lists/* \
+    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
     && wget \
         https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
     && mkdir /root/.conda \
     && bash Miniconda3-latest-Linux-x86_64.sh -b \
     && rm -f Miniconda3-latest-Linux-x86_64.sh \
-    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
-    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
     && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
 
 ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

@@ -35,14 +31,12 @@ WORKDIR /workspace
 RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
     python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
     python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
-    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
-    python3 -m pip cache purge
+    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
 
 RUN git lfs install --skip-repo && \
     pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
-    pip3 install -U --no-cache-dir pydantic==1.10.10 && \
-    pip3 cache purge
+    pip3 install -U --no-cache-dir pydantic==1.10.10
 
 RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
         FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
@@ -22,22 +22,18 @@ RUN apt-get update \
     && mkdir /root/.conda \
     && bash Miniconda3-latest-Linux-x86_64.sh -b \
     && rm -f Miniconda3-latest-Linux-x86_64.sh \
-    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
-    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
     && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
 
 ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 
 WORKDIR /workspace
 
-RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
+RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
     python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
     python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
-    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
-    python3 -m pip cache purge
+    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
 
 RUN git lfs install --skip-repo && \
     pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
-    pip3 install -U --no-cache-dir pydantic==1.10.10 && \
-    pip3 cache purge
+    pip3 install -U --no-cache-dir pydantic==1.10.10
@@ -14,10 +14,7 @@ COPY scripts/motd /etc/motd
 
 RUN pip install jupyterlab notebook ipywidgets && \
     jupyter lab clean
-RUN apt update && \
-    apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
-    rm -rf /var/cache/apt/archives && \
-    rm -rf /var/lib/apt/lists/* && \
+RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
     mkdir -p ~/.ssh && \
     chmod 700 ~/.ssh && \
     printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
docs/mixed_precision.qmd (deleted)

@@ -1,149 +0,0 @@
----
-title: "Mixed Precision Training"
-format:
-  html:
-    toc: true
-    toc-depth: 3
-    number-sections: true
-    code-tools: true
-execute:
-  enabled: false
----
-
-Mixed precision training uses lower-precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats:
-
-- **FP16** - Half precision 16-bit (Pascal generation+)
-- **BF16** - Brain Float 16-bit (Ampere generation+)
-- **FP8** - 8-bit floating point (Hopper generation+)
-
-## FP16 Mixed Precision {#sec-fp16}
-
-### Overview {#sec-fp16-overview}
-
-FP16 is the traditional half-precision format; it is supported on older GPUs but can be less numerically stable than BF16.
-
-### Configuration {#sec-fp16-config}
-
-```{.yaml}
-fp16: true
-```
-
-### FP16 Considerations {#sec-fp16-considerations}
-
-- May require gradient scaling to prevent underflow
-- Less numerically stable than BF16
-- Can cause training instability with some model architectures
-- Consider using BF16 if your hardware supports it
-
-## BF16 Mixed Precision {#sec-bf16}
-
-### Overview {#sec-bf16-overview}
-
-BF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory.
-
-### Configuration {#sec-bf16-config}
-
-```{.yaml}
-# Automatic BF16 detection (recommended)
-bf16: auto
-
-# Or explicitly enable
-bf16: true
-
-# For evaluation with BF16
-bf16: full # Equivalent to bf16_full_eval in the HF trainer
-```
-
-## FP8 Mixed Precision {#sec-fp8}
-
-::: {.callout-note}
-FP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO.
-:::
-
-### What is FP8? {#sec-fp8-overview}
-
-FP8 (8-bit floating point) can provide significant time savings compared to FP16/BF16 while maintaining training stability. Axolotl's implementation uses PyTorch's TorchAO library with a "tensorwise" scaling strategy.
-
-### Requirements {#sec-fp8-software}
-
-- Hopper+ GPUs (H100/H200)
-- PyTorch 2.7+ (+ compatible TorchAO version)
-- CUDA 12.4+
-
-### Configuration {#sec-fp8-config}
-
-Add to your YAML config:
-
-```{.yaml}
-# Enable FP8 mixed precision
-fp8: true
-
-# Optional: Enable FP8 for FSDP all-gather operations
-fp8_enable_fsdp_float8_all_gather: true
-
-# Enable torch.compile (almost always necessary for FP8 speedups)
-torch_compile: true
-```
-
-::: {.callout-important}
-**torch.compile is critical for FP8 performance**
-
-FP8 training requires `torch_compile: true` to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16.
-:::
-
-### Advanced FP8 Configs {#sec-fp8-advanced}
-
-For [FSDP](multi-gpu.qmd#sec-fsdp) (Fully Sharded Data Parallel) training:
-
-```{.yaml}
-fp8: true
-fp8_enable_fsdp_float8_all_gather: true
-
-torch_compile: true
-
-# FSDP configuration
-fsdp_version: 2
-fsdp_config:
-  offload_params: false
-  cpu_ram_efficient_loading: true
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: LlamaDecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  reshard_after_forward: true
-```
-
-## Best Practices {#sec-best-practices}
-
-### Choosing Precision Format {#sec-choosing-format}
-
-- **Start with automatic detection**: `bf16: auto`
-- **For Hopper+ (H100/H200)**: Try FP8 + torch.compile for maximum speed
-- **For Ampere (A100/RTX 30/40)**: Use BF16
-- **For older Pascal/Turing GPUs**: Use FP16 with caution
-- **For very old or unsupported GPUs**: Use FP32
-
-### Validation and Testing {#sec-validation}
-
-Always validate your mixed precision setup:
-
-- **Start with a small dataset** to verify stability
-- **Monitor loss curves** for irregularities
-- **Compare with FP32 baseline** when possible
-- **Test evaluation metrics** match expectations
-
-### FP8 Particulars {#sec-fp8-details}
-
-- Use cases
-  - Single GPU training
-  - Multi GPU training with FSDP2 or DeepSpeed
-- Speedups
-  - Please refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups for different (M, K, N) settings
-  - Concrete numbers for LLaMA 3 8B training can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks)
-- Known issues:
-  - FP8 + DDP + `torch.compile` (causes [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4))
-  - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing tends to be _slower_ than the equivalent BF16 training
-  - Flash Attention 2 does not play nicely with `torch.compile`
-
-See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision + FP8 all-gather training results in ~10% faster iterations per second vs. BF16 for a relatively small (3B param) model.
-
-For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd).
@@ -98,8 +98,8 @@ fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
 fsdp_state_dict_type | state_dict_type
 fsdp_use_orig_params | **REMOVED**
 
-For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
-if you were using the following FSDP1 config:
+For example, if you were using the following FSDP1 config:
 
 ```{.yaml}
 fsdp_version: 1
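Going by the mapping rows visible above (`fsdp_cpu_ram_efficient_loading` → `cpu_ram_efficient_loading`, `fsdp_state_dict_type` → `state_dict_type`, `fsdp_use_orig_params` → removed), the migration mostly amounts to dropping the `fsdp_` prefix inside `fsdp_config` and setting `fsdp_version: 2`. A sketch of the migrated form only, consistent with the `fsdp_version: 2` examples elsewhere in this diff:

```yaml
fsdp_version: 2
fsdp_config:
  cpu_ram_efficient_loading: true   # was fsdp_cpu_ram_efficient_loading
  state_dict_type: FULL_STATE_DICT  # was fsdp_state_dict_type
  # fsdp_use_orig_params has no FSDP2 counterpart and is dropped entirely
```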
@@ -14,7 +14,6 @@ format:
 - [Llava-1.5](#sec-llava-15)
 - [Mistral-Small-3.1](#sec-mistral-small-31)
 - [Gemma-3](#sec-gemma-3)
-- [Gemma-3n](#sec-gemma-3n)
 - [Qwen2-VL](#sec-qwen2-vl)
 - [Qwen2.5-VL](#sec-qwen25-vl)
 

@@ -111,22 +110,6 @@ base_model: google/gemma-3-4b-it
 chat_template: gemma3
 ```
 
-### Gemma-3n {#sec-gemma-3n}
-
-::: {.callout-warning}
-The model's initial loss and grad norm will be very high. We suspect this to be due to the Conv in the vision layers.
-:::
-
-::: {.callout-tip}
-Please make sure to install `timm` via `pip3 install timm==1.0.17`
-:::
-
-```yaml
-base_model: google/gemma-3n-E2B-it
-
-chat_template: gemma3n
-```
-
 ### Qwen2-VL {#sec-qwen2-vl}
 
 ```yaml

@@ -149,9 +132,7 @@ For multi-modal datasets, we adopt an extended `chat_template` format similar to
 
 - A message is a list of `role` and `content`.
 - `role` can be `system`, `user`, `assistant`, etc.
-- `content` is a list of `type` and (`text`, `image`, `path`, `url`, `base64`, or `audio`).
+- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
 
-### Image
-
 ::: {.callout-note}
 For backwards compatibility:

@@ -160,29 +141,15 @@ For backwards compatibility:
 - If `content` is a string, it will be converted to a list with `type` as `text`.
 :::
 
+::: {.callout-tip}
 For image loading, you can use the following keys within `content` alongside `"type": "image"`:
 
 - `"path": "/path/to/image.jpg"`
 - `"url": "https://example.com/image.jpg"`
 - `"base64": "..."`
 - `"image": PIL.Image`
 
-### Audio
-
-For audio loading, you can use the following keys within `content` alongside `"type": "audio"`:
-
-- `"path": "/path/to/audio.mp3"`
-- `"url": "https://example.com/audio.mp3"`
-- `"audio": np.ndarray`
-
-::: {.callout-tip}
-
-You may need to install `librosa` via `pip3 install librosa==0.11.0`.
-
 :::
 
-### Example
-
 Here is an example of a multi-modal dataset:
 ```json
 [

@@ -211,9 +178,3 @@ Here is an example of a multi-modal dataset:
 }
 ]
 ```
-
-## FAQ
-
-1. `PIL.UnidentifiedImageError: cannot identify image file ...`
-
-   `PIL` could not retrieve the file at `url` using `requests`. Please check for typos. One alternative reason is that the request is blocked by the server.
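To make the `role`/`content` structure described in the doc fragment above concrete, here is a small illustrative record (file names and text hypothetical) in the extended chat_template format, using the `path` key for image loading:

```json
{
  "messages": [
    {
      "role": "user",
      "content": [
        {"type": "image", "path": "/data/images/elephant.jpg"},
        {"type": "text", "text": "What animal is shown in this image?"}
      ]
    },
    {
      "role": "assistant",
      "content": [
        {"type": "text", "text": "The image shows an African elephant."}
      ]
    }
  ]
}
```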
@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:

@@ -20,7 +20,7 @@ lora_model_dir:
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16

@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:

@@ -20,7 +20,7 @@ lora_model_dir:
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16

@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:

@@ -20,7 +20,7 @@ lora_model_dir:
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16

@@ -21,7 +21,7 @@ output_dir: ./outputs/lora-out
 sequence_len: 4096
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:

@@ -25,7 +25,7 @@ lora_target_linear: true
 sequence_len: 4096
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -16,7 +16,7 @@ output_dir: ./outputs/out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 adapter:
 lora_model_dir:

@@ -19,7 +19,7 @@ output_dir: ./outputs/lora-out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:

@@ -19,7 +19,7 @@ lora_model_dir:
 
 sequence_len: 8192
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16

@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 eval_sample_packing: false
 
 adapter: lora

@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
 sequence_len: 4096
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:

@@ -21,7 +21,7 @@ lora_model_dir:
 sequence_len: 4096
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16

@@ -26,3 +26,5 @@ timeout: 86400
 # Preprocess specific configurations
 memory_preprocess: 32
 timeout_preprocess: 14400
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config

@@ -27,7 +27,7 @@ lora_target_linear: true
 sequence_len: 2048
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:
@@ -40,7 +40,7 @@
     "%%capture\n",
     "# This step can take ~5-10 minutes to install dependencies\n",
     "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@631d646\""
+    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19\""
    ]
   },
   {
@@ -21,7 +21,7 @@ output_dir: ./outputs/lora-out
 sequence_len: 4096
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:

@@ -21,7 +21,7 @@ output_dir: ./outputs/lora-out
 sequence_len: 4096
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:

@@ -12,7 +12,7 @@ output_dir: ./outputs/out
 
 sequence_len: 2048
 sample_packing: true
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -30,7 +30,7 @@ output_dir: ./outputs/out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -25,7 +25,7 @@ lora_model_dir:
 
 sequence_len: 2048
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16

@@ -38,7 +38,7 @@ lora_target_modules:
 sequence_len: 2048
 sample_packing: false
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -38,7 +38,7 @@ lora_target_modules:
 sequence_len: 2048
 sample_packing: false
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -38,7 +38,7 @@ lora_target_modules:
 sequence_len: 2048
 sample_packing: false
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -38,7 +38,7 @@ lora_target_modules:
 sequence_len: 2048
 sample_packing: false
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -38,7 +38,7 @@ lora_target_modules:
 sequence_len: 2048
 sample_packing: false
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -38,7 +38,7 @@ lora_target_modules:
 sequence_len: 2048
 sample_packing: false
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -31,7 +31,7 @@ lora_target_linear: true
 sequence_len: 2048
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -18,7 +18,7 @@ remove_unused_columns: false
 sequence_len: 2048
 sample_packing: false
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -35,7 +35,7 @@ lora_target_linear: true
 sequence_len: 2048
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:

@@ -25,7 +25,7 @@ lora_model_dir:
 
 sequence_len: 2048
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16
@@ -1,19 +0,0 @@
-# Gemma-3n
-
-## Requirements
-
-In addition to Axolotl's requirements, Gemma-3n requires
-
-```
-pip3 install timm
-```
-
-If you will load audio datasets, please also install
-
-```
-pip3 install librosa
-```
-
-## Usage
-
-See example configs and the [multimodal doc](https://docs.axolotl.ai/docs/multimodal.html).
@@ -1,74 +0,0 @@
-base_model: google/gemma-3n-E2B-it
-
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-plugins:
-  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
-cut_cross_entropy: true
-
-load_in_8bit: false
-load_in_4bit: true
-
-# for use with fft to only train on language model layers
-# unfrozen_parameters:
-#   - model.language_model.*
-#   - lm_head
-#   - embed_tokens
-
-
-chat_template: gemma3n
-eot_tokens:
-  - <end_of_turn>
-datasets:
-  - path: cgato/SlimOrcaDedupCleaned
-    type: chat_template
-    split: train[:1%]
-    field_messages: conversations
-    message_property_mappings:
-      role: from
-      content: value
-
-val_set_size: 0.0
-output_dir: ./outputs/out
-
-adapter: qlora
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-# lora_target_linear: # Does not work with gemma3n currently
-lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
-
-sequence_len: 2048
-sample_packing: true
-eval_sample_packing: true
-pad_to_sequence_len: true
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 1
-num_epochs: 4
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-bf16: auto
-tf32: true
-
-gradient_checkpointing: true
-gradient_checkpointing_kwargs:
-  use_reentrant: false
-resume_from_checkpoint:
-logging_steps: 1
-# flash_attention: true # Any attention impl does not work with gemma3n now
-
-warmup_ratio: 0.1
-evals_per_epoch:
-saves_per_epoch: 1
-weight_decay: 0.0
-special_tokens:
@@ -1,80 +0,0 @@
-base_model: google/gemma-3n-E2B-it
-processor_type: AutoProcessor
-
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-plugins:
-  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
-cut_cross_entropy: true
-
-# for use with fft to only train on language model layers
-# unfrozen_parameters:
-#   - model.language_model.*
-#   - lm_head
-#   - embed_tokens
-
-load_in_4bit: true
-
-# these 3 lines are needed for now to handle vision chat templates w images
-skip_prepare_dataset: true
-remove_unused_columns: false
-sample_packing: false
-
-# gemma3 doesn't seem to play nice with ddp
-ddp_find_unused_parameters: true
-
-chat_template: gemma3n
-eot_tokens:
-  - <end_of_turn>
-
-# sample dataset below requires downloading audio/image in advance
-# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
-# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
-datasets:
-  - path: Nanobit/text-vision-audio-2k-test
-    type: chat_template
-    data_files:
-      - dataset.jsonl
-dataset_prepared_path:
-val_set_size: 0.01
-output_dir: ./outputs/out
-
-adapter: qlora
-lora_model_dir:
-
-sequence_len: 2048
-pad_to_sequence_len: false
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-bf16: true
-fp16:
-tf32: true
-
-gradient_checkpointing: true
-gradient_checkpointing_kwargs:
-  use_reentrant: false
-logging_steps: 1
-# flash_attention: true # Any attention impl does not work with gemma3n now
-
-warmup_ratio: 0.1
-evals_per_epoch: 1
-saves_per_epoch: 1
-weight_decay: 0.0
@@ -1,75 +0,0 @@
-base_model: google/gemma-3n-E2B-it
-processor_type: AutoProcessor
-
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-plugins:
-  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
-cut_cross_entropy: true
-
-# for use with fft to only train on language model layers
-# unfrozen_parameters:
-#   - model.language_model.*
-#   - lm_head
-#   - embed_tokens
-
-load_in_4bit: true
-
-# these 3 lines are needed for now to handle vision chat templates w images
-skip_prepare_dataset: true
-remove_unused_columns: false
-sample_packing: false
-
-# gemma3 doesn't seem to play nice with ddp
-ddp_find_unused_parameters: true
-
-chat_template: gemma3n
-eot_tokens:
-  - <end_of_turn>
-datasets:
-  - path: HuggingFaceH4/llava-instruct-mix-vsft
-    type: chat_template
-    split: train[:1%]
-dataset_prepared_path:
-val_set_size: 0.01
-output_dir: ./outputs/out
-
-adapter: qlora
-lora_model_dir:
-
-sequence_len: 2048
-pad_to_sequence_len: false
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-bf16: true
-fp16:
-tf32: true
-
-gradient_checkpointing: true
-gradient_checkpointing_kwargs:
-  use_reentrant: false
-logging_steps: 1
-# flash_attention: true # Any attention impl does not work with gemma3n now
-
-warmup_ratio: 0.1
-evals_per_epoch: 1
-saves_per_epoch: 1
-weight_decay: 0.0
@@ -17,7 +17,7 @@ lora_model_dir:
|
|||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
eval_sample_packing: true
|
eval_sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
lora_alpha: 32
|
lora_alpha: 32
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ save_safetensors: true
|
|||||||
adapter: qlora
|
adapter: qlora
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ output_dir: ./outputs/out
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ output_dir: ./outputs/out
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ output_dir: ./outputs/lisa-out
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ output_dir: ./outputs/lora-out
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ lora_model_dir:
|
|||||||
|
|
||||||
sequence_len: 512
|
sequence_len: 512
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ lora_model_dir:
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ lora_model_dir:
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ datasets:
 - path: HuggingFaceH4/llava-instruct-mix-vsft
   type: chat_template
   split: train[:1%]
-dataset_prepared_path:
+  field_messages: messages
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: ./outputs/out
 
@@ -49,8 +50,8 @@ tf32: true
 
 gradient_checkpointing: true
 logging_steps: 1
-# flash_attention: true # use for text-only mode
-sdp_attention: true
+flash_attention: true
+eager_attention:
 
 warmup_ratio: 0.1
 evals_per_epoch: 1
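The dataset stanza in the hunk above uses the chat_template loader, with field_messages naming the column that holds the conversation turns. Assembling the right-hand side into one plausible stanza (path and split as in the hunk; a sketch, not the full file):

datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages    # column containing the chat turns
dataset_prepared_path: last_run_prepared
flash_attention: true           # per the second hunk; eager_attention left unset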
@@ -1,76 +0,0 @@
-base_model: meta-llama/Llama-3.2-3B
-# Automatically upload checkpoint and final model to HF
-# hub_model_id: username/custom_model_name
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true
-
-datasets:
-  - path: yahma/alpaca-cleaned
-    type: alpaca
-
-output_dir: ./outputs/fp8_out/
-
-sample_packing: true
-pad_to_sequence_len: true
-sequence_len: 512
-
-flex_attention: true
-flex_attn_compile_kwargs:
-  dynamic: false
-  mode: max-autotune-no-cudagraphs
-
-torch_compile: true
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 1
-micro_batch_size: 16
-num_epochs: 1
-optimizer: adamw_torch_fused
-
-cosine_constant_lr_ratio: 0
-cosine_min_lr_ratio: 1.0
-learning_rate: 2e-5
-save_only_model: true
-
-fp8: true
-fp8_enable_fsdp_float8_all_gather: true
-
-resume_from_checkpoint:
-logging_steps: 1
-
-evals_per_epoch: 1
-saves_per_epoch: 1
-
-warmup_steps: 10
-weight_decay: 0.0
-
-fsdp_version: 2
-fsdp_config:
-  offload_params: false
-  auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  transformer_layer_cls_to_wrap: LlamaDecoderLayer
-  state_dict_type: FULL_STATE_DICT
-  sharding_strategy: FULL_SHARD
-  reshard_after_forward: true
-  activation_checkpointing: false
-
-special_tokens:
-  pad_token: <|end_of_text|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
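The file removed above was an FP8 + FSDP2 example; its distinguishing switches, condensed from the deleted lines (everything else was a stock alpaca full fine-tune):

fp8: true                                 # FP8 mixed-precision training
fp8_enable_fsdp_float8_all_gather: true   # float8 all-gather between FSDP2 shards
torch_compile: true                       # the example paired FP8 with compile
fsdp_version: 2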
@@ -22,7 +22,7 @@ datasets:
 output_dir: ./outputs/qat_out/
 
 sample_packing: true
-
+pad_to_sequence_len: true
 sequence_len: 512
 
 flex_attention: true
@@ -26,7 +26,7 @@ output_dir: ./outputs/out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:
@@ -11,7 +11,7 @@ output_dir: ./outputs/out
 
 sequence_len: 8192
 sample_packing: true
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:
@@ -37,7 +37,7 @@ output_dir: ./outputs/lora-out
 
 sequence_len: 4096
 sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:
@@ -28,7 +28,7 @@ output_dir: ./outputs/lora-out
 
 sequence_len: 4096
 sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:
@@ -49,7 +49,7 @@ output_dir: ./outputs/lora-out
 
 sequence_len: 4096
 sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:
@@ -22,7 +22,7 @@ dataset_exact_deduplication: true
 sequence_len: 4096
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:
@@ -14,7 +14,7 @@ lora_model_dir:
 
 sequence_len: 2048
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 16
 lora_alpha: 32
@@ -15,7 +15,7 @@ lora_model_dir:
 sequence_len: 2048
 sample_packing: true
 eval_sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 16
 lora_alpha: 32
@@ -24,7 +24,7 @@ sample_packing: true
 sample_packing_sequentially: true
 curriculum_sampling: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:
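The hunk above combines sequential packing with curriculum sampling: as the key names suggest, sample_packing_sequentially fills each pack with samples in dataset order, and curriculum_sampling iterates the dataset in order instead of shuffling. A sketch of the combination using only keys present in the hunk:

sample_packing: true
sample_packing_sequentially: true   # fill packs in dataset order
curriculum_sampling: true           # visit samples in order, no shuffle
eval_sample_packing: false          # leave evaluation unpacked
pad_to_sequence_len: true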
@@ -15,7 +15,7 @@ lora_model_dir:
 sequence_len: 2048
 sample_packing: true
 eval_sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 16
 lora_alpha: 32
@@ -18,7 +18,7 @@ output_dir: ./outputs/lora-out
 sequence_len: 4096
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: lora
 lora_model_dir:
@@ -18,7 +18,7 @@ lora_model_dir:
 sequence_len: 2048
 sample_packing: true
 eval_sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16
@@ -18,7 +18,7 @@ adapter: qlora
 
 sequence_len: 2048
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 16
 lora_alpha: 16
@@ -20,7 +20,7 @@ lora_model_dir:
 
 sequence_len: 512
 sample_packing: false
-
+pad_to_sequence_len: true
 
 lora_r: 8
 lora_alpha: 16
@@ -20,7 +20,7 @@ lora_model_dir:
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16
@@ -16,7 +16,7 @@ output_dir: ./outputs/out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 eval_sample_packing: false
 
 wandb_project:
@@ -47,7 +47,7 @@ output_dir: ./outputs/out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 gradient_accumulation_steps: 1
 micro_batch_size: 1
@@ -48,7 +48,7 @@ output_dir: ./outputs/out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:
@@ -51,7 +51,7 @@ output_dir: ./outputs/out
 
 sequence_len: 4096 # up to 8k will work on a single H100
 sample_packing: true
-
+pad_to_sequence_len: true
 
 wandb_project:
 wandb_entity:
@@ -46,7 +46,7 @@ output_dir: ./outputs/out
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 gradient_accumulation_steps: 1
 micro_batch_size: 2
@@ -74,7 +74,7 @@ fsdp:
 fsdp_config:
   fsdp_version: 2
   fsdp_offload_params: false
-  # fsdp_cpu_ram_efficient_loading: true # does not work with load_in_8bit/4bit
+  fsdp_cpu_ram_efficient_loading: true
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
   fsdp_state_dict_type: SHARDED_STATE_DICT
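On the fsdp_cpu_ram_efficient_loading change above: when enabled, only one rank materializes the full weights at load time and the other ranks receive their shards, which cuts peak host RAM on multi-GPU nodes; the removed comment had warned it was incompatible with load_in_8bit/load_in_4bit. The resulting block, as a sketch:

fsdp_config:
  fsdp_cpu_ram_efficient_loading: true   # rank 0 loads, remaining ranks get shards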
@@ -51,7 +51,7 @@ output_dir: ./outputs/out
 
 sequence_len: 4096 # up to 8k will work on a single H100
 sample_packing: true
-
+pad_to_sequence_len: true
 
 gradient_accumulation_steps: 1
 micro_batch_size: 1
@@ -11,7 +11,8 @@ datasets:
 - path: HuggingFaceH4/llava-instruct-mix-vsft
   type: chat_template
   split: train[:1%]
-dataset_prepared_path:
+  field_messages: messages
+dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: ./outputs/out
 
@@ -23,7 +23,7 @@ lora_model_dir:
 sequence_len: 2048
 sample_packing: true
 eval_sample_packing: false
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16
@@ -22,7 +22,7 @@ lora_model_dir:
 
 sequence_len: 2048
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16
@@ -27,7 +27,7 @@ output_dir: ./outputs/out
 
 sequence_len: 2048
 sample_packing: true
-
+pad_to_sequence_len: true
 
 gradient_accumulation_steps: 1
 micro_batch_size: 1
@@ -14,7 +14,7 @@ output_dir: ./outputs/out
 
 sequence_len: 8192
 sample_packing: true
-
+pad_to_sequence_len: true
 eval_sample_packing: false
 
 wandb_project:
@@ -18,7 +18,7 @@ lora_model_dir:
 
 sequence_len: 4096
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16
@@ -20,7 +20,7 @@ lora_model_dir:
 
 sequence_len: 8192
 sample_packing: true
-
+pad_to_sequence_len: true
 
 lora_r: 32
 lora_alpha: 16
@@ -31,7 +31,7 @@ output_dir: ./outputs/dpo-qlora
 
 sequence_len: 2048
 sample_packing: false
-
+pad_to_sequence_len: true
 
 adapter: qlora
 lora_model_dir:
Some files were not shown because too many files have changed in this diff.