Compare commits
1 Commits
dft
...
upgrade-to
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a08e4117a |
5
.github/PULL_REQUEST_TEMPLATE.md
vendored
5
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -15,11 +15,6 @@
|
||||
<!--- Include details of your testing environment, tests ran to see how -->
|
||||
<!--- your change affects other areas of the code, etc. -->
|
||||
|
||||
## AI Usage Disclaimer
|
||||
|
||||
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
|
||||
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
|
||||
|
||||
## Screenshots (if appropriate)
|
||||
|
||||
## Types of changes
|
||||
|
||||
16
.github/workflows/base.yml
vendored
16
.github/workflows/base.yml
vendored
@@ -21,8 +21,6 @@ jobs:
|
||||
timeout-minutes: 480
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -34,7 +32,6 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -42,7 +39,6 @@ jobs:
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -50,7 +46,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
@@ -58,7 +53,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
@@ -85,7 +79,6 @@ jobs:
|
||||
axolotlai/axolotl-base
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -96,7 +89,6 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
platforms: ${{ matrix.platforms }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
@@ -111,8 +103,6 @@ jobs:
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
timeout-minutes: 480
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -124,7 +114,6 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -132,7 +121,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -140,7 +128,6 @@ jobs:
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
@@ -148,7 +135,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -160,7 +146,6 @@ jobs:
|
||||
axolotlai/axolotl-base-uv
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -171,7 +156,6 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
platforms: ${{ matrix.platforms }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
43
.github/workflows/main.yml
vendored
43
.github/workflows/main.yml
vendored
@@ -20,26 +20,22 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64"
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
is_latest: true
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: 130
|
||||
# cuda_version: 13.0.0
|
||||
# python_version: "3.11"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -65,7 +61,6 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
@@ -92,26 +87,22 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64"
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: 130
|
||||
# cuda_version: 13.0.0
|
||||
# python_version: "3.11"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -136,7 +127,6 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
@@ -157,11 +147,11 @@ jobs:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
is_latest:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
@@ -190,7 +180,6 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
|
||||
7
.github/workflows/multi-gpu-e2e.yml
vendored
7
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -43,13 +43,6 @@ jobs:
|
||||
axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
|
||||
6
.github/workflows/tests.yml
vendored
6
.github/workflows/tests.yml
vendored
@@ -316,12 +316,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -6,7 +6,6 @@ ARG AXOLOTL_EXTRAS=""
|
||||
ARG AXOLOTL_ARGS=""
|
||||
ARG CUDA="118"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG TARGETARCH
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
@@ -21,17 +20,13 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \ python scripts/unsloth_install.py | sh && \
|
||||
python scripts/unsloth_install.py | sh && \
|
||||
python scripts/cutcrossentropy_install.py | sh && \
|
||||
pip install pytest && \
|
||||
pip cache purge
|
||||
|
||||
@@ -2,16 +2,14 @@ ARG CUDA_VERSION="11.8.0"
|
||||
ARG CUDNN_VERSION="8"
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
ARG TARGETARCH
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTHON_VERSION="3.10"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG CUDA="128"
|
||||
ARG CUDA="118"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
@@ -24,17 +22,11 @@ RUN apt-get update \
|
||||
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
|
||||
&& rm -rf /var/cache/apt/archives \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
MINICONDA_ARCH="x86_64"; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
MINICONDA_ARCH="aarch64"; \
|
||||
else \
|
||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||
fi \
|
||||
&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
@@ -59,34 +51,8 @@ RUN git lfs install --skip-repo && \
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
|
||||
RUN case "$PYTORCH_VERSION" in \
|
||||
2.9.[0-9]*) \
|
||||
if [ "$CUDA" = "128" ]; then \
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
|
||||
WHL_VERSION="v0.5.4"; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
|
||||
WHL_VERSION="v0.6.4"; \
|
||||
else \
|
||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||
fi; \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
|
||||
pip3 install --no-cache-dir ${WHL_FILE}; \
|
||||
rm ${WHL_FILE}; \
|
||||
elif [ "$CUDA" = "130" ]; then \
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
|
||||
WHL_VERSION="v0.5.4"; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
|
||||
WHL_VERSION="v0.6.4"; \
|
||||
else \
|
||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||
fi; \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
|
||||
pip3 install --no-cache-dir ${WHL_FILE}; \
|
||||
rm ${WHL_FILE}; \
|
||||
fi \
|
||||
;; \
|
||||
esac
|
||||
RUN if [ "$PYTORCH_VERSION" =~ ^2\.9\.[0-9]+$ ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
fi
|
||||
|
||||
@@ -2,7 +2,6 @@ ARG CUDA_VERSION="12.6.3"
|
||||
ARG CUDNN_VERSION=""
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
ARG TARGETARCH
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
@@ -32,35 +31,12 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
|
||||
|
||||
RUN uv pip install packaging setuptools wheel psutil \
|
||||
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
|
||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||
&& uv pip install awscli pydantic
|
||||
|
||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
|
||||
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
fi
|
||||
|
||||
RUN case "$PYTORCH_VERSION" in \
|
||||
2.9.[0-9]*) \
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
if [ "$CUDA" = "128" ]; then \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
elif [ "$CUDA" = "130" ]; then \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
fi \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
if [ "$CUDA" = "128" ]; then \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
elif [ "$CUDA" = "130" ]; then \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
fi \
|
||||
fi \
|
||||
;; \
|
||||
esac
|
||||
|
||||
@@ -52,7 +52,6 @@ gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
scaling_softmax: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
ddp_find_unused_parameters: true
|
||||
|
||||
chat_template: gemma3
|
||||
eot_tokens:
|
||||
- <end_of_turn>
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/gemma-3-1b-fft-dft
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
use_dynamic_finetuning: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -1,7 +1,6 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
base_model: google/gemma-3-270m-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
@@ -2,7 +2,6 @@ base_model: google/gemma-3-4b-it
|
||||
|
||||
# Need to set else transformers tries to load vision too
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
@@ -33,8 +32,8 @@ sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -31,7 +31,7 @@ pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
|
||||
@@ -59,7 +59,6 @@ gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
scaling_softmax: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -1,285 +0,0 @@
|
||||
# SwanLab Integration Examples
|
||||
|
||||
This directory contains example configurations demonstrating SwanLab integration with Axolotl.
|
||||
|
||||
## Examples Overview
|
||||
|
||||
### 1. DPO with Completion Logging
|
||||
**File**: `dpo-swanlab-completions.yml`
|
||||
|
||||
Demonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging.
|
||||
|
||||
**Features**:
|
||||
- Basic SwanLab experiment tracking
|
||||
- Completion table logging (prompts, chosen/rejected responses, rewards)
|
||||
- Memory-bounded buffer for long training runs
|
||||
- Cloud sync configuration
|
||||
|
||||
**Best for**: RLHF practitioners who want to analyze model outputs qualitatively
|
||||
|
||||
**Quick start**:
|
||||
```bash
|
||||
export SWANLAB_API_KEY=your-api-key
|
||||
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. LoRA with Performance Profiling
|
||||
**File**: `lora-swanlab-profiling.yml`
|
||||
|
||||
Demonstrates standard LoRA fine-tuning with performance profiling enabled.
|
||||
|
||||
**Features**:
|
||||
- SwanLab experiment tracking
|
||||
- Automatic profiling of trainer methods
|
||||
- Profiling metrics visualization
|
||||
- Performance optimization guidance
|
||||
|
||||
**Best for**: Engineers optimizing training performance and comparing different configurations
|
||||
|
||||
**Quick start**:
|
||||
```bash
|
||||
export SWANLAB_API_KEY=your-api-key
|
||||
accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. Full-Featured DPO Production Setup
|
||||
**File**: `dpo-swanlab-full-featured.yml`
|
||||
|
||||
Comprehensive production-ready configuration with ALL SwanLab features enabled.
|
||||
|
||||
**Features**:
|
||||
- Experiment tracking with team workspace
|
||||
- RLHF completion logging
|
||||
- Performance profiling
|
||||
- Lark (Feishu) team notifications
|
||||
- Private deployment support
|
||||
- Production checklist and troubleshooting
|
||||
|
||||
**Best for**: Production RLHF training with team collaboration
|
||||
|
||||
**Quick start**:
|
||||
```bash
|
||||
export SWANLAB_API_KEY=your-api-key
|
||||
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||
export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. Custom Trainer Profiling (Python)
|
||||
**File**: `custom_trainer_profiling.py`
|
||||
|
||||
Python code examples showing how to add SwanLab profiling to custom trainers.
|
||||
|
||||
**Features**:
|
||||
- `@swanlab_profile` decorator examples
|
||||
- Context manager profiling for fine-grained timing
|
||||
- `ProfilingConfig` for advanced filtering and throttling
|
||||
- Multiple profiling patterns and best practices
|
||||
|
||||
**Best for**: Advanced users creating custom trainers
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from custom_trainer_profiling import CustomTrainerWithProfiling
|
||||
# See file for detailed examples and patterns
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Feature Matrix
|
||||
|
||||
| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace |
|
||||
|---------|----------|-------------------|-----------|-------------------|----------------|
|
||||
| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | ➖ (commented) | ➖ (commented) |
|
||||
| lora-swanlab-profiling.yml | ✅ | ➖ (disabled) | ✅ (auto) | ➖ (commented) | ➖ (commented) |
|
||||
| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ |
|
||||
| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A |
|
||||
|
||||
---
|
||||
|
||||
## Configuration Quick Reference
|
||||
|
||||
### Basic SwanLab Setup
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
use_swanlab: true
|
||||
swanlab_project: my-project
|
||||
swanlab_experiment_name: my-experiment
|
||||
swanlab_mode: cloud # cloud, local, offline, disabled
|
||||
```
|
||||
|
||||
### RLHF Completion Logging
|
||||
```yaml
|
||||
swanlab_log_completions: true
|
||||
swanlab_completion_log_interval: 100 # Log every 100 steps
|
||||
swanlab_completion_max_buffer: 128 # Memory-bounded buffer
|
||||
```
|
||||
|
||||
### Lark Team Notifications
|
||||
```yaml
|
||||
swanlab_lark_webhook_url: https://open.feishu.cn/...
|
||||
swanlab_lark_secret: your-webhook-secret # Required for production
|
||||
```
|
||||
|
||||
### Team Workspace
|
||||
```yaml
|
||||
swanlab_workspace: my-research-team
|
||||
```
|
||||
|
||||
### Private Deployment
|
||||
```yaml
|
||||
swanlab_web_host: https://swanlab.yourcompany.com
|
||||
swanlab_api_host: https://api.swanlab.yourcompany.com
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Authentication
|
||||
|
||||
### Recommended: Environment Variable
|
||||
```bash
|
||||
export SWANLAB_API_KEY=your-api-key
|
||||
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||
export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||
```
|
||||
|
||||
### Alternative: Config File (less secure)
|
||||
```yaml
|
||||
swanlab_api_key: your-api-key
|
||||
swanlab_lark_webhook_url: https://open.feishu.cn/...
|
||||
swanlab_lark_secret: your-webhook-secret
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
### Use Case 1: Migrate from WandB to SwanLab
|
||||
Start with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB:
|
||||
```yaml
|
||||
use_swanlab: true
|
||||
use_wandb: false
|
||||
```
|
||||
|
||||
### Use Case 2: Analyze DPO Model Outputs
|
||||
Use `dpo-swanlab-completions.yml`, adjust completion logging interval based on your training length:
|
||||
```yaml
|
||||
swanlab_completion_log_interval: 50 # More frequent for short training
|
||||
swanlab_completion_log_interval: 200 # Less frequent for long training
|
||||
```
|
||||
|
||||
### Use Case 3: Optimize Training Performance
|
||||
Use `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations:
|
||||
- Baseline: `flash_attention: false, gradient_checkpointing: false`
|
||||
- Flash Attention: `flash_attention: true`
|
||||
- Gradient Checkpointing: `gradient_checkpointing: true`
|
||||
- Both: `flash_attention: true, gradient_checkpointing: true`
|
||||
|
||||
Compare profiling metrics in SwanLab dashboard.
|
||||
|
||||
### Use Case 4: Production RLHF with Team Collaboration
|
||||
Use `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications:
|
||||
```yaml
|
||||
swanlab_workspace: ml-team
|
||||
swanlab_lark_webhook_url: ...
|
||||
swanlab_lark_secret: ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Viewing Your Experiments
|
||||
|
||||
### Cloud Mode
|
||||
Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project.
|
||||
|
||||
**Dashboard sections**:
|
||||
- **Metrics**: Training loss, learning rate, profiling metrics
|
||||
- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO)
|
||||
- **Config**: Hyperparameters and configuration
|
||||
- **System**: Resource usage (GPU, memory, CPU)
|
||||
- **Files**: Logged artifacts
|
||||
|
||||
### Local Mode
|
||||
```bash
|
||||
swanlab watch ./swanlog
|
||||
# Open browser to http://localhost:5092
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### SwanLab not initializing
|
||||
```bash
|
||||
# Check API key
|
||||
echo $SWANLAB_API_KEY
|
||||
|
||||
# Verify SwanLab is installed
|
||||
pip show swanlab
|
||||
|
||||
# Check config
|
||||
grep -A 5 "use_swanlab" your-config.yml
|
||||
```
|
||||
|
||||
### Completions not appearing
|
||||
- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
|
||||
- Check `swanlab_log_completions: true`
|
||||
- Wait for `swanlab_completion_log_interval` steps
|
||||
- Look for "Registered SwanLab RLHF completion logging" in logs
|
||||
|
||||
### Lark notifications not working
|
||||
- Test webhook manually: `curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...`
|
||||
- Verify `SWANLAB_LARK_SECRET` is set correctly
|
||||
- Check bot is added to Lark group chat
|
||||
- Look for "Registered Lark notification callback" in logs
|
||||
|
||||
### Profiling metrics not appearing
|
||||
- Verify `use_swanlab: true`
|
||||
- Check SwanLab is initialized (look for init log message)
|
||||
- Profiling metrics are under "profiling/" namespace
|
||||
- Profiling auto-enabled when SwanLab is enabled
|
||||
|
||||
---
|
||||
|
||||
## Performance Notes
|
||||
|
||||
### Overhead Comparison
|
||||
|
||||
| Feature | Overhead per Step | Memory Usage |
|
||||
|---------|------------------|--------------|
|
||||
| Basic tracking | < 0.1% | ~10 MB |
|
||||
| Completion logging | < 0.5% | ~64 KB (buffer=128) |
|
||||
| Profiling | < 0.1% | ~1 KB |
|
||||
| **Total** | **< 0.7%** | **~10 MB** |
|
||||
|
||||
### Best Practices
|
||||
1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab)
|
||||
2. Adjust completion log interval based on training length (100-200 steps)
|
||||
3. Keep completion buffer size reasonable (128-512)
|
||||
4. Profile critical path methods first (training_step, compute_loss)
|
||||
5. Use ProfilingConfig to throttle high-frequency operations
|
||||
|
||||
---
|
||||
|
||||
## Further Reading
|
||||
|
||||
- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md)
|
||||
- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn)
|
||||
- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/)
|
||||
- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
Found an issue or have an improvement? Please submit a PR or open an issue:
|
||||
- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues)
|
||||
- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues)
|
||||
@@ -1,299 +0,0 @@
|
||||
"""Example: Custom Trainer with SwanLab Profiling
|
||||
|
||||
This example demonstrates how to add SwanLab profiling to your custom trainer.
|
||||
|
||||
Features:
|
||||
- @swanlab_profile decorator for automatic profiling
|
||||
- swanlab_profiling_context for fine-grained profiling
|
||||
- ProfilingConfig for advanced filtering and throttling
|
||||
|
||||
Usage:
|
||||
1. Create your custom trainer extending AxolotlTrainer
|
||||
2. Add @swanlab_profile decorators to methods you want to profile
|
||||
3. Use swanlab_profiling_context for fine-grained profiling within methods
|
||||
4. Enable SwanLab in your config (use_swanlab: true)
|
||||
|
||||
See also:
|
||||
- examples/swanlab/lora-swanlab-profiling.yml for config
|
||||
- src/axolotl/integrations/swanlab/profiling.py for implementation
|
||||
"""
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.integrations.swanlab.profiling import (
|
||||
ProfilingConfig,
|
||||
swanlab_profile,
|
||||
swanlab_profiling_context,
|
||||
swanlab_profiling_context_advanced,
|
||||
)
|
||||
|
||||
|
||||
class CustomTrainerWithProfiling(AxolotlTrainer):
|
||||
"""Custom trainer with SwanLab profiling enabled.
|
||||
|
||||
This trainer demonstrates three profiling patterns:
|
||||
1. Decorator-based profiling (@swanlab_profile)
|
||||
2. Context manager profiling (swanlab_profiling_context)
|
||||
3. Advanced profiling with filtering (ProfilingConfig)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Create custom profiling config for high-frequency operations
|
||||
self.fast_op_config = ProfilingConfig(
|
||||
enabled=True,
|
||||
min_duration_ms=0.5, # Only log if duration > 0.5ms
|
||||
log_interval=50, # Log every 50th call
|
||||
)
|
||||
|
||||
# ========================================================================
|
||||
# Pattern 1: Decorator-based Profiling
|
||||
# ========================================================================
|
||||
# Best for: Methods you always want to profile
|
||||
# Overhead: ~2-5 microseconds per call (negligible)
|
||||
|
||||
@swanlab_profile
|
||||
def training_step(self, model, inputs):
|
||||
"""Main training step - always profile.
|
||||
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step
|
||||
"""
|
||||
return super().training_step(model, inputs)
|
||||
|
||||
@swanlab_profile
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
"""Loss computation - always profile.
|
||||
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss
|
||||
"""
|
||||
return super().compute_loss(model, inputs, return_outputs)
|
||||
|
||||
@swanlab_profile
|
||||
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
||||
"""Prediction step - always profile.
|
||||
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step
|
||||
"""
|
||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
|
||||
|
||||
# ========================================================================
|
||||
# Pattern 2: Fine-grained Context Manager Profiling
|
||||
# ========================================================================
|
||||
# Best for: Profiling specific code blocks within a method
|
||||
# Use case: When you want to profile forward vs backward separately
|
||||
|
||||
def complex_training_step(self, model, inputs):
|
||||
"""Training step with fine-grained profiling.
|
||||
|
||||
Profiling metrics:
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.optimizer_step
|
||||
"""
|
||||
# Profile just the forward pass
|
||||
with swanlab_profiling_context(self, "forward_pass"):
|
||||
outputs = model(**inputs)
|
||||
loss = outputs.loss
|
||||
|
||||
# Profile just the backward pass
|
||||
with swanlab_profiling_context(self, "backward_pass"):
|
||||
loss.backward()
|
||||
|
||||
# Profile optimizer step
|
||||
with swanlab_profiling_context(self, "optimizer_step"):
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
return outputs
|
||||
|
||||
# ========================================================================
|
||||
# Pattern 3: Advanced Profiling with Filtering
|
||||
# ========================================================================
|
||||
# Best for: High-frequency operations where you want to throttle logging
|
||||
# Use case: Methods called 100+ times per step
|
||||
|
||||
def _prepare_inputs(self, inputs):
|
||||
"""Prepare inputs - throttled profiling.
|
||||
|
||||
This method is called frequently (once per batch), so we throttle
|
||||
profiling to reduce overhead:
|
||||
- Only log if duration > 0.5ms (skip very fast operations)
|
||||
- Only log every 50th call (reduce logging frequency)
|
||||
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs
|
||||
"""
|
||||
with swanlab_profiling_context_advanced(
|
||||
self, "prepare_inputs", config=self.fast_op_config
|
||||
):
|
||||
return super()._prepare_inputs(inputs)
|
||||
|
||||
def _prepare_input_for_model(self, input_ids):
|
||||
"""Another high-frequency operation - throttled profiling.
|
||||
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model
|
||||
"""
|
||||
with swanlab_profiling_context_advanced(
|
||||
self, "prepare_input_for_model", config=self.fast_op_config
|
||||
):
|
||||
# Your custom input preparation logic
|
||||
return input_ids
|
||||
|
||||
# ========================================================================
|
||||
# Pattern 4: Exception-safe Profiling
|
||||
# ========================================================================
|
||||
# Profiling is exception-safe: duration is logged even if method raises
|
||||
|
||||
@swanlab_profile
|
||||
def potentially_failing_method(self):
|
||||
"""This method may raise an exception.
|
||||
|
||||
SwanLab profiling will still log the duration before re-raising.
|
||||
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method
|
||||
"""
|
||||
# Do some work
|
||||
result = self._do_risky_computation()
|
||||
|
||||
# If this raises, profiling duration is still logged
|
||||
if result < 0:
|
||||
raise ValueError("Invalid result")
|
||||
|
||||
return result
|
||||
|
||||
def _do_risky_computation(self):
|
||||
"""Placeholder for risky computation."""
|
||||
return 42
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Advanced Example: Custom ProfilingConfig Per Method
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AdvancedProfilingTrainer(AxolotlTrainer):
|
||||
"""Trainer with method-specific profiling configurations."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Different profiling configs for different method types
|
||||
self.critical_path_config = ProfilingConfig(
|
||||
enabled=True,
|
||||
min_duration_ms=0.0, # Log everything on critical path
|
||||
log_interval=1, # Log every call
|
||||
)
|
||||
|
||||
self.fast_path_config = ProfilingConfig(
|
||||
enabled=True,
|
||||
min_duration_ms=1.0, # Only log if > 1ms
|
||||
log_interval=100, # Log every 100th call
|
||||
)
|
||||
|
||||
self.debug_config = ProfilingConfig(
|
||||
enabled=True,
|
||||
min_duration_ms=0.0, # Log everything
|
||||
log_interval=1, # Log every call
|
||||
)
|
||||
|
||||
def training_step(self, model, inputs):
|
||||
"""Critical path - log everything."""
|
||||
with swanlab_profiling_context_advanced(
|
||||
self, "training_step", config=self.critical_path_config
|
||||
):
|
||||
return super().training_step(model, inputs)
|
||||
|
||||
def _prepare_inputs(self, inputs):
|
||||
"""Fast path - throttle logging."""
|
||||
with swanlab_profiling_context_advanced(
|
||||
self, "prepare_inputs", config=self.fast_path_config
|
||||
):
|
||||
return super()._prepare_inputs(inputs)
|
||||
|
||||
def _debug_method(self, data):
|
||||
"""Debug-only method - verbose logging."""
|
||||
with swanlab_profiling_context_advanced(
|
||||
self, "debug_method", config=self.debug_config
|
||||
):
|
||||
# Your debug logic
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# How to Use This Custom Trainer
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
To use this custom trainer:
|
||||
|
||||
1. Save this file to your project (e.g., my_custom_trainer.py)
|
||||
|
||||
2. Create a config file that uses your custom trainer:
|
||||
|
||||
# config.yml
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
|
||||
# ... other config ...
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
use_swanlab: true
|
||||
swanlab_project: my-profiling-experiment
|
||||
|
||||
# Optional: Specify custom trainer
|
||||
# (Or modify axolotl to use your custom trainer class)
|
||||
|
||||
3. Run training:
|
||||
|
||||
export SWANLAB_API_KEY=your-api-key
|
||||
accelerate launch -m axolotl.cli.train config.yml
|
||||
|
||||
4. View profiling metrics in SwanLab dashboard:
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.training_step
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
|
||||
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
|
||||
- etc.
|
||||
|
||||
5. Compare profiling metrics across runs:
|
||||
- Run baseline without optimizations
|
||||
- Run with flash_attention enabled
|
||||
- Run with gradient_checkpointing enabled
|
||||
- Compare profiling metrics to see performance impact
|
||||
"""
|
||||
|
||||
# ============================================================================
|
||||
# Tips for Effective Profiling
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
1. Profile the critical path first:
|
||||
- training_step, compute_loss, prediction_step
|
||||
- These methods are called most frequently and have biggest impact
|
||||
|
||||
2. Use throttling for high-frequency operations:
|
||||
- Methods called 100+ times per step
|
||||
- Use log_interval=50 or log_interval=100
|
||||
- Reduces profiling overhead and dashboard clutter
|
||||
|
||||
3. Filter noise with min_duration_ms:
|
||||
- Set min_duration_ms=1.0 to skip very fast operations
|
||||
- Focus on operations that actually take time
|
||||
|
||||
4. Compare across runs:
|
||||
- Run same config multiple times to check consistency
|
||||
- Compare different optimization strategies
|
||||
- Track profiling trends over time
|
||||
|
||||
5. Monitor distributed training:
|
||||
- Check for per-rank timing differences
|
||||
- Look for stragglers (slower ranks)
|
||||
- Identify synchronization bottlenecks
|
||||
|
||||
6. Disable profiling in production:
|
||||
- from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG
|
||||
- DEFAULT_PROFILING_CONFIG.enabled = False
|
||||
|
||||
7. Exception handling:
|
||||
- Profiling is exception-safe
|
||||
- Duration logged even if method raises
|
||||
- Useful for debugging methods that fail intermittently
|
||||
"""
|
||||
@@ -1,168 +0,0 @@
|
||||
# SwanLab DPO Training Example with Completion Logging
|
||||
#
|
||||
# This example demonstrates DPO (Direct Preference Optimization) training
|
||||
# with SwanLab integration for experiment tracking and completion table logging.
|
||||
#
|
||||
# Features enabled:
|
||||
# - SwanLab experiment tracking
|
||||
# - RLHF completion table logging (prompts, chosen/rejected responses, rewards)
|
||||
# - Lark (Feishu) team notifications (optional)
|
||||
#
|
||||
# To run:
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
|
||||
|
||||
# Model Configuration
|
||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot_id|>
|
||||
|
||||
# Quantization
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
# LoRA Configuration
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
# DPO Configuration
|
||||
chat_template: llama3
|
||||
rl: dpo
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
# Dataset and Output
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/dpo-swanlab-out
|
||||
|
||||
# Training Configuration
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
num_epochs: 4
|
||||
|
||||
# Optimization
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
# Precision
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
# Performance
|
||||
gradient_checkpointing: true
|
||||
flash_attention: true
|
||||
|
||||
# Checkpointing and Logging
|
||||
logging_steps: 1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
|
||||
# ============================================================================
|
||||
# SwanLab Integration
|
||||
# ============================================================================
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
# Basic SwanLab Configuration
|
||||
use_swanlab: true
|
||||
swanlab_project: dpo-training
|
||||
swanlab_experiment_name: llama-3-dpo-completions-demo
|
||||
swanlab_description: "DPO training with completion table logging"
|
||||
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
||||
|
||||
# SwanLab Authentication
|
||||
# Recommended: Set via environment variable
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# Or set in config (less secure):
|
||||
# swanlab_api_key: your-api-key
|
||||
|
||||
# Optional: Team workspace
|
||||
# swanlab_workspace: my-research-team
|
||||
|
||||
# ============================================================================
|
||||
# RLHF Completion Table Logging
|
||||
# ============================================================================
|
||||
#
|
||||
# Automatically logs model completions to SwanLab for qualitative analysis:
|
||||
# - Prompts from your DPO dataset
|
||||
# - Chosen responses (preferred)
|
||||
# - Rejected responses (non-preferred)
|
||||
# - Reward differences
|
||||
#
|
||||
# View the table in SwanLab dashboard under "rlhf_completions"
|
||||
|
||||
swanlab_log_completions: true
|
||||
swanlab_completion_log_interval: 100 # Log every 100 training steps
|
||||
swanlab_completion_max_buffer: 128 # Keep last 128 completions in memory
|
||||
|
||||
# Memory Usage Notes:
|
||||
# - Buffer size 128: ~64 KB (default, recommended)
|
||||
# - Buffer size 512: ~256 KB (for more historical completions)
|
||||
# - Buffer size 1024: ~512 KB (maximum for very long training runs)
|
||||
|
||||
# Performance Notes:
|
||||
# - Completion logging overhead: < 0.5% per training step
|
||||
# - Only logs every N steps to minimize impact
|
||||
# - Memory-bounded buffer prevents memory leaks
|
||||
|
||||
# ============================================================================
|
||||
# Optional: Lark (Feishu) Team Notifications
|
||||
# ============================================================================
|
||||
#
|
||||
# Get real-time training notifications in your team chat
|
||||
# Uncomment to enable:
|
||||
|
||||
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
|
||||
# swanlab_lark_secret: your-webhook-secret # Recommended for production
|
||||
|
||||
# Notifications sent for:
|
||||
# - Training start
|
||||
# - Training completion
|
||||
# - Training errors
|
||||
# - Metric milestones (if configured)
|
||||
|
||||
# ============================================================================
|
||||
# Optional: Private SwanLab Deployment
|
||||
# ============================================================================
|
||||
#
|
||||
# For enterprise users with private SwanLab deployment:
|
||||
|
||||
# swanlab_web_host: https://swanlab.yourcompany.com
|
||||
# swanlab_api_host: https://api.swanlab.yourcompany.com
|
||||
|
||||
# ============================================================================
|
||||
# Disable WandB if you're migrating from it
|
||||
# ============================================================================
|
||||
|
||||
# wandb_project:
|
||||
# wandb_entity:
|
||||
# use_wandb: false
|
||||
@@ -1,329 +0,0 @@
|
||||
# SwanLab Full-Featured DPO Training Example
|
||||
#
|
||||
# This example demonstrates ALL SwanLab integration features:
|
||||
# - Experiment tracking with cloud sync
|
||||
# - RLHF completion table logging
|
||||
# - Performance profiling
|
||||
# - Lark (Feishu) team notifications
|
||||
# - Team workspace collaboration
|
||||
#
|
||||
# Use this as a reference for production RLHF training setups.
|
||||
#
|
||||
# To run:
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||
# export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
|
||||
|
||||
# ============================================================================
|
||||
# Model Configuration
|
||||
# ============================================================================
|
||||
|
||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot_id|>
|
||||
|
||||
# Quantization for efficient training
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
# ============================================================================
|
||||
# LoRA Configuration
|
||||
# ============================================================================
|
||||
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true # Target all linear layers
|
||||
|
||||
# ============================================================================
|
||||
# DPO (Direct Preference Optimization) Configuration
|
||||
# ============================================================================
|
||||
|
||||
chat_template: llama3
|
||||
rl: dpo # Enable DPO trainer
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
# ============================================================================
|
||||
# Dataset and Output Configuration
|
||||
# ============================================================================
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/dpo-swanlab-full-featured-out
|
||||
|
||||
# ============================================================================
|
||||
# Training Configuration
|
||||
# ============================================================================
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
num_epochs: 4
|
||||
|
||||
# ============================================================================
|
||||
# Optimization
|
||||
# ============================================================================
|
||||
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
# ============================================================================
|
||||
# Precision and Performance
|
||||
# ============================================================================
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
flash_attention: true
|
||||
|
||||
# ============================================================================
|
||||
# Checkpointing and Logging
|
||||
# ============================================================================
|
||||
|
||||
logging_steps: 1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
|
||||
# ============================================================================
|
||||
# SwanLab Integration - Full Configuration
|
||||
# ============================================================================
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Basic SwanLab Configuration
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
use_swanlab: true
|
||||
swanlab_project: dpo-production
|
||||
swanlab_experiment_name: llama-3-dpo-full-featured-v1
|
||||
swanlab_description: |
|
||||
Production DPO training with all SwanLab features enabled:
|
||||
- Completion table logging for qualitative analysis
|
||||
- Performance profiling for optimization
|
||||
- Lark notifications for team collaboration
|
||||
|
||||
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Team Collaboration
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Workspace for team collaboration (shared experiments)
|
||||
swanlab_workspace: ml-research-team
|
||||
|
||||
# Authentication (recommended: use environment variable)
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# Or set in config (less secure):
|
||||
# swanlab_api_key: your-api-key
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# RLHF Completion Table Logging
|
||||
# ------------------------------------------------------------------------------
|
||||
# Automatically logs model completions for qualitative analysis:
|
||||
# - Prompts from your DPO dataset
|
||||
# - Chosen responses (preferred)
|
||||
# - Rejected responses (non-preferred)
|
||||
# - Reward differences
|
||||
#
|
||||
# View in SwanLab dashboard under "rlhf_completions" table
|
||||
|
||||
swanlab_log_completions: true
|
||||
swanlab_completion_log_interval: 100 # Log every 100 steps
|
||||
swanlab_completion_max_buffer: 256 # Larger buffer for long training runs
|
||||
|
||||
# Buffer size recommendations:
|
||||
# - 128: Default, ~64 KB memory (recommended for most cases)
|
||||
# - 256: ~128 KB memory (this config, good for longer training)
|
||||
# - 512: ~256 KB memory (maximum for very long runs)
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Lark (Feishu) Team Notifications
|
||||
# ------------------------------------------------------------------------------
|
||||
# Get real-time training notifications in your team chat
|
||||
#
|
||||
# Notifications sent for:
|
||||
# - Training start
|
||||
# - Training completion
|
||||
# - Training errors
|
||||
# - Metric milestones (if configured)
|
||||
|
||||
# Recommended: Set via environment variables
|
||||
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
|
||||
# export SWANLAB_LARK_SECRET=your-webhook-secret
|
||||
|
||||
# Or set in config (less secure):
|
||||
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
|
||||
# swanlab_lark_secret: your-webhook-secret # REQUIRED for production
|
||||
|
||||
# Security note: ALWAYS use swanlab_lark_secret in production to prevent
|
||||
# unauthorized parties from sending fake notifications to your team chat.
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Performance Profiling
|
||||
# ------------------------------------------------------------------------------
|
||||
# Profiling is automatically enabled when SwanLab is enabled.
|
||||
# Metrics logged to SwanLab under "profiling/" namespace:
|
||||
# profiling/Time taken: AxolotlTrainer.training_step
|
||||
# profiling/Time taken: AxolotlTrainer.compute_loss
|
||||
# profiling/Time taken: AxolotlTrainer.prediction_step
|
||||
#
|
||||
# Use these metrics to:
|
||||
# - Identify bottlenecks in training loop
|
||||
# - Compare performance across different configurations
|
||||
# - Monitor performance regressions over time
|
||||
# - Debug unexpected slowdowns
|
||||
|
||||
# For custom profiling in your own trainer, see:
|
||||
# examples/swanlab/custom_trainer_profiling.py
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Optional: Private SwanLab Deployment
|
||||
# ------------------------------------------------------------------------------
|
||||
# For enterprise users with private SwanLab deployment:
|
||||
|
||||
# swanlab_web_host: https://swanlab.yourcompany.com
|
||||
# swanlab_api_host: https://api.swanlab.yourcompany.com
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Optional: Model Checkpointing to SwanLab
|
||||
# ------------------------------------------------------------------------------
|
||||
# Log model checkpoints to SwanLab (coming soon)
|
||||
|
||||
swanlab_log_model: false
|
||||
|
||||
# ============================================================================
|
||||
# Disable Other Logging Tools (Recommended)
|
||||
# ============================================================================
|
||||
# Using multiple logging tools simultaneously can impact performance:
|
||||
# - Expected overhead: ~1-2% per logger
|
||||
# - Potential config/callback conflicts
|
||||
#
|
||||
# For production training, use ONLY SwanLab:
|
||||
|
||||
# wandb_project:
|
||||
# use_wandb: false
|
||||
#
|
||||
# use_mlflow: false
|
||||
#
|
||||
# use_comet: false
|
||||
|
||||
# ============================================================================
|
||||
# Expected Training Behavior
|
||||
# ============================================================================
|
||||
|
||||
# With this configuration, you should see:
|
||||
#
|
||||
# 1. SwanLab Initialization (rank 0 only):
|
||||
# INFO: SwanLab initialized for project: dpo-production
|
||||
# INFO: SwanLab experiment: llama-3-dpo-full-featured-v1
|
||||
# INFO: SwanLab mode: cloud
|
||||
# INFO: SwanLab workspace: ml-research-team
|
||||
#
|
||||
# 2. Completion Logging (rank 0 only):
|
||||
# INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer
|
||||
# (log_interval=100, max_buffer=256)
|
||||
#
|
||||
# 3. Lark Notifications (rank 0 only):
|
||||
# INFO: Registered Lark notification callback with HMAC authentication
|
||||
#
|
||||
# 4. Distributed Training Detection (if multi-GPU):
|
||||
# INFO: Distributed training detected (world_size=N)
|
||||
# INFO: Only rank 0 will initialize SwanLab
|
||||
# INFO: Other ranks will skip SwanLab to avoid conflicts
|
||||
#
|
||||
# 5. Training Start Notification (Lark):
|
||||
# Your team chat receives: "Training started: llama-3-dpo-full-featured-v1"
|
||||
#
|
||||
# 6. Periodic Completion Logging:
|
||||
# Every 100 steps, completion table is updated in SwanLab dashboard
|
||||
#
|
||||
# 7. Training Complete Notification (Lark):
|
||||
# Your team chat receives: "Training completed: llama-3-dpo-full-featured-v1"
|
||||
# With link to SwanLab dashboard and final metrics
|
||||
#
|
||||
# 8. SwanLab Dashboard Shows:
|
||||
# - Training metrics (loss, learning rate, etc.)
|
||||
# - Completion table (rlhf_completions)
|
||||
# - Profiling metrics (profiling/Time taken: ...)
|
||||
# - Hyperparameters and configuration
|
||||
# - System resource usage
|
||||
|
||||
# ============================================================================
|
||||
# Production Checklist
|
||||
# ============================================================================
|
||||
|
||||
# Before deploying to production, verify:
|
||||
# ✅ SwanLab API key is set via environment variable (not in config)
|
||||
# ✅ Lark webhook secret is set (required for HMAC authentication)
|
||||
# ✅ Workspace is set to your team's workspace
|
||||
# ✅ Experiment name is descriptive and unique
|
||||
# ✅ Only SwanLab is enabled (other loggers disabled)
|
||||
# ✅ Completion logging buffer size is appropriate for your training duration
|
||||
# ✅ Private deployment hosts are set (if using enterprise SwanLab)
|
||||
# ✅ Test run completes successfully and shows up in SwanLab dashboard
|
||||
# ✅ Lark notifications are received in team chat
|
||||
# ✅ Profiling metrics are logged correctly
|
||||
|
||||
# ============================================================================
|
||||
# Troubleshooting
|
||||
# ============================================================================
|
||||
|
||||
# If SwanLab initialization fails:
|
||||
# 1. Check SWANLAB_API_KEY environment variable is set
|
||||
# 2. Verify swanlab_project is set in config
|
||||
# 3. Check swanlab_mode is valid (cloud/local/offline/disabled)
|
||||
# 4. Verify internet connectivity (for cloud mode)
|
||||
|
||||
# If Lark notifications not received:
|
||||
# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly
|
||||
# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings
|
||||
# 3. Test webhook manually: curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...
|
||||
# 4. Check training logs for "Registered Lark notification callback"
|
||||
# 5. Verify bot is added to the target Lark group chat
|
||||
|
||||
# If completions not appearing in SwanLab:
|
||||
# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
|
||||
# 2. Check swanlab_log_completions is true
|
||||
# 3. Wait for log_interval steps (default: 100)
|
||||
# 4. Check training logs for "Registered SwanLab RLHF completion logging"
|
||||
|
||||
# If profiling metrics not appearing:
|
||||
# 1. Verify use_swanlab is true
|
||||
# 2. Check SwanLab is initialized (check logs)
|
||||
# 3. Look under "profiling/" namespace in dashboard
|
||||
# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False
|
||||
|
||||
# For more help:
|
||||
# - SwanLab docs: https://docs.swanlab.cn
|
||||
# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md
|
||||
# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues
|
||||
@@ -1,178 +0,0 @@
|
||||
# SwanLab LoRA Training Example with Performance Profiling
|
||||
#
|
||||
# This example demonstrates standard LoRA fine-tuning with SwanLab integration
|
||||
# for performance profiling and optimization.
|
||||
#
|
||||
# Features enabled:
|
||||
# - SwanLab experiment tracking
|
||||
# - Performance profiling (training step, forward/backward pass timing)
|
||||
# - Real-time metrics visualization
|
||||
#
|
||||
# To run:
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
|
||||
|
||||
# Model Configuration
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
|
||||
# Dataset Configuration
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-swanlab-profiling-out
|
||||
|
||||
# LoRA Configuration
|
||||
adapter: lora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Training Configuration
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 2
|
||||
num_epochs: 1
|
||||
|
||||
# Optimization
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
# Precision
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
# Performance
|
||||
gradient_checkpointing: true
|
||||
flash_attention: true
|
||||
|
||||
# Checkpointing and Logging
|
||||
logging_steps: 1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
|
||||
# Loss Monitoring
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
# ============================================================================
|
||||
# SwanLab Integration
|
||||
# ============================================================================
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
# Basic SwanLab Configuration
|
||||
use_swanlab: true
|
||||
swanlab_project: lora-profiling
|
||||
swanlab_experiment_name: llama-3.2-1b-profiling-demo
|
||||
swanlab_description: "LoRA fine-tuning with performance profiling"
|
||||
swanlab_mode: cloud # Options: cloud, local, offline, disabled
|
||||
|
||||
# SwanLab Authentication
|
||||
# Recommended: Set via environment variable
|
||||
# export SWANLAB_API_KEY=your-api-key
|
||||
# Or set in config (less secure):
|
||||
# swanlab_api_key: your-api-key
|
||||
|
||||
# Optional: Team workspace
|
||||
# swanlab_workspace: my-ml-team
|
||||
|
||||
# ============================================================================
|
||||
# Performance Profiling
|
||||
# ============================================================================
|
||||
#
|
||||
# SwanLab automatically profiles trainer methods when enabled.
|
||||
# Profiling metrics appear in SwanLab dashboard under "profiling/" namespace.
|
||||
#
|
||||
# Built-in profiling:
|
||||
# - Minimal overhead (< 0.1% per step)
|
||||
# - High-precision timing (microsecond accuracy)
|
||||
# - Exception-safe (logs duration even if method fails)
|
||||
#
|
||||
# View profiling metrics in SwanLab dashboard:
|
||||
# profiling/Time taken: AxolotlTrainer.training_step
|
||||
# profiling/Time taken: AxolotlTrainer.compute_loss
|
||||
# profiling/Time taken: AxolotlTrainer.prediction_step
|
||||
#
|
||||
# For custom profiling in your own trainer, see:
|
||||
# examples/swanlab/custom_trainer_profiling.py
|
||||
|
||||
# Completion logging is disabled for non-RLHF trainers
|
||||
swanlab_log_completions: false # Only works with DPO/KTO/ORPO/GRPO
|
||||
|
||||
# ============================================================================
|
||||
# Optional: Compare with Multiple Runs
|
||||
# ============================================================================
|
||||
#
|
||||
# To compare profiling metrics across different configurations:
|
||||
#
|
||||
# 1. Run baseline without flash attention:
|
||||
# swanlab_experiment_name: llama-3.2-1b-no-flash-attn
|
||||
# flash_attention: false
|
||||
#
|
||||
# 2. Run with gradient checkpointing:
|
||||
# swanlab_experiment_name: llama-3.2-1b-grad-checkpoint
|
||||
# gradient_checkpointing: true
|
||||
#
|
||||
# 3. Run with both:
|
||||
# swanlab_experiment_name: llama-3.2-1b-optimized
|
||||
# flash_attention: true
|
||||
# gradient_checkpointing: true
|
||||
#
|
||||
# Then compare profiling metrics in SwanLab dashboard to see performance impact
|
||||
|
||||
# ============================================================================
|
||||
# Optional: Lark (Feishu) Team Notifications
|
||||
# ============================================================================
|
||||
#
|
||||
# Get notified when profiling experiments complete:
|
||||
|
||||
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
|
||||
# swanlab_lark_secret: your-webhook-secret
|
||||
|
||||
# ============================================================================
|
||||
# Profiling Best Practices
|
||||
# ============================================================================
|
||||
#
|
||||
# 1. Run multiple epochs to see profiling trends over time
|
||||
# 2. Ignore first ~10 steps (warmup period, slower)
|
||||
# 3. Look for outliers (steps that take significantly longer)
|
||||
# 4. Compare profiling metrics before/after optimization changes
|
||||
# 5. Monitor per-rank profiling in distributed training
|
||||
#
|
||||
# Common bottlenecks to profile:
|
||||
# - training_step: Overall step time (should be consistent)
|
||||
# - compute_loss: Loss computation (scales with sequence length)
|
||||
# - prediction_step: Evaluation time (can be slow for large val sets)
|
||||
#
|
||||
# If you see inconsistent timing:
|
||||
# - Check for data loading bottlenecks
|
||||
# - Monitor GPU utilization (may be CPU-bound)
|
||||
# - Check for gradient accumulation effects
|
||||
# - Verify CUDA kernel synchronization
|
||||
|
||||
# ============================================================================
|
||||
# Disable WandB if you're migrating from it
|
||||
# ============================================================================
|
||||
|
||||
# wandb_project:
|
||||
# use_wandb: false
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.49.1
|
||||
bitsandbytes==0.48.2
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
@@ -63,7 +63,7 @@ langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.13.0
|
||||
torchao==0.15.0
|
||||
openenv-core==0.1.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
|
||||
@@ -373,11 +373,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||
|
||||
if self.cfg.use_dynamic_finetuning:
|
||||
from axolotl.monkeypatch.loss.dft import dft_loss
|
||||
|
||||
trainer_kwargs["compute_loss_func"] = dft_loss
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
|
||||
@@ -660,10 +660,11 @@ class AxolotlTrainer(
|
||||
logs["tokens/train_per_sec_per_gpu"] = round(
|
||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||
)
|
||||
if "total" in self.state.tokens:
|
||||
logs["tokens/total"] = int(self.state.tokens["total"].item())
|
||||
if "trainable" in self.state.tokens:
|
||||
logs["tokens/trainable"] = int(self.state.tokens["trainable"].item())
|
||||
if (
|
||||
hasattr(self.state, "total_tokens")
|
||||
and self.state.total_tokens is not None
|
||||
):
|
||||
logs["total_tokens"] = int(self.state.total_tokens.item())
|
||||
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +0,0 @@
|
||||
"""SwanLab integration plugin for Axolotl"""
|
||||
|
||||
from axolotl.integrations.swanlab.args import SwanLabConfig
|
||||
from axolotl.integrations.swanlab.plugins import SwanLabPlugin
|
||||
|
||||
__all__ = ["SwanLabConfig", "SwanLabPlugin"]
|
||||
@@ -1,140 +0,0 @@
|
||||
"""SwanLab configuration arguments"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class SwanLabConfig(BaseModel):
|
||||
"""SwanLab configuration subset"""
|
||||
|
||||
use_swanlab: bool | None = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Enable SwanLab experiment tracking and visualization"
|
||||
},
|
||||
)
|
||||
swanlab_project: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Your SwanLab project name"},
|
||||
)
|
||||
swanlab_experiment_name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Set the name of your SwanLab experiment"},
|
||||
)
|
||||
swanlab_description: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Description for your SwanLab experiment"},
|
||||
)
|
||||
swanlab_mode: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": '"cloud" to sync to SwanLab cloud, "local" for local only, "offline" to save metadata locally, "disabled" to turn off SwanLab'
|
||||
},
|
||||
)
|
||||
swanlab_workspace: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "SwanLab workspace name (organization or username)"
|
||||
},
|
||||
)
|
||||
swanlab_api_key: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "SwanLab API key for authentication. Can also be set via SWANLAB_API_KEY environment variable"
|
||||
},
|
||||
)
|
||||
swanlab_log_model: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Whether to log model checkpoints to SwanLab (feature coming soon)"
|
||||
},
|
||||
)
|
||||
swanlab_web_host: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Web address for SwanLab cloud environment (for private deployment)"
|
||||
},
|
||||
)
|
||||
swanlab_api_host: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "API address for SwanLab cloud environment (for private deployment)"
|
||||
},
|
||||
)
|
||||
swanlab_lark_webhook_url: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Lark (Feishu) webhook URL for sending training notifications to team chat"
|
||||
},
|
||||
)
|
||||
swanlab_lark_secret: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Secret for Lark webhook HMAC signature authentication (optional)"
|
||||
},
|
||||
)
|
||||
swanlab_log_completions: bool | None = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Enable logging RLHF completions to SwanLab for qualitative analysis (DPO/KTO/ORPO/GRPO)"
|
||||
},
|
||||
)
|
||||
swanlab_completion_log_interval: int | None = Field(
|
||||
default=100,
|
||||
json_schema_extra={
|
||||
"description": "Number of training steps between completion table logging to SwanLab"
|
||||
},
|
||||
)
|
||||
swanlab_completion_max_buffer: int | None = Field(
|
||||
default=128,
|
||||
json_schema_extra={
|
||||
"description": "Maximum number of completions to buffer before logging (prevents memory leaks)"
|
||||
},
|
||||
)
|
||||
|
||||
@field_validator("swanlab_mode")
|
||||
@classmethod
|
||||
def validate_swanlab_mode(cls, v):
|
||||
"""Validate swanlab_mode is one of the allowed values."""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
valid_modes = ["cloud", "local", "offline", "disabled"]
|
||||
if v not in valid_modes:
|
||||
raise ValueError(
|
||||
f"Invalid swanlab_mode: '{v}'.\n\n"
|
||||
f"Valid options: {', '.join(valid_modes)}\n\n"
|
||||
f"Examples:\n"
|
||||
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
|
||||
f" swanlab_mode: local # Local only, no cloud sync\n"
|
||||
f" swanlab_mode: offline # Save metadata locally\n"
|
||||
f" swanlab_mode: disabled # Turn off SwanLab\n"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("swanlab_project")
|
||||
@classmethod
|
||||
def validate_swanlab_project(cls, v):
|
||||
"""Validate swanlab_project is non-empty when provided."""
|
||||
if v is not None and isinstance(v, str) and len(v.strip()) == 0:
|
||||
raise ValueError(
|
||||
"swanlab_project cannot be an empty string.\n\n"
|
||||
"Either:\n"
|
||||
" 1. Provide a valid project name: swanlab_project: my-project\n"
|
||||
" 2. Remove the swanlab_project field entirely\n"
|
||||
)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_swanlab_enabled_requires_project(self):
|
||||
"""Validate that if use_swanlab is True, swanlab_project must be set."""
|
||||
if self.use_swanlab is True and not self.swanlab_project:
|
||||
raise ValueError(
|
||||
"SwanLab enabled (use_swanlab: true) but 'swanlab_project' is not set.\n\n"
|
||||
"Solutions:\n"
|
||||
" 1. Add 'swanlab_project: your-project-name' to your config\n"
|
||||
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
|
||||
"Example:\n"
|
||||
" use_swanlab: true\n"
|
||||
" swanlab_project: my-llm-training\n"
|
||||
)
|
||||
return self
|
||||
@@ -1,179 +0,0 @@
|
||||
"""SwanLab callbacks for Axolotl trainers.
|
||||
|
||||
This module provides HuggingFace Trainer callbacks for logging
|
||||
RLHF completions to SwanLab.
|
||||
"""
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from axolotl.integrations.swanlab.completion_logger import CompletionLogger
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class SwanLabRLHFCompletionCallback(TrainerCallback):
|
||||
"""Callback for logging RLHF completions to SwanLab.
|
||||
|
||||
This callback periodically logs model completions (prompts, chosen/rejected
|
||||
responses, rewards) to SwanLab during RLHF training for qualitative analysis.
|
||||
|
||||
Supports DPO, KTO, ORPO, and GRPO trainers.
|
||||
|
||||
Example usage:
|
||||
>>> callback = SwanLabRLHFCompletionCallback(
|
||||
... log_interval=100, # Log every 100 steps
|
||||
... max_completions=128, # Keep last 128 completions
|
||||
... )
|
||||
>>> trainer.add_callback(callback)
|
||||
|
||||
Attributes:
|
||||
logger: CompletionLogger instance
|
||||
log_interval: Number of steps between SwanLab logging
|
||||
trainer_type: Auto-detected trainer type (dpo/kto/orpo/grpo)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_interval: int = 100,
|
||||
max_completions: int = 128,
|
||||
table_name: str = "rlhf_completions",
|
||||
):
|
||||
"""Initialize SwanLab RLHF completion callback.
|
||||
|
||||
Args:
|
||||
log_interval: Log to SwanLab every N steps. Default: 100
|
||||
max_completions: Maximum completions to buffer. Default: 128
|
||||
table_name: SwanLab table name. Default: "rlhf_completions"
|
||||
"""
|
||||
super().__init__()
|
||||
self.logger = CompletionLogger(maxlen=max_completions)
|
||||
self.log_interval = log_interval
|
||||
self.table_name = table_name
|
||||
self.trainer_type: str | None = None # Auto-detected
|
||||
self._last_logged_step = 0
|
||||
|
||||
def on_init_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Detect trainer type on initialization."""
|
||||
trainer = kwargs.get("trainer")
|
||||
if trainer is not None:
|
||||
trainer_name = trainer.__class__.__name__
|
||||
if "DPO" in trainer_name:
|
||||
self.trainer_type = "dpo"
|
||||
elif "KTO" in trainer_name:
|
||||
self.trainer_type = "kto"
|
||||
elif "ORPO" in trainer_name:
|
||||
self.trainer_type = "orpo"
|
||||
elif "GRPO" in trainer_name:
|
||||
self.trainer_type = "grpo"
|
||||
else:
|
||||
self.trainer_type = "unknown"
|
||||
|
||||
LOG.info(
|
||||
f"SwanLab RLHF completion logging enabled for {trainer_name} "
|
||||
f"(type: {self.trainer_type})"
|
||||
)
|
||||
|
||||
def on_log(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
logs: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Capture completions from logs and buffer them.
|
||||
|
||||
Different trainers log completions in different formats:
|
||||
- DPO: logs['dpo/chosen'], logs['dpo/rejected'], logs['dpo/reward_diff']
|
||||
- KTO: logs['kto/completion'], logs['kto/label'], logs['kto/reward']
|
||||
- ORPO: logs['orpo/chosen'], logs['orpo/rejected']
|
||||
- GRPO: logs['grpo/completion'], logs['grpo/reward']
|
||||
|
||||
Note: This is a placeholder implementation. Actual log keys depend
|
||||
on the TRL trainer implementation. You may need to patch the trainers
|
||||
to expose completion data in logs.
|
||||
"""
|
||||
if logs is None or self.trainer_type is None:
|
||||
return
|
||||
|
||||
step = state.global_step
|
||||
|
||||
# DPO completions
|
||||
if self.trainer_type == "dpo":
|
||||
if all(key in logs for key in ["dpo/prompt", "dpo/chosen", "dpo/rejected"]):
|
||||
self.logger.add_dpo_completion(
|
||||
step=step,
|
||||
prompt=logs.get("dpo/prompt", ""),
|
||||
chosen=logs.get("dpo/chosen", ""),
|
||||
rejected=logs.get("dpo/rejected", ""),
|
||||
reward_diff=logs.get("dpo/reward_diff"),
|
||||
)
|
||||
|
||||
# KTO completions
|
||||
elif self.trainer_type == "kto":
|
||||
if all(key in logs for key in ["kto/prompt", "kto/completion"]):
|
||||
self.logger.add_kto_completion(
|
||||
step=step,
|
||||
prompt=logs.get("kto/prompt", ""),
|
||||
completion=logs.get("kto/completion", ""),
|
||||
label=logs.get("kto/label", False),
|
||||
reward=logs.get("kto/reward"),
|
||||
)
|
||||
|
||||
# ORPO completions
|
||||
elif self.trainer_type == "orpo":
|
||||
if all(
|
||||
key in logs for key in ["orpo/prompt", "orpo/chosen", "orpo/rejected"]
|
||||
):
|
||||
self.logger.add_orpo_completion(
|
||||
step=step,
|
||||
prompt=logs.get("orpo/prompt", ""),
|
||||
chosen=logs.get("orpo/chosen", ""),
|
||||
rejected=logs.get("orpo/rejected", ""),
|
||||
log_odds_ratio=logs.get("orpo/log_odds_ratio"),
|
||||
)
|
||||
|
||||
# GRPO completions
|
||||
elif self.trainer_type == "grpo":
|
||||
if all(key in logs for key in ["grpo/prompt", "grpo/completion"]):
|
||||
self.logger.add_grpo_completion(
|
||||
step=step,
|
||||
prompt=logs.get("grpo/prompt", ""),
|
||||
completion=logs.get("grpo/completion", ""),
|
||||
reward=logs.get("grpo/reward"),
|
||||
advantage=logs.get("grpo/advantage"),
|
||||
)
|
||||
|
||||
# Periodically log to SwanLab
|
||||
if step - self._last_logged_step >= self.log_interval:
|
||||
if len(self.logger) > 0:
|
||||
self.logger.log_to_swanlab(table_name=self.table_name)
|
||||
self.logger.clear()
|
||||
self._last_logged_step = step
|
||||
|
||||
def on_train_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Log remaining completions at end of training."""
|
||||
if len(self.logger) > 0:
|
||||
LOG.info(
|
||||
f"Training complete, logging final {len(self.logger)} completions to SwanLab"
|
||||
)
|
||||
self.logger.log_to_swanlab(table_name=self.table_name)
|
||||
self._last_logged_step = state.global_step
|
||||
@@ -1,228 +0,0 @@
|
||||
"""SwanLab completion logger for RLHF/DPO/KTO/ORPO/GRPO training.
|
||||
|
||||
This module provides utilities for logging model completions during
|
||||
preference training to SwanLab for qualitative analysis.
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class CompletionLogger:
|
||||
"""Memory-bounded logger for RLHF completions.
|
||||
|
||||
Stores prompts, completions, and rewards in fixed-size deques to prevent
|
||||
memory leaks during long training runs. Logs completion tables to SwanLab
|
||||
for qualitative analysis of model outputs.
|
||||
|
||||
Example usage:
|
||||
>>> logger = CompletionLogger(maxlen=128)
|
||||
>>> logger.add_dpo_completion(
|
||||
... step=0,
|
||||
... prompt="What is AI?",
|
||||
... chosen="Artificial Intelligence is...",
|
||||
... rejected="AI means...",
|
||||
... reward_diff=0.5
|
||||
... )
|
||||
>>> logger.log_to_swanlab()
|
||||
|
||||
Attributes:
|
||||
maxlen: Maximum number of completions to store (older ones are dropped)
|
||||
data: Deque storing completion dictionaries
|
||||
"""
|
||||
|
||||
def __init__(self, maxlen: int = 128):
|
||||
"""Initialize completion logger with bounded buffer.
|
||||
|
||||
Args:
|
||||
maxlen: Maximum number of completions to store. When the buffer
|
||||
is full, oldest completions are automatically discarded.
|
||||
Default: 128 (sufficient for most RLHF runs without memory issues)
|
||||
"""
|
||||
self.maxlen = maxlen
|
||||
self.data: deque[Mapping[str, Any]] = deque(maxlen=maxlen)
|
||||
|
||||
def add_dpo_completion(
|
||||
self,
|
||||
step: int,
|
||||
prompt: str,
|
||||
chosen: str,
|
||||
rejected: str,
|
||||
reward_diff: float | None = None,
|
||||
) -> None:
|
||||
"""Add a DPO completion to the buffer.
|
||||
|
||||
Args:
|
||||
step: Training step number
|
||||
prompt: Input prompt
|
||||
chosen: Chosen (preferred) completion
|
||||
rejected: Rejected (non-preferred) completion
|
||||
reward_diff: Reward difference (chosen - rejected), if available
|
||||
"""
|
||||
entry = {
|
||||
"step": step,
|
||||
"prompt": prompt,
|
||||
"chosen": chosen,
|
||||
"rejected": rejected,
|
||||
}
|
||||
if reward_diff is not None:
|
||||
entry["reward_diff"] = reward_diff
|
||||
|
||||
self.data.append(entry)
|
||||
|
||||
def add_kto_completion(
|
||||
self,
|
||||
step: int,
|
||||
prompt: str,
|
||||
completion: str,
|
||||
label: bool,
|
||||
reward: float | None = None,
|
||||
) -> None:
|
||||
"""Add a KTO completion to the buffer.
|
||||
|
||||
Args:
|
||||
step: Training step number
|
||||
prompt: Input prompt
|
||||
completion: Model-generated completion
|
||||
label: True if desirable, False if undesirable
|
||||
reward: Reward score, if available
|
||||
"""
|
||||
entry = {
|
||||
"step": step,
|
||||
"prompt": prompt,
|
||||
"completion": completion,
|
||||
"label": "desirable" if label else "undesirable",
|
||||
}
|
||||
if reward is not None:
|
||||
entry["reward"] = reward
|
||||
|
||||
self.data.append(entry)
|
||||
|
||||
def add_orpo_completion(
|
||||
self,
|
||||
step: int,
|
||||
prompt: str,
|
||||
chosen: str,
|
||||
rejected: str,
|
||||
log_odds_ratio: float | None = None,
|
||||
) -> None:
|
||||
"""Add an ORPO completion to the buffer.
|
||||
|
||||
Args:
|
||||
step: Training step number
|
||||
prompt: Input prompt
|
||||
chosen: Chosen (preferred) completion
|
||||
rejected: Rejected (non-preferred) completion
|
||||
log_odds_ratio: Log odds ratio between chosen and rejected
|
||||
"""
|
||||
entry = {
|
||||
"step": step,
|
||||
"prompt": prompt,
|
||||
"chosen": chosen,
|
||||
"rejected": rejected,
|
||||
}
|
||||
if log_odds_ratio is not None:
|
||||
entry["log_odds_ratio"] = log_odds_ratio
|
||||
|
||||
self.data.append(entry)
|
||||
|
||||
def add_grpo_completion(
|
||||
self,
|
||||
step: int,
|
||||
prompt: str,
|
||||
completion: str,
|
||||
reward: float | None = None,
|
||||
advantage: float | None = None,
|
||||
) -> None:
|
||||
"""Add a GRPO completion to the buffer.
|
||||
|
||||
Args:
|
||||
step: Training step number
|
||||
prompt: Input prompt
|
||||
completion: Model-generated completion
|
||||
reward: Reward score from reward model
|
||||
advantage: Advantage estimate (reward - baseline)
|
||||
"""
|
||||
entry = {
|
||||
"step": step,
|
||||
"prompt": prompt,
|
||||
"completion": completion,
|
||||
}
|
||||
if reward is not None:
|
||||
entry["reward"] = reward
|
||||
if advantage is not None:
|
||||
entry["advantage"] = advantage
|
||||
|
||||
self.data.append(entry)
|
||||
|
||||
def log_to_swanlab(self, table_name: str = "completions") -> bool:
|
||||
"""Log buffered completions to SwanLab as a table.
|
||||
|
||||
Creates a SwanLab echarts Table with all buffered completions.
|
||||
Only logs if SwanLab is initialized and data is available.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table in SwanLab dashboard.
|
||||
Default: "completions"
|
||||
|
||||
Returns:
|
||||
True if logging succeeded, False otherwise
|
||||
"""
|
||||
if not self.data:
|
||||
LOG.debug("No completions to log to SwanLab")
|
||||
return False
|
||||
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
if swanlab.get_run() is None:
|
||||
LOG.debug("SwanLab not initialized, skipping completion logging")
|
||||
return False
|
||||
|
||||
# Convert deque to list of dicts
|
||||
completions = list(self.data)
|
||||
|
||||
# Extract headers from first entry (all entries should have same structure)
|
||||
headers = list(completions[0].keys())
|
||||
|
||||
# Build rows: each completion becomes one row
|
||||
rows = []
|
||||
for completion in completions:
|
||||
row = [completion.get(header, "") for header in headers]
|
||||
rows.append(row)
|
||||
|
||||
# Log to SwanLab as echarts Table
|
||||
swanlab.log({table_name: swanlab.echarts.Table().add(headers, rows)})
|
||||
|
||||
LOG.info(f"Logged {len(rows)} completions to SwanLab table '{table_name}'")
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
LOG.warning(
|
||||
"SwanLab not installed, cannot log completions. "
|
||||
"Install with: pip install swanlab"
|
||||
)
|
||||
return False
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.exception("Failed to log completions to SwanLab: %s", err)
|
||||
return False
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all buffered completions."""
|
||||
self.data.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return number of buffered completions."""
|
||||
return len(self.data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation showing buffer status."""
|
||||
return (
|
||||
f"CompletionLogger(maxlen={self.maxlen}, "
|
||||
f"buffered={len(self.data)}/{self.maxlen})"
|
||||
)
|
||||
@@ -1,554 +0,0 @@
|
||||
"""SwanLab Plugin for Axolotl"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class SwanLabPlugin(BasePlugin):
|
||||
"""
|
||||
SwanLab integration plugin for Axolotl.
|
||||
|
||||
Provides experiment tracking, visualization, and logging capabilities
|
||||
using SwanLab (https://swanlab.cn).
|
||||
|
||||
Usage in config.yaml:
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
use_swanlab: true
|
||||
swanlab_project: my-project
|
||||
swanlab_experiment_name: my-experiment
|
||||
swanlab_mode: cloud # or 'local', 'offline', 'disabled'
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.swanlab_initialized = False
|
||||
LOG.info("SwanLab plugin initialized")
|
||||
|
||||
def get_input_args(self) -> str:
|
||||
"""Returns the configuration model for SwanLab integration."""
|
||||
return "axolotl.integrations.swanlab.SwanLabConfig"
|
||||
|
||||
def register(self, cfg: dict):
|
||||
"""Register SwanLab plugin with configuration and conflict detection."""
|
||||
LOG.info("Registering SwanLab plugin")
|
||||
|
||||
# === Conflict Detection: Required Fields ===
|
||||
|
||||
# Check if SwanLab is enabled
|
||||
if cfg.get("use_swanlab"):
|
||||
# 1. Validate project name is set
|
||||
if not cfg.get("swanlab_project"):
|
||||
raise ValueError(
|
||||
"SwanLab enabled but 'swanlab_project' is not set.\n\n"
|
||||
"Solutions:\n"
|
||||
" 1. Add 'swanlab_project: your-project-name' to your config\n"
|
||||
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
|
||||
"See: src/axolotl/integrations/swanlab/README.md for examples"
|
||||
)
|
||||
|
||||
# 2. Validate swanlab_mode value
|
||||
valid_modes = ["cloud", "local", "offline", "disabled"]
|
||||
mode = cfg.get("swanlab_mode")
|
||||
if mode and mode not in valid_modes:
|
||||
raise ValueError(
|
||||
f"Invalid swanlab_mode: '{mode}'.\n\n"
|
||||
f"Valid options: {', '.join(valid_modes)}\n\n"
|
||||
f"Example:\n"
|
||||
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
|
||||
f" swanlab_mode: local # Local only, no cloud sync\n"
|
||||
)
|
||||
|
||||
# 3. Check API key for cloud mode
|
||||
import os
|
||||
|
||||
mode = cfg.get("swanlab_mode", "cloud") # Default is cloud
|
||||
if mode == "cloud":
|
||||
api_key = cfg.get("swanlab_api_key") or os.environ.get(
|
||||
"SWANLAB_API_KEY"
|
||||
)
|
||||
if not api_key:
|
||||
LOG.warning(
|
||||
"SwanLab cloud mode enabled but no API key found.\n"
|
||||
"SwanLab may fail to initialize during training.\n\n"
|
||||
"Solutions:\n"
|
||||
" 1. Set SWANLAB_API_KEY environment variable:\n"
|
||||
" export SWANLAB_API_KEY=your-api-key\n"
|
||||
" 2. Add 'swanlab_api_key: your-api-key' to config (less secure)\n"
|
||||
" 3. Run 'swanlab login' before training\n"
|
||||
" 4. Use 'swanlab_mode: local' for offline tracking\n"
|
||||
)
|
||||
|
||||
# === Conflict Detection: Multi-Logger Performance Warning ===
|
||||
|
||||
# Detect all active logging tools
|
||||
active_loggers = []
|
||||
if cfg.get("use_wandb"):
|
||||
active_loggers.append("WandB")
|
||||
if cfg.get("use_mlflow"):
|
||||
active_loggers.append("MLflow")
|
||||
if cfg.get("comet_api_key") or cfg.get("comet_project_name"):
|
||||
active_loggers.append("Comet")
|
||||
if cfg.get("use_swanlab"):
|
||||
active_loggers.append("SwanLab")
|
||||
|
||||
if len(active_loggers) > 1:
|
||||
LOG.warning(
|
||||
f"\n{'=' * 70}\n"
|
||||
f"Multiple logging tools enabled: {', '.join(active_loggers)}\n"
|
||||
f"{'=' * 70}\n"
|
||||
f"This may cause:\n"
|
||||
f" - Performance overhead (~1-2% per logger, cumulative)\n"
|
||||
f" - Increased memory usage\n"
|
||||
f" - Longer training time per step\n"
|
||||
f" - Potential config/callback conflicts\n\n"
|
||||
f"Recommendations:\n"
|
||||
f" - Choose ONE primary logging tool for production training\n"
|
||||
f" - Use multiple loggers only for:\n"
|
||||
f" * Migration period (transitioning between tools)\n"
|
||||
f" * Short comparison runs\n"
|
||||
f" * Debugging specific tool issues\n"
|
||||
f" - Monitor system resources (CPU, memory) during training\n"
|
||||
f"{'=' * 70}\n"
|
||||
)
|
||||
|
||||
if len(active_loggers) >= 3:
|
||||
LOG.error(
|
||||
f"\n{'!' * 70}\n"
|
||||
f"WARNING: {len(active_loggers)} logging tools enabled simultaneously!\n"
|
||||
f"{'!' * 70}\n"
|
||||
f"This is likely unintentional and WILL significantly impact performance.\n"
|
||||
f"Expected overhead: ~{len(active_loggers) * 1.5:.1f}% per training step.\n\n"
|
||||
f"STRONGLY RECOMMEND:\n"
|
||||
f" - Disable all but ONE logging tool\n"
|
||||
f" - Use config inheritance to manage multiple configs\n"
|
||||
f"{'!' * 70}\n"
|
||||
)
|
||||
|
||||
# === Auto-Enable Logic ===
|
||||
|
||||
# Enable SwanLab if project is specified
|
||||
if cfg.get("swanlab_project") and not cfg.get("use_swanlab"):
|
||||
cfg["use_swanlab"] = True
|
||||
LOG.info("Automatically enabled use_swanlab because swanlab_project is set")
|
||||
|
||||
def pre_model_load(self, cfg: DictDefault):
|
||||
"""Initialize SwanLab before model loading with runtime checks."""
|
||||
if not cfg.use_swanlab:
|
||||
return
|
||||
|
||||
# === Runtime Check: Import Availability ===
|
||||
try:
|
||||
import swanlab
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"SwanLab is not installed.\n\n"
|
||||
"Install with:\n"
|
||||
" pip install swanlab\n\n"
|
||||
"Or add to requirements:\n"
|
||||
" swanlab>=0.3.0\n\n"
|
||||
f"Original error: {err}"
|
||||
) from err
|
||||
|
||||
# Log SwanLab version
|
||||
try:
|
||||
swanlab_version = swanlab.__version__
|
||||
LOG.info(f"SwanLab version: {swanlab_version}")
|
||||
except AttributeError:
|
||||
LOG.warning("Could not determine SwanLab version")
|
||||
|
||||
# === Runtime Check: Distributed Training Setup ===
|
||||
from axolotl.utils.distributed import get_world_size, is_main_process
|
||||
|
||||
world_size = get_world_size()
|
||||
if world_size > 1:
|
||||
mode = getattr(cfg, "swanlab_mode", "cloud")
|
||||
LOG.info(
|
||||
f"\n{'=' * 70}\n"
|
||||
f"Distributed training detected (world_size={world_size})\n"
|
||||
f"SwanLab mode: {mode}\n"
|
||||
f"{'=' * 70}\n"
|
||||
f"Behavior:\n"
|
||||
f" - Only rank 0 will initialize SwanLab\n"
|
||||
f" - Other ranks will skip SwanLab to avoid conflicts\n"
|
||||
)
|
||||
|
||||
if mode == "cloud":
|
||||
LOG.info(
|
||||
f" - Only rank 0 will upload to SwanLab cloud\n"
|
||||
f" - Other ranks run without SwanLab overhead\n"
|
||||
f"{'=' * 70}\n"
|
||||
)
|
||||
|
||||
# Only initialize SwanLab on the main process (rank 0)
|
||||
# to avoid creating multiple runs in distributed training
|
||||
if not is_main_process():
|
||||
LOG.debug("Skipping SwanLab initialization on non-main process")
|
||||
return
|
||||
|
||||
# Initialize SwanLab run (passing all params directly to init)
|
||||
try:
|
||||
init_kwargs = self._get_swanlab_init_kwargs(cfg)
|
||||
swanlab.init(**init_kwargs)
|
||||
self.swanlab_initialized = True
|
||||
LOG.info(f"SwanLab initialized with project: {cfg.swanlab_project}")
|
||||
|
||||
# Register Lark notification callback (if configured)
|
||||
self._register_lark_callback(cfg)
|
||||
|
||||
# Log configuration (with error handling)
|
||||
try:
|
||||
config_dict = self._prepare_config_for_logging(cfg)
|
||||
swanlab.config.update(config_dict)
|
||||
LOG.debug("Successfully logged config to SwanLab")
|
||||
except Exception as config_err: # pylint: disable=broad-except
|
||||
LOG.warning(
|
||||
f"Failed to log config to SwanLab: {config_err}. Continuing anyway."
|
||||
)
|
||||
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.exception("Failed to initialize SwanLab: %s", err)
|
||||
self.swanlab_initialized = False
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg: DictDefault, model):
|
||||
"""Add SwanLab callbacks before trainer creation."""
|
||||
callbacks: list[TrainerCallback] = []
|
||||
|
||||
if not cfg.use_swanlab:
|
||||
return callbacks
|
||||
|
||||
if not self.swanlab_initialized:
|
||||
LOG.warning("SwanLab not initialized, skipping callback registration")
|
||||
return callbacks
|
||||
|
||||
try:
|
||||
from axolotl.utils.callbacks.swanlab import (
|
||||
CustomSwanLabCallback,
|
||||
SaveAxolotlConfigtoSwanLabCallback,
|
||||
)
|
||||
|
||||
# Add our custom lightweight SwanLabCallback
|
||||
# (avoids omegaconf/antlr4 version conflicts)
|
||||
swanlab_callback = CustomSwanLabCallback()
|
||||
callbacks.append(swanlab_callback)
|
||||
LOG.info("Added CustomSwanLabCallback for metrics logging")
|
||||
|
||||
# Add Axolotl config logging callback
|
||||
if cfg.axolotl_config_path:
|
||||
config_callback = SaveAxolotlConfigtoSwanLabCallback(
|
||||
cfg.axolotl_config_path
|
||||
)
|
||||
callbacks.append(config_callback)
|
||||
LOG.info("Added SaveAxolotlConfigtoSwanLabCallback")
|
||||
|
||||
except ImportError as err:
|
||||
LOG.exception("Failed to import SwanLab callbacks: %s", err)
|
||||
|
||||
return callbacks
|
||||
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer):
|
||||
"""Post-trainer creation hook."""
|
||||
if cfg.use_swanlab and self.swanlab_initialized:
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
# Log additional trainer information (with safe conversion)
|
||||
trainer_config = {
|
||||
"total_steps": int(trainer.state.max_steps)
|
||||
if trainer.state.max_steps
|
||||
else None,
|
||||
"num_train_epochs": float(trainer.args.num_train_epochs)
|
||||
if trainer.args.num_train_epochs
|
||||
else None,
|
||||
"train_batch_size": int(trainer.args.train_batch_size)
|
||||
if hasattr(trainer.args, "train_batch_size")
|
||||
else None,
|
||||
"gradient_accumulation_steps": int(
|
||||
trainer.args.gradient_accumulation_steps
|
||||
)
|
||||
if trainer.args.gradient_accumulation_steps
|
||||
else None,
|
||||
}
|
||||
# Remove None values
|
||||
trainer_config = {
|
||||
k: v for k, v in trainer_config.items() if v is not None
|
||||
}
|
||||
|
||||
if trainer_config:
|
||||
swanlab.config.update(trainer_config)
|
||||
LOG.info("Logged trainer configuration to SwanLab")
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.debug(f"Failed to log trainer config to SwanLab: {err}")
|
||||
|
||||
# Register RLHF completion logging callback if enabled
|
||||
self._register_completion_callback(cfg, trainer)
|
||||
|
||||
def _get_swanlab_init_kwargs(self, cfg: DictDefault) -> dict:
|
||||
"""Prepare kwargs for swanlab.init().
|
||||
|
||||
Passes all configuration parameters directly to swanlab.init()
|
||||
instead of using environment variables as an intermediate layer.
|
||||
|
||||
Returns:
|
||||
dict: Keyword arguments for swanlab.init()
|
||||
"""
|
||||
init_kwargs = {}
|
||||
|
||||
# Project name (required)
|
||||
if cfg.swanlab_project:
|
||||
init_kwargs["project"] = cfg.swanlab_project
|
||||
|
||||
# Experiment name
|
||||
if cfg.swanlab_experiment_name:
|
||||
init_kwargs["experiment_name"] = cfg.swanlab_experiment_name
|
||||
|
||||
# Description
|
||||
if cfg.swanlab_description:
|
||||
init_kwargs["description"] = cfg.swanlab_description
|
||||
|
||||
# Workspace (organization)
|
||||
if cfg.swanlab_workspace:
|
||||
init_kwargs["workspace"] = cfg.swanlab_workspace
|
||||
|
||||
# Mode: cloud, local, offline, disabled
|
||||
if cfg.swanlab_mode:
|
||||
init_kwargs["mode"] = cfg.swanlab_mode
|
||||
|
||||
# API key (pass directly instead of via env var)
|
||||
if cfg.swanlab_api_key:
|
||||
init_kwargs["api_key"] = cfg.swanlab_api_key
|
||||
|
||||
# Private deployment hosts (pass directly instead of via env var)
|
||||
if cfg.swanlab_web_host:
|
||||
init_kwargs["web_host"] = cfg.swanlab_web_host
|
||||
|
||||
if cfg.swanlab_api_host:
|
||||
init_kwargs["api_host"] = cfg.swanlab_api_host
|
||||
|
||||
# Log model checkpoints (coming soon in SwanLab)
|
||||
if cfg.swanlab_log_model:
|
||||
init_kwargs["log_model"] = cfg.swanlab_log_model
|
||||
|
||||
# Custom branding - adds Axolotl identifier to SwanLab UI
|
||||
# This helps identify runs from Axolotl vs other frameworks
|
||||
init_kwargs["config"] = {"UPPERFRAME": "🦎 Axolotl"}
|
||||
|
||||
return init_kwargs
|
||||
|
||||
def _prepare_config_for_logging(self, cfg: DictDefault) -> dict:
|
||||
"""Prepare configuration dict for logging to SwanLab."""
|
||||
|
||||
def safe_convert(value):
|
||||
"""Convert value to JSON-serializable type."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
# Convert everything else to string
|
||||
return str(value)
|
||||
|
||||
try:
|
||||
# Extract important training parameters with safe conversion
|
||||
config_dict = {
|
||||
"base_model": safe_convert(getattr(cfg, "base_model", "")),
|
||||
"model_type": safe_convert(getattr(cfg, "model_type", "")),
|
||||
"sequence_len": safe_convert(getattr(cfg, "sequence_len", None)),
|
||||
"micro_batch_size": safe_convert(
|
||||
getattr(cfg, "micro_batch_size", None)
|
||||
),
|
||||
"gradient_accumulation_steps": safe_convert(
|
||||
getattr(cfg, "gradient_accumulation_steps", None)
|
||||
),
|
||||
"num_epochs": safe_convert(getattr(cfg, "num_epochs", None)),
|
||||
"max_steps": safe_convert(getattr(cfg, "max_steps", None)),
|
||||
"learning_rate": safe_convert(getattr(cfg, "learning_rate", None)),
|
||||
"lr_scheduler": safe_convert(getattr(cfg, "lr_scheduler", "")),
|
||||
"optimizer": safe_convert(getattr(cfg, "optimizer", "")),
|
||||
"warmup_ratio": safe_convert(getattr(cfg, "warmup_ratio", None)),
|
||||
"weight_decay": safe_convert(getattr(cfg, "weight_decay", None)),
|
||||
"seed": safe_convert(getattr(cfg, "seed", None)),
|
||||
"bf16": safe_convert(getattr(cfg, "bf16", None)),
|
||||
"tf32": safe_convert(getattr(cfg, "tf32", None)),
|
||||
"flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
|
||||
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
|
||||
}
|
||||
|
||||
# Add FSDP/parallel config - only boolean flags
|
||||
if hasattr(cfg, "fsdp_config") and cfg.fsdp_config:
|
||||
config_dict["fsdp_enabled"] = True
|
||||
config_dict["fsdp_version"] = safe_convert(
|
||||
getattr(cfg, "fsdp_version", None)
|
||||
)
|
||||
|
||||
if hasattr(cfg, "deepspeed") and cfg.deepspeed:
|
||||
config_dict["deepspeed_enabled"] = True
|
||||
|
||||
# Add context parallel info
|
||||
if hasattr(cfg, "context_parallel_size"):
|
||||
config_dict["context_parallel_size"] = safe_convert(
|
||||
getattr(cfg, "context_parallel_size", None)
|
||||
)
|
||||
if hasattr(cfg, "tensor_parallel_size"):
|
||||
config_dict["tensor_parallel_size"] = safe_convert(
|
||||
getattr(cfg, "tensor_parallel_size", None)
|
||||
)
|
||||
if hasattr(cfg, "dp_shard_size"):
|
||||
config_dict["dp_shard_size"] = safe_convert(
|
||||
getattr(cfg, "dp_shard_size", None)
|
||||
)
|
||||
|
||||
# Remove None values and empty strings
|
||||
config_dict = {
|
||||
k: v
|
||||
for k, v in config_dict.items()
|
||||
if v is not None and v != "" and v != "None"
|
||||
}
|
||||
|
||||
return config_dict
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.warning(f"Failed to prepare config for logging: {err}")
|
||||
# Return minimal config
|
||||
try:
|
||||
lr = getattr(cfg, "learning_rate", None)
|
||||
lr_value = float(lr) if lr is not None else None
|
||||
except (TypeError, ValueError):
|
||||
lr_value = None
|
||||
return {
|
||||
"base_model": str(getattr(cfg, "base_model", "unknown")),
|
||||
"learning_rate": lr_value,
|
||||
}
|
||||
|
||||
def _register_lark_callback(self, cfg: DictDefault):
|
||||
"""Register Lark (Feishu) notification callback if configured.
|
||||
|
||||
Lark notifications enable sending training updates to team chat channels,
|
||||
useful for production monitoring and team collaboration.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object with Lark webhook settings
|
||||
"""
|
||||
# Check if Lark webhook URL is configured
|
||||
lark_webhook_url = getattr(cfg, "swanlab_lark_webhook_url", None)
|
||||
if not lark_webhook_url:
|
||||
return # Lark not configured, skip
|
||||
|
||||
try:
|
||||
import swanlab
|
||||
from swanlab.plugin.notification import LarkCallback
|
||||
|
||||
# Get optional secret for HMAC signature authentication
|
||||
lark_secret = getattr(cfg, "swanlab_lark_secret", None)
|
||||
|
||||
# Create Lark callback with webhook URL and optional secret
|
||||
lark_callback = LarkCallback(
|
||||
webhook_url=lark_webhook_url,
|
||||
secret=lark_secret,
|
||||
)
|
||||
|
||||
# Register callback with SwanLab
|
||||
swanlab.register_callbacks([lark_callback])
|
||||
|
||||
if lark_secret:
|
||||
LOG.info(
|
||||
"Registered Lark notification callback with HMAC authentication"
|
||||
)
|
||||
else:
|
||||
LOG.info("Registered Lark notification callback (no HMAC secret)")
|
||||
LOG.warning(
|
||||
"Lark webhook has no secret configured. "
|
||||
"For production use, set 'swanlab_lark_secret' to enable HMAC signature verification."
|
||||
)
|
||||
|
||||
except ImportError as err:
|
||||
LOG.warning(
|
||||
f"Failed to import SwanLab Lark plugin: {err}\n\n"
|
||||
"Lark notifications require SwanLab >= 0.3.0 with plugin support.\n"
|
||||
"Install with: pip install 'swanlab>=0.3.0'\n\n"
|
||||
"Continuing without Lark notifications..."
|
||||
)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.exception(
|
||||
"Failed to register Lark callback: %s\n\n"
|
||||
"Check your Lark webhook URL and secret configuration.\n"
|
||||
"Continuing without Lark notifications...",
|
||||
err,
|
||||
)
|
||||
|
||||
def _register_completion_callback(self, cfg: DictDefault, trainer):
|
||||
"""Register RLHF completion logging callback if enabled and applicable.
|
||||
|
||||
This callback logs model completions (prompts, chosen/rejected responses,
|
||||
rewards) to SwanLab during RLHF training for qualitative analysis.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object with completion logging settings
|
||||
trainer: The trainer instance to add callback to
|
||||
"""
|
||||
# Check if completion logging is enabled
|
||||
log_completions = getattr(cfg, "swanlab_log_completions", True)
|
||||
if not log_completions:
|
||||
LOG.debug("SwanLab completion logging disabled by config")
|
||||
return
|
||||
|
||||
# Check if trainer is an RLHF trainer
|
||||
trainer_name = trainer.__class__.__name__
|
||||
rlhf_trainers = ["DPO", "KTO", "ORPO", "GRPO", "CPO"]
|
||||
is_rlhf_trainer = any(name in trainer_name for name in rlhf_trainers)
|
||||
|
||||
if not is_rlhf_trainer:
|
||||
LOG.debug(
|
||||
f"Trainer {trainer_name} is not an RLHF trainer, "
|
||||
"skipping completion logging callback"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from axolotl.integrations.swanlab.callbacks import (
|
||||
SwanLabRLHFCompletionCallback,
|
||||
)
|
||||
|
||||
# Get configuration parameters
|
||||
log_interval = getattr(cfg, "swanlab_completion_log_interval", 100)
|
||||
max_buffer = getattr(cfg, "swanlab_completion_max_buffer", 128)
|
||||
|
||||
# Create and register callback
|
||||
completion_callback = SwanLabRLHFCompletionCallback(
|
||||
log_interval=log_interval,
|
||||
max_completions=max_buffer,
|
||||
table_name="rlhf_completions",
|
||||
)
|
||||
|
||||
trainer.add_callback(completion_callback)
|
||||
|
||||
LOG.info(
|
||||
f"Registered SwanLab RLHF completion logging callback for {trainer_name} "
|
||||
f"(log_interval={log_interval}, max_buffer={max_buffer})"
|
||||
)
|
||||
|
||||
except ImportError as err:
|
||||
LOG.warning(
|
||||
f"Failed to import SwanLab completion callback: {err}\n\n"
|
||||
"This is a bug - the callback should be available.\n"
|
||||
"Please report this issue.\n\n"
|
||||
"Continuing without completion logging..."
|
||||
)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.exception(
|
||||
"Failed to register SwanLab completion callback: %s\n\n"
|
||||
"Continuing without completion logging...",
|
||||
err,
|
||||
)
|
||||
@@ -1,203 +0,0 @@
|
||||
"""SwanLab profiling utilities for Axolotl trainers.
|
||||
|
||||
This module provides decorators and context managers for profiling
|
||||
trainer methods and logging execution times to SwanLab.
|
||||
"""
|
||||
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def swanlab_profiling_context(trainer: Any, func_name: str):
|
||||
"""Context manager for profiling trainer methods.
|
||||
|
||||
Measures execution time and logs to SwanLab if enabled.
|
||||
|
||||
Example usage:
|
||||
>>> with swanlab_profiling_context(self, "training_step"):
|
||||
... result = do_expensive_computation()
|
||||
|
||||
Args:
|
||||
trainer: Trainer instance (must have cfg attribute with use_swanlab flag)
|
||||
func_name: Name of the function being profiled
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
duration = time.perf_counter() - start_time
|
||||
|
||||
# Check if SwanLab is enabled and initialized
|
||||
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
|
||||
if use_swanlab:
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
if swanlab.get_run() is not None:
|
||||
# Log profiling metric
|
||||
trainer_class = trainer.__class__.__name__
|
||||
metric_name = f"profiling/Time taken: {trainer_class}.{func_name}"
|
||||
|
||||
swanlab.log({metric_name: duration})
|
||||
|
||||
except ImportError:
|
||||
# SwanLab not installed, silently skip
|
||||
pass
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
# Log error but don't fail training
|
||||
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")
|
||||
|
||||
|
||||
def swanlab_profile(func: Callable) -> Callable:
|
||||
"""Decorator to profile and log function execution time to SwanLab.
|
||||
|
||||
Automatically measures execution time of trainer methods and logs
|
||||
to SwanLab as profiling metrics.
|
||||
|
||||
Example usage:
|
||||
>>> class MyTrainer:
|
||||
... @swanlab_profile
|
||||
... def training_step(self, model, inputs):
|
||||
... return super().training_step(model, inputs)
|
||||
|
||||
Args:
|
||||
func: Function to profile (must be a method of a trainer instance)
|
||||
|
||||
Returns:
|
||||
Wrapped function with profiling
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with swanlab_profiling_context(self, func.__name__):
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ProfilingConfig:
|
||||
"""Configuration for SwanLab profiling.
|
||||
|
||||
This class provides a centralized way to control profiling behavior.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether profiling is enabled globally
|
||||
min_duration_ms: Minimum duration (in ms) to log (filters out very fast ops)
|
||||
log_interval: Log every N function calls (to reduce overhead)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool = True,
|
||||
min_duration_ms: float = 0.1,
|
||||
log_interval: int = 1,
|
||||
):
|
||||
"""Initialize profiling configuration.
|
||||
|
||||
Args:
|
||||
enabled: Enable profiling. Default: True
|
||||
min_duration_ms: Minimum duration to log (ms). Default: 0.1
|
||||
log_interval: Log every N calls. Default: 1 (log all)
|
||||
"""
|
||||
self.enabled = enabled
|
||||
self.min_duration_ms = min_duration_ms
|
||||
self.log_interval = log_interval
|
||||
self._call_counts: dict[str, int] = {}
|
||||
|
||||
def should_log(self, func_name: str, duration_seconds: float) -> bool:
|
||||
"""Check if a profiling measurement should be logged.
|
||||
|
||||
Args:
|
||||
func_name: Name of the profiled function
|
||||
duration_seconds: Execution duration in seconds
|
||||
|
||||
Returns:
|
||||
True if should log, False otherwise
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
# Check minimum duration threshold
|
||||
duration_ms = duration_seconds * 1000
|
||||
if duration_ms < self.min_duration_ms:
|
||||
return False
|
||||
|
||||
# Check log interval
|
||||
self._call_counts.setdefault(func_name, 0)
|
||||
self._call_counts[func_name] += 1
|
||||
|
||||
# Always log on first call OR at intervals
|
||||
count = self._call_counts[func_name]
|
||||
if count == 1 or count % self.log_interval == 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Global profiling config (can be modified by users)
|
||||
DEFAULT_PROFILING_CONFIG = ProfilingConfig()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def swanlab_profiling_context_advanced(
|
||||
trainer: Any,
|
||||
func_name: str,
|
||||
config: ProfilingConfig | None = None,
|
||||
):
|
||||
"""Advanced profiling context with configurable behavior.
|
||||
|
||||
Similar to swanlab_profiling_context but with additional configuration
|
||||
options for filtering and throttling profiling logs.
|
||||
|
||||
Example usage:
|
||||
>>> config = ProfilingConfig(min_duration_ms=1.0, log_interval=10)
|
||||
>>> with swanlab_profiling_context_advanced(self, "forward", config):
|
||||
... output = model(inputs)
|
||||
|
||||
Args:
|
||||
trainer: Trainer instance
|
||||
func_name: Function name
|
||||
config: Profiling configuration. If None, uses DEFAULT_PROFILING_CONFIG
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
if config is None:
|
||||
config = DEFAULT_PROFILING_CONFIG
|
||||
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
duration = time.perf_counter() - start_time
|
||||
|
||||
# Check if should log based on config
|
||||
if config.should_log(func_name, duration):
|
||||
# Check if SwanLab is enabled
|
||||
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
|
||||
if use_swanlab:
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
if swanlab.get_run() is not None:
|
||||
trainer_class = trainer.__class__.__name__
|
||||
metric_name = (
|
||||
f"profiling/Time taken: {trainer_class}.{func_name}"
|
||||
)
|
||||
|
||||
swanlab.log({metric_name: duration})
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")
|
||||
@@ -138,7 +138,6 @@ class PatchManager:
|
||||
self._apply_llama_flash_attn_patches(model)
|
||||
self._apply_unsloth_patches(model)
|
||||
self._apply_lora_kernel_patch(model)
|
||||
self._apply_scaling_softmax_patch(model)
|
||||
|
||||
def _apply_flash_attention_patches(self):
|
||||
"""Apply patches related to Flash Attention."""
|
||||
@@ -561,16 +560,3 @@ class PatchManager:
|
||||
)
|
||||
|
||||
patch_apertus_xielu_activation()
|
||||
|
||||
def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
|
||||
"""Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
|
||||
if self.cfg.scaling_softmax:
|
||||
from axolotl.monkeypatch.scaled_softmax_attn import (
|
||||
patch_scaled_softmax_attention,
|
||||
)
|
||||
|
||||
patch_scaled_softmax_attention(
|
||||
scaling_factor_init=self.cfg.scaling_softmax_factor or 0.43,
|
||||
bias=self.cfg.scaling_softmax_bias or 0.0,
|
||||
model=model,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Type
|
||||
|
||||
import addict
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -154,9 +153,6 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
|
||||
This function determines the appropriate model config source, loads it, applies any
|
||||
necessary overrides, and validates it for compatibility with the `axolotl` config.
|
||||
|
||||
If `cfg.cls_model_config` is set, a custom config class from transformers will be
|
||||
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
@@ -178,13 +174,8 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
|
||||
if cfg.num_labels:
|
||||
# num_labels is used to initialize classifier models
|
||||
config_kwargs["num_labels"] = cfg.num_labels
|
||||
|
||||
config_cls = AutoConfig
|
||||
if cfg.cls_model_config:
|
||||
config_cls = getattr(transformers, cfg.cls_model_config)
|
||||
|
||||
try:
|
||||
model_config = config_cls.from_pretrained(
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_config_name,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**config_kwargs,
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
"""Dynamic Fine-Tuning (DFT) loss implementation"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def selective_log_softmax(logits, index):
|
||||
"""Memory-efficient log_softmax -> gather"""
|
||||
if logits.dtype in [torch.float32, torch.float64]:
|
||||
selected_logits = torch.gather(
|
||||
logits, dim=-1, index=index.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
||||
per_token_logps = selected_logits - logsumexp_values
|
||||
else:
|
||||
per_token_logps = []
|
||||
for row_logits, row_labels in zip(logits, index, strict=True):
|
||||
row_logps = F.log_softmax(row_logits, dim=-1)
|
||||
row_per_token_logps = row_logps.gather(
|
||||
dim=-1, index=row_labels.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
per_token_logps.append(row_per_token_logps)
|
||||
per_token_logps = torch.stack(per_token_logps)
|
||||
return per_token_logps
|
||||
|
||||
|
||||
def get_dft_loss(ignore_index: int = -100):
|
||||
"""Creates DFT loss function"""
|
||||
|
||||
def for_causal_lm_dft_loss(
|
||||
logits,
|
||||
labels,
|
||||
vocab_size: int = None,
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
ignore_index: int = -100,
|
||||
shift_labels: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""DFT loss: -exp(logprobs).detach() * logprobs"""
|
||||
if shift_labels is None:
|
||||
# Shift so that tokens < n predict n
|
||||
labels = F.pad(labels, (0, 1), value=ignore_index)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
shift_labels = shift_labels.to(logits.device)
|
||||
|
||||
# Create loss mask
|
||||
loss_mask = shift_labels != ignore_index
|
||||
shift_labels_masked = shift_labels.clone()
|
||||
shift_labels_masked[~loss_mask] = 0
|
||||
|
||||
# Compute log probabilities
|
||||
logprobs = selective_log_softmax(logits, shift_labels_masked)
|
||||
|
||||
# DFT loss: -exp(logprobs).detach() * logprobs
|
||||
per_token_loss = -logprobs.exp().detach() * logprobs
|
||||
|
||||
# Sum over valid tokens and normalize
|
||||
if num_items_in_batch is None:
|
||||
num_items_in_batch = loss_mask.sum()
|
||||
|
||||
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
|
||||
return loss
|
||||
|
||||
return for_causal_lm_dft_loss
|
||||
|
||||
|
||||
def dft_loss(outputs, labels, num_items_in_batch=None):
|
||||
"""DFT loss compatible with Trainer.compute_loss_func signature.
|
||||
|
||||
This function is designed to be passed to Trainer's compute_loss_func parameter.
|
||||
"""
|
||||
ignore_index = -100
|
||||
|
||||
# Shift labels for causal LM
|
||||
labels = F.pad(labels, (0, 1), value=ignore_index)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_labels = shift_labels.to(outputs.logits.device)
|
||||
|
||||
# Create loss mask
|
||||
loss_mask = shift_labels != ignore_index
|
||||
shift_labels_masked = shift_labels.clone()
|
||||
shift_labels_masked[~loss_mask] = 0
|
||||
|
||||
# Compute log probabilities
|
||||
logprobs = selective_log_softmax(outputs.logits, shift_labels_masked)
|
||||
|
||||
# DFT loss: -exp(logprobs).detach() * logprobs
|
||||
per_token_loss = -logprobs.exp().detach() * logprobs
|
||||
|
||||
# Sum over valid tokens and normalize
|
||||
if num_items_in_batch is None:
|
||||
num_items_in_batch = loss_mask.sum()
|
||||
|
||||
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
|
||||
return loss
|
||||
@@ -1,141 +0,0 @@
|
||||
"""
|
||||
Scaled Softmax (SSMax) attention patch using FlexAttention.
|
||||
SSMax: softmax(scores * s * log(n) + b) where n is the position index
|
||||
Ref: https://arxiv.org/abs/2501.19399
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
try:
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
from transformers.integrations.flex_attention import (
|
||||
compile_friendly_flex_attention,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
FLEX_ATTENTION_AVAILABLE = True
|
||||
except ImportError:
|
||||
FLEX_ATTENTION_AVAILABLE = False
|
||||
BlockMask = None
|
||||
|
||||
_ssmax_config = {}
|
||||
|
||||
|
||||
def patch_scaled_softmax_attention(
|
||||
scaling_factor_init: float = 0.43, bias: float = 0.0, model: PreTrainedModel = None
|
||||
):
|
||||
"""Patch attention to apply SSMax via FlexAttention score_mod."""
|
||||
global _ssmax_config
|
||||
|
||||
if not FLEX_ATTENTION_AVAILABLE:
|
||||
raise RuntimeError("SSMax requires FlexAttention.")
|
||||
|
||||
_ssmax_config["ssmax_s"] = scaling_factor_init
|
||||
_ssmax_config["ssmax_b"] = bias
|
||||
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
if "flex_attention" in ALL_ATTENTION_FUNCTIONS:
|
||||
_ssmax_config["original_flex_fn"] = ALL_ATTENTION_FUNCTIONS["flex_attention"]
|
||||
ALL_ATTENTION_FUNCTIONS["flex_attention"] = ssmax_flex_attention_forward
|
||||
LOG.info(
|
||||
f"Patched flex_attention with SSMax (s={scaling_factor_init}, b={bias})"
|
||||
)
|
||||
else:
|
||||
LOG.warning("flex_attention not found. Ensure flex_attention: true is set.")
|
||||
|
||||
|
||||
def ssmax_flex_attention_forward(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask,
|
||||
scaling: float | None = None,
|
||||
softcap: float | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""FlexAttention forward with SSMax: score * (s * log(n) + b)."""
|
||||
|
||||
if kwargs.get("dropout", 0.0) > 0:
|
||||
raise ValueError("flex_attention does not support dropout")
|
||||
|
||||
ssmax_s = _ssmax_config.get("ssmax_s", 0.43)
|
||||
ssmax_b = _ssmax_config.get("ssmax_b", 0.0)
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
position_ids_flat = position_ids.view(-1) if position_ids is not None else None
|
||||
|
||||
block_mask = attention_mask if isinstance(attention_mask, BlockMask) else None
|
||||
score_mask = None if block_mask else attention_mask
|
||||
|
||||
if score_mask is not None:
|
||||
score_mask = score_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
|
||||
"""
|
||||
Apply SSMax scaling: score * (s * log(n) + b)
|
||||
where n is the relative position within each packed sequence.
|
||||
"""
|
||||
if position_ids_flat is not None:
|
||||
relative_pos = position_ids_flat[q_idx]
|
||||
n = (relative_pos + 1).float()
|
||||
else:
|
||||
n = (q_idx + 1).float()
|
||||
|
||||
n = torch.clamp(n, min=2.0)
|
||||
|
||||
ssmax_scale = ssmax_s * torch.log(n) + ssmax_b
|
||||
score = score * ssmax_scale
|
||||
|
||||
if softcap is not None:
|
||||
score = softcap * torch.tanh(score / softcap)
|
||||
|
||||
if score_mask is not None:
|
||||
score = score + score_mask[batch_idx][0][q_idx][kv_idx]
|
||||
|
||||
return score
|
||||
|
||||
enable_gqa = True
|
||||
if (query.shape[1] & (query.shape[1] - 1)) != 0:
|
||||
key = repeat_kv(key, query.shape[1] // key.shape[1])
|
||||
value = repeat_kv(value, query.shape[1] // value.shape[1])
|
||||
enable_gqa = False
|
||||
|
||||
return_lse = query.device.type != "cpu"
|
||||
flex_output = compile_friendly_flex_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
score_mod=score_mod,
|
||||
block_mask=block_mask,
|
||||
enable_gqa=enable_gqa,
|
||||
scale=scaling,
|
||||
kernel_options=kwargs.get("kernel_options"),
|
||||
return_lse=return_lse,
|
||||
training=module.training,
|
||||
)
|
||||
|
||||
if return_lse:
|
||||
attention_output, lse = flex_output
|
||||
lse = lse.to(value.dtype)
|
||||
else:
|
||||
attention_output, lse = flex_output, None
|
||||
|
||||
return attention_output.transpose(1, 2).contiguous(), lse
|
||||
|
||||
|
||||
def unpatch_scaled_softmax_attention():
|
||||
"""Restore the original FlexAttention function."""
|
||||
global _ssmax_config
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
if "original_flex_fn" in _ssmax_config:
|
||||
ALL_ATTENTION_FUNCTIONS["flex_attention"] = _ssmax_config["original_flex_fn"]
|
||||
_ssmax_config.clear()
|
||||
LOG.info("Unpatched flex_attention, restored original")
|
||||
@@ -1,248 +0,0 @@
|
||||
"""Callbacks for SwanLab integration"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class CustomSwanLabCallback(TrainerCallback):
|
||||
"""
|
||||
Lightweight SwanLab callback that directly logs metrics without using
|
||||
SwanLab's transformers integration (which requires omegaconf).
|
||||
|
||||
This avoids the antlr4 version conflict between omegaconf and axolotl.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._initialized = False
|
||||
self.swanlab = None
|
||||
|
||||
def setup(self):
|
||||
"""Lazy initialization of SwanLab"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
self.swanlab = swanlab
|
||||
|
||||
# Check if SwanLab run is initialized
|
||||
if swanlab.get_run() is None:
|
||||
LOG.warning("SwanLab run is not initialized")
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
LOG.info("CustomSwanLabCallback initialized successfully")
|
||||
except ImportError:
|
||||
LOG.error("SwanLab is not installed")
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Called at the beginning of training"""
|
||||
if not state.is_world_process_zero:
|
||||
return control
|
||||
|
||||
self.setup()
|
||||
|
||||
if not self._initialized:
|
||||
return control
|
||||
|
||||
# Log training configuration
|
||||
try:
|
||||
self.swanlab.config.update(
|
||||
{
|
||||
"train_batch_size": args.per_device_train_batch_size,
|
||||
"eval_batch_size": args.per_device_eval_batch_size,
|
||||
"learning_rate": args.learning_rate,
|
||||
"num_train_epochs": args.num_train_epochs,
|
||||
"max_steps": args.max_steps,
|
||||
"warmup_steps": args.warmup_steps,
|
||||
"logging_steps": args.logging_steps,
|
||||
"save_steps": args.save_steps,
|
||||
"gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||
}
|
||||
)
|
||||
LOG.debug("Training configuration logged to SwanLab")
|
||||
except Exception as err:
|
||||
LOG.warning(f"Failed to log training config: {err}")
|
||||
|
||||
return control
|
||||
|
||||
def on_log(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
logs=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Called when logging metrics"""
|
||||
if not state.is_world_process_zero:
|
||||
return control
|
||||
|
||||
if not self._initialized:
|
||||
self.setup()
|
||||
|
||||
if not self._initialized or logs is None:
|
||||
return control
|
||||
|
||||
# Log metrics to SwanLab
|
||||
try:
|
||||
# Filter out non-numeric values and prepare for logging
|
||||
metrics = {}
|
||||
for key, value in logs.items():
|
||||
if isinstance(value, (int, float)):
|
||||
# Use step from state
|
||||
metrics[key] = value
|
||||
|
||||
if metrics and state.global_step is not None:
|
||||
self.swanlab.log(metrics, step=state.global_step)
|
||||
except Exception as err:
|
||||
LOG.warning(f"Failed to log metrics to SwanLab: {err}")
|
||||
|
||||
return control
|
||||
|
||||
def on_train_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Called at the end of training"""
|
||||
if not state.is_world_process_zero:
|
||||
return control
|
||||
|
||||
if self._initialized:
|
||||
LOG.info("Training completed. SwanLab logs are available.")
|
||||
|
||||
return control
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoSwanLabCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to SwanLab"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: AxolotlTrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if state.is_world_process_zero:
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
# Check if SwanLab is initialized
|
||||
if swanlab.get_run() is None:
|
||||
LOG.warning(
|
||||
"SwanLab run is not initialized. Please initialize SwanLab before training."
|
||||
)
|
||||
return control
|
||||
|
||||
# Log Axolotl config as artifact
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
|
||||
# Log config file to SwanLab
|
||||
with open(temp_file.name, "r", encoding="utf-8") as config_file:
|
||||
swanlab.log(
|
||||
{
|
||||
"axolotl_config": swanlab.Text(
|
||||
config_file.read(), caption="Axolotl Config"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the SwanLab run under logs."
|
||||
)
|
||||
|
||||
# Clean up temp file
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
except ImportError:
|
||||
LOG.warning(
|
||||
"SwanLab is not installed. Install it with: pip install swanlab"
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to SwanLab: {err}")
|
||||
|
||||
# Log DeepSpeed config if available
|
||||
if args.deepspeed:
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
with NamedTemporaryFile(
|
||||
mode="w",
|
||||
delete=False,
|
||||
suffix=".json",
|
||||
prefix="deepspeed_config_",
|
||||
) as temp_file:
|
||||
skip_upload = False
|
||||
if isinstance(args.deepspeed, dict):
|
||||
json.dump(args.deepspeed, temp_file, indent=4)
|
||||
elif isinstance(args.deepspeed, str) and os.path.exists(
|
||||
args.deepspeed
|
||||
):
|
||||
copyfile(args.deepspeed, temp_file.name)
|
||||
else:
|
||||
skip_upload = True
|
||||
|
||||
if not skip_upload:
|
||||
temp_file.flush()
|
||||
with open(
|
||||
temp_file.name, "r", encoding="utf-8"
|
||||
) as ds_config_file:
|
||||
swanlab.log(
|
||||
{
|
||||
"deepspeed_config": swanlab.Text(
|
||||
ds_config_file.read(),
|
||||
caption="DeepSpeed Config",
|
||||
)
|
||||
}
|
||||
)
|
||||
LOG.info(
|
||||
"The DeepSpeed config has been saved to the SwanLab run under logs."
|
||||
)
|
||||
|
||||
# Clean up temp file
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(
|
||||
f"Error while saving DeepSpeed config to SwanLab: {err}"
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return control
|
||||
@@ -101,3 +101,9 @@ class TokensPerSecondCallback(TrainerCallback):
|
||||
# Clear per-step tokens after logging
|
||||
if tokens and "trainable_tokens" in tokens:
|
||||
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
|
||||
|
||||
if tokens and "total" in tokens:
|
||||
logs["tokens/total"] = tokens["total"].item()
|
||||
|
||||
if tokens and "trainable" in tokens:
|
||||
logs["tokens/trainable"] = tokens["trainable"].item()
|
||||
|
||||
@@ -9,6 +9,10 @@ from torchao.quantization import quantize_
|
||||
from torchao.quantization.qat import (
|
||||
QATConfig,
|
||||
)
|
||||
from torchao.quantization.qat import fake_quantizer
|
||||
from torchao.quantization.qat.fake_quantizer import (
|
||||
Int4WeightFakeQuantizer as AoInt4WeightFakeQuantizer,
|
||||
)
|
||||
from torchao.quantization.quant_api import (
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
Float8DynamicActivationInt4WeightConfig,
|
||||
@@ -17,6 +21,27 @@ from torchao.quantization.quant_api import (
|
||||
|
||||
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||
|
||||
|
||||
class Int4WeightFakeQuantizer(AoInt4WeightFakeQuantizer):
|
||||
"""
|
||||
Adds 'enabled' attribute to Int4WeightFakeQuantizer (removed in torchao 0.15).
|
||||
Allows toggling fake quantization on/off for fake_quant_after_n_steps.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.enabled = True
|
||||
|
||||
def forward(self, w: torch.Tensor) -> torch.Tensor:
|
||||
if not self.enabled:
|
||||
return w
|
||||
return super().forward(w)
|
||||
|
||||
|
||||
# Replace the original Int4WeightFakeQuantizer in the fake_quantizer module
|
||||
# so that torchao's quantize_() function will use our version
|
||||
fake_quantizer.Int4WeightFakeQuantizer = Int4WeightFakeQuantizer
|
||||
|
||||
quantization_config_to_str = {
|
||||
Int8DynamicActivationInt4WeightConfig: "int8int4",
|
||||
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
||||
|
||||
@@ -619,25 +619,6 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
scaling_softmax: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to use Scaled Softmax (SSMax) attention. Ref: https://arxiv.org/abs/2501.19399"
|
||||
},
|
||||
)
|
||||
scaling_softmax_factor: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Scaling factor for SSMax attention. Default is 0.43"
|
||||
},
|
||||
)
|
||||
scaling_softmax_bias: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Bias for SSMax attention. Default is 0.0. Note: The paper recommends bias=0 for better length generalization."
|
||||
},
|
||||
)
|
||||
|
||||
unsloth_cross_entropy_loss: bool | None = None
|
||||
unsloth_lora_mlp: bool | None = None
|
||||
unsloth_lora_qkv: bool | None = None
|
||||
@@ -676,10 +657,6 @@ class AxolotlInputConfig(
|
||||
"description": "Number of chunks to use for chunked cross entropy loss"
|
||||
},
|
||||
)
|
||||
use_dynamic_finetuning: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Enable Dynamic Fine-Tuning loss (DFT)"},
|
||||
)
|
||||
|
||||
tiled_mlp: bool | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -25,12 +25,7 @@ class ModelInputConfig(BaseModel):
|
||||
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
|
||||
},
|
||||
)
|
||||
cls_model_config: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). Defaults to AutoConfig."
|
||||
},
|
||||
)
|
||||
cls_model_config: str | None = None
|
||||
tokenizer_config: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -201,16 +201,6 @@ class AttentionValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_scaling_softmax_requires_flex(cls, data):
|
||||
if data.get("scaling_softmax") and not data.get("flex_attention"):
|
||||
raise ValueError(
|
||||
"scaling_softmax requires flex_attention: true\n"
|
||||
"Add 'flex_attention: true' to your config file.\n"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class TrainingValidationMixin:
|
||||
"""Validation methods related to training configuration."""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,92 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestLoRAConfigValidation:
|
||||
"""Test suite for LoRA/QLoRA configuration validation"""
|
||||
|
||||
def test_basic_configuration_validation(self):
|
||||
"""Test basic LoRA configuration validation"""
|
||||
|
||||
valid_config = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_target_modules": ["q_proj", "v_proj"],
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
|
||||
result = validate_config(valid_config)
|
||||
assert result["adapter"] == "lora"
|
||||
|
||||
with pytest.raises(ValueError, match="not compatible with DoRA"):
|
||||
invalid_config = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
"lora_mlp_kernel": True,
|
||||
"peft_use_dora": True,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
validate_config(invalid_config)
|
||||
|
||||
def test_qlora_4bit_validation(self):
|
||||
"""Test QLoRA 4-bit configuration validation"""
|
||||
valid_config = DictDefault(
|
||||
{
|
||||
"adapter": "qlora",
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_compute_dtype": "float16",
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
result = validate_config(valid_config)
|
||||
assert result["adapter"] == "qlora"
|
||||
assert result["load_in_4bit"] is True
|
||||
|
||||
# Test QLoRA without 4-bit (should fail via PEFT validation)
|
||||
with pytest.raises(ValueError, match=r"Require cfg\.load_in_4bit"):
|
||||
invalid_config = DictDefault(
|
||||
{
|
||||
"adapter": "qlora",
|
||||
"load_in_4bit": False,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
validate_config(invalid_config)
|
||||
|
||||
# Test QLoRA with 8-bit (incompatible)
|
||||
with pytest.raises(ValueError, match="Can't load qlora in 8bit"):
|
||||
invalid_config = DictDefault(
|
||||
{
|
||||
"adapter": "qlora",
|
||||
"load_in_8bit": True,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
validate_config(invalid_config)
|
||||
@@ -1,261 +0,0 @@
|
||||
import importlib.util
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from axolotl.kernels.lora import get_lora_parameters
|
||||
|
||||
PEFT_AVAILABLE = importlib.util.find_spec("peft") is not None
|
||||
|
||||
|
||||
class TestLoRAParameterFreezing:
|
||||
"""Test suite for LoRA parameter freezing validation."""
|
||||
|
||||
def setup_method(self):
|
||||
self.dtype = torch.float32
|
||||
|
||||
def create_mock_lora_layer(
|
||||
self, has_adapters=True, adapters_disabled=False, merged=False
|
||||
):
|
||||
"""Create a mock LoRA layer for testing."""
|
||||
mock_layer = Mock()
|
||||
|
||||
base_layer = Mock()
|
||||
base_layer.weight = torch.randn(512, 256, dtype=self.dtype)
|
||||
base_layer.bias = torch.randn(512, dtype=self.dtype)
|
||||
|
||||
if has_adapters:
|
||||
mock_layer.base_layer = base_layer
|
||||
mock_layer.disable_adapters = adapters_disabled
|
||||
mock_layer.merged = merged
|
||||
|
||||
mock_layer.active_adapters = ["default"]
|
||||
mock_layer.lora_A = {"default": Mock()}
|
||||
mock_layer.lora_B = {"default": Mock()}
|
||||
mock_layer.scaling = {"default": 0.1}
|
||||
|
||||
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
|
||||
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
|
||||
else:
|
||||
mock_layer.weight = base_layer.weight
|
||||
mock_layer.bias = base_layer.bias
|
||||
|
||||
return mock_layer
|
||||
|
||||
def test_parameter_freezing_adapters_disabled(self):
|
||||
"""Test that LoRA parameters are None when adapters are disabled."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
# Base parameters should be returned
|
||||
assert W is not None
|
||||
assert b is not None
|
||||
# LoRA parameters should be None (frozen)
|
||||
assert A is None
|
||||
assert B is None
|
||||
assert s is None
|
||||
|
||||
def test_parameter_freezing_adapters_merged(self):
|
||||
"""Test that LoRA parameters are None when adapters are merged."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
# Base parameters should be returned
|
||||
assert W is not None
|
||||
assert b is not None
|
||||
|
||||
# LoRA parameters should be None (frozen)
|
||||
assert A is None
|
||||
assert B is None
|
||||
assert s is None
|
||||
|
||||
def test_parameter_freezing_no_adapters(self):
|
||||
"""Test parameter behavior when no adapters are present."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=False)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
# Base parameters should be returned
|
||||
assert W is not None
|
||||
assert b is not None
|
||||
|
||||
# LoRA parameters should be None (frozen)
|
||||
assert A is None
|
||||
assert B is None
|
||||
assert s is None
|
||||
|
||||
def test_parameter_active_adapters_enabled(self):
|
||||
"""Test that LoRA parameters are returned when adapters are active."""
|
||||
layer = self.create_mock_lora_layer(
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
# All parameters should be returned
|
||||
assert W is not None
|
||||
assert b is not None
|
||||
assert A is not None
|
||||
assert B is not None
|
||||
assert s is not None
|
||||
assert s == 0.1
|
||||
|
||||
def test_parameter_shapes_consistency(self):
|
||||
"""Test that parameter shapes are consistent when active."""
|
||||
layer = self.create_mock_lora_layer(
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
# Check shape consistency
|
||||
assert W.shape == (512, 256)
|
||||
assert b.shape == (512,)
|
||||
assert A.shape == (16, 256)
|
||||
assert B.shape == (512, 16)
|
||||
|
||||
def test_parameter_dtypes_consistency(self):
|
||||
"""Test that parameter dtypes are consistent."""
|
||||
layer = self.create_mock_lora_layer(
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
assert W.dtype == self.dtype
|
||||
assert b.dtype == self.dtype
|
||||
assert A.dtype == self.dtype
|
||||
assert B.dtype == self.dtype
|
||||
|
||||
def test_quantization_state_handling(self):
|
||||
"""Test that quantization state is properly handled."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=True)
|
||||
|
||||
quant_state_mock = Mock()
|
||||
layer.base_layer.weight.quant_state = quant_state_mock
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
assert quant_state == quant_state_mock
|
||||
|
||||
def test_multiple_adapters_active_adapter_selection(self):
|
||||
"""Test that the correct adapter is selected when multiple adapters exist."""
|
||||
layer = self.create_mock_lora_layer(
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
layer.lora_A["adapter2"] = Mock()
|
||||
layer.lora_B["adapter2"] = Mock()
|
||||
layer.scaling["adapter2"] = 0.2
|
||||
|
||||
layer.lora_A["adapter2"].weight = torch.randn(16, 256, dtype=self.dtype)
|
||||
layer.lora_B["adapter2"].weight = torch.randn(512, 16, dtype=self.dtype)
|
||||
|
||||
layer.active_adapters = ["adapter2"]
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
|
||||
assert s == 0.2
|
||||
assert torch.equal(A, layer.lora_A["adapter2"].weight)
|
||||
assert torch.equal(B, layer.lora_B["adapter2"].weight)
|
||||
|
||||
|
||||
class TestLoRAParameterFreezingIntegration:
|
||||
"""Integration tests for parameter freezing with actual LoRA layers."""
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
|
||||
)
|
||||
def test_parameter_freezing_with_real_lora_layer(self):
|
||||
"""Test parameter freezing with actual PEFT LoRA layer."""
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
class SimpleModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(256, 512)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
base_model = SimpleModel()
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["linear"],
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
model = get_peft_model(base_model, lora_config)
|
||||
lora_layer = model.base_model.model.linear
|
||||
# Test with adapters enabled
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
||||
assert A is not None
|
||||
assert B is not None
|
||||
assert s is not None
|
||||
# Test with adapters disabled
|
||||
model.disable_adapter_layers()
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
||||
assert A is None
|
||||
assert B is None
|
||||
assert s is None
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not PEFT_AVAILABLE, reason="PEFT not available for integration tests"
|
||||
)
|
||||
def test_parameter_freezing_gradient_behavior(self):
|
||||
"""Test that frozen parameters don't receive gradients."""
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
class SimpleModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(256, 512)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
base_model = SimpleModel()
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["linear"],
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
model = get_peft_model(base_model, lora_config)
|
||||
x = torch.randn(1, 256)
|
||||
target = torch.randn(1, 512)
|
||||
model.enable_adapter_layers()
|
||||
output = model(x)
|
||||
loss = nn.MSELoss()(output, target)
|
||||
loss.backward()
|
||||
lora_layer = model.base_model.model.linear
|
||||
has_lora_grads = any(
|
||||
param.grad is not None
|
||||
for name, param in lora_layer.named_parameters()
|
||||
if "lora_" in name
|
||||
)
|
||||
assert has_lora_grads, (
|
||||
"LoRA parameters should have gradients when adapters are enabled"
|
||||
)
|
||||
model.zero_grad()
|
||||
model.disable_adapter_layers()
|
||||
output = model(x)
|
||||
loss = nn.MSELoss()(output, target)
|
||||
any_requires_grad = any(param.requires_grad for param in model.parameters())
|
||||
if any_requires_grad:
|
||||
loss.backward()
|
||||
has_lora_grads_disabled = any(
|
||||
param.grad is not None
|
||||
for name, param in lora_layer.named_parameters()
|
||||
if "lora_" in name
|
||||
)
|
||||
assert not has_lora_grads_disabled, (
|
||||
"LoRA parameters should not have gradients when adapters are disabled"
|
||||
)
|
||||
model.zero_grad()
|
||||
del model, base_model, lora_layer, x, target, output, loss
|
||||
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
||||
@@ -1,181 +0,0 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.cli.merge_lora import do_merge_lora
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestAdapterMergeUnmerge:
|
||||
"""Test suite for LoRA adapter merging/unmerging functionality"""
|
||||
|
||||
def setup_method(self):
|
||||
self.dtype = torch.float32
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
def create_mock_base_model(self, vocab_size=1000, hidden_size=256):
|
||||
"""Create a mock base model with linear layers"""
|
||||
mock_model = Mock()
|
||||
|
||||
mock_model.config = Mock()
|
||||
mock_model.config.vocab_size = vocab_size
|
||||
mock_model.config.hidden_size = hidden_size
|
||||
|
||||
mock_model.q_proj = Mock()
|
||||
mock_model.q_proj.weight = torch.randn(
|
||||
hidden_size, hidden_size, dtype=self.dtype
|
||||
)
|
||||
mock_model.q_proj.bias = torch.randn(hidden_size, dtype=self.dtype)
|
||||
|
||||
mock_model.v_proj = Mock()
|
||||
mock_model.v_proj.weight = torch.randn(
|
||||
hidden_size, hidden_size, dtype=self.dtype
|
||||
)
|
||||
mock_model.v_proj.bias = torch.randn(hidden_size, dtype=self.dtype)
|
||||
|
||||
return mock_model
|
||||
|
||||
def create_mock_lora_model(self, base_model, r=8, alpha=16):
|
||||
"""Create a mock LoRA model wrapping the base model"""
|
||||
mock_lora_model = Mock()
|
||||
mock_lora_model.base_model = base_model
|
||||
|
||||
mock_lora_model.merge_and_unload = None
|
||||
mock_lora_model.to = Mock(return_value=mock_lora_model)
|
||||
|
||||
mock_lora_model.generation_config = Mock()
|
||||
mock_lora_model.config = Mock()
|
||||
|
||||
self.original_q_weight = base_model.q_proj.weight.clone()
|
||||
self.original_v_weight = base_model.v_proj.weight.clone()
|
||||
|
||||
mock_lora_model.peft_config = {"default": Mock()}
|
||||
mock_lora_model.peft_config["default"].r = r
|
||||
mock_lora_model.peft_config["default"].lora_alpha = alpha
|
||||
|
||||
self.lora_A_q = torch.randn(
|
||||
r, base_model.q_proj.weight.shape[1], dtype=self.dtype
|
||||
)
|
||||
self.lora_B_q = torch.randn(
|
||||
base_model.q_proj.weight.shape[0], r, dtype=self.dtype
|
||||
)
|
||||
|
||||
self.lora_A_v = torch.randn(
|
||||
r, base_model.v_proj.weight.shape[1], dtype=self.dtype
|
||||
)
|
||||
self.lora_B_v = torch.randn(
|
||||
base_model.v_proj.weight.shape[0], r, dtype=self.dtype
|
||||
)
|
||||
|
||||
self.scaling = alpha / r
|
||||
|
||||
def mock_merge_and_unload(progressbar=False):
|
||||
"""Simulate the actual merge operation"""
|
||||
# Apply LoRA delta to base weights: W_new = W_base + (B @ A) * scaling
|
||||
delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
|
||||
delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling
|
||||
|
||||
base_model.q_proj.weight = self.original_q_weight + delta_q
|
||||
base_model.v_proj.weight = self.original_v_weight + delta_v
|
||||
|
||||
return base_model
|
||||
|
||||
mock_lora_model.merge_and_unload = mock_merge_and_unload
|
||||
return mock_lora_model
|
||||
|
||||
def test_basic_lora_merge_unmerge_cycle(self):
|
||||
"""Test: original_weights -> merge -> unmerge -> should equal original_weights"""
|
||||
|
||||
base_model = self.create_mock_base_model()
|
||||
lora_model = self.create_mock_lora_model(base_model)
|
||||
|
||||
original_q_weight = self.original_q_weight.clone()
|
||||
original_v_weight = self.original_v_weight.clone()
|
||||
|
||||
merged_model = lora_model.merge_and_unload()
|
||||
|
||||
assert not torch.equal(merged_model.q_proj.weight, original_q_weight)
|
||||
assert not torch.equal(merged_model.v_proj.weight, original_v_weight)
|
||||
|
||||
delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
|
||||
delta_v = (self.lora_B_v @ self.lora_A_v) * self.scaling
|
||||
|
||||
unmerged_q_weight = merged_model.q_proj.weight - delta_q
|
||||
unmerged_v_weight = merged_model.v_proj.weight - delta_v
|
||||
|
||||
assert torch.allclose(unmerged_q_weight, original_q_weight, atol=1e-6)
|
||||
assert torch.allclose(unmerged_v_weight, original_v_weight, atol=1e-6)
|
||||
|
||||
def test_merge_weight_calculation_accuracy(self):
|
||||
"""Test: merged_weight = base_weight + (lora_B @ lora_A * scaling)"""
|
||||
base_model = self.create_mock_base_model()
|
||||
lora_model = self.create_mock_lora_model(base_model, r=16, alpha=32)
|
||||
|
||||
expected_delta_q = (self.lora_B_q @ self.lora_A_q) * self.scaling
|
||||
expected_merged_q = self.original_q_weight + expected_delta_q
|
||||
merged_model = lora_model.merge_and_unload()
|
||||
|
||||
assert torch.allclose(merged_model.q_proj.weight, expected_merged_q, atol=1e-6)
|
||||
|
||||
@patch("axolotl.cli.merge_lora.load_model_and_tokenizer")
|
||||
def test_cli_do_merge_functionality(self, mock_load_model, tmp_path):
|
||||
base_model = self.create_mock_base_model()
|
||||
lora_model = self.create_mock_lora_model(base_model)
|
||||
tokenizer = Mock()
|
||||
processor = None
|
||||
|
||||
mock_load_model.return_value = (lora_model, tokenizer, processor)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_safetensors": True,
|
||||
"torch_dtype": torch.float32,
|
||||
"local_rank": 0,
|
||||
"output_dir": str(tmp_path),
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
patch("pathlib.Path.mkdir"),
|
||||
patch.object(base_model, "save_pretrained") as mock_save_model,
|
||||
patch.object(tokenizer, "save_pretrained") as mock_save_tokenizer,
|
||||
):
|
||||
do_merge_lora(cfg=cfg)
|
||||
|
||||
mock_save_model.assert_called_once()
|
||||
mock_save_tokenizer.assert_called_once()
|
||||
|
||||
def test_quantized_model_merge_compatibility(self):
|
||||
"""Test 4-bit/8-bit model merging scenarios"""
|
||||
base_model = self.create_mock_base_model()
|
||||
|
||||
# Mock quantized weights
|
||||
base_model.q_proj.weight.quant_state = Mock()
|
||||
base_model.q_proj.weight.quant_state.dtype = torch.uint8
|
||||
|
||||
lora_model = self.create_mock_lora_model(base_model)
|
||||
|
||||
merged_model = lora_model.merge_and_unload()
|
||||
assert merged_model is not None
|
||||
|
||||
@patch.dict("os.environ", {"CUDA_VISIBLE_DEVICES": ""})
|
||||
def test_memory_efficient_merge_with_cpu_offload(self, tmp_path):
|
||||
"""Test lora_on_cpu configuration during merge"""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"lora_on_cpu": True,
|
||||
"save_safetensors": True,
|
||||
"output_dir": str(tmp_path),
|
||||
"local_rank": 0,
|
||||
}
|
||||
)
|
||||
|
||||
with patch("axolotl.cli.merge_lora.load_model_and_tokenizer") as mock_load:
|
||||
base_model = self.create_mock_base_model()
|
||||
lora_model = self.create_mock_lora_model(base_model)
|
||||
mock_load.return_value = (lora_model, Mock(), None)
|
||||
|
||||
with patch("pathlib.Path.mkdir"), patch("torch.save"):
|
||||
do_merge_lora(cfg=cfg)
|
||||
|
||||
assert mock_load.called
|
||||
Reference in New Issue
Block a user