Compare commits: uv-first ... fix/hpc-ro (38 commits)
| SHA1 |
|---|
| 83ff8bfa1a |
| c37decb073 |
| 01a346d86a |
| 26f05b6008 |
| ed58fa8a75 |
| 633afffacb |
| 4b1b4fa6d8 |
| 0f7c886b7b |
| a4b921135b |
| 98333e639a |
| 9d4d39e939 |
| bb33fda44d |
| 4dc018992d |
| 243620394a |
| 3750fdcf79 |
| 613bcf90e5 |
| 383f220cfd |
| 8bb871b5cf |
| 87565ecc05 |
| 93ba57396f |
| aa1240acd8 |
| 4cdfdfebb5 |
| 6e2f5ccf9f |
| 8c7f63cf97 |
| cd856b45b1 |
| 143dea4753 |
| bc2ffb8204 |
| 153edcfe79 |
| 08b8fa62cc |
| 3a5c97e6e5 |
| 37f78c8592 |
| ab63b92c38 |
| 6f8ce024d1 |
| d0e9c3c1c5 |
| 4c3488cc9f |
| 130637a3fa |
| 377c510e95 |
| 409cfb8a87 |
.github/workflows/base.yml (vendored, 49 changed lines)

@@ -25,20 +25,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
@@ -67,6 +53,20 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -122,13 +122,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
@@ -150,6 +143,20 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4

.github/workflows/main.yml (vendored, 15 changed lines)

@@ -15,11 +15,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -88,11 +83,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -162,11 +152,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"

.github/workflows/multi-gpu-e2e.yml (vendored, 14 changed lines)

@@ -26,13 +26,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -47,6 +40,13 @@ jobs:
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

.github/workflows/nightlies.yml (vendored, 20 changed lines)

@@ -12,16 +12,16 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -65,16 +65,16 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

.github/workflows/precommit-autoupdate.yml (vendored, 2 changed lines)

@@ -2,7 +2,7 @@ name: Pre-commit auto-update

on:
schedule:
- cron: '0 0 * * 0' # Run weekly
- cron: '0 0 1 * *' # Run monthly
workflow_dispatch: # Manual kickoff

jobs:

.github/workflows/tests-nightly.yml (vendored, 10 changed lines)

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0"]
pytorch_version: ["2.7.1", "2.8.0"]
timeout-minutes: 20

steps:
@@ -102,14 +102,14 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"

.github/workflows/tests.yml (vendored, 36 changed lines)

@@ -55,7 +55,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
timeout-minutes: 20

steps:
@@ -130,7 +130,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
timeout-minutes: 20

steps:
@@ -152,7 +152,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil

- name: Install PyTorch
run: |
@@ -231,16 +231,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -289,15 +283,15 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
# - cuda: 128
# cuda_version: 12.8.1
# python_version: "3.11"
# pytorch: 2.7.1
# num_gpus: 1
# axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -305,6 +299,12 @@ jobs:
num_gpus: 1
gpu_type: "B200"
axolotl_extras: fbgemm-gpu
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4

.pre-commit-config.yaml

@@ -11,13 +11,6 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.12
rev: v0.14.3
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1
rev: v1.18.2
hooks:
- id: mypy
additional_dependencies:

README.md

@@ -73,7 +73,7 @@ Features:

- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.6.0
- PyTorch ≥2.7.1

### Google Colab

@@ -32,6 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi

RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN uv pip install torchvision
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

@@ -1,6 +1,6 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}

ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
@@ -9,7 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_PROCESSES="8"
ENV AXOLOTL_DATASET_NUM_PROC="8"

RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi

RUN pip install packaging==23.2 setuptools==75.8.0
RUN pip install packaging==23.2 setuptools==75.8.0 psutil
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

@@ -65,8 +65,13 @@ def run_cmd(cmd: str, run_folder: str):
import subprocess  # nosec

sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"

# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env):  # nosec
exit(exit_code)
try:
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env)  # nosec
if exit_code:
print(f"Command '{cmd}' failed with exit code {exit_code}")
return exit_code
except Exception as e:  # pylint: disable=broad-except
print(f"Command '{cmd}' failed with exception {e}")

@@ -13,7 +13,7 @@ datasets:
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model
dataset_prepared_path: temp_debug/axolotl_outputs/data
dataset_processes: 1
dataset_num_proc: 1

sequence_len: 4096
sample_packing: false

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/root/miniconda3/bin:${PATH}"
ENV PATH="/workspace/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
@@ -24,29 +24,35 @@ RUN apt-get update \
&& rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& mkdir -p /workspace/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace

RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge

RUN if [ "$CUDA" != "130" ] ; then \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4"; \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
python3 -m pip cache purge; \
fi

RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge

RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/root/miniconda3/bin:${PATH}"
ENV PATH="/workspace/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="next"
@@ -19,12 +19,12 @@ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& mkdir -p /workspace/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/root/miniconda3/bin:${PATH}"
ENV PATH="/workspace/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
@@ -19,14 +19,14 @@ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& mkdir -p /workspace/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace

@@ -30,7 +30,13 @@ RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"

RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic

RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi

@@ -29,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
- Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):

```yaml
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"-m", "axolotl.cli.train", "dev_chat_template.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
"--dataset_num_proc=1", // limits data preprocessing to one process
"--max_steps=1", // limits training to just one step
"--batch_size=1", // minimizes batch size
"--micro_batch_size=1", // minimizes batch size

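Taken together, the debugging tips in this hunk amount to a handful of config overrides. A minimal sketch of a debug-friendly override block, assuming the keys shown in the diff above; the `shards` key for dataset sharding is an assumption, since the actual yaml snippet is cut off at the hunk boundary:

```yaml
# Debug overrides mirroring the tips above (sketch, not from the diff verbatim)
dataset_num_proc: 1        # single-process preprocessing (renamed from dataset_processes)
max_steps: 1               # stop after one training step
micro_batch_size: 1        # smallest possible batch
sample_packing: false      # avoid packing errors on tiny datasets
eval_sample_packing: false
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
    shards: 20             # assumed key name; trains on 1/20th of the data
```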
@@ -63,6 +63,14 @@ description: Frequently asked questions

> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.

**Q: Can we mix text and text+image datasets for VLM training?**

> A: Yes, you can for newer VLM arch. The ones that would not work are LLaVA / Pixtral arch. If you notice one not working, please let us know!

**Q: Why is `memory/max_*` different from `nvidia-smi`?**

> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.

### Chat templates

**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

@@ -27,3 +27,9 @@ learning_rate: 2e-5

In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
self attention `q_proj` module.

::: {.callout-note}

We currently only support varying `lr` for now. If you're interested in adding support for others (`weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17

:::

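The hunk only captures the tail of the example it describes. A sketch of what the full `lr_groups` block plausibly looks like for the quoted description — the `name`/`modules`/`lr` schema is an assumption based on the linked optimizer mixin, not shown in this diff:

```yaml
lr_groups:
  - name: o_proj_group
    modules:
      - self_attn.o_proj              # every layer's o_proj at 1e-6
    lr: 1e-6
  - name: layer3_q_proj
    modules:
      - model.layers.2.self_attn.q_proj  # 3rd layer (0-indexed) q_proj at 1e-5
    lr: 1e-5
learning_rate: 2e-5                    # default for all remaining parameters
```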
@@ -88,6 +88,7 @@ fsdp_sync_module_states | **REMOVED**
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED**
fsdp_activation_checkpointing | activation_checkpointing

For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
if you were using the following FSDP1 config:

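The FSDP1 config the doc goes on to show is cut off at this hunk boundary. As an illustration of the renames in the table above, a hedged before/after sketch — only the key renames come from the table; the `fsdp_config` nesting and `fsdp_version: 2` switch are assumptions:

```yaml
# FSDP1 (before)
fsdp_config:
  fsdp_cpu_ram_efficient_loading: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_activation_checkpointing: true

# FSDP2 (after) -- prefixes dropped; fsdp_use_orig_params and
# fsdp_sync_module_states have no equivalent and are removed
fsdp_version: 2
fsdp_config:
  cpu_ram_efficient_loading: true
  state_dict_type: FULL_STATE_DICT
  activation_checkpointing: true
```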
@@ -56,10 +56,14 @@ image_resize_algorithm: bilinear

Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.

::: {.callout-warning}
::: {.callout-tip}
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
:::

::: {.callout-note}
As of now, we do not truncate nor drop samples based on `sequence_len` as each arch has different ways to process non-text tokens. We are looking for help on this.
:::

### Mllama {#sec-mllama}

```yaml
@@ -168,6 +172,14 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```

### Qwen3-VL {#sec-qwen3-vl}

```yaml
base_model: Qwen/Qwen3-VL-4B-Instruct

chat_template: qwen2_vl # same as qwen2-vl
```

### SmolVLM2 {#sec-smolvlm2}

::: {.callout-tip}

@@ -219,6 +219,21 @@ DPO supports the following types with the following dataset format:
}
```

#### chat_template.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### chat_template.default

```yaml

@@ -6,6 +6,8 @@ LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-

This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.

Thanks to the team at LiquidAI for giving us early access to prepare for these releases.

## Getting Started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
@@ -31,6 +33,14 @@ This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```

**LFM2-MoE**
```bash
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6

# LoRA SFT (1x48GB @ 16.2GiB)
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
```

### TIPS

- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
@@ -45,14 +55,13 @@ This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.

## Optimization Guides

- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)

## Related Resources

- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [LFM2-MoE Blog](https://www.liquid.ai/blog/lfm2-8b-a1b-an-efficient-on-device-mixture-of-experts)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

@@ -1,6 +1,7 @@
base_model: LiquidAI/LFM2-350M

chunked_cross_entropy: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

eot_tokens:
- "<|im_end|>"

examples/LiquidAI/lfm2-8b-a1b-lora.yaml (new file, 59 lines)

@@ -0,0 +1,59 @@
base_model: LiquidAI/LFM2-8B-A1B

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_8bit: true

eot_tokens:
- "<|im_end|>"
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out

sequence_len: 4096
sample_packing: true

adapter: lora
lora_model_dir:

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5

bf16: true
tf32: true

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1

weight_decay: 0.0

# save_first_step: true # uncomment this to validate checkpoint saving works with your config

@@ -3,6 +3,9 @@ trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\""
]
},
{

@@ -1,7 +1,7 @@
base_model: google/gemma-3-1b-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

model_type: Gemma3ForCausalLM

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

@@ -1,7 +1,7 @@
base_model: google/gemma-3-270m-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

model_type: Gemma3ForCausalLM

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

@@ -1,5 +1,8 @@
base_model: google/gemma-3-4b-it

# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM

load_in_4bit: true

# gemma3 doesn't seem to play nice with ddp

@@ -2,6 +2,8 @@

[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.

In October 2025, OpenAI released safeguard models built upon GPT-OSS called [GPT-OSS-Safeguard](https://huggingface.co/collections/openai/gpt-oss-safeguard). They use the same architecture, so the same examples below can be re-used.

This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.

## Getting started
@@ -64,6 +66,16 @@ axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offlo
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```

### How to set reasoning_effort in template?

The harmony template has a feature to set the `reasoning_effort` during prompt building. The default is `medium`. If you would like to adjust this, you can add the following to your config:

```yaml
chat_template_kwargs:
reasoning_effort: "high" # low | medium | high
```

Currently, this applies globally. There is no method to apply per sample yet. If you are interested in adding this, please feel free to create an Issue to discuss.

### Inferencing your fine-tuned model

@@ -0,0 +1,67 @@
base_model: openai/gpt-oss-safeguard-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding

datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-safeguard-out/

sequence_len: 4096
sample_packing: true

adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true

# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1

special_tokens:
eot_tokens:
- "<|end|>"

@@ -29,7 +29,7 @@ flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs

save_strategy: no
torch_compile: true

wandb_project:

examples/llama-3/opentelemetry-qlora.yml (new file, 50 lines)

@@ -0,0 +1,50 @@
base_model: NousResearch/Llama-3.2-1B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_4bit: true

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca

output_dir: ./outputs/opentelemetry-example

adapter: qlora
sequence_len: 512
sample_packing: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

# OpenTelemetry Configuration
use_otel_metrics: true
otel_metrics_host: "localhost"
otel_metrics_port: 8000

# Disable WandB
use_wandb: false

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
logging_steps: 1
flash_attention: false

warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

special_tokens:
pad_token: "<|end_of_text|>"

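This new example exposes training metrics through a Prometheus exporter on `localhost:8000` (per `otel_metrics_host`/`otel_metrics_port` above, and the `opentelemetry-exporter-prometheus` extra added in setup.py below). A sketch of a matching Prometheus scrape config — the job name and scrape interval are arbitrary choices for illustration, not part of this PR:

```yaml
# prometheus.yml (sketch): scrape the axolotl OpenTelemetry metrics endpoint
scrape_configs:
  - job_name: axolotl-training
    scrape_interval: 15s
    static_configs:
      - targets: ["localhost:8000"]
```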
@@ -12,7 +12,7 @@ Before starting, ensure you have:
Run the thinking model fine-tuning:

```bash
axolotl train magistral-small-think-qlora.yaml
axolotl train examples/magistral/think/magistral-small-think-qlora.yaml
```

This config uses about 19.1 GiB VRAM.

@@ -21,7 +21,7 @@ Before starting, ensure you have:

3. Run the fine-tuning:
```bash
axolotl train magistral-small-vision-24B-qlora.yml
axolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml
```

This config uses about 17GiB VRAM.

examples/mistral/mistral-small/README.md (new file, 51 lines)

@@ -0,0 +1,51 @@
# Mistral Small 3.1/3.2 Fine-tuning

This guide covers fine-tuning [Mistral Small 3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503) and [Mistral Small 3.2](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506) with vision capabilities using Axolotl.

## Prerequisites

Before starting, ensure you have:
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))

## Getting Started

1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.5'
```

2. Download the example dataset image:
```bash
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
```

3. Run the fine-tuning:
```bash
axolotl train examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml
```

This config uses about 29.4 GiB VRAM.

## Dataset Format

The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).

One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.

Example:
```json
{
  "messages": [
    {"role": "system", "content": [{"type": "text", "text": "{SYSTEM_PROMPT}"}]},
    {"role": "user", "content": [
      {"type": "text", "text": "What's in this image?"},
      {"type": "image", "path": "path/to/image.jpg"}
    ]},
    {"role": "assistant", "content": [{"type": "text", "text": "..."}]}
  ]
}
```

## Limitations

- Sample Packing is not supported for multi-modality training currently.

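For completeness, a sketch of the `datasets:` stanza one might pair with the JSON format above — the dataset path reuses the test dataset downloaded earlier, but the `chat_template` type and `field_messages` key are assumptions for illustration, not taken from this README:

```yaml
datasets:
  - path: Nanobit/text-vision-2k-test   # or a local JSONL in the messages format above
    type: chat_template
    field_messages: messages
```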
@@ -39,7 +39,7 @@ wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine

@@ -5,31 +5,30 @@ bitsandbytes==0.47.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.6.1
liger-kernel==0.6.3
# END section

packaging==23.2

huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.56.1
huggingface_hub>=0.36.0
peft>=0.17.1
tokenizers>=0.21.1
transformers==4.57.1
accelerate==1.10.1
datasets==4.0.0
datasets==4.3.0
deepspeed>=0.17.0
trl==0.23.0
hf_xet==1.1.5
kernels==0.9.0
trl==0.24.0
hf_xet==1.2.0
kernels>=0.9.0
trackio

optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.41.1
gradio==5.49.1

modal==1.0.2
pydantic==2.10.6
pydantic>=2.10.6
addict
fire
PyYAML>=6.0
@@ -37,8 +36,8 @@ requests
wandb
einops
colorama
numba
numpy>=1.24.4,<=2.0.1
numba>=0.61.2
numpy>=2.2.6

# qlora things
evaluate==0.4.1
@@ -51,7 +50,7 @@ python-dotenv==1.0.1

# remote filesystems
s3fs>=2024.5.0
gcsfs>=2024.5.0
gcsfs>=2025.3.0
adlfs>=2024.5.0
ocifs==1.3.2

@@ -67,7 +66,7 @@ antlr4-python3-runtime==4.13.2
torchao==0.13.0
schedulefree==1.4.1

axolotl-contribs-lgpl==0.0.6
axolotl-contribs-lgpl==0.0.7
axolotl-contribs-mit==0.0.5

mistral-common==1.8.5

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"'
)

setup.py (26 changed lines)

@@ -26,7 +26,6 @@ def parse_requirements(extras_require_map):
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
@@ -34,7 +33,6 @@ def parse_requirements(extras_require_map):
"triton",
"mamba-ssm",
"xformers",
"autoawq",
"liger-kernel",
]
_install_requires = [
@@ -51,7 +49,7 @@ def parse_requirements(extras_require_map):
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.6.0"  # default to torch 2.6
torch_version = "2.8.0"  # default to torch 2.8.0
_install_requires.append(f"torch=={torch_version}")

version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -64,8 +62,15 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")

if (major, minor) >= (2, 8):
pass
if (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
extras_require_map["vllm"] = ["vllm==0.11.1"]
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"]
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -74,7 +79,7 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm>=0.10.0"]
extras_require_map["vllm"] = ["vllm==0.10.1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
@@ -87,7 +92,6 @@ def parse_requirements(extras_require_map):
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
@@ -161,7 +165,13 @@ extras_require = {
"llmcompressor": [
"llmcompressor==0.5.1",
],
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
"opentelemetry": [
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-prometheus",
"prometheus-client",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require

@@ -99,7 +99,7 @@ def ray_train_func(kwargs: dict):
resolve_dtype(cfg)

# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
if cfg.deepspeed:
if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"):
cfg.deepspeed = cfg.deepspeed.to_dict()

# initialize accelerator before model instantiation

@@ -12,6 +12,9 @@ MOE_ARCH_BLOCK = {
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
}

@@ -29,7 +29,11 @@ from transformers.trainer_pt_utils import AcceleratorConfig

from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils import (
is_comet_available,
is_mlflow_available,
is_opentelemetry_available,
)
from axolotl.utils.callbacks import (
GCCallback,
SaveAxolotlConfigtoWandBCallback,
@@ -134,6 +138,12 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_otel_metrics and is_opentelemetry_available():
from axolotl.utils.callbacks.opentelemetry import (
OpenTelemetryMetricsCallback,
)

callbacks.append(OpenTelemetryMetricsCallback(self.cfg))
if self.cfg.save_first_step:
callbacks.append(SaveModelOnFirstStepCallback())

@@ -491,6 +501,7 @@ class TrainerBuilderBase(abc.ABC):
"dion_momentum",
"dion_rank_fraction",
"dion_rank_multiple_of",
"dataset_num_proc",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
@@ -514,9 +525,6 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes

# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len

@@ -12,7 +12,7 @@ from transformers import (
EarlyStoppingCallback,
Trainer,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
from trl.trainer.reward_trainer import DataCollatorForPreference

from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
@@ -28,7 +28,6 @@ from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
@@ -63,12 +62,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.relora:
callbacks.append(ReLoRACallback(self.cfg))

if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback())

# TODO: check if can move to base class
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
@@ -460,7 +453,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
DataCollatorForPreference,
]
]
collator_args = [self.tokenizer]
@@ -477,7 +470,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if kwargs and isinstance(kwargs, dict):
kwargs.update(collator_cls_and_kwargs[1])
elif self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
collator = DataCollatorForPreference
tokenizer = collator_args.pop(0)
kwargs["pad_token_id"] = tokenizer.pad_token_id
kwargs.pop("padding")
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama

@@ -225,17 +225,6 @@ class AxolotlTrainer(

data_collator = self.data_collator if is_training else self.eval_data_collator

if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if (
dataset.column_names
and "position_ids" in dataset.column_names
and "attention_mask" in dataset.column_names
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
dataset = dataset.remove_columns(["attention_mask"])

if isinstance(dataset, datasets.Dataset):
if is_training:
if not self.args.sample_packing or self.args.pretraining:
@@ -294,6 +283,18 @@ class AxolotlTrainer(
):
self.accelerator.even_batches = False

if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])

if (
dataset.column_names
and "position_ids" in dataset.column_names
and "attention_mask" in dataset.column_names
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
dataset = dataset.remove_columns(["attention_mask"])

dataloader = DataLoader(dataset, **dataloader_params)

# Accelerator.free_memory() will destroy the references, so
@@ -560,13 +561,6 @@ class AxolotlTrainer(

super().create_accelerator_and_postprocess()

if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
and self.args.fsdp_config["limit_all_gathers"]
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True

def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
) -> dict[str, Any]:

@@ -52,6 +52,7 @@ class GRPOStrategy:
if trl.vllm_mode:
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate":
grpo_args_kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode  # type: ignore[attr-defined]
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization
)

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh

- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"
```

## Usage
@@ -54,9 +54,13 @@ plugins:
- granitemoehybrid
- hunyuan_v1_dense
- hunyuan_v1_moe
- lfm2
- lfm2_moe
- lfm2_vl
- llama
- llama4
- llama4_text
- llava
- mistral
- mistral3
- mixtral

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)

_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`'
)

@@ -7,7 +7,7 @@ import torch

from axolotl.utils.logging import get_logger

from .utils import create_bidirectional_attention_mask
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions

LOG = get_logger(__name__)

@@ -360,7 +360,7 @@ def _diffusion_step(

# Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask)
logits = outputs.logits
logits = shift_logits_to_input_positions(outputs.logits)

# Only sample at currently masked positions
if current_mask.any():

@@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from .callbacks import DiffusionGenerationCallback
from .utils import create_bidirectional_attention_mask
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions

LOG = get_logger(__name__)

@@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer):
input_ids=noisy_batch.long(),
attention_mask=bidirectional_mask,
)
logits = outputs.logits
logits = shift_logits_to_input_positions(outputs.logits)

if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)

@@ -157,3 +157,10 @@ def create_bidirectional_attention_mask(

# Add head dimension: [batch_size, 1, seq_len, seq_len]
return bidirectional_mask.unsqueeze(1)


def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
"""Align next-token logits with their input token positions for diffusion."""
if logits.size(1) <= 1:
return logits
return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

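In effect, `shift_logits_to_input_positions` keeps position 0 and shifts every other column right by one: since a causal LM head emits the prediction for token `t` at position `t-1`, the shift makes `logits[:, t]` line up with input token `t` itself, which is where `_diffusion_step` and the `DiffusionTrainer` (in the hunks above) index them against the masked input positions.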
@@ -72,9 +72,9 @@ def kldiv_forward_llama_like(

# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
# self._loss_function should be LigerFusedLinearKLTopKLogprobLoss

loss = self.loss_function(
loss = self._loss_function(
self.lm_head.weight,
hidden_states,
target_token_ids,

@@ -29,7 +29,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_accepts_loss_kwargs = True
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(

loss_fn = LigerFusedLinearKLTopKLogprobLoss(
self.args.kd_ce_alpha,  # hard label loss
self.args.kd_alpha,  # kd loss
self.args.kd_temperature,
@@ -37,6 +38,14 @@ class AxolotlKDTrainer(AxolotlTrainer):
compute_ce_loss=bool(self.args.kd_ce_alpha),
normalize_topk=self.args.kd_normalize_topk,
)
target = self.model

# Unwrap PEFT wrapper
if hasattr(target, "get_base_model"):
target = target.get_base_model()

# Set on the actual model instance
target._loss_function = loss_fn

def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()

@@ -515,9 +515,6 @@ class ModelLoader:
if self.cfg.model_quantization_config_kwargs:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
else:
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit

if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
@@ -552,9 +549,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
"load_in_4bit", False
):
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -580,9 +575,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
"load_in_8bit", False
):
elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit:
bnb_config = {
"load_in_8bit": True,
}
@@ -596,11 +589,6 @@ class ModelLoader:
**bnb_config,
)

# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)
self.model_kwargs.pop("load_in_4bit", None)

def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.attn_implementation:

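For reference, the qlora branch above now builds its BitsAndBytesConfig straight from cfg.load_in_4bit; a minimal standalone sketch (the extra 4-bit options are illustrative — the diff elides the rest of bnb_config):

    import torch
    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        llm_int8_threshold=6.0,
        bnb_4bit_use_double_quant=True,          # illustrative
        bnb_4bit_compute_dtype=torch.bfloat16,   # illustrative
    )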
@@ -134,6 +134,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:

return Qwen2Attention

if model_type == "qwen3_vl":
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextAttention

return Qwen3VLTextAttention

if model_type == "mllama":
from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention


@@ -45,6 +45,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"gpt_oss",
"arcee",
"seed_oss",
"lfm2",
"lfm2_moe",
]

@@ -13,9 +13,7 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)

GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
PATCHED_GUARD = (
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
)
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'


def patch_prepare_context_parallel_inputs() -> None:

@@ -6,8 +6,10 @@ from typing import Optional
from PIL import Image, ImageOps
from PIL.Image import Resampling
from torch import Tensor, zeros_like
from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
from transformers import ProcessorMixin
from transformers.image_utils import load_image
from transformers.models.smolvlm import SmolVLMProcessor
from transformers.models.voxtral import VoxtralProcessor

from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger

@@ -71,10 +71,10 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
]

return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"chosen_input_ids": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0,
"input_ids_rejected": rejected_tokenized["input_ids"],
"rejected_input_ids": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0,
}

@@ -120,3 +120,123 @@ def default(cfg, dataset_idx=0, **kwargs):
return result

return transform_fn, {"remove_columns": [field_messages]}


def argilla_chat(cfg, dataset_idx=0, **kwargs):
"""
DPO chat template strategy for argilla-style datasets.

For argilla-style datasets where chosen/rejected contain full conversations
instead of single response messages. Extracts the conversation history from
the chosen field and formats both chosen/rejected responses using the
configured chat template.

Args:
cfg: Configuration object containing chat_template and dataset settings
dataset_idx: Index of the dataset in the config (default: 0)
**kwargs: Additional keyword arguments (unused)

Returns:
tuple: (transform_fn, dataset_kwargs) where:
- transform_fn: Function to transform dataset samples
- dataset_kwargs: Dict with 'remove_columns' specifying columns to drop

Dataset format:
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
"""
ds_cfg = cfg["datasets"][dataset_idx]
ds_cfg = handle_legacy_message_fields_logic(ds_cfg)

chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
message_property_mappings = ds_cfg.get(
"message_property_mappings",
{
"role": "role",
"content": "content",
},
)
role_map_inv = ds_cfg.get(
"roles",
{
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
)
role_map = {}
for target, sources in role_map_inv.items():
for source in sources:
role_map[source] = target

def transform_fn(sample, tokenizer=None):
chat_template_string = get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)

chosen_raw = sample[field_chosen]
rejected_raw = sample[field_rejected]

# Extract messages (all but last) and responses (last message)
chosen_messages = [
{
"role": role_map[m[message_property_mappings["role"]]],
"content": m[message_property_mappings["content"]],
}
for m in chosen_raw[:-1]
]
chosen_response = {
"role": role_map[chosen_raw[-1][message_property_mappings["role"]]],
"content": chosen_raw[-1][message_property_mappings["content"]],
}

rejected_response = {
"role": role_map[rejected_raw[-1][message_property_mappings["role"]]],
"content": rejected_raw[-1][message_property_mappings["content"]],
}

dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

result = {}
result["prompt"] = tokenizer.apply_chat_template(
chosen_messages,
add_generation_prompt=True,
chat_template=chat_template_string,
tokenize=False,
)

result["chosen"] = tokenizer.apply_chat_template(
[dummy_user_message, chosen_response],
add_generation_prompt=False,
chat_template=chat_template_string,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen_response["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()

result["rejected"] = tokenizer.apply_chat_template(
[dummy_user_message, rejected_response],
add_generation_prompt=False,
chat_template=chat_template_string,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected_response["content"])
result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()

return result

return transform_fn, {"remove_columns": [field_chosen, field_rejected]}

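As a rough usage sketch (cfg and tokenizer are stand-ins here; compare the tests added later in this changeset), the returned transform_fn turns one argilla-style row into the prompt/chosen/rejected strings DPO expects:

    sample = {
        "chosen": [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "goodbye"},
        ],
        "rejected": [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "party on"},
        ],
    }
    transform_fn, dataset_kwargs = argilla_chat(cfg)
    row = transform_fn(sample, tokenizer=tokenizer)
    # row["prompt"] ends with the assistant generation prompt;
    # row["chosen"] / row["rejected"] hold only the final responses.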
@@ -40,11 +40,6 @@ from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer

try:
from optimum.bettertransformer import BetterTransformer
except ImportError:
BetterTransformer = None

if typing.TYPE_CHECKING:
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder

@@ -141,8 +136,6 @@ def setup_signal_handler(
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
_model = model_weakref()
if cfg.flash_optimum and BetterTransformer:
_model = BetterTransformer.reverse(_model)
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
@@ -321,9 +314,6 @@ def save_trained_model(
except FileNotFoundError:
pass
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)

if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
@@ -535,6 +525,17 @@ def setup_model_and_trainer(
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)

if cfg.use_ray:
try:
import ray.train.huggingface.transformers

trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
except ImportError:
LOG.warning(
"The Ray integration with Hugging Face Transformers is not available. "
"To use Ray, install the 'ray[train]' package."
)

return (
trainer,
model,

@@ -17,6 +17,13 @@ def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None


def is_opentelemetry_available():
return (
importlib.util.find_spec("opentelemetry") is not None
and importlib.util.find_spec("prometheus_client") is not None
)


def get_pytorch_version() -> tuple[int, int, int]:
"""
Get Pytorch version as a tuple of (major, minor, patch).

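One plausible way the new helper gets used (the actual call site is not part of this diff, so treat this as a hypothetical sketch) is to gate callback registration on both the config flag and the installed packages:

    from axolotl.utils import is_opentelemetry_available

    if cfg.use_otel_metrics and is_opentelemetry_available():
        from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback

        callbacks.append(OpenTelemetryMetricsCallback(cfg))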
@@ -16,8 +16,8 @@ import pandas as pd
import torch
import torch.distributed as dist
import wandb
import yaml
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
GenerationConfig,
@@ -28,8 +28,6 @@ from transformers import (
TrainingArguments,
)
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
IntervalStrategy,
SaveStrategy,
)
from trl.models import unwrap_model_for_generation
@@ -56,40 +54,6 @@ IGNORE_INDEX = -100
LOG = get_logger(__name__)


class SaveBetterTransformerModelCallback(TrainerCallback):
"""Callback to save the BetterTransformer wrapped model"""

def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> TrainerControl:
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True

if control.should_save:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)

model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
# FIXME - need to cleanup old checkpoints

# since we're saving here, we don't need the trainer loop to attempt to save too b/c
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control


class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""

@@ -796,6 +760,37 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")

try:
with open(self.axolotl_config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}

chat_tpl = cfg.get("chat_template_jinja")
if chat_tpl:
with NamedTemporaryFile(
mode="w", delete=True, suffix=".jinja", prefix="chat_template_"
) as temp_ct_file:
if (
isinstance(chat_tpl, str)
and os.path.exists(chat_tpl)
and os.path.isfile(chat_tpl)
):
copyfile(chat_tpl, temp_ct_file.name)
else:
temp_ct_file.write(str(chat_tpl))
temp_ct_file.flush()

artifact = wandb.Artifact(
f"chat-template-{wandb.run.id}", type="jinja-template"
)
artifact.add_file(temp_ct_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_ct_file.name)
LOG.info(
"The chat_template_jinja has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError, yaml.YAMLError) as err:
LOG.warning(f"Error while saving chat_template_jinja to WandB: {err}")

if args.deepspeed:
try:
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.

238
src/axolotl/utils/callbacks/opentelemetry.py
Normal file
@@ -0,0 +1,238 @@
"""OpenTelemetry metrics callback for Axolotl training"""

import threading
from typing import Dict, Optional

from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

try:
from opentelemetry import metrics
from opentelemetry.exporter.prometheus import PrometheusMetricReader
from opentelemetry.metrics import set_meter_provider
from opentelemetry.sdk.metrics import MeterProvider as SDKMeterProvider
from prometheus_client import start_http_server

OPENTELEMETRY_AVAILABLE = True
except ImportError:
LOG.warning("OpenTelemetry not available. pip install [opentelemetry]")
OPENTELEMETRY_AVAILABLE = False


class OpenTelemetryMetricsCallback(TrainerCallback):
"""
TrainerCallback that exports training metrics to OpenTelemetry/Prometheus.

This callback automatically tracks key training metrics including:
- Training loss
- Evaluation loss
- Learning rate
- Epoch progress
- Global step count
- Gradient norm

Metrics are exposed via HTTP endpoint for Prometheus scraping.
"""

def __init__(self, cfg):
if not OPENTELEMETRY_AVAILABLE:
LOG.warning("OpenTelemetry not available, metrics will not be collected")
self.metrics_enabled = False
return

self.cfg = cfg
self.metrics_host = getattr(cfg, "otel_metrics_host", "localhost")
self.metrics_port = getattr(cfg, "otel_metrics_port", 8000)
self.metrics_enabled = True
self.server_started = False
self.metrics_lock = threading.Lock()

try:
# Create Prometheus metrics reader
prometheus_reader = PrometheusMetricReader()

# Create meter provider with Prometheus exporter
provider = SDKMeterProvider(metric_readers=[prometheus_reader])
set_meter_provider(provider)

# Get meter for creating metrics
self.meter = metrics.get_meter("axolotl.training")

# Create metrics
self._create_metrics()

except Exception as e:
LOG.warning(f"Failed to initialize OpenTelemetry metrics: {e}")
self.metrics_enabled = False

def _create_metrics(self):
"""Create all metrics that will be tracked"""
self.train_loss_gauge = self.meter.create_gauge(
name="axolotl_train_loss",
description="Current training loss",
unit="1",
)

self.eval_loss_gauge = self.meter.create_gauge(
name="axolotl_eval_loss",
description="Current evaluation loss",
unit="1",
)

self.learning_rate_gauge = self.meter.create_gauge(
name="axolotl_learning_rate",
description="Current learning rate",
unit="1",
)

self.epoch_gauge = self.meter.create_gauge(
name="axolotl_epoch",
description="Current training epoch",
unit="1",
)

self.global_step_counter = self.meter.create_counter(
name="axolotl_global_steps",
description="Total training steps completed",
unit="1",
)

self.grad_norm_gauge = self.meter.create_gauge(
name="axolotl_gradient_norm",
description="Gradient norm",
unit="1",
)

self.memory_usage_gauge = self.meter.create_gauge(
name="axolotl_memory_usage",
description="Current memory usage in MB",
unit="MB",
)

def _start_metrics_server(self):
"""Start the HTTP server for metrics exposure"""
if self.server_started:
return

try:
start_http_server(self.metrics_port, addr=self.metrics_host)
self.server_started = True
LOG.info(
f"OpenTelemetry metrics server started on http://{self.metrics_host}:{self.metrics_port}/metrics"
)

except Exception as e:
LOG.error(f"Failed to start OpenTelemetry metrics server: {e}")

def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the beginning of training"""
if not self.metrics_enabled:
return

self._start_metrics_server()
LOG.info("OpenTelemetry metrics collection started")

def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs: Optional[Dict[str, float]] = None,
**kwargs,
):
"""Called when logging occurs"""
if not self.metrics_enabled or not logs:
return

if "loss" in logs:
self.train_loss_gauge.set(logs["loss"])

if "eval_loss" in logs:
self.eval_loss_gauge.set(logs["eval_loss"])

if "learning_rate" in logs:
self.learning_rate_gauge.set(logs["learning_rate"])

if "epoch" in logs:
self.epoch_gauge.set(logs["epoch"])

if "grad_norm" in logs:
self.grad_norm_gauge.set(logs["grad_norm"])
if "memory_usage" in logs:
self.memory_usage_gauge.set(logs["memory_usage"])

def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the end of each training step"""
if not self.metrics_enabled:
return

# Update step counter and epoch
self.global_step_counter.add(1)
if state.epoch is not None:
self.epoch_gauge.set(state.epoch)

def on_evaluate(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
metrics: Optional[Dict[str, float]] = None,
**kwargs,
):
"""Called after evaluation"""
if not self.metrics_enabled or not metrics:
return

if "eval_loss" in metrics:
self.eval_loss_gauge.set(metrics["eval_loss"])

# Record any other eval metrics as gauges
for key, value in metrics.items():
if key.startswith("eval_") and isinstance(value, (int, float)):
# Create gauge for this metric if it doesn't exist
gauge_name = f"axolotl_{key}"
try:
gauge = self.meter.create_gauge(
name=gauge_name,
description=f"Evaluation metric: {key}",
unit="1",
)
gauge.set(value)
except Exception as e:
LOG.warning(f"Failed to create/update metric {gauge_name}: {e}")

def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the end of training"""
if not self.metrics_enabled:
return

LOG.info("Training completed. OpenTelemetry metrics collection finished.")
LOG.info(
f"Metrics are still available at http://{self.metrics_host}:{self.metrics_port}/metrics"
)
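Tying the callback to the OpenTelemetryConfig schema added later in this changeset, enabling the exporter might look like this (host and port values are illustrative):

    from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "use_otel_metrics": True,        # off by default
            "otel_metrics_host": "0.0.0.0",  # default is "localhost"
            "otel_metrics_port": 9090,       # default is 8000
        }
    )
    callback = OpenTelemetryMetricsCallback(cfg)
    # once training begins, Prometheus can scrape http://0.0.0.0:9090/metrics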
@@ -113,7 +113,7 @@ def _map_dataset(

dataset = dataset.map(
ds_transform_fn,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Mapping RL Dataset",
**map_kwargs,
@@ -234,7 +234,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)

@@ -239,6 +239,11 @@ def _load_from_local_path(
return load_dataset(dataset_config.path, **load_dataset_kwargs)
elif local_path.is_file():
dataset_type = get_dataset_type(dataset_config)

# For single file datasets, HF always creates only a "train" split
if dataset_type in ("json", "csv", "text"):
load_dataset_kwargs["split"] = "train"

return load_dataset(
dataset_type,
data_files=dataset_config.path,
@@ -409,7 +414,7 @@ def save_preprocessed_dataset(
) -> None:
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
num_workers = cfg.dataset_processes or get_default_process_count()
num_workers = cfg.dataset_num_proc or get_default_process_count()
if isinstance(dataset, IterableDataset):
ds_from_iter = Dataset.from_generator(
functools.partial(_generate_from_iterable_dataset, dataset),

@@ -223,7 +223,7 @@ def handle_long_seq_in_dataset(

filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess

drop_long_kwargs = {}

@@ -80,7 +80,7 @@ def get_dataset_wrapper(
"""
# Common parameters for dataset wrapping
dataset_kwargs: dict[str, Any] = {
"process_count": cfg.dataset_processes,
"process_count": cfg.dataset_num_proc,
"keep_in_memory": cfg.dataset_keep_in_memory is True,
}

@@ -4,6 +4,8 @@ import os


def get_default_process_count():
if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
return int(axolotl_dataset_num_proc)
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
return int(axolotl_dataset_processes)
if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):

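With the renamed variable checked first, the resolution order is AXOLOTL_DATASET_NUM_PROC, then the legacy AXOLOTL_DATASET_PROCESSES, then RUNPOD_CPU_COUNT. A small sketch with the function above in scope:

    import os

    os.environ["AXOLOTL_DATASET_PROCESSES"] = "8"
    os.environ["AXOLOTL_DATASET_NUM_PROC"] = "4"
    assert get_default_process_count() == 4  # the new variable wins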
@@ -3,66 +3,46 @@ utils to get GPU info for the current environment
"""

import os
import subprocess # nosec B404
from importlib.metadata import version

import torch
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
get_gpu_info,
)
from packaging.version import Version, parse

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
if not check_runpod_p2p_support():
if not check_cuda_p2p_support():
return False
unsupported_devices = {"RTX 6000 Ada", "L40S"}
try:
device_names, device_count = get_gpu_info()
if 1 < device_count < 8:
if any(
unsupported_device in device_name
for device_name in device_names
for unsupported_device in unsupported_devices
):
return False
except Exception: # nosec B110
pass
return True


def check_runpod_p2p_support() -> bool:
if "RUNPOD_GPU_COUNT" not in os.environ:
return True
def check_cuda_p2p_support() -> bool:
try:
gpu_count = int(os.environ.get("RUNPOD_GPU_COUNT", "1"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
except ValueError:
return True
if gpu_count >= 2:
# run `nvidia-smi topo -p2p n` and inspect the GPU0 row

if world_size > 1:
node_world_size = int(os.environ.get("NODE_WORLD_SIZE", "8"))
local_other_rank = (local_rank // node_world_size) * node_world_size
local_other_rank += 1 if (local_rank % node_world_size) == 0 else 0
try:
result = subprocess.run( # nosec B603 B607
["nvidia-smi", "topo", "-p2p", "n"],
check=True,
capture_output=True,
text=True,
timeout=5,
)
except (
subprocess.CalledProcessError,
FileNotFoundError,
subprocess.TimeoutExpired,
):
return True # fail-open if detection fails
output_lines = result.stdout.strip().split("\n")
# filter rows that start with "GPU0" (avoid header row)
gpu0_rows = [line for line in output_lines if line.lstrip().startswith("GPU0")]
if not gpu0_rows:
can_p2p = torch.cuda.can_device_access_peer(local_rank, local_other_rank)
except AssertionError as exc:
# some sort of logic error in indexing processes, assume p2p is fine for now
LOG.warning(exc)
return True
# consider P2P supported if any OK is present in the GPU0 row
return "OK" in gpu0_rows[-1]
return can_p2p

return True

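The rewritten check queries PyTorch's peer-access support directly instead of shelling out to nvidia-smi topo. A minimal standalone sketch (device indices are illustrative):

    import torch

    # True when GPU 0 can access GPU 1's memory directly over NVLink/PCIe P2P
    if torch.cuda.device_count() >= 2:
        print(torch.cuda.can_device_access_peer(0, 1))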
@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.

import gc
import math
import os
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
@@ -291,7 +292,10 @@ class MultipackBatchSampler(BatchSampler):
self.total_token_slots = 0

# The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples
world_size = int(os.environ.get("WORLD_SIZE", "1"))
self.num_count_samples = (
1 if world_size >= num_count_samples else num_count_samples
)

if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warning(

@@ -24,11 +24,13 @@ from axolotl.utils.schemas.datasets import (
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.fsdp import FSDPConfig
from axolotl.utils.schemas.integrations import (
CometConfig,
GradioConfig,
LISAConfig,
MLFlowConfig,
OpenTelemetryConfig,
RayConfig,
WandbConfig,
)
@@ -59,6 +61,7 @@ class AxolotlInputConfig(
WandbConfig,
MLFlowConfig,
CometConfig,
OpenTelemetryConfig,
LISAConfig,
GradioConfig,
RayConfig,
@@ -233,6 +236,7 @@ class AxolotlInputConfig(
)
dataset_processes: int | None = Field(
default=None,
deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.",
json_schema_extra={
"description": (
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
@@ -240,6 +244,16 @@ class AxolotlInputConfig(
)
},
)
dataset_num_proc: int | None = Field(
default=None,
json_schema_extra={
"description": (
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
)
},
)

dataset_exact_deduplication: bool | None = Field(
default=None,
json_schema_extra={
@@ -667,8 +681,7 @@ class AxolotlInputConfig(
json_schema_extra={"description": "FSDP configuration"},
deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
)
# TODO @SalmanMohammadi strongly type this as its own schema
fsdp_config: dict[str, Any] | None = Field(
fsdp_config: FSDPConfig | None = Field(
default=None, json_schema_extra={"description": "FSDP configuration options"}
)
fsdp_version: int | None = Field(
@@ -1314,10 +1327,22 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):

@model_validator(mode="before")
@classmethod
def default_dataset_processes(cls, data):
if data.get("dataset_processes") is None:
data["dataset_processes"] = get_default_process_count()

def default_dataset_num_proc(cls, data):
if data.get("dataset_processes") is not None:
if data.get("dataset_num_proc") is None:
data["dataset_num_proc"] = data["dataset_processes"]
LOG.warning(
"dataset_processes is deprecated and will be removed in a future version. "
"Please use dataset_num_proc instead."
)
else:
LOG.warning(
"Both dataset_processes and dataset_num_proc are set. "
"Using dataset_num_proc and ignoring dataset_processes."
)
del data["dataset_processes"]
elif data.get("dataset_num_proc") is None:
data["dataset_num_proc"] = get_default_process_count()
return data

@model_validator(mode="before")

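The precedence rules of the new validator, restated as a plain-Python sketch (the stub stands in for the get_default_process_count helper shown earlier in this changeset):

    def get_default_process_count() -> int:  # stand-in for the real helper
        return 4

    def resolve_num_proc(data: dict) -> dict:
        """Sketch of the validator's precedence: dataset_num_proc wins."""
        if data.get("dataset_processes") is not None:
            if data.get("dataset_num_proc") is None:
                data["dataset_num_proc"] = data["dataset_processes"]
            del data["dataset_processes"]
        elif data.get("dataset_num_proc") is None:
            data["dataset_num_proc"] = get_default_process_count()
        return data

    assert resolve_num_proc({"dataset_processes": 8}) == {"dataset_num_proc": 8}
    assert resolve_num_proc({"dataset_processes": 8, "dataset_num_proc": 4}) == {
        "dataset_num_proc": 4
    }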
71
src/axolotl/utils/schemas/fsdp.py
Normal file
@@ -0,0 +1,71 @@
"""
FSDP Configuration Schema
"""

from typing import Literal

from pydantic import BaseModel, Field


class FSDPConfig(BaseModel):
"""
FSDP Configuration Schema
"""

activation_checkpointing: bool | None = Field(
default=None,
description="Enable activation checkpointing to reduce memory usage during forward passes",
)
offload_params: bool | None = Field(
default=None,
description="Offload parameters to CPU to reduce GPU memory usage",
)
sync_module_states: bool | None = Field(
default=None,
description="Synchronize module states across all processes",
)
cpu_ram_efficient_loading: bool | None = Field(
default=None,
description="Enable CPU RAM efficient loading to reduce memory usage during model loading",
)
cpu_offload_pin_memory: bool | None = Field(
default=None,
description="Disabling this enables swap memory usage for resource-constrained setups when offload_params is enabled.",
)
use_orig_params: bool | None = Field(
default=None,
description="Use original parameters instead of flattened parameters",
)

state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
default=None,
description="Type of state dict to use for saving/loading checkpoints",
)
final_state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
default=None,
description="Final state dict type to use after training completion",
)

auto_wrap_policy: Literal["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP"] | None = (
Field(
default=None,
description="Policy for automatically wrapping modules with FSDP",
)
)
transformer_layer_cls_to_wrap: str | None = Field(
default=None,
description="Class name of transformer layers to wrap (e.g., 'LlamaDecoderLayer')",
)

reshard_after_forward: bool | None = Field(
default=None,
description="Reshard parameters after forward pass to save memory",
)
mixed_precision_policy: str | None = Field(
default=None,
description="Mixed precision policy for FSDP (e.g., 'fp16', 'bf16')",
)
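With fsdp_config now strongly typed, bad keys or values fail at config validation time rather than deep inside accelerate. A small sketch of the schema in use (field values are illustrative):

    from axolotl.utils.schemas.fsdp import FSDPConfig

    fsdp = FSDPConfig(
        offload_params=False,
        auto_wrap_policy="TRANSFORMER_BASED_WRAP",
        transformer_layer_cls_to_wrap="LlamaDecoderLayer",
        state_dict_type="SHARDED_STATE_DICT",
    )
    # e.g. auto_wrap_policy="TRANSFORMER_WRAP" would now raise a pydantic
    # ValidationError instead of surfacing later at runtime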
@@ -176,3 +176,27 @@ class RayConfig(BaseModel):
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
},
)


class OpenTelemetryConfig(BaseModel):
"""OpenTelemetry configuration subset"""

use_otel_metrics: bool | None = Field(
default=False,
json_schema_extra={
"description": "Enable OpenTelemetry metrics collection and Prometheus export"
},
)
otel_metrics_host: str | None = Field(
default="localhost",
json_schema_extra={
"title": "OpenTelemetry Metrics Host",
"description": "Host to bind the OpenTelemetry metrics server to",
},
)
otel_metrics_port: int | None = Field(
default=8000,
json_schema_extra={
"description": "Port for the Prometheus metrics HTTP server"
},
)

@@ -167,3 +167,9 @@ class TRLConfig(BaseModel):
"description": "Whether to exclude truncated completions from loss calculation."
},
)
vllm_enable_sleep_mode: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable sleep mode for vLLM to offload VRAM when idle"
},
)

@@ -783,15 +783,6 @@ class OptimizationValidationMixin:

return data

@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
if data.get("deepspeed") and data.get("torch_compile"):
raise ValueError(
"torch_compile should be set within your deepspeed config file"
)
return data

@model_validator(mode="before")
@classmethod
def check_xentropy_patch_conflicts(cls, data):
@@ -890,7 +881,7 @@ class OptimizationValidationMixin:
and self.fsdp_config
and self.optimizer
and "8bit" in self.optimizer.value
and self.fsdp_config["offload_params"]
and self.fsdp_config.offload_params
and str(self.fsdp_version) != "2"
):
raise ValueError(

@@ -6,6 +6,7 @@ import os
import random
from contextlib import contextmanager
from functools import partial
from tempfile import NamedTemporaryFile
from typing import List, Optional

import numpy as np
@@ -15,6 +16,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.logging import get_logger
@@ -276,7 +278,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
prior_len = None
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess

drop_long_kwargs = {}
@@ -316,7 +318,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length",
)
@@ -333,7 +335,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
@@ -342,7 +344,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
@@ -467,7 +469,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
num_processes=cfg.dataset_processes,
num_processes=cfg.dataset_num_proc,
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
)

@@ -540,6 +542,13 @@ def setup_deepspeed_env(cfg, stage=None):
)

os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
if isinstance(cfg.deepspeed, DictDefault):
with NamedTemporaryFile(
mode="w", delete=False, suffix=".json", prefix="deepspeed_config_"
) as temp_file:
temp_file.write(json.dumps(cfg.deepspeed.to_dict(), indent=4))
temp_file.close()
cfg.deepspeed = str(temp_file.name)
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(
cfg.gradient_accumulation_steps
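The net effect of this hunk: cfg.deepspeed may now be an inline mapping instead of a path, and it is serialized to a temporary JSON file so accelerate can pick it up. A rough sketch (the config values are illustrative):

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "deepspeed": {"zero_optimization": {"stage": 2}},  # inline, no file
            "gradient_accumulation_steps": 1,
        }
    )
    setup_deepspeed_env(cfg)
    # cfg.deepspeed now points at a generated deepspeed_config_*.json file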
@@ -562,6 +571,7 @@ def setup_deepspeed_env(cfg, stage=None):
if (
int(os.environ.get("WORLD_SIZE", "1")) == 1
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
and cfg.use_ray is not True
):
os.environ["WORLD_SIZE"] = "1" # force it in case not set
os.environ["LOCAL_RANK"] = "0" # force it in case not set
@@ -631,6 +641,7 @@ def setup_parallelism_envs(cfg):
def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None:
LOG.warning("P2P support not detected, setting `NCCL_P2P_DISABLE=1`")
os.environ["NCCL_P2P_DISABLE"] = "1"
# TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12
if cfg.fsdp or cfg.fsdp_config:
@@ -638,11 +649,15 @@ def prepare_optim_env(cfg):
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
stage = None
deepspeed_config = None
# check if the cfg.deepspeed is a file
if os.path.isfile(cfg.deepspeed):
if isinstance(cfg.deepspeed, DictDefault):
deepspeed_config = cfg.deepspeed
elif os.path.isfile(cfg.deepspeed):
# parse with json
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
deepspeed_config = json.load(fin)
if deepspeed_config:
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage)

@@ -33,7 +33,6 @@ def parse_requirements():
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]

if "Darwin" in platform.system():
# don't install xformers on MacOS
@@ -63,7 +62,6 @@ def parse_requirements():
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))

@@ -440,7 +440,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
]
else:
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
cfg["dataset_processes"] = 4
cfg["dataset_num_proc"] = 4

if cfg_string == "grpo_cfg":
rewards_dir = tmp_path / "rewards_test"

@@ -104,7 +104,6 @@ class TestKnowledgeDistillation:
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
)

@pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
@pytest.mark.parametrize(
"load_in_8bit",
[True, False],
@@ -120,6 +119,10 @@ class TestKnowledgeDistillation:
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.0,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"lora_mlp_kernel": False,
"lora_qkv_kernel": False,
"lora_o_kernel": False,
}
| kd_min_cfg
)

@@ -353,7 +353,6 @@ class TestMultiGPULlama:
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
@@ -431,7 +430,6 @@ class TestMultiGPULlama:
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
@@ -594,7 +592,6 @@ class TestMultiGPULlama:
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,

@@ -13,7 +13,6 @@ from axolotl.utils.dict import DictDefault
from tests.e2e.utils import (
check_tensorboard,
require_torch_2_7_0,
require_torch_lt_2_6_0,
)

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -24,7 +23,7 @@ class TestMultiGPURay:
Test cases for AnyScale Ray post training
"""

@require_torch_lt_2_6_0
@require_torch_2_7_0
def test_lora_ddp(self, temp_dir):
cfg = DictDefault(
{
@@ -83,7 +82,7 @@ class TestMultiGPURay:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)

@require_torch_lt_2_6_0
@require_torch_2_7_0
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],

@@ -69,7 +69,7 @@ class TestActivationCheckpointing:
"save_safetensors": True,
"gradient_checkpointing": gradient_checkpointing,
"save_first_step": False,
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)


@@ -29,7 +29,7 @@ class TestPretrainLlama:
"sequence_len": 1024,
"sample_packing": sample_packing,
"pretrain_multipack_attn": pretrain_multipack_attn,
"dataset_processes": 1,
"dataset_num_proc": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},

@@ -8,7 +8,7 @@ import pytest
from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default
from axolotl.utils.dict import DictDefault

from tests.hf_offline_utils import enable_hf_offline
@@ -78,6 +78,36 @@ def fixture_custom_assistant_dataset():
)


@pytest.fixture(name="argilla_chat_dataset")
def fixture_argilla_chat_dataset():
return Dataset.from_list(
[
{
"chosen": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "goodbye",
},
],
"rejected": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "party on",
},
],
}
]
)


@pytest.fixture(name="phi3_tokenizer")
@enable_hf_offline
def fixture_phi3_tokenizer():
@@ -216,5 +246,51 @@ class TestAssistantDPOChatTemplateGemma:
assert result["rejected"] == "party on<end_of_turn>"


class TestArgillaChatDPOChatTemplate:
"""
Test class for argilla_chat style datasets (chosen/rejected contain full conversations).
"""

def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
transform_fn, _ = argilla_chat(
DictDefault(
{
"chat_template": "llama3",
"datasets": [
{
"type": "chat_template.argilla_chat",
}
],
}
)
)
result = transform_fn(argilla_chat_dataset[0], tokenizer=llama3_tokenizer)
assert result["prompt"] == (
"<|begin_of_text|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert result["chosen"] == "goodbye<|eot_id|>"
assert result["rejected"] == "party on<|eot_id|>"

def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
transform_fn, _ = argilla_chat(
DictDefault(
{
"chat_template": "tokenizer_default",
"datasets": [
{
"type": "chat_template.argilla_chat",
}
],
}
)
)
result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)
assert result["prompt"] == "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
assert result["chosen"] == "goodbye<|end|>"
assert result["rejected"] == "party on<|end|>"


if __name__ == "__main__":
unittest.main()

@@ -141,7 +141,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)

@@ -180,7 +180,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)

@@ -219,7 +219,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)

@@ -252,7 +252,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)

@@ -285,7 +285,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)

@@ -370,7 +370,7 @@ class TestDatasetPreparation:
"rl": "dpo",
"chat_template": "llama3",
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)

@@ -471,7 +471,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)


@@ -210,7 +210,7 @@ class TestDeduplicateRLDataset:
ALPACA_MESSAGES_CONFIG_REVISION,
ALPACA_MESSAGES_CONFIG_REVISION,
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
yield fixture

@@ -80,16 +80,26 @@ class TestModelsUtils:
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
)
elif load_in_8bit and self.cfg.adapter is not None:
assert self.model_loader.model_kwargs["load_in_8bit"]
elif load_in_4bit and self.cfg.adapter is not None:
assert self.model_loader.model_kwargs["load_in_4bit"]

if (self.cfg.adapter == "qlora" and load_in_4bit) or (
self.cfg.adapter == "lora" and load_in_8bit
):
assert self.model_loader.model_kwargs.get(
"quantization_config", BitsAndBytesConfig
if self.cfg.adapter == "qlora" and load_in_4bit:
assert isinstance(
self.model_loader.model_kwargs.get("quantization_config"),
BitsAndBytesConfig,
)

assert (
self.model_loader.model_kwargs["quantization_config"]._load_in_4bit
is True
)
if self.cfg.adapter == "lora" and load_in_8bit:
assert isinstance(
self.model_loader.model_kwargs.get("quantization_config"),
BitsAndBytesConfig,
)

assert (
self.model_loader.model_kwargs["quantization_config"]._load_in_8bit
is True
)

def test_message_property_mapping(self):

@@ -111,7 +111,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"regular_param": "value",
}
}
)
@@ -124,7 +123,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
)
self.assertEqual(cfg_with_version.fsdp_config.offload_params, False)
self.assertEqual(cfg_with_version.fsdp_config.cpu_ram_efficient_loading, True)
self.assertEqual(cfg_with_version.fsdp_config.regular_param, "value")

self.assertNotIn("fsdp_auto_wrap_policy", cfg_with_version.fsdp_config)
self.assertNotIn("fsdp_offload_params", cfg_with_version.fsdp_config)
@@ -137,7 +135,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
"fsdp_config": {
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
"fsdp_offload_params": True,
"regular_param": "value",
}
}
)
@@ -149,7 +146,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
cfg_without_version.fsdp_config.auto_wrap_policy, "SIZE_BASED_WRAP"
)
self.assertEqual(cfg_without_version.fsdp_config.offload_params, True)
self.assertEqual(cfg_without_version.fsdp_config.regular_param, "value")

self.assertNotIn("fsdp_auto_wrap_policy", cfg_without_version.fsdp_config)
self.assertNotIn("fsdp_offload_params", cfg_without_version.fsdp_config)

349
tests/test_opentelemetry_callback.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""Tests for OpenTelemetry metrics callback functionality."""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_otel_config():
|
||||
"""Mock configuration for OpenTelemetry callback."""
|
||||
return DictDefault(
|
||||
{
|
||||
"use_otel_metrics": True,
|
||||
"otel_metrics_host": "localhost",
|
||||
"otel_metrics_port": 8003, # Use unique port for tests
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trainer_state():
|
||||
"""Mock trainer state for callback testing."""
|
||||
from transformers import TrainerState
|
||||
|
||||
state = TrainerState()
|
||||
state.epoch = 1.0
|
||||
state.global_step = 100
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_training_args():
|
||||
"""Mock training arguments for callback testing."""
|
||||
from transformers import TrainingArguments
|
||||
|
||||
return TrainingArguments(output_dir="/tmp/test")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trainer_control():
|
||||
"""Mock trainer control for callback testing."""
|
||||
from transformers.trainer_callback import TrainerControl
|
||||
|
||||
return TrainerControl()
|
||||
|
||||
|
||||
class TestOpenTelemetryConfig:
|
||||
"""Test OpenTelemetry configuration schema."""
|
||||
|
||||
def test_config_schema_valid(self):
|
||||
"""Test OpenTelemetry configuration schema validation."""
|
||||
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
|
||||
|
||||
# Test valid config
|
||||
valid_config = {
|
||||
"use_otel_metrics": True,
|
||||
"otel_metrics_host": "localhost",
|
||||
"otel_metrics_port": 8000,
|
||||
}
|
||||
|
||||
otel_config = OpenTelemetryConfig(**valid_config)
|
||||
assert otel_config.use_otel_metrics is True
|
||||
assert otel_config.otel_metrics_host == "localhost"
|
||||
assert otel_config.otel_metrics_port == 8000
|
||||
|
||||
def test_config_defaults(self):
|
||||
"""Test OpenTelemetry configuration default values."""
|
||||
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
|
||||
|
||||
# Test minimal config with defaults
|
||||
minimal_config = {"use_otel_metrics": True}
|
||||
|
||||
otel_config = OpenTelemetryConfig(**minimal_config)
|
||||
assert otel_config.use_otel_metrics is True
|
||||
assert otel_config.otel_metrics_host == "localhost" # default
|
||||
assert otel_config.otel_metrics_port == 8000 # default
|
||||
|
||||
def test_config_disabled_by_default(self):
|
||||
"""Test that OpenTelemetry is disabled by default."""
|
||||
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
|
||||
|
||||
# Test default config
|
||||
default_config = OpenTelemetryConfig()
|
||||
assert default_config.use_otel_metrics is False
|
||||
|
||||
|
||||
class TestOpenTelemetryCallback:
|
||||
"""Test OpenTelemetry callback functionality."""
|
||||
|
||||
def test_callback_import(self):
|
||||
"""Test that OpenTelemetry callback can be imported."""
|
||||
from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
|
||||
|
||||
assert OpenTelemetryMetricsCallback is not None
|
||||
|
||||
def test_callback_graceful_fallback(self, mock_otel_config):
|
||||
"""Test callback gracefully handles missing dependencies."""
|
||||
from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
|
||||
|
||||
# This should not raise an exception even if dependencies are missing
|
||||
callback = OpenTelemetryMetricsCallback(mock_otel_config)
|
||||
|
||||
# Callback should exist but may have metrics disabled
|
||||
assert callback is not None
|
||||
assert hasattr(callback, "metrics_enabled")
|
||||
|
||||
    def test_callback_initialization_enabled(self, mock_otel_config):
        """Test callback initialization when OpenTelemetry is available."""
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        callback = OpenTelemetryMetricsCallback(mock_otel_config)

        if OPENTELEMETRY_AVAILABLE:
            assert callback.metrics_enabled is True
            assert callback.cfg == mock_otel_config
            assert callback.metrics_host == "localhost"
            assert callback.metrics_port == 8003
        else:
            assert callback.metrics_enabled is False

    def test_metrics_server_lifecycle(
        self,
        mock_otel_config,
        mock_trainer_state,
        mock_training_args,
        mock_trainer_control,
    ):
        """Test metrics server starts and stops correctly."""
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        if not OPENTELEMETRY_AVAILABLE:
            pytest.skip("OpenTelemetry dependencies not available")

        callback = OpenTelemetryMetricsCallback(mock_otel_config)

        # Start server
        callback.on_train_begin(
            mock_training_args, mock_trainer_state, mock_trainer_control
        )
        assert callback.server_started is True

        # End training
        callback.on_train_end(
            mock_training_args, mock_trainer_state, mock_trainer_control
        )

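# The tests in this class consume fixtures defined elsewhere in the suite.
# A minimal mock_otel_config consistent with the assertions above (note the
# non-default port 8003) might look like this; the real fixture may build a
# fuller config object.
from types import SimpleNamespace

import pytest


@pytest.fixture
def mock_otel_config():
    """Hypothetical stand-in for the fixture these tests consume."""
    return SimpleNamespace(
        use_otel_metrics=True,
        otel_metrics_host="localhost",
        otel_metrics_port=8003,
    )
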
    def test_metrics_recording(
        self,
        mock_otel_config,
        mock_trainer_state,
        mock_training_args,
        mock_trainer_control,
    ):
        """Test that metrics are recorded during training."""
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        if not OPENTELEMETRY_AVAILABLE:
            pytest.skip("OpenTelemetry dependencies not available")

        callback = OpenTelemetryMetricsCallback(mock_otel_config)
        callback.on_train_begin(
            mock_training_args, mock_trainer_state, mock_trainer_control
        )

        # Test logging metrics
        test_logs = {
            "loss": 0.5,
            "learning_rate": 1e-4,
            "grad_norm": 0.8,
        }

        # This should not raise an exception
        callback.on_log(
            mock_training_args, mock_trainer_state, mock_trainer_control, logs=test_logs
        )
        assert callback.metrics_enabled is True

    def test_evaluation_metrics(
        self,
        mock_otel_config,
        mock_trainer_state,
        mock_training_args,
        mock_trainer_control,
    ):
        """Test evaluation metrics recording."""
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        if not OPENTELEMETRY_AVAILABLE:
            pytest.skip("OpenTelemetry dependencies not available")

        callback = OpenTelemetryMetricsCallback(mock_otel_config)
        callback.on_train_begin(
            mock_training_args, mock_trainer_state, mock_trainer_control
        )

        # Test evaluation metrics
        eval_logs = {
            "eval_loss": 0.3,
            "eval_accuracy": 0.95,
        }

        # This should not raise an exception
        callback.on_evaluate(
            mock_training_args, mock_trainer_state, mock_trainer_control, eval_logs
        )
        assert callback.metrics_enabled is True

    def test_thread_safety(self, mock_otel_config):
        """Test that callback has thread safety mechanisms."""
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        if not OPENTELEMETRY_AVAILABLE:
            pytest.skip("OpenTelemetry dependencies not available")

        callback = OpenTelemetryMetricsCallback(mock_otel_config)
        assert hasattr(callback, "metrics_lock")
        # Check it's a lock-like object
        assert hasattr(callback.metrics_lock, "__enter__")
        assert hasattr(callback.metrics_lock, "__exit__")

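# The duck-typed check above accepts any context-manager style lock; a plain
# threading.Lock is the simplest object that satisfies it.
import threading

metrics_lock = threading.Lock()
with metrics_lock:  # usable exactly as the callback's metrics_lock would be
    pass
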
class TestOpenTelemetryIntegration:
    """Integration tests for OpenTelemetry."""

    def test_availability_check(self):
        """Test availability check function."""
        from axolotl.utils import is_opentelemetry_available

        result = is_opentelemetry_available()
        assert isinstance(result, bool)

    def test_prometheus_endpoint_basic(
        self,
        mock_otel_config,
        mock_trainer_state,
        mock_training_args,
        mock_trainer_control,
    ):
        """Test basic Prometheus endpoint functionality."""
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        if not OPENTELEMETRY_AVAILABLE:
            pytest.skip("OpenTelemetry dependencies not available")

        try:
            import requests
        except ImportError:
            pytest.skip("requests library not available")

        callback = OpenTelemetryMetricsCallback(mock_otel_config)
        callback.on_train_begin(
            mock_training_args, mock_trainer_state, mock_trainer_control
        )

        if not callback.server_started:
            pytest.skip("Metrics server failed to start")

        # Give server time to start
        time.sleep(1)

        # Try to access metrics endpoint
        try:
            response = requests.get(
                f"http://{callback.metrics_host}:{callback.metrics_port}/metrics",
                timeout=2,
            )
            assert response.status_code == 200
            # Check for Prometheus format
            assert "# TYPE" in response.text or "# HELP" in response.text
        except requests.exceptions.RequestException:
            pytest.skip(
                "Could not connect to metrics endpoint - this is expected in some environments"
            )

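# The substring check above targets the Prometheus text exposition format,
# where each metric family is preceded by HELP/TYPE comment lines, e.g.
# (metric names here are hypothetical):
#
#   # HELP train_loss Training loss
#   # TYPE train_loss gauge
#   train_loss 0.5
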
class TestOpenTelemetryCallbackMethods:
    """Test specific callback methods."""

    def test_step_end_callback(
        self,
        mock_otel_config,
        mock_trainer_state,
        mock_training_args,
        mock_trainer_control,
    ):
        """Test step end callback method."""
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        if not OPENTELEMETRY_AVAILABLE:
            pytest.skip("OpenTelemetry dependencies not available")

        callback = OpenTelemetryMetricsCallback(mock_otel_config)
        callback.on_train_begin(
            mock_training_args, mock_trainer_state, mock_trainer_control
        )

        # Should not raise an exception
        callback.on_step_end(
            mock_training_args, mock_trainer_state, mock_trainer_control
        )

    def test_epoch_end_callback(
        self,
        mock_otel_config,
        mock_trainer_state,
        mock_training_args,
        mock_trainer_control,
    ):
        """Test epoch end callback method."""
        from axolotl.utils.callbacks.opentelemetry import (
            OPENTELEMETRY_AVAILABLE,
            OpenTelemetryMetricsCallback,
        )

        if not OPENTELEMETRY_AVAILABLE:
            pytest.skip("OpenTelemetry dependencies not available")

        callback = OpenTelemetryMetricsCallback(mock_otel_config)
        callback.on_train_begin(
            mock_training_args, mock_trainer_state, mock_trainer_control
        )

        # Should not raise an exception
        callback.on_epoch_end(
            mock_training_args, mock_trainer_state, mock_trainer_control
        )
@@ -55,7 +55,7 @@ class TestPacking(unittest.TestCase):
                     "type": "alpaca",
                 },
             ],
-            "dataset_processes": 4,
+            "dataset_num_proc": 4,
             "num_epochs": 1,
             "max_steps": 20,
             "save_steps": 10,