Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
208f8b253f add validation for DFT 2026-01-13 09:33:04 -05:00
Wing Lian
75ad1a9932 use dynamic finetuning with chunked cross entropy 2026-01-13 09:33:04 -05:00
19 changed files with 133 additions and 148 deletions

View File

@@ -15,11 +15,6 @@
<!--- Include details of your testing environment, tests ran to see how --> <!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. --> <!--- your change affects other areas of the code, etc. -->
## AI Usage Disclaimer
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
## Screenshots (if appropriate) ## Screenshots (if appropriate)
## Types of changes ## Types of changes

View File

@@ -21,8 +21,6 @@ jobs:
timeout-minutes: 480 timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@@ -34,7 +32,6 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -42,7 +39,6 @@ jobs:
pytorch: 2.9.0 pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -50,7 +46,6 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130" - cuda: "130"
cuda_version: 13.0.0 cuda_version: 13.0.0
cudnn_version: "" cudnn_version: ""
@@ -58,7 +53,6 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX" torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base" dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128" # - cuda: "128"
# cuda_version: 12.8.1 # cuda_version: 12.8.1
# cudnn_version: "" # cudnn_version: ""
@@ -85,7 +79,7 @@ jobs:
axolotlai/axolotl-base axolotlai/axolotl-base
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }} if: ${{ github.event_name != 'pull_request' && secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -96,7 +90,7 @@ jobs:
with: with:
context: . context: .
file: ./docker/${{ matrix.dockerfile }} file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }} platforms: linux/amd64,linux/arm64
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
@@ -111,8 +105,6 @@ jobs:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }} if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480 timeout-minutes: 480
runs-on: ubuntu-latest-m runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@@ -124,7 +116,6 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -132,7 +123,6 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""
@@ -140,7 +130,6 @@ jobs:
pytorch: 2.9.0 pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130" - cuda: "130"
cuda_version: 13.0.0 cuda_version: 13.0.0
cudnn_version: "" cudnn_version: ""
@@ -148,7 +137,6 @@ jobs:
pytorch: 2.9.1 pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX" torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base" dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -160,7 +148,6 @@ jobs:
axolotlai/axolotl-base-uv axolotlai/axolotl-base-uv
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -171,7 +158,6 @@ jobs:
with: with:
context: . context: .
file: ./docker/${{ matrix.dockerfile }} file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -20,26 +20,22 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.8.0
axolotl_extras: axolotl_extras:
platforms: "linux/amd64" is_latest: true
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.0 pytorch: 2.9.0
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 130 - cuda: 130
cuda_version: 13.0.0 cuda_version: 13.0.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -65,7 +61,7 @@ jobs:
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
platforms: ${{ matrix.platforms }} platforms: linux/amd64,linux/arm64
build-args: | build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
@@ -92,26 +88,22 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.8.0
axolotl_extras: axolotl_extras:
platforms: "linux/amd64" is_latest: true
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.0 pytorch: 2.9.0
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 130 - cuda: 130
cuda_version: 13.0.0 cuda_version: 13.0.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -136,7 +128,7 @@ jobs:
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
platforms: ${{ matrix.platforms }} platforms: linux/amd64,linux/arm64
build-args: | build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
@@ -157,11 +149,11 @@ jobs:
- cuda: 128 - cuda: 128
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.8.0
axolotl_extras: axolotl_extras:
is_latest: true is_latest:
- cuda: 130 - cuda: 128
cuda_version: 13.0.0 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras:

View File

@@ -47,8 +47,7 @@ jobs:
cuda_version: 13.0.0 cuda_version: 13.0.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.9.1 pytorch: 2.9.1
axolotl_extras: axolotl_extras: fbgemm-gpu
# axolotl_extras: fbgemm-gpu
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]

View File

@@ -6,7 +6,6 @@ ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS="" ARG AXOLOTL_ARGS=""
ARG CUDA="118" ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2" ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -21,17 +20,13 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64 # If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$TARGETARCH" = "arm64" ]; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \ pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \ pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi && \ fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \ python scripts/unsloth_install.py | sh && \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \ python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \ python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \ pip install pytest && \
pip cache purge pip cache purge

View File

@@ -2,7 +2,6 @@ ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION="" ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04" ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4 ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
@@ -32,35 +31,20 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \ RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \ && uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic && uv pip install awscli pydantic
RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
fi
RUN case "$PYTORCH_VERSION" in \ RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \ 2.9.[0-9]*) \
if [ "$TARGETARCH" = "amd64" ]; then \ if [ "$CUDA" = "128" ]; then \
if [ "$CUDA" = "128" ]; then \ wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ elif [ "$CUDA" = "130" ]; then \
elif [ "$CUDA" = "130" ]; then \ wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \ uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \ rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
elif [ "$TARGETARCH" = "arm64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
fi \
fi \ fi \
;; \ ;; \
esac esac

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-3-1b-it base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
adapter: qlora adapter: qlora
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0 lora_dropout: 0.05
lora_target_linear: true lora_target_linear: true
sequence_len: 2048 sequence_len: 2048

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-3-270m-it base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
adapter: qlora adapter: qlora
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0 lora_dropout: 0.05
lora_target_linear: true lora_target_linear: true
sequence_len: 2048 sequence_len: 2048

View File

@@ -2,7 +2,6 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too # Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true load_in_4bit: true
@@ -33,8 +32,8 @@ sample_packing: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0 lora_dropout: 0.05
lora_target_linear: true lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project: wandb_project:
wandb_entity: wandb_entity:

View File

@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0 lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project: wandb_project:

View File

@@ -11,11 +11,11 @@ liger-kernel==0.6.4
packaging==23.2 packaging==23.2
huggingface_hub>=0.36.0 huggingface_hub>=0.36.0
peft>=0.18.1 peft>=0.18.0
tokenizers>=0.22.1 tokenizers>=0.22.1
transformers==4.57.6 transformers==4.57.1
accelerate==1.12.0 accelerate==1.12.0
datasets==4.5.0 datasets==4.4.2
deepspeed>=0.18.3 deepspeed>=0.18.3
trl==0.25.1 trl==0.25.1
hf_xet==1.2.0 hf_xet==1.2.0

View File

@@ -26,7 +26,6 @@ def parse_requirements(extras_require_map):
_install_requires.append(line) _install_requires.append(line)
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0] xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# skip packages not compatible with OSX # skip packages not compatible with OSX
skip_packages = [ skip_packages = [
@@ -63,63 +62,44 @@ def parse_requirements(extras_require_map):
else: else:
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
torch_parts = torch_version.split("+")
if len(torch_parts) == 2:
torch_cuda_version = torch_parts[1]
_dependency_links.append(
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 9): if (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu") extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [ extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
extras_require_map["vllm"] = ["vllm==0.11.1"] extras_require_map["vllm"] = ["vllm==0.11.1"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 8): elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu") extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"] extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"] extras_require_map["vllm"] = ["vllm==0.11.0"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 7): elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if patch == 0: if patch == 0:
if install_xformers: _install_requires.append("xformers==0.0.30")
_install_requires.append("xformers==0.0.30")
# vllm 0.9.x is incompatible with latest transformers # vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
else: else:
if install_xformers: _install_requires.append("xformers==0.0.31")
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm==0.10.1"] extras_require_map["vllm"] = ["vllm==0.10.1"]
elif (major, minor) >= (2, 6): elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if install_xformers: _install_requires.append("xformers==0.0.29.post3")
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126 # since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126") _dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5): elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if install_xformers: if patch == 0:
if patch == 0: _install_requires.append("xformers==0.0.28.post2")
_install_requires.append("xformers==0.0.28.post2") else:
else: _install_requires.append("xformers>=0.0.28.post3")
_install_requires.append("xformers>=0.0.28.post3")
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4): elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm") extras_require_map.pop("vllm")
if install_xformers: if patch == 0:
if patch == 0: _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers>=0.0.27")
_install_requires.append("xformers>=0.0.27") else:
else: _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.append("xformers==0.0.28.post1")
_install_requires.append("xformers==0.0.28.post1")
else: else:
raise ValueError("axolotl requires torch>=2.4") raise ValueError("axolotl requires torch>=2.4")

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package __path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.13.1" __version__ = "0.13.0.dev"

View File

@@ -153,9 +153,12 @@ class PatchManager:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks: if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks) patch_chunked_ce_loss_fn(
self.cfg.chunked_cross_entropy_num_chunks,
use_dft=self.cfg.use_dynamic_finetuning,
)
else: else:
patch_chunked_ce_loss_fn() patch_chunked_ce_loss_fn(use_dft=self.cfg.use_dynamic_finetuning)
def _apply_fsdp_patches(self): def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations.""" """Apply patches for FSDP configurations."""

View File

@@ -5,7 +5,6 @@ from typing import Type
import addict import addict
import torch import torch
import transformers
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -154,9 +153,6 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
This function determines the appropriate model config source, loads it, applies any This function determines the appropriate model config source, loads it, applies any
necessary overrides, and validates it for compatibility with the `axolotl` config. necessary overrides, and validates it for compatibility with the `axolotl` config.
If `cfg.cls_model_config` is set, a custom config class from transformers will be
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
@@ -178,13 +174,8 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
if cfg.num_labels: if cfg.num_labels:
# num_labels is used to initialize classifier models # num_labels is used to initialize classifier models
config_kwargs["num_labels"] = cfg.num_labels config_kwargs["num_labels"] = cfg.num_labels
config_cls = AutoConfig
if cfg.cls_model_config:
config_cls = getattr(transformers, cfg.cls_model_config)
try: try:
model_config = config_cls.from_pretrained( model_config = AutoConfig.from_pretrained(
model_config_name, model_config_name,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**config_kwargs, **config_kwargs,

View File

@@ -16,10 +16,16 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390 For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
""" """
def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): def __init__(
self,
num_output_chunks: int = 8,
ignore_index: int = -100,
use_dft: bool = False,
):
super().__init__() super().__init__()
self.num_output_chunks = num_output_chunks self.num_output_chunks = num_output_chunks
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.use_dft = use_dft
def compute_cross_entropy( def compute_cross_entropy(
self, self,
@@ -30,10 +36,30 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
""" """
Upcast logits to fp32 and compute cross entropy loss. Upcast logits to fp32 and compute cross entropy loss.
""" """
return F.cross_entropy( ce_loss = F.cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum" logits.float(), labels, ignore_index=self.ignore_index, reduction="none"
) )
if self.use_dft:
# Compute probabilities and gather the ones corresponding to labels
with torch.no_grad(): # Stop gradient
probs = torch.softmax(logits.float(), dim=-1)
# Create mask for valid tokens (not ignore_index)
valid_mask = labels != self.ignore_index
# Gather probabilities for the correct tokens
label_probs = probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
# Apply mask to only scale valid tokens
label_probs = label_probs * valid_mask
# Avoid multiplication by 0 for ignored tokens
label_probs = torch.where(
valid_mask, label_probs, torch.ones_like(label_probs)
)
# Scale the loss by the probability (DFT)
ce_loss = ce_loss * label_probs
return ce_loss.sum()
def forward( def forward(
self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum" self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
) -> torch.Tensor: ) -> torch.Tensor:
@@ -71,16 +97,20 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
return total_loss / total_elements return total_loss / total_elements
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): def _build_chunked_ce_loss_fn(
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index) num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index, use_dft)
loss_fn_ce.compute_cross_entropy = torch.compile( loss_fn_ce.compute_cross_entropy = torch.compile(
loss_fn_ce.compute_cross_entropy, backend="inductor" loss_fn_ce.compute_cross_entropy, backend="inductor"
) )
return loss_fn_ce return loss_fn_ce
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100): def get_causal_lm_loss(
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index) num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index, use_dft)
def chunked_fix_cross_entropy( def chunked_fix_cross_entropy(
source, source,
@@ -124,10 +154,14 @@ def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
return for_causal_lm_chunked_loss return for_causal_lm_chunked_loss
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): def patch_chunked_ce_loss_fn(
num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
):
import transformers.loss.loss_utils import transformers.loss.loss_utils
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index) for_causal_lm_chunked_loss = get_causal_lm_loss(
num_output_chunks, ignore_index, use_dft
)
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = ( transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
for_causal_lm_chunked_loss for_causal_lm_chunked_loss

View File

@@ -664,6 +664,13 @@ class AxolotlInputConfig(
}, },
) )
use_dynamic_finetuning: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use dynamic fine-tuning for scaled SFT gradients."
},
)
chunked_cross_entropy: bool | None = Field( chunked_cross_entropy: bool | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={

View File

@@ -25,12 +25,7 @@ class ModelInputConfig(BaseModel):
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model" "description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
}, },
) )
cls_model_config: str | None = Field( cls_model_config: str | None = None
default=None,
json_schema_extra={
"description": "transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). Defaults to AutoConfig."
},
)
tokenizer_config: str | None = Field( tokenizer_config: str | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={

View File

@@ -434,6 +434,18 @@ class TrainingValidationMixin:
return data return data
@model_validator(mode="before")
@classmethod
def check_ao_optim_fsdp2_offload(cls, data):
if data.get("fsdp_config") and data.get("fsdp_config", {}).get(
"offload_params"
):
if data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]:
raise ValueError(
"low bit ao optimizers is not supported with FSDP2 w/ offload_params."
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_use_reentrant_mismatch(cls, data): def check_use_reentrant_mismatch(cls, data):
@@ -557,6 +569,20 @@ class TrainingValidationMixin:
return data return data
class CELossValidationMixin:
"""Validation methods related to CE loss configuration."""
@model_validator(mode="before")
@classmethod
def check_dft_loss_fn(cls, data):
if data.get("use_dynamic_finetuning"):
if not data.get("chunked_cross_entropy"):
raise ValueError(
"`use_dynamic_finetuning` requires `chunked_cross_entropy`"
)
return data
class LoRAValidationMixin: class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration.""" """Validation methods related to LoRA/QLoRA configuration."""
@@ -1464,6 +1490,7 @@ class ValidationMixin(
DatasetValidationMixin, DatasetValidationMixin,
AttentionValidationMixin, AttentionValidationMixin,
TrainingValidationMixin, TrainingValidationMixin,
CELossValidationMixin,
LoRAValidationMixin, LoRAValidationMixin,
RLValidationMixin, RLValidationMixin,
OptimizationValidationMixin, OptimizationValidationMixin,