Compare commits

..

1 Commits

Author SHA1 Message Date
Salman Mohammadi
7a08e4117a wip ao upgrade 2026-01-05 18:23:33 +00:00
44 changed files with 83 additions and 6494 deletions

View File

@@ -15,11 +15,6 @@
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->
## AI Usage Disclaimer
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
## Screenshots (if appropriate)
## Types of changes

View File

@@ -21,8 +21,6 @@ jobs:
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
@@ -34,7 +32,6 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -42,7 +39,6 @@ jobs:
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -50,7 +46,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -58,7 +53,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -85,7 +79,6 @@ jobs:
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -96,7 +89,6 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
@@ -111,8 +103,6 @@ jobs:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
@@ -124,7 +114,6 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -132,7 +121,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -140,7 +128,6 @@ jobs:
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -148,7 +135,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -160,7 +146,6 @@ jobs:
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -171,7 +156,6 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -20,26 +20,22 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
platforms: "linux/amd64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 130
# cuda_version: 13.0.0
# python_version: "3.11"
# pytorch: 2.9.1
# axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -65,7 +61,6 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
@@ -92,26 +87,22 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
platforms: "linux/amd64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 130
# cuda_version: 13.0.0
# python_version: "3.11"
# pytorch: 2.9.1
# axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -136,7 +127,6 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
@@ -157,11 +147,11 @@ jobs:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.8.0
axolotl_extras:
is_latest: true
- cuda: 130
cuda_version: 13.0.0
is_latest:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
@@ -190,7 +180,6 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}

View File

@@ -43,13 +43,6 @@ jobs:
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -316,12 +316,6 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -6,7 +6,6 @@ ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -21,17 +20,13 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \ python scripts/unsloth_install.py | sh && \
python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge

View File

@@ -2,16 +2,14 @@ ARG CUDA_VERSION="11.8.0"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
ARG CUDA="128"
ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -24,17 +22,11 @@ RUN apt-get update \
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
&& rm -rf /var/cache/apt/archives \
&& rm -rf /var/lib/apt/lists/* \
&& if [ "$TARGETARCH" = "amd64" ]; then \
MINICONDA_ARCH="x86_64"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
MINICONDA_ARCH="aarch64"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi \
&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
&& rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
@@ -59,34 +51,8 @@ RUN git lfs install --skip-repo && \
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$CUDA" = "128" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
elif [ "$CUDA" = "130" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
fi \
;; \
esac
RUN if [ "$PYTORCH_VERSION" =~ ^2\.9\.[0-9]+$ ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi

View File

@@ -2,7 +2,6 @@ ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
@@ -32,35 +31,12 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$TARGETARCH" = "amd64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
elif [ "$TARGETARCH" = "arm64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
fi \
fi \
;; \
esac

View File

@@ -52,7 +52,6 @@ gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
scaling_softmax: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

View File

@@ -1,53 +0,0 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.05
output_dir: ./outputs/gemma-3-1b-fft-dft
sequence_len: 2048
use_dynamic_finetuning: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048

View File

@@ -2,7 +2,6 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
@@ -33,8 +32,8 @@ sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_linear: true
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:

View File

@@ -59,7 +59,6 @@ gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
scaling_softmax: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -1,285 +0,0 @@
# SwanLab Integration Examples
This directory contains example configurations demonstrating SwanLab integration with Axolotl.
## Examples Overview
### 1. DPO with Completion Logging
**File**: `dpo-swanlab-completions.yml`
Demonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging.
**Features**:
- Basic SwanLab experiment tracking
- Completion table logging (prompts, chosen/rejected responses, rewards)
- Memory-bounded buffer for long training runs
- Cloud sync configuration
**Best for**: RLHF practitioners who want to analyze model outputs qualitatively
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
```
---
### 2. LoRA with Performance Profiling
**File**: `lora-swanlab-profiling.yml`
Demonstrates standard LoRA fine-tuning with performance profiling enabled.
**Features**:
- SwanLab experiment tracking
- Automatic profiling of trainer methods
- Profiling metrics visualization
- Performance optimization guidance
**Best for**: Engineers optimizing training performance and comparing different configurations
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
```
---
### 3. Full-Featured DPO Production Setup
**File**: `dpo-swanlab-full-featured.yml`
Comprehensive production-ready configuration with ALL SwanLab features enabled.
**Features**:
- Experiment tracking with team workspace
- RLHF completion logging
- Performance profiling
- Lark (Feishu) team notifications
- Private deployment support
- Production checklist and troubleshooting
**Best for**: Production RLHF training with team collaboration
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
export SWANLAB_LARK_SECRET=your-webhook-secret
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
```
---
### 4. Custom Trainer Profiling (Python)
**File**: `custom_trainer_profiling.py`
Python code examples showing how to add SwanLab profiling to custom trainers.
**Features**:
- `@swanlab_profile` decorator examples
- Context manager profiling for fine-grained timing
- `ProfilingConfig` for advanced filtering and throttling
- Multiple profiling patterns and best practices
**Best for**: Advanced users creating custom trainers
**Usage**:
```python
from custom_trainer_profiling import CustomTrainerWithProfiling
# See file for detailed examples and patterns
```
---
## Feature Matrix
| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace |
|---------|----------|-------------------|-----------|-------------------|----------------|
| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | (commented) | (commented) |
| lora-swanlab-profiling.yml | ✅ | (disabled) | ✅ (auto) | (commented) | (commented) |
| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ |
| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A |
---
## Configuration Quick Reference
### Basic SwanLab Setup
```yaml
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-project
swanlab_experiment_name: my-experiment
swanlab_mode: cloud # cloud, local, offline, disabled
```
### RLHF Completion Logging
```yaml
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 steps
swanlab_completion_max_buffer: 128 # Memory-bounded buffer
```
### Lark Team Notifications
```yaml
swanlab_lark_webhook_url: https://open.feishu.cn/...
swanlab_lark_secret: your-webhook-secret # Required for production
```
### Team Workspace
```yaml
swanlab_workspace: my-research-team
```
### Private Deployment
```yaml
swanlab_web_host: https://swanlab.yourcompany.com
swanlab_api_host: https://api.swanlab.yourcompany.com
```
---
## Authentication
### Recommended: Environment Variable
```bash
export SWANLAB_API_KEY=your-api-key
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
export SWANLAB_LARK_SECRET=your-webhook-secret
```
### Alternative: Config File (less secure)
```yaml
swanlab_api_key: your-api-key
swanlab_lark_webhook_url: https://open.feishu.cn/...
swanlab_lark_secret: your-webhook-secret
```
---
## Common Use Cases
### Use Case 1: Migrate from WandB to SwanLab
Start with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB:
```yaml
use_swanlab: true
use_wandb: false
```
### Use Case 2: Analyze DPO Model Outputs
Use `dpo-swanlab-completions.yml`, adjust completion logging interval based on your training length:
```yaml
swanlab_completion_log_interval: 50 # More frequent for short training
swanlab_completion_log_interval: 200 # Less frequent for long training
```
### Use Case 3: Optimize Training Performance
Use `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations:
- Baseline: `flash_attention: false, gradient_checkpointing: false`
- Flash Attention: `flash_attention: true`
- Gradient Checkpointing: `gradient_checkpointing: true`
- Both: `flash_attention: true, gradient_checkpointing: true`
Compare profiling metrics in SwanLab dashboard.
### Use Case 4: Production RLHF with Team Collaboration
Use `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications:
```yaml
swanlab_workspace: ml-team
swanlab_lark_webhook_url: ...
swanlab_lark_secret: ...
```
---
## Viewing Your Experiments
### Cloud Mode
Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project.
**Dashboard sections**:
- **Metrics**: Training loss, learning rate, profiling metrics
- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO)
- **Config**: Hyperparameters and configuration
- **System**: Resource usage (GPU, memory, CPU)
- **Files**: Logged artifacts
### Local Mode
```bash
swanlab watch ./swanlog
# Open browser to http://localhost:5092
```
---
## Troubleshooting
### SwanLab not initializing
```bash
# Check API key
echo $SWANLAB_API_KEY
# Verify SwanLab is installed
pip show swanlab
# Check config
grep -A 5 "use_swanlab" your-config.yml
```
### Completions not appearing
- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
- Check `swanlab_log_completions: true`
- Wait for `swanlab_completion_log_interval` steps
- Look for "Registered SwanLab RLHF completion logging" in logs
### Lark notifications not working
- Test webhook manually: `curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...`
- Verify `SWANLAB_LARK_SECRET` is set correctly
- Check bot is added to Lark group chat
- Look for "Registered Lark notification callback" in logs
### Profiling metrics not appearing
- Verify `use_swanlab: true`
- Check SwanLab is initialized (look for init log message)
- Profiling metrics are under "profiling/" namespace
- Profiling auto-enabled when SwanLab is enabled
---
## Performance Notes
### Overhead Comparison
| Feature | Overhead per Step | Memory Usage |
|---------|------------------|--------------|
| Basic tracking | < 0.1% | ~10 MB |
| Completion logging | < 0.5% | ~64 KB (buffer=128) |
| Profiling | < 0.1% | ~1 KB |
| **Total** | **< 0.7%** | **~10 MB** |
### Best Practices
1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab)
2. Adjust completion log interval based on training length (100-200 steps)
3. Keep completion buffer size reasonable (128-512)
4. Profile critical path methods first (training_step, compute_loss)
5. Use ProfilingConfig to throttle high-frequency operations
---
## Further Reading
- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md)
- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn)
- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/)
- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
---
## Contributing
Found an issue or have an improvement? Please submit a PR or open an issue:
- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues)
- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues)

View File

@@ -1,299 +0,0 @@
"""Example: Custom Trainer with SwanLab Profiling
This example demonstrates how to add SwanLab profiling to your custom trainer.
Features:
- @swanlab_profile decorator for automatic profiling
- swanlab_profiling_context for fine-grained profiling
- ProfilingConfig for advanced filtering and throttling
Usage:
1. Create your custom trainer extending AxolotlTrainer
2. Add @swanlab_profile decorators to methods you want to profile
3. Use swanlab_profiling_context for fine-grained profiling within methods
4. Enable SwanLab in your config (use_swanlab: true)
See also:
- examples/swanlab/lora-swanlab-profiling.yml for config
- src/axolotl/integrations/swanlab/profiling.py for implementation
"""
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import (
ProfilingConfig,
swanlab_profile,
swanlab_profiling_context,
swanlab_profiling_context_advanced,
)
class CustomTrainerWithProfiling(AxolotlTrainer):
"""Custom trainer with SwanLab profiling enabled.
This trainer demonstrates three profiling patterns:
1. Decorator-based profiling (@swanlab_profile)
2. Context manager profiling (swanlab_profiling_context)
3. Advanced profiling with filtering (ProfilingConfig)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Create custom profiling config for high-frequency operations
self.fast_op_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.5, # Only log if duration > 0.5ms
log_interval=50, # Log every 50th call
)
# ========================================================================
# Pattern 1: Decorator-based Profiling
# ========================================================================
# Best for: Methods you always want to profile
# Overhead: ~2-5 microseconds per call (negligible)
@swanlab_profile
def training_step(self, model, inputs):
"""Main training step - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step
"""
return super().training_step(model, inputs)
@swanlab_profile
def compute_loss(self, model, inputs, return_outputs=False):
"""Loss computation - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss
"""
return super().compute_loss(model, inputs, return_outputs)
@swanlab_profile
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
"""Prediction step - always profile.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step
"""
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
# ========================================================================
# Pattern 2: Fine-grained Context Manager Profiling
# ========================================================================
# Best for: Profiling specific code blocks within a method
# Use case: When you want to profile forward vs backward separately
def complex_training_step(self, model, inputs):
"""Training step with fine-grained profiling.
Profiling metrics:
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
- profiling/Time taken: CustomTrainerWithProfiling.optimizer_step
"""
# Profile just the forward pass
with swanlab_profiling_context(self, "forward_pass"):
outputs = model(**inputs)
loss = outputs.loss
# Profile just the backward pass
with swanlab_profiling_context(self, "backward_pass"):
loss.backward()
# Profile optimizer step
with swanlab_profiling_context(self, "optimizer_step"):
self.optimizer.step()
self.optimizer.zero_grad()
return outputs
# ========================================================================
# Pattern 3: Advanced Profiling with Filtering
# ========================================================================
# Best for: High-frequency operations where you want to throttle logging
# Use case: Methods called 100+ times per step
def _prepare_inputs(self, inputs):
"""Prepare inputs - throttled profiling.
This method is called frequently (once per batch), so we throttle
profiling to reduce overhead:
- Only log if duration > 0.5ms (skip very fast operations)
- Only log every 50th call (reduce logging frequency)
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs
"""
with swanlab_profiling_context_advanced(
self, "prepare_inputs", config=self.fast_op_config
):
return super()._prepare_inputs(inputs)
def _prepare_input_for_model(self, input_ids):
"""Another high-frequency operation - throttled profiling.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model
"""
with swanlab_profiling_context_advanced(
self, "prepare_input_for_model", config=self.fast_op_config
):
# Your custom input preparation logic
return input_ids
# ========================================================================
# Pattern 4: Exception-safe Profiling
# ========================================================================
# Profiling is exception-safe: duration is logged even if method raises
@swanlab_profile
def potentially_failing_method(self):
"""This method may raise an exception.
SwanLab profiling will still log the duration before re-raising.
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method
"""
# Do some work
result = self._do_risky_computation()
# If this raises, profiling duration is still logged
if result < 0:
raise ValueError("Invalid result")
return result
def _do_risky_computation(self):
"""Placeholder for risky computation."""
return 42
# ============================================================================
# Advanced Example: Custom ProfilingConfig Per Method
# ============================================================================
class AdvancedProfilingTrainer(AxolotlTrainer):
"""Trainer with method-specific profiling configurations."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Different profiling configs for different method types
self.critical_path_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.0, # Log everything on critical path
log_interval=1, # Log every call
)
self.fast_path_config = ProfilingConfig(
enabled=True,
min_duration_ms=1.0, # Only log if > 1ms
log_interval=100, # Log every 100th call
)
self.debug_config = ProfilingConfig(
enabled=True,
min_duration_ms=0.0, # Log everything
log_interval=1, # Log every call
)
def training_step(self, model, inputs):
"""Critical path - log everything."""
with swanlab_profiling_context_advanced(
self, "training_step", config=self.critical_path_config
):
return super().training_step(model, inputs)
def _prepare_inputs(self, inputs):
"""Fast path - throttle logging."""
with swanlab_profiling_context_advanced(
self, "prepare_inputs", config=self.fast_path_config
):
return super()._prepare_inputs(inputs)
def _debug_method(self, data):
"""Debug-only method - verbose logging."""
with swanlab_profiling_context_advanced(
self, "debug_method", config=self.debug_config
):
# Your debug logic
pass
# ============================================================================
# How to Use This Custom Trainer
# ============================================================================
"""
To use this custom trainer:
1. Save this file to your project (e.g., my_custom_trainer.py)
2. Create a config file that uses your custom trainer:
# config.yml
base_model: NousResearch/Llama-3.2-1B
# ... other config ...
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-profiling-experiment
# Optional: Specify custom trainer
# (Or modify axolotl to use your custom trainer class)
3. Run training:
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train config.yml
4. View profiling metrics in SwanLab dashboard:
- profiling/Time taken: CustomTrainerWithProfiling.training_step
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
- etc.
5. Compare profiling metrics across runs:
- Run baseline without optimizations
- Run with flash_attention enabled
- Run with gradient_checkpointing enabled
- Compare profiling metrics to see performance impact
"""
# ============================================================================
# Tips for Effective Profiling
# ============================================================================
"""
1. Profile the critical path first:
- training_step, compute_loss, prediction_step
- These methods are called most frequently and have biggest impact
2. Use throttling for high-frequency operations:
- Methods called 100+ times per step
- Use log_interval=50 or log_interval=100
- Reduces profiling overhead and dashboard clutter
3. Filter noise with min_duration_ms:
- Set min_duration_ms=1.0 to skip very fast operations
- Focus on operations that actually take time
4. Compare across runs:
- Run same config multiple times to check consistency
- Compare different optimization strategies
- Track profiling trends over time
5. Monitor distributed training:
- Check for per-rank timing differences
- Look for stragglers (slower ranks)
- Identify synchronization bottlenecks
6. Disable profiling in production:
- from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG
- DEFAULT_PROFILING_CONFIG.enabled = False
7. Exception handling:
- Profiling is exception-safe
- Duration logged even if method raises
- Useful for debugging methods that fail intermittently
"""

View File

@@ -1,168 +0,0 @@
# SwanLab DPO Training Example with Completion Logging
#
# This example demonstrates DPO (Direct Preference Optimization) training
# with SwanLab integration for experiment tracking and completion table logging.
#
# Features enabled:
# - SwanLab experiment tracking
# - RLHF completion table logging (prompts, chosen/rejected responses, rewards)
# - Lark (Feishu) team notifications (optional)
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
# Model Configuration
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
# Quantization
load_in_8bit: true
load_in_4bit: false
# LoRA Configuration
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
# DPO Configuration
chat_template: llama3
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
# Dataset and Output
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-swanlab-out
# Training Configuration
sequence_len: 4096
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 4
# Optimization
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# Precision
bf16: auto
tf32: false
# Performance
gradient_checkpointing: true
flash_attention: true
# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# ============================================================================
# SwanLab Integration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: dpo-training
swanlab_experiment_name: llama-3-dpo-completions-demo
swanlab_description: "DPO training with completion table logging"
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# SwanLab Authentication
# Recommended: Set via environment variable
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# Optional: Team workspace
# swanlab_workspace: my-research-team
# ============================================================================
# RLHF Completion Table Logging
# ============================================================================
#
# Automatically logs model completions to SwanLab for qualitative analysis:
# - Prompts from your DPO dataset
# - Chosen responses (preferred)
# - Rejected responses (non-preferred)
# - Reward differences
#
# View the table in SwanLab dashboard under "rlhf_completions"
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 training steps
swanlab_completion_max_buffer: 128 # Keep last 128 completions in memory
# Memory Usage Notes:
# - Buffer size 128: ~64 KB (default, recommended)
# - Buffer size 512: ~256 KB (for more historical completions)
# - Buffer size 1024: ~512 KB (maximum for very long training runs)
# Performance Notes:
# - Completion logging overhead: < 0.5% per training step
# - Only logs every N steps to minimize impact
# - Memory-bounded buffer prevents memory leaks
# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get real-time training notifications in your team chat
# Uncomment to enable:
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret # Recommended for production
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)
# ============================================================================
# Optional: Private SwanLab Deployment
# ============================================================================
#
# For enterprise users with private SwanLab deployment:
# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com
# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================
# wandb_project:
# wandb_entity:
# use_wandb: false

View File

@@ -1,329 +0,0 @@
# SwanLab Full-Featured DPO Training Example
#
# This example demonstrates ALL SwanLab integration features:
# - Experiment tracking with cloud sync
# - RLHF completion table logging
# - Performance profiling
# - Lark (Feishu) team notifications
# - Team workspace collaboration
#
# Use this as a reference for production RLHF training setups.
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
# export SWANLAB_LARK_SECRET=your-webhook-secret
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
# ============================================================================
# Model Configuration
# ============================================================================
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
# Quantization for efficient training
load_in_8bit: true
load_in_4bit: false
# ============================================================================
# LoRA Configuration
# ============================================================================
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true # Target all linear layers
# ============================================================================
# DPO (Direct Preference Optimization) Configuration
# ============================================================================
chat_template: llama3
rl: dpo # Enable DPO trainer
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
# ============================================================================
# Dataset and Output Configuration
# ============================================================================
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-swanlab-full-featured-out
# ============================================================================
# Training Configuration
# ============================================================================
sequence_len: 4096
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 4
# ============================================================================
# Optimization
# ============================================================================
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# ============================================================================
# Precision and Performance
# ============================================================================
bf16: auto
tf32: false
gradient_checkpointing: true
flash_attention: true
# ============================================================================
# Checkpointing and Logging
# ============================================================================
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# ============================================================================
# SwanLab Integration - Full Configuration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# ------------------------------------------------------------------------------
# Basic SwanLab Configuration
# ------------------------------------------------------------------------------
use_swanlab: true
swanlab_project: dpo-production
swanlab_experiment_name: llama-3-dpo-full-featured-v1
swanlab_description: |
Production DPO training with all SwanLab features enabled:
- Completion table logging for qualitative analysis
- Performance profiling for optimization
- Lark notifications for team collaboration
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# ------------------------------------------------------------------------------
# Team Collaboration
# ------------------------------------------------------------------------------
# Workspace for team collaboration (shared experiments)
swanlab_workspace: ml-research-team
# Authentication (recommended: use environment variable)
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# ------------------------------------------------------------------------------
# RLHF Completion Table Logging
# ------------------------------------------------------------------------------
# Automatically logs model completions for qualitative analysis:
# - Prompts from your DPO dataset
# - Chosen responses (preferred)
# - Rejected responses (non-preferred)
# - Reward differences
#
# View in SwanLab dashboard under "rlhf_completions" table
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 steps
swanlab_completion_max_buffer: 256 # Larger buffer for long training runs
# Buffer size recommendations:
# - 128: Default, ~64 KB memory (recommended for most cases)
# - 256: ~128 KB memory (this config, good for longer training)
# - 512: ~256 KB memory (maximum for very long runs)
# ------------------------------------------------------------------------------
# Lark (Feishu) Team Notifications
# ------------------------------------------------------------------------------
# Get real-time training notifications in your team chat
#
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)
# Recommended: Set via environment variables
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
# export SWANLAB_LARK_SECRET=your-webhook-secret
# Or set in config (less secure):
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret # REQUIRED for production
# Security note: ALWAYS use swanlab_lark_secret in production to prevent
# unauthorized parties from sending fake notifications to your team chat.
# ------------------------------------------------------------------------------
# Performance Profiling
# ------------------------------------------------------------------------------
# Profiling is automatically enabled when SwanLab is enabled.
# Metrics logged to SwanLab under "profiling/" namespace:
# profiling/Time taken: AxolotlTrainer.training_step
# profiling/Time taken: AxolotlTrainer.compute_loss
# profiling/Time taken: AxolotlTrainer.prediction_step
#
# Use these metrics to:
# - Identify bottlenecks in training loop
# - Compare performance across different configurations
# - Monitor performance regressions over time
# - Debug unexpected slowdowns
# For custom profiling in your own trainer, see:
# examples/swanlab/custom_trainer_profiling.py
# ------------------------------------------------------------------------------
# Optional: Private SwanLab Deployment
# ------------------------------------------------------------------------------
# For enterprise users with private SwanLab deployment:
# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com
# ------------------------------------------------------------------------------
# Optional: Model Checkpointing to SwanLab
# ------------------------------------------------------------------------------
# Log model checkpoints to SwanLab (coming soon)
swanlab_log_model: false
# ============================================================================
# Disable Other Logging Tools (Recommended)
# ============================================================================
# Using multiple logging tools simultaneously can impact performance:
# - Expected overhead: ~1-2% per logger
# - Potential config/callback conflicts
#
# For production training, use ONLY SwanLab:
# wandb_project:
# use_wandb: false
#
# use_mlflow: false
#
# use_comet: false
# ============================================================================
# Expected Training Behavior
# ============================================================================
# With this configuration, you should see:
#
# 1. SwanLab Initialization (rank 0 only):
# INFO: SwanLab initialized for project: dpo-production
# INFO: SwanLab experiment: llama-3-dpo-full-featured-v1
# INFO: SwanLab mode: cloud
# INFO: SwanLab workspace: ml-research-team
#
# 2. Completion Logging (rank 0 only):
# INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer
# (log_interval=100, max_buffer=256)
#
# 3. Lark Notifications (rank 0 only):
# INFO: Registered Lark notification callback with HMAC authentication
#
# 4. Distributed Training Detection (if multi-GPU):
# INFO: Distributed training detected (world_size=N)
# INFO: Only rank 0 will initialize SwanLab
# INFO: Other ranks will skip SwanLab to avoid conflicts
#
# 5. Training Start Notification (Lark):
# Your team chat receives: "Training started: llama-3-dpo-full-featured-v1"
#
# 6. Periodic Completion Logging:
# Every 100 steps, completion table is updated in SwanLab dashboard
#
# 7. Training Complete Notification (Lark):
# Your team chat receives: "Training completed: llama-3-dpo-full-featured-v1"
# With link to SwanLab dashboard and final metrics
#
# 8. SwanLab Dashboard Shows:
# - Training metrics (loss, learning rate, etc.)
# - Completion table (rlhf_completions)
# - Profiling metrics (profiling/Time taken: ...)
# - Hyperparameters and configuration
# - System resource usage
# ============================================================================
# Production Checklist
# ============================================================================
# Before deploying to production, verify:
# ✅ SwanLab API key is set via environment variable (not in config)
# ✅ Lark webhook secret is set (required for HMAC authentication)
# ✅ Workspace is set to your team's workspace
# ✅ Experiment name is descriptive and unique
# ✅ Only SwanLab is enabled (other loggers disabled)
# ✅ Completion logging buffer size is appropriate for your training duration
# ✅ Private deployment hosts are set (if using enterprise SwanLab)
# ✅ Test run completes successfully and shows up in SwanLab dashboard
# ✅ Lark notifications are received in team chat
# ✅ Profiling metrics are logged correctly
# ============================================================================
# Troubleshooting
# ============================================================================
# If SwanLab initialization fails:
# 1. Check SWANLAB_API_KEY environment variable is set
# 2. Verify swanlab_project is set in config
# 3. Check swanlab_mode is valid (cloud/local/offline/disabled)
# 4. Verify internet connectivity (for cloud mode)
# If Lark notifications not received:
# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly
# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings
# 3. Test webhook manually: curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...
# 4. Check training logs for "Registered Lark notification callback"
# 5. Verify bot is added to the target Lark group chat
# If completions not appearing in SwanLab:
# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
# 2. Check swanlab_log_completions is true
# 3. Wait for log_interval steps (default: 100)
# 4. Check training logs for "Registered SwanLab RLHF completion logging"
# If profiling metrics not appearing:
# 1. Verify use_swanlab is true
# 2. Check SwanLab is initialized (check logs)
# 3. Look under "profiling/" namespace in dashboard
# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False
# For more help:
# - SwanLab docs: https://docs.swanlab.cn
# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md
# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues

View File

@@ -1,178 +0,0 @@
# SwanLab LoRA Training Example with Performance Profiling
#
# This example demonstrates standard LoRA fine-tuning with SwanLab integration
# for performance profiling and optimization.
#
# Features enabled:
# - SwanLab experiment tracking
# - Performance profiling (training step, forward/backward pass timing)
# - Real-time metrics visualization
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
# Model Configuration
base_model: NousResearch/Llama-3.2-1B
# Dataset Configuration
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
val_set_size: 0.1
output_dir: ./outputs/lora-swanlab-profiling-out
# LoRA Configuration
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
# Training Configuration
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1
# Optimization
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# Precision
bf16: auto
tf32: false
# Performance
gradient_checkpointing: true
flash_attention: true
# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# Loss Monitoring
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
special_tokens:
pad_token: "<|end_of_text|>"
# ============================================================================
# SwanLab Integration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: lora-profiling
swanlab_experiment_name: llama-3.2-1b-profiling-demo
swanlab_description: "LoRA fine-tuning with performance profiling"
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# SwanLab Authentication
# Recommended: Set via environment variable
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# Optional: Team workspace
# swanlab_workspace: my-ml-team
# ============================================================================
# Performance Profiling
# ============================================================================
#
# SwanLab automatically profiles trainer methods when enabled.
# Profiling metrics appear in SwanLab dashboard under "profiling/" namespace.
#
# Built-in profiling:
# - Minimal overhead (< 0.1% per step)
# - High-precision timing (microsecond accuracy)
# - Exception-safe (logs duration even if method fails)
#
# View profiling metrics in SwanLab dashboard:
# profiling/Time taken: AxolotlTrainer.training_step
# profiling/Time taken: AxolotlTrainer.compute_loss
# profiling/Time taken: AxolotlTrainer.prediction_step
#
# For custom profiling in your own trainer, see:
# examples/swanlab/custom_trainer_profiling.py
# Completion logging is disabled for non-RLHF trainers
swanlab_log_completions: false # Only works with DPO/KTO/ORPO/GRPO
# ============================================================================
# Optional: Compare with Multiple Runs
# ============================================================================
#
# To compare profiling metrics across different configurations:
#
# 1. Run baseline without flash attention:
# swanlab_experiment_name: llama-3.2-1b-no-flash-attn
# flash_attention: false
#
# 2. Run with gradient checkpointing:
# swanlab_experiment_name: llama-3.2-1b-grad-checkpoint
# gradient_checkpointing: true
#
# 3. Run with both:
# swanlab_experiment_name: llama-3.2-1b-optimized
# flash_attention: true
# gradient_checkpointing: true
#
# Then compare profiling metrics in SwanLab dashboard to see performance impact
# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get notified when profiling experiments complete:
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret
# ============================================================================
# Profiling Best Practices
# ============================================================================
#
# 1. Run multiple epochs to see profiling trends over time
# 2. Ignore first ~10 steps (warmup period, slower)
# 3. Look for outliers (steps that take significantly longer)
# 4. Compare profiling metrics before/after optimization changes
# 5. Monitor per-rank profiling in distributed training
#
# Common bottlenecks to profile:
# - training_step: Overall step time (should be consistent)
# - compute_loss: Loss computation (scales with sequence length)
# - prediction_step: Evaluation time (can be slow for large val sets)
#
# If you see inconsistent timing:
# - Check for data loading bottlenecks
# - Monitor GPU utilization (may be CPU-bound)
# - Check for gradient accumulation effects
# - Verify CUDA kernel synchronization
# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================
# wandb_project:
# use_wandb: false

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1
bitsandbytes==0.48.2
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
torchao==0.15.0
openenv-core==0.1.0
schedulefree==1.4.1

View File

@@ -373,11 +373,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple
if self.cfg.use_dynamic_finetuning:
from axolotl.monkeypatch.loss.dft import dft_loss
trainer_kwargs["compute_loss_func"] = dft_loss
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(

View File

@@ -660,10 +660,11 @@ class AxolotlTrainer(
logs["tokens/train_per_sec_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
if "total" in self.state.tokens:
logs["tokens/total"] = int(self.state.tokens["total"].item())
if "trainable" in self.state.tokens:
logs["tokens/trainable"] = int(self.state.tokens["trainable"].item())
if (
hasattr(self.state, "total_tokens")
and self.state.total_tokens is not None
):
logs["total_tokens"] = int(self.state.total_tokens.item())
del self._stored_metrics[train_eval]

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +0,0 @@
"""SwanLab integration plugin for Axolotl"""
from axolotl.integrations.swanlab.args import SwanLabConfig
from axolotl.integrations.swanlab.plugins import SwanLabPlugin
__all__ = ["SwanLabConfig", "SwanLabPlugin"]

View File

@@ -1,140 +0,0 @@
"""SwanLab configuration arguments"""
from pydantic import BaseModel, Field, field_validator, model_validator
class SwanLabConfig(BaseModel):
"""SwanLab configuration subset"""
use_swanlab: bool | None = Field(
default=True,
json_schema_extra={
"description": "Enable SwanLab experiment tracking and visualization"
},
)
swanlab_project: str | None = Field(
default=None,
json_schema_extra={"description": "Your SwanLab project name"},
)
swanlab_experiment_name: str | None = Field(
default=None,
json_schema_extra={"description": "Set the name of your SwanLab experiment"},
)
swanlab_description: str | None = Field(
default=None,
json_schema_extra={"description": "Description for your SwanLab experiment"},
)
swanlab_mode: str | None = Field(
default=None,
json_schema_extra={
"description": '"cloud" to sync to SwanLab cloud, "local" for local only, "offline" to save metadata locally, "disabled" to turn off SwanLab'
},
)
swanlab_workspace: str | None = Field(
default=None,
json_schema_extra={
"description": "SwanLab workspace name (organization or username)"
},
)
swanlab_api_key: str | None = Field(
default=None,
json_schema_extra={
"description": "SwanLab API key for authentication. Can also be set via SWANLAB_API_KEY environment variable"
},
)
swanlab_log_model: bool | None = Field(
default=False,
json_schema_extra={
"description": "Whether to log model checkpoints to SwanLab (feature coming soon)"
},
)
swanlab_web_host: str | None = Field(
default=None,
json_schema_extra={
"description": "Web address for SwanLab cloud environment (for private deployment)"
},
)
swanlab_api_host: str | None = Field(
default=None,
json_schema_extra={
"description": "API address for SwanLab cloud environment (for private deployment)"
},
)
swanlab_lark_webhook_url: str | None = Field(
default=None,
json_schema_extra={
"description": "Lark (Feishu) webhook URL for sending training notifications to team chat"
},
)
swanlab_lark_secret: str | None = Field(
default=None,
json_schema_extra={
"description": "Secret for Lark webhook HMAC signature authentication (optional)"
},
)
swanlab_log_completions: bool | None = Field(
default=True,
json_schema_extra={
"description": "Enable logging RLHF completions to SwanLab for qualitative analysis (DPO/KTO/ORPO/GRPO)"
},
)
swanlab_completion_log_interval: int | None = Field(
default=100,
json_schema_extra={
"description": "Number of training steps between completion table logging to SwanLab"
},
)
swanlab_completion_max_buffer: int | None = Field(
default=128,
json_schema_extra={
"description": "Maximum number of completions to buffer before logging (prevents memory leaks)"
},
)
@field_validator("swanlab_mode")
@classmethod
def validate_swanlab_mode(cls, v):
"""Validate swanlab_mode is one of the allowed values."""
if v is None:
return v
valid_modes = ["cloud", "local", "offline", "disabled"]
if v not in valid_modes:
raise ValueError(
f"Invalid swanlab_mode: '{v}'.\n\n"
f"Valid options: {', '.join(valid_modes)}\n\n"
f"Examples:\n"
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
f" swanlab_mode: local # Local only, no cloud sync\n"
f" swanlab_mode: offline # Save metadata locally\n"
f" swanlab_mode: disabled # Turn off SwanLab\n"
)
return v
@field_validator("swanlab_project")
@classmethod
def validate_swanlab_project(cls, v):
"""Validate swanlab_project is non-empty when provided."""
if v is not None and isinstance(v, str) and len(v.strip()) == 0:
raise ValueError(
"swanlab_project cannot be an empty string.\n\n"
"Either:\n"
" 1. Provide a valid project name: swanlab_project: my-project\n"
" 2. Remove the swanlab_project field entirely\n"
)
return v
@model_validator(mode="after")
def validate_swanlab_enabled_requires_project(self):
"""Validate that if use_swanlab is True, swanlab_project must be set."""
if self.use_swanlab is True and not self.swanlab_project:
raise ValueError(
"SwanLab enabled (use_swanlab: true) but 'swanlab_project' is not set.\n\n"
"Solutions:\n"
" 1. Add 'swanlab_project: your-project-name' to your config\n"
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
"Example:\n"
" use_swanlab: true\n"
" swanlab_project: my-llm-training\n"
)
return self

View File

@@ -1,179 +0,0 @@
"""SwanLab callbacks for Axolotl trainers.
This module provides HuggingFace Trainer callbacks for logging
RLHF completions to SwanLab.
"""
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.integrations.swanlab.completion_logger import CompletionLogger
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class SwanLabRLHFCompletionCallback(TrainerCallback):
"""Callback for logging RLHF completions to SwanLab.
This callback periodically logs model completions (prompts, chosen/rejected
responses, rewards) to SwanLab during RLHF training for qualitative analysis.
Supports DPO, KTO, ORPO, and GRPO trainers.
Example usage:
>>> callback = SwanLabRLHFCompletionCallback(
... log_interval=100, # Log every 100 steps
... max_completions=128, # Keep last 128 completions
... )
>>> trainer.add_callback(callback)
Attributes:
logger: CompletionLogger instance
log_interval: Number of steps between SwanLab logging
trainer_type: Auto-detected trainer type (dpo/kto/orpo/grpo)
"""
def __init__(
self,
log_interval: int = 100,
max_completions: int = 128,
table_name: str = "rlhf_completions",
):
"""Initialize SwanLab RLHF completion callback.
Args:
log_interval: Log to SwanLab every N steps. Default: 100
max_completions: Maximum completions to buffer. Default: 128
table_name: SwanLab table name. Default: "rlhf_completions"
"""
super().__init__()
self.logger = CompletionLogger(maxlen=max_completions)
self.log_interval = log_interval
self.table_name = table_name
self.trainer_type: str | None = None # Auto-detected
self._last_logged_step = 0
def on_init_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Detect trainer type on initialization."""
trainer = kwargs.get("trainer")
if trainer is not None:
trainer_name = trainer.__class__.__name__
if "DPO" in trainer_name:
self.trainer_type = "dpo"
elif "KTO" in trainer_name:
self.trainer_type = "kto"
elif "ORPO" in trainer_name:
self.trainer_type = "orpo"
elif "GRPO" in trainer_name:
self.trainer_type = "grpo"
else:
self.trainer_type = "unknown"
LOG.info(
f"SwanLab RLHF completion logging enabled for {trainer_name} "
f"(type: {self.trainer_type})"
)
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs: dict | None = None,
**kwargs,
):
"""Capture completions from logs and buffer them.
Different trainers log completions in different formats:
- DPO: logs['dpo/chosen'], logs['dpo/rejected'], logs['dpo/reward_diff']
- KTO: logs['kto/completion'], logs['kto/label'], logs['kto/reward']
- ORPO: logs['orpo/chosen'], logs['orpo/rejected']
- GRPO: logs['grpo/completion'], logs['grpo/reward']
Note: This is a placeholder implementation. Actual log keys depend
on the TRL trainer implementation. You may need to patch the trainers
to expose completion data in logs.
"""
if logs is None or self.trainer_type is None:
return
step = state.global_step
# DPO completions
if self.trainer_type == "dpo":
if all(key in logs for key in ["dpo/prompt", "dpo/chosen", "dpo/rejected"]):
self.logger.add_dpo_completion(
step=step,
prompt=logs.get("dpo/prompt", ""),
chosen=logs.get("dpo/chosen", ""),
rejected=logs.get("dpo/rejected", ""),
reward_diff=logs.get("dpo/reward_diff"),
)
# KTO completions
elif self.trainer_type == "kto":
if all(key in logs for key in ["kto/prompt", "kto/completion"]):
self.logger.add_kto_completion(
step=step,
prompt=logs.get("kto/prompt", ""),
completion=logs.get("kto/completion", ""),
label=logs.get("kto/label", False),
reward=logs.get("kto/reward"),
)
# ORPO completions
elif self.trainer_type == "orpo":
if all(
key in logs for key in ["orpo/prompt", "orpo/chosen", "orpo/rejected"]
):
self.logger.add_orpo_completion(
step=step,
prompt=logs.get("orpo/prompt", ""),
chosen=logs.get("orpo/chosen", ""),
rejected=logs.get("orpo/rejected", ""),
log_odds_ratio=logs.get("orpo/log_odds_ratio"),
)
# GRPO completions
elif self.trainer_type == "grpo":
if all(key in logs for key in ["grpo/prompt", "grpo/completion"]):
self.logger.add_grpo_completion(
step=step,
prompt=logs.get("grpo/prompt", ""),
completion=logs.get("grpo/completion", ""),
reward=logs.get("grpo/reward"),
advantage=logs.get("grpo/advantage"),
)
# Periodically log to SwanLab
if step - self._last_logged_step >= self.log_interval:
if len(self.logger) > 0:
self.logger.log_to_swanlab(table_name=self.table_name)
self.logger.clear()
self._last_logged_step = step
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Log remaining completions at end of training."""
if len(self.logger) > 0:
LOG.info(
f"Training complete, logging final {len(self.logger)} completions to SwanLab"
)
self.logger.log_to_swanlab(table_name=self.table_name)
self._last_logged_step = state.global_step

View File

@@ -1,228 +0,0 @@
"""SwanLab completion logger for RLHF/DPO/KTO/ORPO/GRPO training.
This module provides utilities for logging model completions during
preference training to SwanLab for qualitative analysis.
"""
from collections import deque
from collections.abc import Mapping
from typing import Any
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class CompletionLogger:
"""Memory-bounded logger for RLHF completions.
Stores prompts, completions, and rewards in fixed-size deques to prevent
memory leaks during long training runs. Logs completion tables to SwanLab
for qualitative analysis of model outputs.
Example usage:
>>> logger = CompletionLogger(maxlen=128)
>>> logger.add_dpo_completion(
... step=0,
... prompt="What is AI?",
... chosen="Artificial Intelligence is...",
... rejected="AI means...",
... reward_diff=0.5
... )
>>> logger.log_to_swanlab()
Attributes:
maxlen: Maximum number of completions to store (older ones are dropped)
data: Deque storing completion dictionaries
"""
def __init__(self, maxlen: int = 128):
"""Initialize completion logger with bounded buffer.
Args:
maxlen: Maximum number of completions to store. When the buffer
is full, oldest completions are automatically discarded.
Default: 128 (sufficient for most RLHF runs without memory issues)
"""
self.maxlen = maxlen
self.data: deque[Mapping[str, Any]] = deque(maxlen=maxlen)
def add_dpo_completion(
self,
step: int,
prompt: str,
chosen: str,
rejected: str,
reward_diff: float | None = None,
) -> None:
"""Add a DPO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
chosen: Chosen (preferred) completion
rejected: Rejected (non-preferred) completion
reward_diff: Reward difference (chosen - rejected), if available
"""
entry = {
"step": step,
"prompt": prompt,
"chosen": chosen,
"rejected": rejected,
}
if reward_diff is not None:
entry["reward_diff"] = reward_diff
self.data.append(entry)
def add_kto_completion(
self,
step: int,
prompt: str,
completion: str,
label: bool,
reward: float | None = None,
) -> None:
"""Add a KTO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
completion: Model-generated completion
label: True if desirable, False if undesirable
reward: Reward score, if available
"""
entry = {
"step": step,
"prompt": prompt,
"completion": completion,
"label": "desirable" if label else "undesirable",
}
if reward is not None:
entry["reward"] = reward
self.data.append(entry)
def add_orpo_completion(
self,
step: int,
prompt: str,
chosen: str,
rejected: str,
log_odds_ratio: float | None = None,
) -> None:
"""Add an ORPO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
chosen: Chosen (preferred) completion
rejected: Rejected (non-preferred) completion
log_odds_ratio: Log odds ratio between chosen and rejected
"""
entry = {
"step": step,
"prompt": prompt,
"chosen": chosen,
"rejected": rejected,
}
if log_odds_ratio is not None:
entry["log_odds_ratio"] = log_odds_ratio
self.data.append(entry)
def add_grpo_completion(
self,
step: int,
prompt: str,
completion: str,
reward: float | None = None,
advantage: float | None = None,
) -> None:
"""Add a GRPO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
completion: Model-generated completion
reward: Reward score from reward model
advantage: Advantage estimate (reward - baseline)
"""
entry = {
"step": step,
"prompt": prompt,
"completion": completion,
}
if reward is not None:
entry["reward"] = reward
if advantage is not None:
entry["advantage"] = advantage
self.data.append(entry)
def log_to_swanlab(self, table_name: str = "completions") -> bool:
"""Log buffered completions to SwanLab as a table.
Creates a SwanLab echarts Table with all buffered completions.
Only logs if SwanLab is initialized and data is available.
Args:
table_name: Name of the table in SwanLab dashboard.
Default: "completions"
Returns:
True if logging succeeded, False otherwise
"""
if not self.data:
LOG.debug("No completions to log to SwanLab")
return False
try:
import swanlab
if swanlab.get_run() is None:
LOG.debug("SwanLab not initialized, skipping completion logging")
return False
# Convert deque to list of dicts
completions = list(self.data)
# Extract headers from first entry (all entries should have same structure)
headers = list(completions[0].keys())
# Build rows: each completion becomes one row
rows = []
for completion in completions:
row = [completion.get(header, "") for header in headers]
rows.append(row)
# Log to SwanLab as echarts Table
swanlab.log({table_name: swanlab.echarts.Table().add(headers, rows)})
LOG.info(f"Logged {len(rows)} completions to SwanLab table '{table_name}'")
return True
except ImportError:
LOG.warning(
"SwanLab not installed, cannot log completions. "
"Install with: pip install swanlab"
)
return False
except Exception as err: # pylint: disable=broad-except
LOG.exception("Failed to log completions to SwanLab: %s", err)
return False
def clear(self) -> None:
"""Clear all buffered completions."""
self.data.clear()
def __len__(self) -> int:
"""Return number of buffered completions."""
return len(self.data)
def __repr__(self) -> str:
"""String representation showing buffer status."""
return (
f"CompletionLogger(maxlen={self.maxlen}, "
f"buffered={len(self.data)}/{self.maxlen})"
)

View File

@@ -1,554 +0,0 @@
"""SwanLab Plugin for Axolotl"""
from __future__ import annotations
from typing import TYPE_CHECKING
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from transformers import TrainerCallback
from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__)
class SwanLabPlugin(BasePlugin):
"""
SwanLab integration plugin for Axolotl.
Provides experiment tracking, visualization, and logging capabilities
using SwanLab (https://swanlab.cn).
Usage in config.yaml:
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-project
swanlab_experiment_name: my-experiment
swanlab_mode: cloud # or 'local', 'offline', 'disabled'
"""
def __init__(self):
super().__init__()
self.swanlab_initialized = False
LOG.info("SwanLab plugin initialized")
def get_input_args(self) -> str:
"""Returns the configuration model for SwanLab integration."""
return "axolotl.integrations.swanlab.SwanLabConfig"
def register(self, cfg: dict):
"""Register SwanLab plugin with configuration and conflict detection."""
LOG.info("Registering SwanLab plugin")
# === Conflict Detection: Required Fields ===
# Check if SwanLab is enabled
if cfg.get("use_swanlab"):
# 1. Validate project name is set
if not cfg.get("swanlab_project"):
raise ValueError(
"SwanLab enabled but 'swanlab_project' is not set.\n\n"
"Solutions:\n"
" 1. Add 'swanlab_project: your-project-name' to your config\n"
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
"See: src/axolotl/integrations/swanlab/README.md for examples"
)
# 2. Validate swanlab_mode value
valid_modes = ["cloud", "local", "offline", "disabled"]
mode = cfg.get("swanlab_mode")
if mode and mode not in valid_modes:
raise ValueError(
f"Invalid swanlab_mode: '{mode}'.\n\n"
f"Valid options: {', '.join(valid_modes)}\n\n"
f"Example:\n"
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
f" swanlab_mode: local # Local only, no cloud sync\n"
)
# 3. Check API key for cloud mode
import os
mode = cfg.get("swanlab_mode", "cloud") # Default is cloud
if mode == "cloud":
api_key = cfg.get("swanlab_api_key") or os.environ.get(
"SWANLAB_API_KEY"
)
if not api_key:
LOG.warning(
"SwanLab cloud mode enabled but no API key found.\n"
"SwanLab may fail to initialize during training.\n\n"
"Solutions:\n"
" 1. Set SWANLAB_API_KEY environment variable:\n"
" export SWANLAB_API_KEY=your-api-key\n"
" 2. Add 'swanlab_api_key: your-api-key' to config (less secure)\n"
" 3. Run 'swanlab login' before training\n"
" 4. Use 'swanlab_mode: local' for offline tracking\n"
)
# === Conflict Detection: Multi-Logger Performance Warning ===
# Detect all active logging tools
active_loggers = []
if cfg.get("use_wandb"):
active_loggers.append("WandB")
if cfg.get("use_mlflow"):
active_loggers.append("MLflow")
if cfg.get("comet_api_key") or cfg.get("comet_project_name"):
active_loggers.append("Comet")
if cfg.get("use_swanlab"):
active_loggers.append("SwanLab")
if len(active_loggers) > 1:
LOG.warning(
f"\n{'=' * 70}\n"
f"Multiple logging tools enabled: {', '.join(active_loggers)}\n"
f"{'=' * 70}\n"
f"This may cause:\n"
f" - Performance overhead (~1-2% per logger, cumulative)\n"
f" - Increased memory usage\n"
f" - Longer training time per step\n"
f" - Potential config/callback conflicts\n\n"
f"Recommendations:\n"
f" - Choose ONE primary logging tool for production training\n"
f" - Use multiple loggers only for:\n"
f" * Migration period (transitioning between tools)\n"
f" * Short comparison runs\n"
f" * Debugging specific tool issues\n"
f" - Monitor system resources (CPU, memory) during training\n"
f"{'=' * 70}\n"
)
if len(active_loggers) >= 3:
LOG.error(
f"\n{'!' * 70}\n"
f"WARNING: {len(active_loggers)} logging tools enabled simultaneously!\n"
f"{'!' * 70}\n"
f"This is likely unintentional and WILL significantly impact performance.\n"
f"Expected overhead: ~{len(active_loggers) * 1.5:.1f}% per training step.\n\n"
f"STRONGLY RECOMMEND:\n"
f" - Disable all but ONE logging tool\n"
f" - Use config inheritance to manage multiple configs\n"
f"{'!' * 70}\n"
)
# === Auto-Enable Logic ===
# Enable SwanLab if project is specified
if cfg.get("swanlab_project") and not cfg.get("use_swanlab"):
cfg["use_swanlab"] = True
LOG.info("Automatically enabled use_swanlab because swanlab_project is set")
def pre_model_load(self, cfg: DictDefault):
"""Initialize SwanLab before model loading with runtime checks."""
if not cfg.use_swanlab:
return
# === Runtime Check: Import Availability ===
try:
import swanlab
except ImportError as err:
raise ImportError(
"SwanLab is not installed.\n\n"
"Install with:\n"
" pip install swanlab\n\n"
"Or add to requirements:\n"
" swanlab>=0.3.0\n\n"
f"Original error: {err}"
) from err
# Log SwanLab version
try:
swanlab_version = swanlab.__version__
LOG.info(f"SwanLab version: {swanlab_version}")
except AttributeError:
LOG.warning("Could not determine SwanLab version")
# === Runtime Check: Distributed Training Setup ===
from axolotl.utils.distributed import get_world_size, is_main_process
world_size = get_world_size()
if world_size > 1:
mode = getattr(cfg, "swanlab_mode", "cloud")
LOG.info(
f"\n{'=' * 70}\n"
f"Distributed training detected (world_size={world_size})\n"
f"SwanLab mode: {mode}\n"
f"{'=' * 70}\n"
f"Behavior:\n"
f" - Only rank 0 will initialize SwanLab\n"
f" - Other ranks will skip SwanLab to avoid conflicts\n"
)
if mode == "cloud":
LOG.info(
f" - Only rank 0 will upload to SwanLab cloud\n"
f" - Other ranks run without SwanLab overhead\n"
f"{'=' * 70}\n"
)
# Only initialize SwanLab on the main process (rank 0)
# to avoid creating multiple runs in distributed training
if not is_main_process():
LOG.debug("Skipping SwanLab initialization on non-main process")
return
# Initialize SwanLab run (passing all params directly to init)
try:
init_kwargs = self._get_swanlab_init_kwargs(cfg)
swanlab.init(**init_kwargs)
self.swanlab_initialized = True
LOG.info(f"SwanLab initialized with project: {cfg.swanlab_project}")
# Register Lark notification callback (if configured)
self._register_lark_callback(cfg)
# Log configuration (with error handling)
try:
config_dict = self._prepare_config_for_logging(cfg)
swanlab.config.update(config_dict)
LOG.debug("Successfully logged config to SwanLab")
except Exception as config_err: # pylint: disable=broad-except
LOG.warning(
f"Failed to log config to SwanLab: {config_err}. Continuing anyway."
)
except Exception as err: # pylint: disable=broad-except
LOG.exception("Failed to initialize SwanLab: %s", err)
self.swanlab_initialized = False
def add_callbacks_pre_trainer(self, cfg: DictDefault, model):
"""Add SwanLab callbacks before trainer creation."""
callbacks: list[TrainerCallback] = []
if not cfg.use_swanlab:
return callbacks
if not self.swanlab_initialized:
LOG.warning("SwanLab not initialized, skipping callback registration")
return callbacks
try:
from axolotl.utils.callbacks.swanlab import (
CustomSwanLabCallback,
SaveAxolotlConfigtoSwanLabCallback,
)
# Add our custom lightweight SwanLabCallback
# (avoids omegaconf/antlr4 version conflicts)
swanlab_callback = CustomSwanLabCallback()
callbacks.append(swanlab_callback)
LOG.info("Added CustomSwanLabCallback for metrics logging")
# Add Axolotl config logging callback
if cfg.axolotl_config_path:
config_callback = SaveAxolotlConfigtoSwanLabCallback(
cfg.axolotl_config_path
)
callbacks.append(config_callback)
LOG.info("Added SaveAxolotlConfigtoSwanLabCallback")
except ImportError as err:
LOG.exception("Failed to import SwanLab callbacks: %s", err)
return callbacks
def post_trainer_create(self, cfg: DictDefault, trainer):
"""Post-trainer creation hook."""
if cfg.use_swanlab and self.swanlab_initialized:
try:
import swanlab
# Log additional trainer information (with safe conversion)
trainer_config = {
"total_steps": int(trainer.state.max_steps)
if trainer.state.max_steps
else None,
"num_train_epochs": float(trainer.args.num_train_epochs)
if trainer.args.num_train_epochs
else None,
"train_batch_size": int(trainer.args.train_batch_size)
if hasattr(trainer.args, "train_batch_size")
else None,
"gradient_accumulation_steps": int(
trainer.args.gradient_accumulation_steps
)
if trainer.args.gradient_accumulation_steps
else None,
}
# Remove None values
trainer_config = {
k: v for k, v in trainer_config.items() if v is not None
}
if trainer_config:
swanlab.config.update(trainer_config)
LOG.info("Logged trainer configuration to SwanLab")
except Exception as err: # pylint: disable=broad-except
LOG.debug(f"Failed to log trainer config to SwanLab: {err}")
# Register RLHF completion logging callback if enabled
self._register_completion_callback(cfg, trainer)
def _get_swanlab_init_kwargs(self, cfg: DictDefault) -> dict:
"""Prepare kwargs for swanlab.init().
Passes all configuration parameters directly to swanlab.init()
instead of using environment variables as an intermediate layer.
Returns:
dict: Keyword arguments for swanlab.init()
"""
init_kwargs = {}
# Project name (required)
if cfg.swanlab_project:
init_kwargs["project"] = cfg.swanlab_project
# Experiment name
if cfg.swanlab_experiment_name:
init_kwargs["experiment_name"] = cfg.swanlab_experiment_name
# Description
if cfg.swanlab_description:
init_kwargs["description"] = cfg.swanlab_description
# Workspace (organization)
if cfg.swanlab_workspace:
init_kwargs["workspace"] = cfg.swanlab_workspace
# Mode: cloud, local, offline, disabled
if cfg.swanlab_mode:
init_kwargs["mode"] = cfg.swanlab_mode
# API key (pass directly instead of via env var)
if cfg.swanlab_api_key:
init_kwargs["api_key"] = cfg.swanlab_api_key
# Private deployment hosts (pass directly instead of via env var)
if cfg.swanlab_web_host:
init_kwargs["web_host"] = cfg.swanlab_web_host
if cfg.swanlab_api_host:
init_kwargs["api_host"] = cfg.swanlab_api_host
# Log model checkpoints (coming soon in SwanLab)
if cfg.swanlab_log_model:
init_kwargs["log_model"] = cfg.swanlab_log_model
# Custom branding - adds Axolotl identifier to SwanLab UI
# This helps identify runs from Axolotl vs other frameworks
init_kwargs["config"] = {"UPPERFRAME": "🦎 Axolotl"}
return init_kwargs
def _prepare_config_for_logging(self, cfg: DictDefault) -> dict:
"""Prepare configuration dict for logging to SwanLab."""
def safe_convert(value):
"""Convert value to JSON-serializable type."""
if value is None:
return None
if isinstance(value, (int, float, bool)):
return value
if isinstance(value, str):
return value
# Convert everything else to string
return str(value)
try:
# Extract important training parameters with safe conversion
config_dict = {
"base_model": safe_convert(getattr(cfg, "base_model", "")),
"model_type": safe_convert(getattr(cfg, "model_type", "")),
"sequence_len": safe_convert(getattr(cfg, "sequence_len", None)),
"micro_batch_size": safe_convert(
getattr(cfg, "micro_batch_size", None)
),
"gradient_accumulation_steps": safe_convert(
getattr(cfg, "gradient_accumulation_steps", None)
),
"num_epochs": safe_convert(getattr(cfg, "num_epochs", None)),
"max_steps": safe_convert(getattr(cfg, "max_steps", None)),
"learning_rate": safe_convert(getattr(cfg, "learning_rate", None)),
"lr_scheduler": safe_convert(getattr(cfg, "lr_scheduler", "")),
"optimizer": safe_convert(getattr(cfg, "optimizer", "")),
"warmup_ratio": safe_convert(getattr(cfg, "warmup_ratio", None)),
"weight_decay": safe_convert(getattr(cfg, "weight_decay", None)),
"seed": safe_convert(getattr(cfg, "seed", None)),
"bf16": safe_convert(getattr(cfg, "bf16", None)),
"tf32": safe_convert(getattr(cfg, "tf32", None)),
"flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
}
# Add FSDP/parallel config - only boolean flags
if hasattr(cfg, "fsdp_config") and cfg.fsdp_config:
config_dict["fsdp_enabled"] = True
config_dict["fsdp_version"] = safe_convert(
getattr(cfg, "fsdp_version", None)
)
if hasattr(cfg, "deepspeed") and cfg.deepspeed:
config_dict["deepspeed_enabled"] = True
# Add context parallel info
if hasattr(cfg, "context_parallel_size"):
config_dict["context_parallel_size"] = safe_convert(
getattr(cfg, "context_parallel_size", None)
)
if hasattr(cfg, "tensor_parallel_size"):
config_dict["tensor_parallel_size"] = safe_convert(
getattr(cfg, "tensor_parallel_size", None)
)
if hasattr(cfg, "dp_shard_size"):
config_dict["dp_shard_size"] = safe_convert(
getattr(cfg, "dp_shard_size", None)
)
# Remove None values and empty strings
config_dict = {
k: v
for k, v in config_dict.items()
if v is not None and v != "" and v != "None"
}
return config_dict
except Exception as err: # pylint: disable=broad-except
LOG.warning(f"Failed to prepare config for logging: {err}")
# Return minimal config
try:
lr = getattr(cfg, "learning_rate", None)
lr_value = float(lr) if lr is not None else None
except (TypeError, ValueError):
lr_value = None
return {
"base_model": str(getattr(cfg, "base_model", "unknown")),
"learning_rate": lr_value,
}
def _register_lark_callback(self, cfg: DictDefault):
"""Register Lark (Feishu) notification callback if configured.
Lark notifications enable sending training updates to team chat channels,
useful for production monitoring and team collaboration.
Args:
cfg: Configuration object with Lark webhook settings
"""
# Check if Lark webhook URL is configured
lark_webhook_url = getattr(cfg, "swanlab_lark_webhook_url", None)
if not lark_webhook_url:
return # Lark not configured, skip
try:
import swanlab
from swanlab.plugin.notification import LarkCallback
# Get optional secret for HMAC signature authentication
lark_secret = getattr(cfg, "swanlab_lark_secret", None)
# Create Lark callback with webhook URL and optional secret
lark_callback = LarkCallback(
webhook_url=lark_webhook_url,
secret=lark_secret,
)
# Register callback with SwanLab
swanlab.register_callbacks([lark_callback])
if lark_secret:
LOG.info(
"Registered Lark notification callback with HMAC authentication"
)
else:
LOG.info("Registered Lark notification callback (no HMAC secret)")
LOG.warning(
"Lark webhook has no secret configured. "
"For production use, set 'swanlab_lark_secret' to enable HMAC signature verification."
)
except ImportError as err:
LOG.warning(
f"Failed to import SwanLab Lark plugin: {err}\n\n"
"Lark notifications require SwanLab >= 0.3.0 with plugin support.\n"
"Install with: pip install 'swanlab>=0.3.0'\n\n"
"Continuing without Lark notifications..."
)
except Exception as err: # pylint: disable=broad-except
LOG.exception(
"Failed to register Lark callback: %s\n\n"
"Check your Lark webhook URL and secret configuration.\n"
"Continuing without Lark notifications...",
err,
)
def _register_completion_callback(self, cfg: DictDefault, trainer):
"""Register RLHF completion logging callback if enabled and applicable.
This callback logs model completions (prompts, chosen/rejected responses,
rewards) to SwanLab during RLHF training for qualitative analysis.
Args:
cfg: Configuration object with completion logging settings
trainer: The trainer instance to add callback to
"""
# Check if completion logging is enabled
log_completions = getattr(cfg, "swanlab_log_completions", True)
if not log_completions:
LOG.debug("SwanLab completion logging disabled by config")
return
# Check if trainer is an RLHF trainer
trainer_name = trainer.__class__.__name__
rlhf_trainers = ["DPO", "KTO", "ORPO", "GRPO", "CPO"]
is_rlhf_trainer = any(name in trainer_name for name in rlhf_trainers)
if not is_rlhf_trainer:
LOG.debug(
f"Trainer {trainer_name} is not an RLHF trainer, "
"skipping completion logging callback"
)
return
try:
from axolotl.integrations.swanlab.callbacks import (
SwanLabRLHFCompletionCallback,
)
# Get configuration parameters
log_interval = getattr(cfg, "swanlab_completion_log_interval", 100)
max_buffer = getattr(cfg, "swanlab_completion_max_buffer", 128)
# Create and register callback
completion_callback = SwanLabRLHFCompletionCallback(
log_interval=log_interval,
max_completions=max_buffer,
table_name="rlhf_completions",
)
trainer.add_callback(completion_callback)
LOG.info(
f"Registered SwanLab RLHF completion logging callback for {trainer_name} "
f"(log_interval={log_interval}, max_buffer={max_buffer})"
)
except ImportError as err:
LOG.warning(
f"Failed to import SwanLab completion callback: {err}\n\n"
"This is a bug - the callback should be available.\n"
"Please report this issue.\n\n"
"Continuing without completion logging..."
)
except Exception as err: # pylint: disable=broad-except
LOG.exception(
"Failed to register SwanLab completion callback: %s\n\n"
"Continuing without completion logging...",
err,
)

View File

@@ -1,203 +0,0 @@
"""SwanLab profiling utilities for Axolotl trainers.
This module provides decorators and context managers for profiling
trainer methods and logging execution times to SwanLab.
"""
import time
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@contextmanager
def swanlab_profiling_context(trainer: Any, func_name: str):
"""Context manager for profiling trainer methods.
Measures execution time and logs to SwanLab if enabled.
Example usage:
>>> with swanlab_profiling_context(self, "training_step"):
... result = do_expensive_computation()
Args:
trainer: Trainer instance (must have cfg attribute with use_swanlab flag)
func_name: Name of the function being profiled
Yields:
None
"""
start_time = time.perf_counter()
try:
yield
finally:
duration = time.perf_counter() - start_time
# Check if SwanLab is enabled and initialized
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
if use_swanlab:
try:
import swanlab
if swanlab.get_run() is not None:
# Log profiling metric
trainer_class = trainer.__class__.__name__
metric_name = f"profiling/Time taken: {trainer_class}.{func_name}"
swanlab.log({metric_name: duration})
except ImportError:
# SwanLab not installed, silently skip
pass
except Exception as err: # pylint: disable=broad-except
# Log error but don't fail training
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")
def swanlab_profile(func: Callable) -> Callable:
"""Decorator to profile and log function execution time to SwanLab.
Automatically measures execution time of trainer methods and logs
to SwanLab as profiling metrics.
Example usage:
>>> class MyTrainer:
... @swanlab_profile
... def training_step(self, model, inputs):
... return super().training_step(model, inputs)
Args:
func: Function to profile (must be a method of a trainer instance)
Returns:
Wrapped function with profiling
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
with swanlab_profiling_context(self, func.__name__):
return func(self, *args, **kwargs)
return wrapper
class ProfilingConfig:
"""Configuration for SwanLab profiling.
This class provides a centralized way to control profiling behavior.
Attributes:
enabled: Whether profiling is enabled globally
min_duration_ms: Minimum duration (in ms) to log (filters out very fast ops)
log_interval: Log every N function calls (to reduce overhead)
"""
def __init__(
self,
enabled: bool = True,
min_duration_ms: float = 0.1,
log_interval: int = 1,
):
"""Initialize profiling configuration.
Args:
enabled: Enable profiling. Default: True
min_duration_ms: Minimum duration to log (ms). Default: 0.1
log_interval: Log every N calls. Default: 1 (log all)
"""
self.enabled = enabled
self.min_duration_ms = min_duration_ms
self.log_interval = log_interval
self._call_counts: dict[str, int] = {}
def should_log(self, func_name: str, duration_seconds: float) -> bool:
"""Check if a profiling measurement should be logged.
Args:
func_name: Name of the profiled function
duration_seconds: Execution duration in seconds
Returns:
True if should log, False otherwise
"""
if not self.enabled:
return False
# Check minimum duration threshold
duration_ms = duration_seconds * 1000
if duration_ms < self.min_duration_ms:
return False
# Check log interval
self._call_counts.setdefault(func_name, 0)
self._call_counts[func_name] += 1
# Always log on first call OR at intervals
count = self._call_counts[func_name]
if count == 1 or count % self.log_interval == 0:
return True
return False
# Global profiling config (can be modified by users)
DEFAULT_PROFILING_CONFIG = ProfilingConfig()
@contextmanager
def swanlab_profiling_context_advanced(
trainer: Any,
func_name: str,
config: ProfilingConfig | None = None,
):
"""Advanced profiling context with configurable behavior.
Similar to swanlab_profiling_context but with additional configuration
options for filtering and throttling profiling logs.
Example usage:
>>> config = ProfilingConfig(min_duration_ms=1.0, log_interval=10)
>>> with swanlab_profiling_context_advanced(self, "forward", config):
... output = model(inputs)
Args:
trainer: Trainer instance
func_name: Function name
config: Profiling configuration. If None, uses DEFAULT_PROFILING_CONFIG
Yields:
None
"""
if config is None:
config = DEFAULT_PROFILING_CONFIG
start_time = time.perf_counter()
try:
yield
finally:
duration = time.perf_counter() - start_time
# Check if should log based on config
if config.should_log(func_name, duration):
# Check if SwanLab is enabled
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
if use_swanlab:
try:
import swanlab
if swanlab.get_run() is not None:
trainer_class = trainer.__class__.__name__
metric_name = (
f"profiling/Time taken: {trainer_class}.{func_name}"
)
swanlab.log({metric_name: duration})
except ImportError:
pass
except Exception as err: # pylint: disable=broad-except
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")

View File

@@ -138,7 +138,6 @@ class PatchManager:
self._apply_llama_flash_attn_patches(model)
self._apply_unsloth_patches(model)
self._apply_lora_kernel_patch(model)
self._apply_scaling_softmax_patch(model)
def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention."""
@@ -561,16 +560,3 @@ class PatchManager:
)
patch_apertus_xielu_activation()
def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
"""Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
if self.cfg.scaling_softmax:
from axolotl.monkeypatch.scaled_softmax_attn import (
patch_scaled_softmax_attention,
)
patch_scaled_softmax_attention(
scaling_factor_init=self.cfg.scaling_softmax_factor or 0.43,
bias=self.cfg.scaling_softmax_bias or 0.0,
model=model,
)

View File

@@ -5,7 +5,6 @@ from typing import Type
import addict
import torch
import transformers
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault
@@ -154,9 +153,6 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
This function determines the appropriate model config source, loads it, applies any
necessary overrides, and validates it for compatibility with the `axolotl` config.
If `cfg.cls_model_config` is set, a custom config class from transformers will be
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -178,13 +174,8 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
if cfg.num_labels:
# num_labels is used to initialize classifier models
config_kwargs["num_labels"] = cfg.num_labels
config_cls = AutoConfig
if cfg.cls_model_config:
config_cls = getattr(transformers, cfg.cls_model_config)
try:
model_config = config_cls.from_pretrained(
model_config = AutoConfig.from_pretrained(
model_config_name,
trust_remote_code=trust_remote_code,
**config_kwargs,

View File

@@ -1,98 +0,0 @@
"""Dynamic Fine-Tuning (DFT) loss implementation"""
from typing import Optional
import torch
import torch.nn.functional as F
def selective_log_softmax(logits, index):
"""Memory-efficient log_softmax -> gather"""
if logits.dtype in [torch.float32, torch.float64]:
selected_logits = torch.gather(
logits, dim=-1, index=index.unsqueeze(-1)
).squeeze(-1)
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
per_token_logps = selected_logits - logsumexp_values
else:
per_token_logps = []
for row_logits, row_labels in zip(logits, index, strict=True):
row_logps = F.log_softmax(row_logits, dim=-1)
row_per_token_logps = row_logps.gather(
dim=-1, index=row_labels.unsqueeze(-1)
).squeeze(-1)
per_token_logps.append(row_per_token_logps)
per_token_logps = torch.stack(per_token_logps)
return per_token_logps
def get_dft_loss(ignore_index: int = -100):
"""Creates DFT loss function"""
def for_causal_lm_dft_loss(
logits,
labels,
vocab_size: int = None,
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""DFT loss: -exp(logprobs).detach() * logprobs"""
if shift_labels is None:
# Shift so that tokens < n predict n
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.to(logits.device)
# Create loss mask
loss_mask = shift_labels != ignore_index
shift_labels_masked = shift_labels.clone()
shift_labels_masked[~loss_mask] = 0
# Compute log probabilities
logprobs = selective_log_softmax(logits, shift_labels_masked)
# DFT loss: -exp(logprobs).detach() * logprobs
per_token_loss = -logprobs.exp().detach() * logprobs
# Sum over valid tokens and normalize
if num_items_in_batch is None:
num_items_in_batch = loss_mask.sum()
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
return loss
return for_causal_lm_dft_loss
def dft_loss(outputs, labels, num_items_in_batch=None):
"""DFT loss compatible with Trainer.compute_loss_func signature.
This function is designed to be passed to Trainer's compute_loss_func parameter.
"""
ignore_index = -100
# Shift labels for causal LM
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.to(outputs.logits.device)
# Create loss mask
loss_mask = shift_labels != ignore_index
shift_labels_masked = shift_labels.clone()
shift_labels_masked[~loss_mask] = 0
# Compute log probabilities
logprobs = selective_log_softmax(outputs.logits, shift_labels_masked)
# DFT loss: -exp(logprobs).detach() * logprobs
per_token_loss = -logprobs.exp().detach() * logprobs
# Sum over valid tokens and normalize
if num_items_in_batch is None:
num_items_in_batch = loss_mask.sum()
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
return loss

View File

@@ -1,141 +0,0 @@
"""
Scaled Softmax (SSMax) attention patch using FlexAttention.
SSMax: softmax(scores * s * log(n) + b) where n is the position index
Ref: https://arxiv.org/abs/2501.19399
"""
import torch
from transformers import PreTrainedModel
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from torch.nn.attention.flex_attention import BlockMask
from transformers.integrations.flex_attention import (
compile_friendly_flex_attention,
repeat_kv,
)
FLEX_ATTENTION_AVAILABLE = True
except ImportError:
FLEX_ATTENTION_AVAILABLE = False
BlockMask = None
_ssmax_config = {}
def patch_scaled_softmax_attention(
scaling_factor_init: float = 0.43, bias: float = 0.0, model: PreTrainedModel = None
):
"""Patch attention to apply SSMax via FlexAttention score_mod."""
global _ssmax_config
if not FLEX_ATTENTION_AVAILABLE:
raise RuntimeError("SSMax requires FlexAttention.")
_ssmax_config["ssmax_s"] = scaling_factor_init
_ssmax_config["ssmax_b"] = bias
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
if "flex_attention" in ALL_ATTENTION_FUNCTIONS:
_ssmax_config["original_flex_fn"] = ALL_ATTENTION_FUNCTIONS["flex_attention"]
ALL_ATTENTION_FUNCTIONS["flex_attention"] = ssmax_flex_attention_forward
LOG.info(
f"Patched flex_attention with SSMax (s={scaling_factor_init}, b={bias})"
)
else:
LOG.warning("flex_attention not found. Ensure flex_attention: true is set.")
def ssmax_flex_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask,
scaling: float | None = None,
softcap: float | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""FlexAttention forward with SSMax: score * (s * log(n) + b)."""
if kwargs.get("dropout", 0.0) > 0:
raise ValueError("flex_attention does not support dropout")
ssmax_s = _ssmax_config.get("ssmax_s", 0.43)
ssmax_b = _ssmax_config.get("ssmax_b", 0.0)
position_ids = kwargs.get("position_ids", None)
position_ids_flat = position_ids.view(-1) if position_ids is not None else None
block_mask = attention_mask if isinstance(attention_mask, BlockMask) else None
score_mask = None if block_mask else attention_mask
if score_mask is not None:
score_mask = score_mask[:, :, :, : key.shape[-2]]
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
"""
Apply SSMax scaling: score * (s * log(n) + b)
where n is the relative position within each packed sequence.
"""
if position_ids_flat is not None:
relative_pos = position_ids_flat[q_idx]
n = (relative_pos + 1).float()
else:
n = (q_idx + 1).float()
n = torch.clamp(n, min=2.0)
ssmax_scale = ssmax_s * torch.log(n) + ssmax_b
score = score * ssmax_scale
if softcap is not None:
score = softcap * torch.tanh(score / softcap)
if score_mask is not None:
score = score + score_mask[batch_idx][0][q_idx][kv_idx]
return score
enable_gqa = True
if (query.shape[1] & (query.shape[1] - 1)) != 0:
key = repeat_kv(key, query.shape[1] // key.shape[1])
value = repeat_kv(value, query.shape[1] // value.shape[1])
enable_gqa = False
return_lse = query.device.type != "cpu"
flex_output = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=enable_gqa,
scale=scaling,
kernel_options=kwargs.get("kernel_options"),
return_lse=return_lse,
training=module.training,
)
if return_lse:
attention_output, lse = flex_output
lse = lse.to(value.dtype)
else:
attention_output, lse = flex_output, None
return attention_output.transpose(1, 2).contiguous(), lse
def unpatch_scaled_softmax_attention():
"""Restore the original FlexAttention function."""
global _ssmax_config
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
if "original_flex_fn" in _ssmax_config:
ALL_ATTENTION_FUNCTIONS["flex_attention"] = _ssmax_config["original_flex_fn"]
_ssmax_config.clear()
LOG.info("Unpatched flex_attention, restored original")

View File

@@ -1,248 +0,0 @@
"""Callbacks for SwanLab integration"""
from __future__ import annotations
import json
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from axolotl.core.training_args import AxolotlTrainingArguments
LOG = get_logger(__name__)
class CustomSwanLabCallback(TrainerCallback):
"""
Lightweight SwanLab callback that directly logs metrics without using
SwanLab's transformers integration (which requires omegaconf).
This avoids the antlr4 version conflict between omegaconf and axolotl.
"""
def __init__(self):
self._initialized = False
self.swanlab = None
def setup(self):
"""Lazy initialization of SwanLab"""
if self._initialized:
return
try:
import swanlab
self.swanlab = swanlab
# Check if SwanLab run is initialized
if swanlab.get_run() is None:
LOG.warning("SwanLab run is not initialized")
return
self._initialized = True
LOG.info("CustomSwanLabCallback initialized successfully")
except ImportError:
LOG.error("SwanLab is not installed")
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the beginning of training"""
if not state.is_world_process_zero:
return control
self.setup()
if not self._initialized:
return control
# Log training configuration
try:
self.swanlab.config.update(
{
"train_batch_size": args.per_device_train_batch_size,
"eval_batch_size": args.per_device_eval_batch_size,
"learning_rate": args.learning_rate,
"num_train_epochs": args.num_train_epochs,
"max_steps": args.max_steps,
"warmup_steps": args.warmup_steps,
"logging_steps": args.logging_steps,
"save_steps": args.save_steps,
"gradient_accumulation_steps": args.gradient_accumulation_steps,
}
)
LOG.debug("Training configuration logged to SwanLab")
except Exception as err:
LOG.warning(f"Failed to log training config: {err}")
return control
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs=None,
**kwargs,
):
"""Called when logging metrics"""
if not state.is_world_process_zero:
return control
if not self._initialized:
self.setup()
if not self._initialized or logs is None:
return control
# Log metrics to SwanLab
try:
# Filter out non-numeric values and prepare for logging
metrics = {}
for key, value in logs.items():
if isinstance(value, (int, float)):
# Use step from state
metrics[key] = value
if metrics and state.global_step is not None:
self.swanlab.log(metrics, step=state.global_step)
except Exception as err:
LOG.warning(f"Failed to log metrics to SwanLab: {err}")
return control
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the end of training"""
if not state.is_world_process_zero:
return control
if self._initialized:
LOG.info("Training completed. SwanLab logs are available.")
return control
class SaveAxolotlConfigtoSwanLabCallback(TrainerCallback):
"""Callback to save axolotl config to SwanLab"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: AxolotlTrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if state.is_world_process_zero:
try:
import swanlab
# Check if SwanLab is initialized
if swanlab.get_run() is None:
LOG.warning(
"SwanLab run is not initialized. Please initialize SwanLab before training."
)
return control
# Log Axolotl config as artifact
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
# Log config file to SwanLab
with open(temp_file.name, "r", encoding="utf-8") as config_file:
swanlab.log(
{
"axolotl_config": swanlab.Text(
config_file.read(), caption="Axolotl Config"
)
}
)
LOG.info(
"The Axolotl config has been saved to the SwanLab run under logs."
)
# Clean up temp file
os.unlink(temp_file.name)
except ImportError:
LOG.warning(
"SwanLab is not installed. Install it with: pip install swanlab"
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to SwanLab: {err}")
# Log DeepSpeed config if available
if args.deepspeed:
try:
import swanlab
with NamedTemporaryFile(
mode="w",
delete=False,
suffix=".json",
prefix="deepspeed_config_",
) as temp_file:
skip_upload = False
if isinstance(args.deepspeed, dict):
json.dump(args.deepspeed, temp_file, indent=4)
elif isinstance(args.deepspeed, str) and os.path.exists(
args.deepspeed
):
copyfile(args.deepspeed, temp_file.name)
else:
skip_upload = True
if not skip_upload:
temp_file.flush()
with open(
temp_file.name, "r", encoding="utf-8"
) as ds_config_file:
swanlab.log(
{
"deepspeed_config": swanlab.Text(
ds_config_file.read(),
caption="DeepSpeed Config",
)
}
)
LOG.info(
"The DeepSpeed config has been saved to the SwanLab run under logs."
)
# Clean up temp file
os.unlink(temp_file.name)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(
f"Error while saving DeepSpeed config to SwanLab: {err}"
)
except ImportError:
pass
return control

View File

@@ -101,3 +101,9 @@ class TokensPerSecondCallback(TrainerCallback):
# Clear per-step tokens after logging
if tokens and "trainable_tokens" in tokens:
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
if tokens and "total" in tokens:
logs["tokens/total"] = tokens["total"].item()
if tokens and "trainable" in tokens:
logs["tokens/trainable"] = tokens["trainable"].item()

View File

@@ -9,6 +9,10 @@ from torchao.quantization import quantize_
from torchao.quantization.qat import (
QATConfig,
)
from torchao.quantization.qat import fake_quantizer
from torchao.quantization.qat.fake_quantizer import (
Int4WeightFakeQuantizer as AoInt4WeightFakeQuantizer,
)
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
@@ -17,6 +21,27 @@ from torchao.quantization.quant_api import (
from axolotl.utils.schemas.enums import TorchAOQuantDType
class Int4WeightFakeQuantizer(AoInt4WeightFakeQuantizer):
"""
Adds 'enabled' attribute to Int4WeightFakeQuantizer (removed in torchao 0.15).
Allows toggling fake quantization on/off for fake_quant_after_n_steps.
"""
def __init__(self, config):
super().__init__(config)
self.enabled = True
def forward(self, w: torch.Tensor) -> torch.Tensor:
if not self.enabled:
return w
return super().forward(w)
# Replace the original Int4WeightFakeQuantizer in the fake_quantizer module
# so that torchao's quantize_() function will use our version
fake_quantizer.Int4WeightFakeQuantizer = Int4WeightFakeQuantizer
quantization_config_to_str = {
Int8DynamicActivationInt4WeightConfig: "int8int4",
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",

View File

@@ -619,25 +619,6 @@ class AxolotlInputConfig(
},
)
scaling_softmax: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use Scaled Softmax (SSMax) attention. Ref: https://arxiv.org/abs/2501.19399"
},
)
scaling_softmax_factor: float | None = Field(
default=None,
json_schema_extra={
"description": "Scaling factor for SSMax attention. Default is 0.43"
},
)
scaling_softmax_bias: float | None = Field(
default=None,
json_schema_extra={
"description": "Bias for SSMax attention. Default is 0.0. Note: The paper recommends bias=0 for better length generalization."
},
)
unsloth_cross_entropy_loss: bool | None = None
unsloth_lora_mlp: bool | None = None
unsloth_lora_qkv: bool | None = None
@@ -676,10 +657,6 @@ class AxolotlInputConfig(
"description": "Number of chunks to use for chunked cross entropy loss"
},
)
use_dynamic_finetuning: bool | None = Field(
default=None,
json_schema_extra={"description": "Enable Dynamic Fine-Tuning loss (DFT)"},
)
tiled_mlp: bool | None = Field(
default=None,

View File

@@ -25,12 +25,7 @@ class ModelInputConfig(BaseModel):
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
},
)
cls_model_config: str | None = Field(
default=None,
json_schema_extra={
"description": "transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). Defaults to AutoConfig."
},
)
cls_model_config: str | None = None
tokenizer_config: str | None = Field(
default=None,
json_schema_extra={

View File

@@ -201,16 +201,6 @@ class AttentionValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_scaling_softmax_requires_flex(cls, data):
if data.get("scaling_softmax") and not data.get("flex_attention"):
raise ValueError(
"scaling_softmax requires flex_attention: true\n"
"Add 'flex_attention: true' to your config file.\n"
)
return data
class TrainingValidationMixin:
"""Validation methods related to training configuration."""

File diff suppressed because it is too large Load Diff

View File

@@ -1,92 +0,0 @@
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
class TestLoRAConfigValidation:
"""Test suite for LoRA/QLoRA configuration validation"""
def test_basic_configuration_validation(self):
"""Test basic LoRA configuration validation"""
valid_config = DictDefault(
{
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.1,
"lora_target_modules": ["q_proj", "v_proj"],
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["adapter"] == "lora"
with pytest.raises(ValueError, match="not compatible with DoRA"):
invalid_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"peft_use_dora": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
def test_qlora_4bit_validation(self):
"""Test QLoRA 4-bit configuration validation"""
valid_config = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"bnb_4bit_compute_dtype": "float16",
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
# Test QLoRA without 4-bit (should fail via PEFT validation)
with pytest.raises(ValueError, match=r"Require cfg\.load_in_4bit"):
invalid_config = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": False,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
# Test QLoRA with 8-bit (incompatible)
with pytest.raises(ValueError, match="Can't load qlora in 8bit"):
invalid_config = DictDefault(
{
"adapter": "qlora",
"load_in_8bit": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)

View File

@@ -1,261 +0,0 @@
import importlib.util
from unittest.mock import Mock
import pytest
import torch
import torch.nn as nn
from axolotl.kernels.lora import get_lora_parameters
PEFT_AVAILABLE = importlib.util.find_spec("peft") is not None
class TestLoRAParameterFreezing:
"""Test suite for LoRA parameter freezing validation."""
def setup_method(self):
self.dtype = torch.float32
def create_mock_lora_layer(
self, has_adapters=True, adapters_disabled=False, merged=False
):
"""Create a mock LoRA layer for testing."""
mock_layer = Mock()
base_layer = Mock()
base_layer.weight = torch.randn(512, 256, dtype=self.dtype)
base_layer.bias = torch.randn(512, dtype=self.dtype)
if has_adapters:
mock_layer.base_layer = base_layer
mock_layer.disable_adapters = adapters_disabled
mock_layer.merged = merged
mock_layer.active_adapters = ["default"]
mock_layer.lora_A = {"default": Mock()}
mock_layer.lora_B = {"default": Mock()}
mock_layer.scaling = {"default": 0.1}
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
else:
mock_layer.weight = base_layer.weight
mock_layer.bias = base_layer.bias
return mock_layer
def test_parameter_freezing_adapters_disabled(self):
"""Test that LoRA parameters are None when adapters are disabled."""
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
assert b is not None
# LoRA parameters should be None (frozen)
assert A is None
assert B is None
assert s is None
def test_parameter_freezing_adapters_merged(self):
"""Test that LoRA parameters are None when adapters are merged."""
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
assert b is not None
# LoRA parameters should be None (frozen)
assert A is None
assert B is None
assert s is None
def test_parameter_freezing_no_adapters(self):
"""Test parameter behavior when no adapters are present."""
layer = self.create_mock_lora_layer(has_adapters=False)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
assert b is not None
# LoRA parameters should be None (frozen)
assert A is None
assert B is None
assert s is None
def test_parameter_active_adapters_enabled(self):
"""Test that LoRA parameters are returned when adapters are active."""
layer = self.create_mock_lora_layer(
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# All parameters should be returned
assert W is not None
assert b is not None
assert A is not None
assert B is not None
assert s is not None
assert s == 0.1
def test_parameter_shapes_consistency(self):
"""Test that parameter shapes are consistent when active."""
layer = self.create_mock_lora_layer(
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Check shape consistency
assert W.shape == (512, 256)
assert b.shape == (512,)
assert A.shape == (16, 256)
assert B.shape == (512, 16)
def test_parameter_dtypes_consistency(self):
"""Test that parameter dtypes are consistent."""
layer = self.create_mock_lora_layer(
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
assert W.dtype == self.dtype
assert b.dtype == self.dtype
assert A.dtype == self.dtype
assert B.dtype == self.dtype
def test_quantization_state_handling(self):
"""Test that quantization state is properly handled."""
layer = self.create_mock_lora_layer(has_adapters=True)
quant_state_mock = Mock()
layer.base_layer.weight.quant_state = quant_state_mock
W, b, quant_state, A, B, s = get_lora_parameters(layer)
assert quant_state == quant_state_mock
def test_multiple_adapters_active_adapter_selection(self):
"""Test that the correct adapter is selected when multiple adapters exist."""
layer = self.create_mock_lora_layer(
has_adapters=True, adapters_disabled=False, merged=False
)
layer.lora_A["adapter2"] = Mock()
layer.lora_B["adapter2"] = Mock()
layer.scaling["adapter2"] = 0.2
layer.lora_A["adapter2"].weight = torch.randn(16, 256, dtype=self.dtype)
layer.lora_B["adapter2"].weight = torch.randn(512, 16, dtype=self.dtype)
layer.active_adapters = ["adapter2"]
W, b, quant_state, A, B, s = get_lora_parameters(layer)
assert s == 0.2
assert torch.equal(A, layer.lora_A["adapter2"].weight)
assert torch.equal(B, layer.lora_B["adapter2"].weight)
class TestLoRAParameterFreezingIntegration:
"""Integration tests for parameter freezing with actual LoRA layers."""
@pytest.mark.skipif(
not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
)
def test_parameter_freezing_with_real_lora_layer(self):
"""Test parameter freezing with actual PEFT LoRA layer."""
from peft import LoraConfig, get_peft_model
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(256, 512)
def forward(self, x):
return self.linear(x)
base_model = SimpleModel()
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["linear"],
lora_dropout=0.1,
)
model = get_peft_model(base_model, lora_config)
lora_layer = model.base_model.model.linear
# Test with adapters enabled
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
assert A is not None
assert B is not None
assert s is not None
# Test with adapters disabled
model.disable_adapter_layers()
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
assert A is None
assert B is None
assert s is None
@pytest.mark.skipif(
not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
)
def test_parameter_freezing_gradient_behavior(self):
"""Test that frozen parameters don't receive gradients."""
from peft import LoraConfig, get_peft_model
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(256, 512)
def forward(self, x):
return self.linear(x)
base_model = SimpleModel()
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["linear"],
lora_dropout=0.1,
)
model = get_peft_model(base_model, lora_config)
x = torch.randn(1, 256)
target = torch.randn(1, 512)
model.enable_adapter_layers()
output = model(x)
loss = nn.MSELoss()(output, target)
loss.backward()
lora_layer = model.base_model.model.linear
has_lora_grads = any(
param.grad is not None
for name, param in lora_layer.named_parameters()
if "lora_" in name
)
assert has_lora_grads, (
"LoRA parameters should have gradients when adapters are enabled"
)
model.zero_grad()
model.disable_adapter_layers()
output = model(x)
loss = nn.MSELoss()(output, target)
any_requires_grad = any(param.requires_grad for param in model.parameters())
if any_requires_grad:
loss.backward()
has_lora_grads_disabled = any(
param.grad is not None
for name, param in lora_layer.named_parameters()
if "lora_" in name
)
assert not has_lora_grads_disabled, (
"LoRA parameters should not have gradients when adapters are disabled"
)
model.zero_grad()
del model, base_model, lora_layer, x, target, output, loss
torch.cuda.empty_cache() if torch.cuda.is_available() else None

View File

@@ -1,181 +0,0 @@
from unittest.mock import Mock, patch
import torch
from axolotl.cli.merge_lora import do_merge_lora
from axolotl.utils.dict import DictDefault
class TestAdapterMergeUnmerge:
"""Test suite for LoRA adapter merging/unmerging functionality"""
def setup_method(self):
self.dtype = torch.float32
self.device = torch.device("cpu")
def create_mock_base_model(self, vocab_size=1000, hidden_size=256):
"""Create a mock base model with linear layers"""
mock_model = Mock()
mock_model.config = Mock()
mock_model.config.vocab_size = vocab_size
mock_model.config.hidden_size = hidden_size
mock_model.q_proj = Mock()
mock_model.q_proj.weight = torch.randn(
hidden_size, hidden_size, dtype=self.dtype
)
mock_model.q_proj.bias = torch.randn(hidden_size, dtype=self.dtype)
mock_model.v_proj = Mock()
mock_model.v_proj.weight = torch.randn(
hidden_size, hidden_size, dtype=self.dtype
)
mock_model.v_proj.bias = torch.randn(hidden_size, dtype=self.dtype)
return mock_model
def create_mock_lora_model(self, base_model, r=8, alpha=16):
"""Create a mock LoRA model wrapping the base model"""
mock_lora_model = Mock()
mock_lora_model.base_model = base_model
mock_lora_model.merge_and_unload = None
mock_lora_model.to = Mock(return_value=mock_lora_model)
mock_lora_model.generation_config = Mock()
mock_lora_model.config = Mock()
self.original_q_weight = base_model.q_proj.weight.clone()
self.original_v_weight = base_model.v_proj.weight.clone()
mock_lora_model.peft_config = {"default": Mock()}
mock_lora_model.peft_config["default"].r = r
mock_lora_model.peft_config["default"].lora_alpha = alpha
self.lora_A_q = torch.randn(
r, base_model.q_proj.weight.shape[1], dtype=self.dtype
)
self.lora_B_q = torch.randn(
base_model.q_proj.weight.shape[0], r, dtype=self.dtype
)
self.lora_A_v = torch.randn(
r, base_model.v_proj.weight.shape[1], dtype=self.dtype
)
self.lora_B_v = torch.randn(
base_model.v_proj.weight.shape[0], r, dtype=self.dtype
)
self.scaling = alpha / r
def mock_merge_and_unload(progressbar=False):
"""Simulate the actual merge operation"""
# Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling
delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling
base_model.q_proj.weight = self.original_q_weight + delta_q
base_model.v_proj.weight = self.original_v_weight + delta_v
return base_model
mock_lora_model.merge_and_unload = mock_merge_and_unload
return mock_lora_model
def test_basic_lora_merge_unmerge_cycle(self):
"""Test: original_weights -> merge -> unmerge -> should equal original_weights"""
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model)
original_q_weight = self.original_q_weight.clone()
original_v_weight = self.original_v_weight.clone()
merged_model = lora_model.merge_and_unload()
assert not torch.equal(merged_model.q_proj.weight, original_q_weight)
assert not torch.equal(merged_model.v_proj.weight, original_v_weight)
delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling
unmerged_q_weight = merged_model.q_proj.weight - delta_q
unmerged_v_weight = merged_model.v_proj.weight - delta_v
assert torch.allclose(unmerged_q_weight, original_q_weight, atol=1e-6)
assert torch.allclose(unmerged_v_weight, original_v_weight, atol=1e-6)
def test_merge_weight_calculation_accuracy(self):
"""Test: merged_weight = base_weight + (lora_B @ lora_A * scaling)"""
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model, r=16, alpha=32)
expected_delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
expected_merged_q = self.original_q_weight + expected_delta_q
merged_model = lora_model.merge_and_unload()
assert torch.allclose(merged_model.q_proj.weight, expected_merged_q, atol=1e-6)
@patch("axolotl.cli.merge_lora.load_model_and_tokenizer")
def test_cli_do_merge_functionality(self, mock_load_model, tmp_path):
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model)
tokenizer = Mock()
processor = None
mock_load_model.return_value = (lora_model, tokenizer, processor)
cfg = DictDefault(
{
"save_safetensors": True,
"torch_dtype": torch.float32,
"local_rank": 0,
"output_dir": str(tmp_path),
}
)
with (
patch("pathlib.Path.mkdir"),
patch.object(base_model, "save_pretrained") as mock_save_model,
patch.object(tokenizer, "save_pretrained") as mock_save_tokenizer,
):
do_merge_lora(cfg=cfg)
mock_save_model.assert_called_once()
mock_save_tokenizer.assert_called_once()
def test_quantized_model_merge_compatibility(self):
"""Test 4-bit/8-bit model merging scenarios"""
base_model = self.create_mock_base_model()
# Mock quantized weights
base_model.q_proj.weight.quant_state = Mock()
base_model.q_proj.weight.quant_state.dtype = torch.uint8
lora_model = self.create_mock_lora_model(base_model)
merged_model = lora_model.merge_and_unload()
assert merged_model is not None
@patch.dict("os.environ", {"CUDA_VISIBLE_DEVICES": ""})
def test_memory_efficient_merge_with_cpu_offload(self, tmp_path):
"""Test lora_on_cpu configuration during merge"""
cfg = DictDefault(
{
"lora_on_cpu": True,
"save_safetensors": True,
"output_dir": str(tmp_path),
"local_rank": 0,
}
)
with patch("axolotl.cli.merge_lora.load_model_and_tokenizer") as mock_load:
base_model = self.create_mock_base_model()
lora_model = self.create_mock_lora_model(base_model)
mock_load.return_value = (lora_model, Mock(), None)
with patch("pathlib.Path.mkdir"), patch("torch.save"):
do_merge_lora(cfg=cfg)
assert mock_load.called