Compare commits

38 Commits

Author SHA1 Message Date
NanoCode012
83ff8bfa1a fix: change docker miniconda install to workspace 2025-11-06 18:54:56 +07:00
salman
c37decb073 update pre-commit cadence (#3245) 2025-11-04 13:43:40 +00:00
NanoCode012
01a346d86a feat(example): add gpt-oss-safeguard docs (#3243)
* feat(example): add gpt-oss-safeguard docs

* fix: add doc on reasoning_effort
2025-11-04 07:39:21 +07:00
NanoCode012
26f05b6008 fix(example): set model_type to load for gemma3 text (#3242)
* fix: set model_type to load for gemma3 text

* chore: simplify

* chore: unify
2025-11-04 07:35:07 +07:00
github-actions[bot]
ed58fa8a75 chore: update pre-commit hooks (#3244) 2025-11-03 15:55:40 +00:00
Wing Lian
633afffacb add torch 2.9.0 to ci (#3223) 2025-10-30 18:50:26 -04:00
Wing Lian
4b1b4fa6d8 upgrade numpy (#3236)
* upgrade numpy to 2.3.4

* bump contribs for numpy

* fix vllm versions

* bump numba

* make sure psutil is installed

* add psutil to cicd dockerfile jinja

* lower dep versions of numba + numpy for vllm

* bump datasets version

* resolve pydantic conflict too
2025-10-30 10:03:24 -04:00
github-actions[bot]
0f7c886b7b chore: update pre-commit hooks (#3222) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-10-29 18:09:46 -04:00
Wing Lian
a4b921135b build cuda 13.0.0 base image with 2.9.0 (#3229)
* build cuda 13.0.0 base image with 2.9.0

* upgrade causal-conv1d

* 1.5.4 not in pypi yet

* pin to 1.3.0

* use github release instead of pypi

* split the logic for incompatible packages

* fix bash in dockerfile
2025-10-29 18:07:29 -04:00
Wing Lian
98333e639a upgrade trl to 0.24.0 and liger to 0.6.3 (#3230)
* upgrade trl to 0.24.0

* fix reward collator init

* use newer DataCollatorForPreference instead

* DataCollatorForPreference doesn't use padding kwarg

* fix input id labels

* fix fbgemm-gpu version for pytorch versions

* tweak pinned deps

* transformers doesn't support hub 1.0 yet

* upgrade liger dep to 0.6.3

* set TORCH_CUDA_ARCH_LIST correctly
2025-10-29 18:02:16 -04:00
Dan Saunders
9d4d39e939 Diffusion trainer fix: shift logits to align with input tokens (#3191)
* shift logits for diffusion generate

* delete unused

* diffusion trainer: token shift
2025-10-27 14:42:01 +07:00
Wing Lian
bb33fda44d install flash attention in 2.9.0 base images (#3224) 2025-10-22 21:24:52 -07:00
VED
4dc018992d Feat/opentelemetry (#3215) 2025-10-22 19:16:55 -07:00
NanoCode012
243620394a fix: force train split for json,csv,txt for test_datasets and misc doc changes (#3226)
* fix: force train split for json,csv,txt for test_datasets

* feat(doc): add info on mixing datasets for VLM

* feat(doc): max memory

* fix(doc): clarify lr groups

* fix: add info on vision not being dropped

* feat: add qwen3-vl to multimodal docs

* fix: add moe blocks to arch list

* feat(doc): improve mistral docs

* chore: add helpful link [skip-e2e]

* fix: add vram usage for mistral small

* Update link in docs/faq.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-10-22 15:23:20 -07:00
Qingyang Wu
3750fdcf79 Fix trainer dataloader slow loading issue (#3219)
* Fix trainer dataloader handling in src/axolotl/core/trainers/base.py

* update comment to reflect torch version

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-10-22 21:22:14 +07:00
Matthew Hambrecht
613bcf90e5 fix: enable_sleep_mode -> vllm_enable_sleep_mode (#3225)
Co-authored-by: Matthew Hambrecht <matthew.hambrecht@patapsco.ai>
2025-10-22 06:55:26 -07:00
Wing Lian
383f220cfd build torch 2.9.0 base images (#3221) 2025-10-20 08:53:49 -04:00
NanoCode012
8bb871b5cf fix: deepspeed with context parallel (#3220) 2025-10-20 14:06:58 +07:00
Leonard
87565ecc05 Add chat_template.argilla_chat support for DPO datasets (#3202)
* Add chat_template.argilla_chat support for DPO datasets

  Creates a new chat_template.argilla_chat prompt strategy for handling
  DPO datasets where chosen/rejected fields contain full conversations
  (messages + final response), following the pattern of chatml.argilla_chat
  and llama3.argilla_chat.

  - Add argilla_chat() function to chat_template.py
  - Add chat_template.argilla_chat to RLHF documentation
  - Add test coverage for argilla_chat with multiple tokenizers

  Dataset format:
  {
    "chosen": [
      {"role": "user", "content": "..."},
      {"role": "assistant", "content": "..."}
    ],
    "rejected": [
      {"role": "user", "content": "..."},
      {"role": "assistant", "content": "..."}
    ]
  }

* Fix chat_template.argilla_chat return value contract and add docstring

- Return (transform_fn, dataset_kwargs) tuple instead of bare transform_fn
- Add remove_columns specification for field_chosen and field_rejected
- Add comprehensive docstring with Args/Returns sections
- Update tests to unpack tuple return value

Addresses PR feedback to maintain consistency with chat_template.default()
and properly specify columns to remove after dataset transformation.

* Update tests/prompt_strategies/test_dpo_chat_templates.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-10-17 17:00:26 +07:00
NanoCode012
93ba57396f fix: qwen3_vl attention config (#3216) 2025-10-17 10:35:03 +07:00
NanoCode012
aa1240acd8 fix: transformers deprecate load_in_Xbit in model_kwargs (#3205)
* fix: transformers deprecate load_in_Xbit in model_kwargs

* fix: test to read from quantization_config kwarg

* fix: test

* fix: access

* fix: test weirdly entering incorrect config
2025-10-16 16:07:27 +07:00
Wing Lian
4cdfdfebb5 upgrade transformers==4.57.1 and peft==0.23.1 (#3214) 2025-10-14 15:54:05 -04:00
github-actions[bot]
6e2f5ccf9f chore: update pre-commit hooks (#3211) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-10-14 10:21:49 -04:00
NanoCode012
8c7f63cf97 fix: unpack cce imported incorrectly (#3212) [skip ci] 2025-10-13 17:19:15 +07:00
VED
cd856b45b1 feat:add support dataset_num_processes (#3129) [skip ci]
* feat:add support dataset_num_processes

* chore

* required changes

* requested changes

* required changes

* required changes

* required changes

* elif get_default_process_count()

* add:del data

* Update cicd/Dockerfile.jinja

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update cicd/single_gpu.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2025-10-13 17:18:12 +07:00
salman
143dea4753 FSDPConfig (#3170) 2025-10-10 14:44:25 +01:00
Hitesh Sagtani
bc2ffb8204 fix: Enable KD plugin support for PEFT/LoRA adapters (#3207)
- Fix _loss_function attribute not found on base model with PEFT
- Fix mismatched attribute name (loss_function vs _loss_function)
- Set _loss_function on unwrapped base model for PEFT
- Enable previously skipped test_llama_lora_kd test
- Add test config fixes for LoRA kernel compatibility

Fixes https://github.com/axolotl-ai-cloud/axolotl/issues/3206
2025-10-10 08:57:00 -04:00
NanoCode012
153edcfe79 fix(doc): add act checkpointing migration to fsdp2 docs (#3193) [skip ci] 2025-10-10 10:57:50 +07:00
Wing Lian
08b8fa62cc only calculate packed ds length once if using a large world size (#3210) 2025-10-09 14:18:46 -04:00
Wing Lian
3a5c97e6e5 use can_device_access_peer for P2P checks (#3209) [skip ci]
* use can_device_access_peer for P2P checks

* also log warn when automatically setting NCCL_P2P_DISABLE=1
2025-10-09 14:17:31 -04:00
VED
37f78c8592 add chat_template_jinja to wandb (#3192) [skip ci]
* add chat_template_jinja to wandb

* temp_ct_file.flush()

* Update src/axolotl/utils/callbacks/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update src/axolotl/utils/callbacks/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Apply suggestion from @winglian

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-10-09 12:05:54 -04:00
NanoCode012
ab63b92c38 feat: add lfm2 family and latest moe model (#3208)
* feat: add lfm2 family and latest moe model

* fix: use ml-cross-entropy for lfm2 examples
2025-10-09 10:47:41 -04:00
Manh Nguyen
6f8ce024d1 Remove check_torch_compile_deepspeed (#3195) [skip ci]
Signed-off-by: nguyen599 <pnvmanh2123@gmail.com>
2025-10-08 11:27:01 -04:00
Wing Lian
d0e9c3c1c5 When using Ray use prepare for dataloader fixes (#3198)
* make sure to use ray prepare for dataloader fixes

* ray tests use 2.7.0+

* don't call init_distributed w ray and deepspeed

* handle dict deepspeed config

* better handling of dict deepspeed config

* use json.dumps

* guard to_dict

* wrap import for optional ray
2025-10-08 10:43:41 -04:00
github-actions[bot]
4c3488cc9f chore: update pre-commit hooks (#3160) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-10-08 08:58:02 -04:00
Wing Lian
130637a3fa upgrade transformers to 4.57.0 (#3201)
* upgrade transformers to 4.57.0

* remove deprecated autoawq and use latest peft

* remove autoawq from setuptools script

* fix imports

* make sure torchvision is installed

* remove support for BetterTransformer

* skip fsdp_qlora_prequant test

* more robust error reporting
2025-10-08 08:43:46 -04:00
VED
377c510e95 sleep model support (#3135)
Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-10-08 12:39:21 +01:00
Wing Lian
409cfb8a87 deprecate torch 2.6.0 support (#3197) [skip ci] 2025-10-07 11:23:41 -04:00
93 changed files with 1604 additions and 359 deletions

View File

@@ -25,20 +25,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
@@ -67,6 +53,20 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -122,13 +122,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
@@ -150,6 +143,20 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -15,11 +15,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -88,11 +83,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -162,11 +152,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"

View File

@@ -26,13 +26,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -47,6 +40,13 @@ jobs:
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -12,16 +12,16 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -65,16 +65,16 @@ jobs:
strategy:
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -2,7 +2,7 @@ name: Pre-commit auto-update
on:
schedule:
- cron: '0 0 * * 0' # Run weekly
- cron: '0 0 1 * *' # Run monthly
workflow_dispatch: # Manual kickoff
jobs:

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0"]
pytorch_version: ["2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -102,14 +102,14 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
timeout-minutes: 20
steps:
@@ -130,7 +130,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
timeout-minutes: 20
steps:
@@ -152,7 +152,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
- name: Install PyTorch
run: |
@@ -231,16 +231,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -289,15 +283,15 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
# - cuda: 128
# cuda_version: 12.8.1
# python_version: "3.11"
# pytorch: 2.7.1
# num_gpus: 1
# axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -305,6 +299,12 @@ jobs:
num_gpus: 1
gpu_type: "B200"
axolotl_extras: fbgemm-gpu
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -11,13 +11,13 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.12
rev: v0.14.3
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1
rev: v1.18.2
hooks:
- id: mypy
additional_dependencies:

View File

@@ -73,7 +73,7 @@ Features:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.6.0
- PyTorch ≥2.7.1
### Google Colab

View File

@@ -32,6 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN uv pip install torchvision
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -1,6 +1,6 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
@@ -9,7 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_PROCESSES="8"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==23.2 setuptools==75.8.0
RUN pip install packaging==23.2 setuptools==75.8.0 psutil
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -65,8 +65,13 @@ def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code)
try:
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
print(f"Command '{cmd}' failed with exit code {exit_code}")
return exit_code
except Exception as e: # pylint: disable=broad-except
print(f"Command '{cmd}' failed with exception {e}")
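For illustration only (not part of the diff): a minimal, self-contained sketch of the pattern the new `run_cmd` body moves toward, where a failing subprocess is reported and its exit code returned to the caller instead of being passed straight to `exit()`. The name `run_and_report`, the list-based `args` parameter, and the choice to return `1` when the command cannot be started are assumptions made for this sketch; only the environment variable and its value come from the hunk above.

```python
import os
import subprocess
import sys


def run_and_report(args: list, run_folder: str = ".") -> int:
    """Run a command and return its exit code, printing a message on failure."""
    sp_env = os.environ.copy()
    sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"  # same variable the hunk sets
    try:
        exit_code = subprocess.call(args, cwd=run_folder, env=sp_env)
    except OSError as exc:  # e.g. the executable does not exist
        print(f"Command {args!r} failed with exception {exc}")
        return 1  # assumption: treat "could not start" as a generic failure
    if exit_code:
        print(f"Command {args!r} failed with exit code {exit_code}")
    return exit_code


if __name__ == "__main__":
    # A child process that exits with status 3; the caller decides what to do next.
    code = run_and_report([sys.executable, "-c", "import sys; sys.exit(3)"])
    print(code)
```

Returning the code rather than exiting leaves the decision to the caller, which can aggregate several commands before failing the job.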

View File

@@ -13,7 +13,7 @@ datasets:
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model
dataset_prepared_path: temp_debug/axolotl_outputs/data
dataset_processes: 1
dataset_num_proc: 1
sequence_len: 4096
sample_packing: false
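For illustration only (not part of the diff): the rename from `dataset_processes` to `dataset_num_proc`, together with the matching environment variable rename from `AXOLOTL_DATASET_PROCESSES` to `AXOLOTL_DATASET_NUM_PROC`, suggests a lookup with a fallback to the legacy name. The sketch below shows one way such a lookup could work. The function name `resolve_num_proc` and the precedence order are hypothetical and are not taken from Axolotl's implementation.

```python
import os
from typing import Mapping, Optional


def resolve_num_proc(config_value: Optional[int] = None,
                     env: Optional[Mapping[str, str]] = None) -> int:
    """Pick a worker count: config value, new env var, legacy env var, then CPU count."""
    env = os.environ if env is None else env
    if config_value is not None:
        return max(1, int(config_value))
    # Assumed precedence: the new variable name wins over the legacy one.
    for name in ("AXOLOTL_DATASET_NUM_PROC", "AXOLOTL_DATASET_PROCESSES"):
        raw = env.get(name)
        if raw:
            return max(1, int(raw))
    return os.cpu_count() or 1
```

With `dataset_num_proc: 1` in the config, as in the hunk above, this sketch returns `1` regardless of the environment.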

View File

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ENV PATH="/workspace/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
@@ -24,29 +24,35 @@ RUN apt-get update \
&& rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& mkdir -p /workspace/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge
RUN if [ "$CUDA" != "130" ] ; then \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4"; \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
python3 -m pip cache purge; \
fi
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi

View File

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ENV PATH="/workspace/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="next"
@@ -19,12 +19,12 @@ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& mkdir -p /workspace/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace

View File

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ENV PATH="/workspace/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
@@ -19,14 +19,14 @@ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& mkdir -p /workspace/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace

View File

@@ -30,7 +30,13 @@ RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
fi

View File

@@ -29,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
- Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to set `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use one from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training). For example, to shard the dataset into 20 pieces, add the following to your axolotl config:
```yaml
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"-m", "axolotl.cli.train", "dev_chat_template.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
"--dataset_num_proc=1", // limits data preprocessing to one process
"--max_steps=1", // limits training to just one step
"--batch_size=1", // minimizes batch size
"--micro_batch_size=1", // minimizes batch size

View File

@@ -63,6 +63,14 @@ description: Frequently asked questions
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
**Q: Can we mix text and text+image datasets for VLM training?**
> A: Yes, this works for newer VLM architectures. The ones that would not work are the LLaVA and Pixtral architectures. If you notice one not working, please let us know!
**Q: Why is `memory/max_*` different from `nvidia-smi`?**
> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

View File

@@ -27,3 +27,9 @@ learning_rate: 2e-5
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
self attention `q_proj` module.
::: {.callout-note}
We currently only support varying `lr`. If you're interested in adding support for other settings (such as `weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17
:::
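For illustration only (not part of the diff): a minimal sketch of how per-group learning rates like the ones described above could be resolved for each parameter. The function name, the group dictionaries, the substring-matching rule, and the zero-based reading of "3rd layer" as index 2 are all assumptions made for this sketch; the linked `optimizer.py` is the authoritative behavior.

```python
from typing import Dict, List


def resolve_learning_rates(param_names: List[str],
                           default_lr: float,
                           lr_groups: List[dict]) -> Dict[str, float]:
    """Map each parameter name to the lr of the first group whose pattern it contains."""
    resolved = {}
    for name in param_names:
        lr = default_lr
        for group in lr_groups:
            # Assumed rule: a group applies when any of its patterns is a substring.
            if any(pattern in name for pattern in group["modules"]):
                lr = group["lr"]
                break
        resolved[name] = lr
    return resolved


# Hypothetical groups mirroring the prose: o_proj everywhere, q_proj in layer index 2.
groups = [
    {"name": "o_proj", "modules": ["self_attn.o_proj.weight"], "lr": 1e-6},
    {"name": "q_proj", "modules": ["model.layers.2.self_attn.q_proj.weight"], "lr": 1e-5},
]
names = [
    "model.layers.0.self_attn.o_proj.weight",
    "model.layers.2.self_attn.q_proj.weight",
    "model.layers.2.self_attn.k_proj.weight",
]
rates = resolve_learning_rates(names, 2e-5, groups)
```

Parameters that match no group, such as the `k_proj` weight here, keep the default learning rate.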

View File

@@ -88,6 +88,7 @@ fsdp_sync_module_states | **REMOVED**
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED**
fsdp_activation_checkpointing | activation_checkpointing
For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
if you were using the following FSDP1 config:

View File

@@ -56,10 +56,14 @@ image_resize_algorithm: bilinear
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
::: {.callout-warning}
::: {.callout-tip}
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
:::
::: {.callout-note}
As of now, we neither truncate nor drop samples based on `sequence_len`, because each architecture processes non-text tokens differently. We are looking for help on this.
:::
### Mllama {#sec-mllama}
```yaml
@@ -168,6 +172,14 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Qwen3-VL {#sec-qwen3-vl}
```yaml
base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}

View File

@@ -219,6 +219,21 @@ DPO supports the following types with the following dataset format:
}
```
#### chat_template.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
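For illustration only (not part of the diff): a minimal sketch of how a record in this format could be split into a shared prompt and the two competing final responses. The function name `split_preference_sample` and the choice to take the prompt from the `chosen` conversation are assumptions; the actual strategy also applies the tokenizer's chat template, which this sketch leaves out.

```python
from typing import Dict, List

Message = Dict[str, str]


def split_preference_sample(sample: Dict[str, List[Message]]) -> Dict[str, object]:
    """Split full chosen/rejected conversations into prompt messages and final replies."""
    chosen, rejected = sample["chosen"], sample["rejected"]
    if not chosen or not rejected:
        raise ValueError("both conversations must contain at least one message")
    if chosen[-1]["role"] != "assistant" or rejected[-1]["role"] != "assistant":
        raise ValueError("the final message of each conversation must be from the assistant")
    return {
        "prompt": chosen[:-1],              # assumption: both share the same prompt turns
        "chosen": chosen[-1]["content"],    # preferred final response
        "rejected": rejected[-1]["content"],  # dispreferred final response
    }


sample = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
}
example = split_preference_sample(sample)
```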
#### chat_template.default
```yaml

View File

@@ -6,6 +6,8 @@ LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
Thanks to the team at LiquidAI for giving us early access to prepare for these releases.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
@@ -31,6 +33,14 @@ This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```
**LFM2-MoE**
```bash
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
# LoRA SFT (1x48GB @ 16.2GiB)
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
```
### TIPS
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
@@ -45,14 +55,13 @@ This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)
## Related Resources
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [LFM2-MoE Blog](https://www.liquid.ai/blog/lfm2-8b-a1b-an-efficient-on-device-mixture-of-experts)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,6 +1,7 @@
base_model: LiquidAI/LFM2-350M
chunked_cross_entropy: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
eot_tokens:
- "<|im_end|>"

View File

@@ -0,0 +1,59 @@
base_model: LiquidAI/LFM2-8B-A1B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
eot_tokens:
- "<|im_end|>"
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
bf16: true
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
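
To check which modules the `lora_target_modules` pattern in this config would select, the regular expression can be tried against candidate module names. The following is a minimal sketch: the module names are hypothetical and are not read from the LFM2 checkpoint, and it assumes the pattern is applied as a full match against each module name.

```python
import re

# Pattern copied from the config above.
PATTERN = r"model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj"

# Hypothetical module names, for illustration only.
candidates = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.12.mlp.gate_proj",
    "model.layers.3.self_attn.out_proj",
    "model.embed_tokens",
]


def matches(name: str) -> bool:
    """Return True when the whole module name matches the pattern."""
    return re.fullmatch(PATTERN, name) is not None


selected = [name for name in candidates if matches(name)]
print(selected)
```

Note that the dots in the pattern are unescaped, so each one matches any single character; this is harmless for typical module names but worth knowing when adapting the pattern.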

View File

@@ -3,6 +3,9 @@ trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\""
]
},
{

View File

@@ -1,7 +1,7 @@
base_model: google/gemma-3-1b-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
model_type: Gemma3ForCausalLM
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -1,7 +1,7 @@
base_model: google/gemma-3-270m-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
model_type: Gemma3ForCausalLM
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -1,5 +1,8 @@
base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
load_in_4bit: true
# gemma3 doesn't seem to play nice with ddp

View File

@@ -2,6 +2,8 @@
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
In October 2025, OpenAI released safeguard models built on GPT-OSS called [GPT-OSS-Safeguard](https://huggingface.co/collections/openai/gpt-oss-safeguard). They use the same architecture, so the examples below can be reused.
This guide shows how to fine-tune these models with Axolotl on multi-turn conversations with proper masking.
## Getting started
@@ -64,6 +66,16 @@ axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offlo
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```
### How to set reasoning_effort in template?
The harmony template lets you set `reasoning_effort` when the prompt is built. The default is `medium`. To change it, add the following to your config:
```yaml
chat_template_kwargs:
  reasoning_effort: "high" # low | medium | high
```
Currently, this setting applies globally; there is no way to set it per sample yet. If you are interested in adding this, please open an issue to discuss.
### Inferencing your fine-tuned model

View File

@@ -0,0 +1,67 @@
base_model: openai/gpt-oss-safeguard-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-safeguard-out/
sequence_len: 4096
sample_packing: true
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true
# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"

View File

@@ -29,7 +29,7 @@ flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
save_strategy: no
torch_compile: true
wandb_project:

View File

@@ -0,0 +1,50 @@
base_model: NousResearch/Llama-3.2-1B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_4bit: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
output_dir: ./outputs/opentelemetry-example
adapter: qlora
sequence_len: 512
sample_packing: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
# OpenTelemetry Configuration
use_otel_metrics: true
otel_metrics_host: "localhost"
otel_metrics_port: 8000
# Disable WandB
use_wandb: false
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
flash_attention: false
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -12,7 +12,7 @@ Before starting, ensure you have:
Run the thinking model fine-tuning:
```bash
axolotl train magistral-small-think-qlora.yaml
axolotl train examples/magistral/think/magistral-small-think-qlora.yaml
```
This config uses about 19.1 GiB VRAM.

View File

@@ -21,7 +21,7 @@ Before starting, ensure you have:
3. Run the fine-tuning:
```bash
axolotl train magistral-small-vision-24B-qlora.yml
axolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml
```
This config uses about 17 GiB VRAM.

View File

@@ -0,0 +1,51 @@
# Mistral Small 3.1/3.2 Fine-tuning
This guide covers fine-tuning [Mistral Small 3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503) and [Mistral Small 3.2](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506) with vision capabilities using Axolotl.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
## Getting Started
1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.5'
```
2. Download the example dataset image:
```bash
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
```
3. Run the fine-tuning:
```bash
axolotl train examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml
```
This config uses about 29.4 GiB VRAM.
## Dataset Format
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
One exception: passing `"image": PIL.Image` is not supported. `MistralTokenizer` currently supports only `path`, `url`, and `base64`.
Example:
```json
{
  "messages": [
    {"role": "system", "content": [{"type": "text", "text": "{SYSTEM_PROMPT}"}]},
    {"role": "user", "content": [
      {"type": "text", "text": "What's in this image?"},
      {"type": "image", "path": "path/to/image.jpg"}
    ]},
    {"role": "assistant", "content": [{"type": "text", "text": "..."}]}
  ]
}
```
## Limitations
- Sample packing is not currently supported for multimodal training.

View File

@@ -39,7 +39,7 @@ wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine

View File

@@ -5,31 +5,30 @@ bitsandbytes==0.47.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.6.1
liger-kernel==0.6.3
# END section
packaging==23.2
huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.56.1
huggingface_hub>=0.36.0
peft>=0.17.1
tokenizers>=0.21.1
transformers==4.57.1
accelerate==1.10.1
datasets==4.0.0
datasets==4.3.0
deepspeed>=0.17.0
trl==0.23.0
hf_xet==1.1.5
kernels==0.9.0
trl==0.24.0
hf_xet==1.2.0
kernels>=0.9.0
trackio
optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.41.1
gradio==5.49.1
modal==1.0.2
pydantic==2.10.6
pydantic>=2.10.6
addict
fire
PyYAML>=6.0
@@ -37,8 +36,8 @@ requests
wandb
einops
colorama
numba
numpy>=1.24.4,<=2.0.1
numba>=0.61.2
numpy>=2.2.6
# qlora things
evaluate==0.4.1
@@ -51,7 +50,7 @@ python-dotenv==1.0.1
# remote filesystems
s3fs>=2024.5.0
gcsfs>=2024.5.0
gcsfs>=2025.3.0
adlfs>=2024.5.0
ocifs==1.3.2
@@ -67,7 +66,7 @@ antlr4-python3-runtime==4.13.2
torchao==0.13.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-lgpl==0.0.7
axolotl-contribs-mit==0.0.5
mistral-common==1.8.5

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"'
)

View File

@@ -26,7 +26,6 @@ def parse_requirements(extras_require_map):
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
@@ -34,7 +33,6 @@ def parse_requirements(extras_require_map):
"triton",
"mamba-ssm",
"xformers",
"autoawq",
"liger-kernel",
]
_install_requires = [
@@ -51,7 +49,7 @@ def parse_requirements(extras_require_map):
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.6.0" # default to torch 2.6
torch_version = "2.8.0" # default to torch 2.8.0
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -64,8 +62,15 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 8):
pass
if (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
extras_require_map["vllm"] = ["vllm==0.11.1"]
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"]
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -74,7 +79,7 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm>=0.10.0"]
extras_require_map["vllm"] = ["vllm==0.10.1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
@@ -87,7 +92,6 @@ def parse_requirements(extras_require_map):
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
@@ -161,7 +165,13 @@ extras_require = {
"llmcompressor": [
"llmcompressor==0.5.1",
],
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
"opentelemetry": [
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-prometheus",
"prometheus-client",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require
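
The version gating in `parse_requirements` can be illustrated in isolation. This is a simplified sketch, not the actual `setup.py` logic: the helper names `parse_torch_version` and `pick_pins` are hypothetical, a missing patch component is assumed to default to 0, and only the 2.8 and 2.9 branches shown in the hunk above are reproduced.

```python
import re


def parse_torch_version(torch_version: str) -> tuple[int, int, int]:
    """Parse 'major.minor[.patch]' from the start of a torch version string."""
    match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
    if not match:
        raise ValueError("Invalid version format")
    major, minor, patch = match.groups()
    # Assumption: treat a missing patch component as 0.
    return int(major), int(minor), int(patch or 0)


def pick_pins(torch_version: str) -> dict[str, str]:
    """Return the extra pins for a torch version (2.8 and 2.9 branches only)."""
    major, minor, _patch = parse_torch_version(torch_version)
    # Tuples compare element by element, so newer branches must come first.
    if (major, minor) >= (2, 9):
        return {"fbgemm-gpu": "fbgemm-gpu-genai==1.4.1", "vllm": "vllm==0.11.1"}
    if (major, minor) >= (2, 8):
        return {"fbgemm-gpu": "fbgemm-gpu-genai==1.3.0", "vllm": "vllm==0.11.0"}
    return {}


print(pick_pins("2.9.0+cu130"))
```

Because the regular expression is anchored only at the start, local version suffixes such as `+cu130` are ignored rather than rejected.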

View File

@@ -99,7 +99,7 @@ def ray_train_func(kwargs: dict):
resolve_dtype(cfg)
# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
if cfg.deepspeed:
if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"):
cfg.deepspeed = cfg.deepspeed.to_dict()
# initialize accelerator before model instantiation

View File

@@ -12,6 +12,9 @@ MOE_ARCH_BLOCK = {
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
}

View File

@@ -29,7 +29,11 @@ from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils import (
is_comet_available,
is_mlflow_available,
is_opentelemetry_available,
)
from axolotl.utils.callbacks import (
GCCallback,
SaveAxolotlConfigtoWandBCallback,
@@ -134,6 +138,12 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_otel_metrics and is_opentelemetry_available():
from axolotl.utils.callbacks.opentelemetry import (
OpenTelemetryMetricsCallback,
)
callbacks.append(OpenTelemetryMetricsCallback(self.cfg))
if self.cfg.save_first_step:
callbacks.append(SaveModelOnFirstStepCallback())
@@ -491,6 +501,7 @@ class TrainerBuilderBase(abc.ABC):
"dion_momentum",
"dion_rank_fraction",
"dion_rank_multiple_of",
"dataset_num_proc",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
@@ -514,9 +525,6 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len

View File

@@ -12,7 +12,7 @@ from transformers import (
EarlyStoppingCallback,
Trainer,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
from trl.trainer.reward_trainer import DataCollatorForPreference
from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import (
@@ -28,7 +28,6 @@ from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
@@ -63,12 +62,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.relora:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback())
# TODO: check if can move to base class
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
@@ -460,7 +453,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
DataCollatorForPreference,
]
]
collator_args = [self.tokenizer]
@@ -477,7 +470,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if kwargs and isinstance(kwargs, dict):
kwargs.update(collator_cls_and_kwargs[1])
elif self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
collator = DataCollatorForPreference
tokenizer = collator_args.pop(0)
kwargs["pad_token_id"] = tokenizer.pad_token_id
kwargs.pop("padding")
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
# supported multipack models, or non-flash-attention llama

View File

@@ -225,17 +225,6 @@ class AxolotlTrainer(
data_collator = self.data_collator if is_training else self.eval_data_collator
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if (
dataset.column_names
and "position_ids" in dataset.column_names
and "attention_mask" in dataset.column_names
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
dataset = dataset.remove_columns(["attention_mask"])
if isinstance(dataset, datasets.Dataset):
if is_training:
if not self.args.sample_packing or self.args.pretraining:
@@ -294,6 +283,18 @@ class AxolotlTrainer(
):
self.accelerator.even_batches = False
if dataset.column_names and "length" in dataset.column_names:
dataset = dataset.remove_columns(["length"])
if (
dataset.column_names
and "position_ids" in dataset.column_names
and "attention_mask" in dataset.column_names
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
dataset = dataset.remove_columns(["attention_mask"])
dataloader = DataLoader(dataset, **dataloader_params)
# Accelerator.free_memory() will destroy the references, so
@@ -560,13 +561,6 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess()
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
and self.args.fsdp_config["limit_all_gathers"]
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
) -> dict[str, Any]:

View File

@@ -52,6 +52,7 @@ class GRPOStrategy:
if trl.vllm_mode:
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate":
grpo_args_kwargs["vllm_enable_sleep_mode"] = trl.vllm_enable_sleep_mode # type: ignore[attr-defined]
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization
)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"
```
## Usage
@@ -54,9 +54,13 @@ plugins:
- granitemoehybrid
- hunyuan_v1_dense
- hunyuan_v1_moe
- lfm2
- lfm2_moe
- lfm2_vl
- llama
- llama4
- llama4_text
- llava
- mistral
- mistral3
- mixtral

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@147ea28"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`'
)

View File

@@ -7,7 +7,7 @@ import torch
from axolotl.utils.logging import get_logger
from .utils import create_bidirectional_attention_mask
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
LOG = get_logger(__name__)
@@ -360,7 +360,7 @@ def _diffusion_step(
# Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask)
logits = outputs.logits
logits = shift_logits_to_input_positions(outputs.logits)
# Only sample at currently masked positions
if current_mask.any():

View File

@@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback
from .utils import create_bidirectional_attention_mask
from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
LOG = get_logger(__name__)
@@ -207,7 +207,7 @@ class DiffusionTrainer(AxolotlTrainer):
input_ids=noisy_batch.long(),
attention_mask=bidirectional_mask,
)
logits = outputs.logits
logits = shift_logits_to_input_positions(outputs.logits)
if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)

View File

@@ -157,3 +157,10 @@ def create_bidirectional_attention_mask(
# Add head dimension: [batch_size, 1, seq_len, seq_len]
return bidirectional_mask.unsqueeze(1)
def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
    """Align next-token logits with their input token positions for diffusion."""
    if logits.size(1) <= 1:
        return logits
    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
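
The effect of `shift_logits_to_input_positions` is easier to see on a toy sequence. The sketch below mirrors the same indexing on plain Python lists instead of tensors, so it runs without `torch`; the function name `shift_to_input_positions` and the string labels are illustrative assumptions, not part of the codebase.

```python
def shift_to_input_positions(batch):
    """Shift each row right by one, repeating the first entry.

    Mirrors torch.cat([x[:, :1], x[:, :-1]], dim=1) for a list of rows.
    """
    shifted = []
    for row in batch:
        if len(row) <= 1:
            # Nothing to shift for empty or single-position sequences.
            shifted.append(list(row))
        else:
            shifted.append([row[0]] + list(row[:-1]))
    return shifted


# "pN" stands for the logits produced at input position N, which a causal
# model uses to predict the token at position N + 1.
batch = [["p0", "p1", "p2", "p3"]]
print(shift_to_input_positions(batch))
```

After the shift, index `i` (for `i >= 1`) holds the logits that were produced at position `i - 1`, that is, the model's prediction for the token at position `i`. Position 0 keeps its own logits as a placeholder, and the last position's logits are dropped.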

View File

@@ -72,9 +72,9 @@ def kldiv_forward_llama_like(
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
# self._loss_function should be LigerFusedLinearKLTopKLogprobLoss
loss = self.loss_function(
loss = self._loss_function(
self.lm_head.weight,
hidden_states,
target_token_ids,

View File

@@ -29,7 +29,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_accepts_loss_kwargs = True
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
loss_fn = LigerFusedLinearKLTopKLogprobLoss(
self.args.kd_ce_alpha, # hard label loss
self.args.kd_alpha, # kd loss
self.args.kd_temperature,
@@ -37,6 +38,14 @@ class AxolotlKDTrainer(AxolotlTrainer):
compute_ce_loss=bool(self.args.kd_ce_alpha),
normalize_topk=self.args.kd_normalize_topk,
)
target = self.model
# Unwrap PEFT wrapper
if hasattr(target, "get_base_model"):
target = target.get_base_model()
# Set on the actual model instance
target._loss_function = loss_fn
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
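
The unwrapping step in `AxolotlKDTrainer.__init__` can be sketched without PEFT installed. The classes below are hypothetical stand-ins: `PeftLikeWrapper` only imitates the `get_base_model()` method that the hunk above checks for, and `attach_loss_function` is an illustrative helper, not a function in the codebase.

```python
class BaseModel:
    """Stand-in for the underlying model."""


class PeftLikeWrapper:
    """Stand-in for an adapter wrapper exposing get_base_model()."""

    def __init__(self, base):
        self._base = base

    def get_base_model(self):
        return self._base


def attach_loss_function(model, loss_fn):
    """Set _loss_function on the innermost model and return that model."""
    target = model
    # Unwrap the adapter wrapper if present, as in the hunk above.
    if hasattr(target, "get_base_model"):
        target = target.get_base_model()
    target._loss_function = loss_fn
    return target


base = BaseModel()
wrapped = PeftLikeWrapper(base)
target = attach_loss_function(wrapped, loss_fn="kd-loss-placeholder")
print(target is base)
```

Setting the attribute on the unwrapped instance means the forward function that reads `self._loss_function` finds it on the object it is actually bound to, rather than relying on the wrapper to forward the attribute.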

View File

@@ -515,9 +515,6 @@ class ModelLoader:
if self.cfg.model_quantization_config_kwargs:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
else:
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
@@ -552,9 +549,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
"load_in_4bit", False
):
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -580,9 +575,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
"load_in_8bit", False
):
elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit:
bnb_config = {
"load_in_8bit": True,
}
@@ -596,11 +589,6 @@ class ModelLoader:
**bnb_config,
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)
self.model_kwargs.pop("load_in_4bit", None)
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.attn_implementation:

View File

@@ -134,6 +134,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return Qwen2Attention
if model_type == "qwen3_vl":
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextAttention
return Qwen3VLTextAttention
if model_type == "mllama":
from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention

View File

@@ -45,6 +45,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"gpt_oss",
"arcee",
"seed_oss",
"lfm2",
"lfm2_moe",
]

View File

@@ -13,9 +13,7 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
PATCHED_GUARD = (
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
)
PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
def patch_prepare_context_parallel_inputs() -> None:

View File

@@ -6,8 +6,10 @@ from typing import Optional
from PIL import Image, ImageOps
from PIL.Image import Resampling
from torch import Tensor, zeros_like
from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
from transformers import ProcessorMixin
from transformers.image_utils import load_image
from transformers.models.smolvlm import SmolVLMProcessor
from transformers.models.voxtral import VoxtralProcessor
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger

View File

@@ -71,10 +71,10 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
]
return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"chosen_input_ids": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0,
"input_ids_rejected": rejected_tokenized["input_ids"],
"rejected_input_ids": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0,
}

View File

@@ -120,3 +120,123 @@ def default(cfg, dataset_idx=0, **kwargs):
return result
return transform_fn, {"remove_columns": [field_messages]}
def argilla_chat(cfg, dataset_idx=0, **kwargs):
"""
DPO chat template strategy for argilla-style datasets.
For argilla-style datasets where chosen/rejected contain full conversations
instead of single response messages. Extracts the conversation history from
the chosen field and formats both chosen/rejected responses using the
configured chat template.
Args:
cfg: Configuration object containing chat_template and dataset settings
dataset_idx: Index of the dataset in the config (default: 0)
**kwargs: Additional keyword arguments (unused)
Returns:
tuple: (transform_fn, dataset_kwargs) where:
- transform_fn: Function to transform dataset samples
- dataset_kwargs: Dict with 'remove_columns' specifying columns to drop
Dataset format:
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
"""
ds_cfg = cfg["datasets"][dataset_idx]
ds_cfg = handle_legacy_message_fields_logic(ds_cfg)
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
message_property_mappings = ds_cfg.get(
"message_property_mappings",
{
"role": "role",
"content": "content",
},
)
role_map_inv = ds_cfg.get(
"roles",
{
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
)
role_map = {}
for target, sources in role_map_inv.items():
for source in sources:
role_map[source] = target
def transform_fn(sample, tokenizer=None):
chat_template_string = get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
chosen_raw = sample[field_chosen]
rejected_raw = sample[field_rejected]
# Extract messages (all but last) and responses (last message)
chosen_messages = [
{
"role": role_map[m[message_property_mappings["role"]]],
"content": m[message_property_mappings["content"]],
}
for m in chosen_raw[:-1]
]
chosen_response = {
"role": role_map[chosen_raw[-1][message_property_mappings["role"]]],
"content": chosen_raw[-1][message_property_mappings["content"]],
}
rejected_response = {
"role": role_map[rejected_raw[-1][message_property_mappings["role"]]],
"content": rejected_raw[-1][message_property_mappings["content"]],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
chosen_messages,
add_generation_prompt=True,
chat_template=chat_template_string,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[dummy_user_message, chosen_response],
add_generation_prompt=False,
chat_template=chat_template_string,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen_response["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[dummy_user_message, rejected_response],
add_generation_prompt=False,
chat_template=chat_template_string,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected_response["content"])
result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
return result
return transform_fn, {"remove_columns": [field_chosen, field_rejected]}
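
Two steps in `argilla_chat` are worth illustrating on their own: inverting the `roles` mapping and splitting a full conversation into history plus final response. This is a minimal sketch with made-up data; the `roles` values and the helper name `split_conversation` are assumptions for illustration, and no tokenizer or chat template is involved.

```python
# Hypothetical config value: target role -> list of source role names.
roles = {
    "user": ["human", "user"],
    "assistant": ["gpt", "assistant"],
    "system": ["system"],
}

# Invert it so each source role name maps to its target role.
role_map = {source: target for target, sources in roles.items() for source in sources}


def split_conversation(conversation):
    """Return (history, response): all messages but the last, and the last."""
    history = [
        {"role": role_map[m["role"]], "content": m["content"]}
        for m in conversation[:-1]
    ]
    last = conversation[-1]
    response = {"role": role_map[last["role"]], "content": last["content"]}
    return history, response


chosen = [
    {"role": "human", "content": "What is 2 + 2?"},
    {"role": "gpt", "content": "4"},
]
history, response = split_conversation(chosen)
print(history)
print(response)
```

In the transform above, the history from the chosen conversation becomes the prompt, while the final messages of the chosen and rejected conversations become the two candidate responses.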

View File

@@ -40,11 +40,6 @@ from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer
try:
from optimum.bettertransformer import BetterTransformer
except ImportError:
BetterTransformer = None
if typing.TYPE_CHECKING:
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -141,8 +136,6 @@ def setup_signal_handler(
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
_model = model_weakref()
if cfg.flash_optimum and BetterTransformer:
_model = BetterTransformer.reverse(_model)
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
@@ -321,9 +314,6 @@ def save_trained_model(
except FileNotFoundError:
pass
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
@@ -535,6 +525,17 @@ def setup_model_and_trainer(
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
if cfg.use_ray:
try:
import ray.train.huggingface.transformers
trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
except ImportError:
LOG.warning(
"The Ray integration with Hugging Face Transformers is not available. "
"To use Ray, install the 'ray[train]' package."
)
return (
trainer,
model,

View File

@@ -17,6 +17,13 @@ def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None
def is_opentelemetry_available():
    return (
        importlib.util.find_spec("opentelemetry") is not None
        and importlib.util.find_spec("prometheus_client") is not None
    )
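
The availability check above relies on `importlib.util.find_spec`, which looks up a module without importing it. A generalized sketch of the same pattern follows; the helper name `all_modules_available` is hypothetical and not part of `axolotl.utils`.

```python
import importlib.util


def all_modules_available(*module_names: str) -> bool:
    """Return True only if every top-level module name can be located."""
    return all(
        importlib.util.find_spec(name) is not None for name in module_names
    )


# Standard-library modules are always found; the second call uses a name
# that is assumed not to exist in the environment.
print(all_modules_available("json", "re"))
print(all_modules_available("json", "definitely_not_a_real_module_xyz"))
```

Using top-level names keeps the check side-effect free; for dotted names, `find_spec` imports the parent package first and raises `ModuleNotFoundError` if that parent is missing.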
def get_pytorch_version() -> tuple[int, int, int]:
"""
Get Pytorch version as a tuple of (major, minor, patch).

View File

@@ -16,8 +16,8 @@ import pandas as pd
import torch
import torch.distributed as dist
import wandb
import yaml
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
GenerationConfig,
@@ -28,8 +28,6 @@ from transformers import (
TrainingArguments,
)
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
IntervalStrategy,
SaveStrategy,
)
from trl.models import unwrap_model_for_generation
@@ -56,40 +54,6 @@ IGNORE_INDEX = -100
LOG = get_logger(__name__)
class SaveBetterTransformerModelCallback(TrainerCallback):
"""Callback to save the BetterTransformer wrapped model"""
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> TrainerControl:
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
and args.save_steps > 0
and state.global_step % args.save_steps == 0
):
control.should_save = True
if control.should_save:
checkpoint_folder = os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
model = BetterTransformer.reverse(kwargs["model"])
model.save_pretrained(checkpoint_folder)
# FIXME - need to cleanup old checkpoints
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control
class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""
@@ -796,6 +760,37 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
try:
with open(self.axolotl_config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
chat_tpl = cfg.get("chat_template_jinja")
if chat_tpl:
with NamedTemporaryFile(
mode="w", delete=True, suffix=".jinja", prefix="chat_template_"
) as temp_ct_file:
if (
isinstance(chat_tpl, str)
and os.path.exists(chat_tpl)
and os.path.isfile(chat_tpl)
):
copyfile(chat_tpl, temp_ct_file.name)
else:
temp_ct_file.write(str(chat_tpl))
temp_ct_file.flush()
artifact = wandb.Artifact(
f"chat-template-{wandb.run.id}", type="jinja-template"
)
artifact.add_file(temp_ct_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_ct_file.name)
LOG.info(
"The chat_template_jinja has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError, yaml.YAMLError) as err:
LOG.warning(f"Error while saving chat_template_jinja to WandB: {err}")
if args.deepspeed:
try:
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.


@@ -0,0 +1,238 @@
"""OpenTelemetry metrics callback for Axolotl training"""
import threading
from typing import Dict, Optional
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from opentelemetry import metrics
from opentelemetry.exporter.prometheus import PrometheusMetricReader
from opentelemetry.metrics import set_meter_provider
from opentelemetry.sdk.metrics import MeterProvider as SDKMeterProvider
from prometheus_client import start_http_server
OPENTELEMETRY_AVAILABLE = True
except ImportError:
LOG.warning("OpenTelemetry not available. Install the OpenTelemetry SDK, its Prometheus exporter, and prometheus_client to enable metrics export.")
OPENTELEMETRY_AVAILABLE = False
class OpenTelemetryMetricsCallback(TrainerCallback):
"""
TrainerCallback that exports training metrics to OpenTelemetry/Prometheus.
This callback automatically tracks key training metrics including:
- Training loss
- Evaluation loss
- Learning rate
- Epoch progress
- Global step count
- Gradient norm
Metrics are exposed via HTTP endpoint for Prometheus scraping.
"""
def __init__(self, cfg):
if not OPENTELEMETRY_AVAILABLE:
LOG.warning("OpenTelemetry not available, metrics will not be collected")
self.metrics_enabled = False
return
self.cfg = cfg
self.metrics_host = getattr(cfg, "otel_metrics_host", "localhost")
self.metrics_port = getattr(cfg, "otel_metrics_port", 8000)
self.metrics_enabled = True
self.server_started = False
self.metrics_lock = threading.Lock()
try:
# Create Prometheus metrics reader
prometheus_reader = PrometheusMetricReader()
# Create meter provider with Prometheus exporter
provider = SDKMeterProvider(metric_readers=[prometheus_reader])
set_meter_provider(provider)
# Get meter for creating metrics
self.meter = metrics.get_meter("axolotl.training")
# Create metrics
self._create_metrics()
except Exception as e:
LOG.warning(f"Failed to initialize OpenTelemetry metrics: {e}")
self.metrics_enabled = False
def _create_metrics(self):
"""Create all metrics that will be tracked"""
self.train_loss_gauge = self.meter.create_gauge(
name="axolotl_train_loss",
description="Current training loss",
unit="1",
)
self.eval_loss_gauge = self.meter.create_gauge(
name="axolotl_eval_loss",
description="Current evaluation loss",
unit="1",
)
self.learning_rate_gauge = self.meter.create_gauge(
name="axolotl_learning_rate",
description="Current learning rate",
unit="1",
)
self.epoch_gauge = self.meter.create_gauge(
name="axolotl_epoch",
description="Current training epoch",
unit="1",
)
self.global_step_counter = self.meter.create_counter(
name="axolotl_global_steps",
description="Total training steps completed",
unit="1",
)
self.grad_norm_gauge = self.meter.create_gauge(
name="axolotl_gradient_norm",
description="Gradient norm",
unit="1",
)
self.memory_usage_gauge = self.meter.create_gauge(
name="axolotl_memory_usage",
description="Current memory usage in MB",
unit="MB",
)
def _start_metrics_server(self):
"""Start the HTTP server for metrics exposure"""
if self.server_started:
return
try:
start_http_server(self.metrics_port, addr=self.metrics_host)
self.server_started = True
LOG.info(
f"OpenTelemetry metrics server started on http://{self.metrics_host}:{self.metrics_port}/metrics"
)
except Exception as e:
LOG.error(f"Failed to start OpenTelemetry metrics server: {e}")
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the beginning of training"""
if not self.metrics_enabled:
return
self._start_metrics_server()
LOG.info("OpenTelemetry metrics collection started")
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs: Optional[Dict[str, float]] = None,
**kwargs,
):
"""Called when logging occurs"""
if not self.metrics_enabled or not logs:
return
if "loss" in logs:
self.train_loss_gauge.set(logs["loss"])
if "eval_loss" in logs:
self.eval_loss_gauge.set(logs["eval_loss"])
if "learning_rate" in logs:
self.learning_rate_gauge.set(logs["learning_rate"])
if "epoch" in logs:
self.epoch_gauge.set(logs["epoch"])
if "grad_norm" in logs:
self.grad_norm_gauge.set(logs["grad_norm"])
if "memory_usage" in logs:
self.memory_usage_gauge.set(logs["memory_usage"])
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the end of each training step"""
if not self.metrics_enabled:
return
# Update step counter and epoch
self.global_step_counter.add(1)
if state.epoch is not None:
self.epoch_gauge.set(state.epoch)
def on_evaluate(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
metrics: Optional[Dict[str, float]] = None,
**kwargs,
):
"""Called after evaluation"""
if not self.metrics_enabled or not metrics:
return
if "eval_loss" in metrics:
self.eval_loss_gauge.set(metrics["eval_loss"])
# Record any other eval metrics as gauges
for key, value in metrics.items():
if key.startswith("eval_") and isinstance(value, (int, float)):
# Create gauge for this metric if it doesn't exist
gauge_name = f"axolotl_{key}"
try:
gauge = self.meter.create_gauge(
name=gauge_name,
description=f"Evaluation metric: {key}",
unit="1",
)
gauge.set(value)
except Exception as e:
LOG.warning(f"Failed to create/update metric {gauge_name}: {e}")
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the end of training"""
if not self.metrics_enabled:
return
LOG.info("Training completed. OpenTelemetry metrics collection finished.")
LOG.info(
f"Metrics are still available at http://{self.metrics_host}:{self.metrics_port}/metrics"
)
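
A note on `on_evaluate` above: it calls `create_gauge` for every `eval_*` key on each evaluation. The following is a minimal sketch of one way to create each instrument once and reuse it by name. `GaugeCache` and its `create_gauge` factory argument are hypothetical names used for illustration; they are not part of this change or of the OpenTelemetry API.

```python
from typing import Any, Callable, Dict


class GaugeCache:
    """Hypothetical helper: create each named gauge once, then reuse it."""

    def __init__(self, create_gauge: Callable[[str], Any]):
        # create_gauge is any callable that builds an instrument from a name.
        self._create_gauge = create_gauge
        self._gauges: Dict[str, Any] = {}

    def get(self, name: str) -> Any:
        # Only call the factory the first time a name is requested.
        if name not in self._gauges:
            self._gauges[name] = self._create_gauge(name)
        return self._gauges[name]
```

Inside the callback, the factory could plausibly be `lambda name: self.meter.create_gauge(name=name, unit="1")`, mirroring the call already used in `_create_metrics`.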


@@ -113,7 +113,7 @@ def _map_dataset(
dataset = dataset.map(
ds_transform_fn,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Mapping RL Dataset",
**map_kwargs,
@@ -234,7 +234,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)


@@ -239,6 +239,11 @@ def _load_from_local_path(
return load_dataset(dataset_config.path, **load_dataset_kwargs)
elif local_path.is_file():
dataset_type = get_dataset_type(dataset_config)
# For single file datasets, HF always creates only a "train" split
if dataset_type in ("json", "csv", "text"):
load_dataset_kwargs["split"] = "train"
return load_dataset(
dataset_type,
data_files=dataset_config.path,
@@ -409,7 +414,7 @@ def save_preprocessed_dataset(
) -> None:
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
num_workers = cfg.dataset_processes or get_default_process_count()
num_workers = cfg.dataset_num_proc or get_default_process_count()
if isinstance(dataset, IterableDataset):
ds_from_iter = Dataset.from_generator(
functools.partial(_generate_from_iterable_dataset, dataset),


@@ -223,7 +223,7 @@ def handle_long_seq_in_dataset(
filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}


@@ -80,7 +80,7 @@ def get_dataset_wrapper(
"""
# Common parameters for dataset wrapping
dataset_kwargs: dict[str, Any] = {
"process_count": cfg.dataset_processes,
"process_count": cfg.dataset_num_proc,
"keep_in_memory": cfg.dataset_keep_in_memory is True,
}


@@ -4,6 +4,8 @@ import os
def get_default_process_count():
if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
return int(axolotl_dataset_num_proc)
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
return int(axolotl_dataset_processes)
if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
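
The hunk above is cut off after the `RUNPOD_CPU_COUNT` check, so the lookup order is easier to see in isolation. The sketch below shows the precedence implied by the diff: the new `AXOLOTL_DATASET_NUM_PROC` variable wins, then the older `AXOLOTL_DATASET_PROCESSES`, then `RUNPOD_CPU_COUNT`. The function name, the `environ` parameter, and the final `os.cpu_count()` fallback are assumptions (the fallback is taken from the schema description, not from the truncated hunk).

```python
import os
from typing import Mapping, Optional


def default_process_count_sketch(environ: Optional[Mapping[str, str]] = None) -> int:
    """Illustrative lookup order; not the project's actual function."""
    env = os.environ if environ is None else environ
    # Checked in order: the first variable that is set and non-empty wins.
    for name in (
        "AXOLOTL_DATASET_NUM_PROC",
        "AXOLOTL_DATASET_PROCESSES",
        "RUNPOD_CPU_COUNT",
    ):
        if value := env.get(name):
            return int(value)
    # Assumed fallback, based on the schema description of the option.
    return os.cpu_count() or 1
```

Passing a plain dict as `environ` makes the precedence easy to check without touching the real environment.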


@@ -3,66 +3,46 @@ utils to get GPU info for the current environment
"""
import os
import subprocess # nosec B404
from importlib.metadata import version
import torch
from accelerate.utils.environment import (
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
get_gpu_info,
)
from packaging.version import Version, parse
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
if not check_runpod_p2p_support():
if not check_cuda_p2p_support():
return False
unsupported_devices = {"RTX 6000 Ada", "L40S"}
try:
device_names, device_count = get_gpu_info()
if 1 < device_count < 8:
if any(
unsupported_device in device_name
for device_name in device_names
for unsupported_device in unsupported_devices
):
return False
except Exception: # nosec B110
pass
return True
def check_runpod_p2p_support() -> bool:
if "RUNPOD_GPU_COUNT" not in os.environ:
return True
def check_cuda_p2p_support() -> bool:
try:
gpu_count = int(os.environ.get("RUNPOD_GPU_COUNT", "1"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
except ValueError:
return True
if gpu_count >= 2:
# run `nvidia-smi topo -p2p n` and inspect the GPU0 row
if world_size > 1:
node_world_size = int(os.environ.get("NODE_WORLD_SIZE", "8"))
local_other_rank = (local_rank // node_world_size) * node_world_size
local_other_rank += 1 if (local_rank % node_world_size) == 0 else 0
try:
result = subprocess.run( # nosec B603 B607
["nvidia-smi", "topo", "-p2p", "n"],
check=True,
capture_output=True,
text=True,
timeout=5,
)
except (
subprocess.CalledProcessError,
FileNotFoundError,
subprocess.TimeoutExpired,
):
return True # fail-open if detection fails
output_lines = result.stdout.strip().split("\n")
# filter rows that start with "GPU0" (avoid header row)
gpu0_rows = [line for line in output_lines if line.lstrip().startswith("GPU0")]
if not gpu0_rows:
can_p2p = torch.cuda.can_device_access_peer(local_rank, local_other_rank)
except AssertionError as exc:
# some sort of logic error in indexing processes, assume p2p is fine for now
LOG.warning(exc)
return True
# consider P2P supported if any OK is present in the GPU0 row
return "OK" in gpu0_rows[-1]
return can_p2p
return True
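
Because this view interleaves removed and added lines, the peer-selection arithmetic of the new `check_cuda_p2p_support` is hard to follow. A small sketch of just that arithmetic, under the assumption that the added lines are the ones using `WORLD_SIZE`, `LOCAL_RANK` and `NODE_WORLD_SIZE`: each rank picks one other device index to probe with `torch.cuda.can_device_access_peer`. The function name `pick_peer_rank` is hypothetical.

```python
def pick_peer_rank(local_rank: int, node_world_size: int = 8) -> int:
    """Return the device index that `local_rank` would probe for P2P access.

    Mirrors the arithmetic in the diff: start from the first index of the
    rank's block of `node_world_size`, and move to the second index only
    when the rank itself is the first one in that block.
    """
    base = (local_rank // node_world_size) * node_world_size
    return base + (1 if local_rank % node_world_size == 0 else 0)
```

With the default of 8, rank 0 probes device 1 and every other rank in the first block probes device 0, so no rank ever probes itself.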


@@ -5,6 +5,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
import gc
import math
import os
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
@@ -291,7 +292,10 @@ class MultipackBatchSampler(BatchSampler):
self.total_token_slots = 0
# The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples
world_size = int(os.environ.get("WORLD_SIZE", "1"))
self.num_count_samples = (
1 if world_size >= num_count_samples else num_count_samples
)
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warning(


@@ -24,11 +24,13 @@ from axolotl.utils.schemas.datasets import (
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.fsdp import FSDPConfig
from axolotl.utils.schemas.integrations import (
CometConfig,
GradioConfig,
LISAConfig,
MLFlowConfig,
OpenTelemetryConfig,
RayConfig,
WandbConfig,
)
@@ -59,6 +61,7 @@ class AxolotlInputConfig(
WandbConfig,
MLFlowConfig,
CometConfig,
OpenTelemetryConfig,
LISAConfig,
GradioConfig,
RayConfig,
@@ -233,6 +236,7 @@ class AxolotlInputConfig(
)
dataset_processes: int | None = Field(
default=None,
deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.",
json_schema_extra={
"description": (
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
@@ -240,6 +244,16 @@ class AxolotlInputConfig(
)
},
)
dataset_num_proc: int | None = Field(
default=None,
json_schema_extra={
"description": (
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
)
},
)
dataset_exact_deduplication: bool | None = Field(
default=None,
json_schema_extra={
@@ -667,8 +681,7 @@ class AxolotlInputConfig(
json_schema_extra={"description": "FSDP configuration"},
deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
)
# TODO @SalmanMohammadi strongly type this as its own schema
fsdp_config: dict[str, Any] | None = Field(
fsdp_config: FSDPConfig | None = Field(
default=None, json_schema_extra={"description": "FSDP configuration options"}
)
fsdp_version: int | None = Field(
@@ -1314,10 +1327,22 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before")
@classmethod
def default_dataset_processes(cls, data):
if data.get("dataset_processes") is None:
data["dataset_processes"] = get_default_process_count()
def default_dataset_num_proc(cls, data):
if data.get("dataset_processes") is not None:
if data.get("dataset_num_proc") is None:
data["dataset_num_proc"] = data["dataset_processes"]
LOG.warning(
"dataset_processes is deprecated and will be removed in a future version. "
"Please use dataset_num_proc instead."
)
else:
LOG.warning(
"Both dataset_processes and dataset_num_proc are set. "
"Using dataset_num_proc and ignoring dataset_processes."
)
del data["dataset_processes"]
elif data.get("dataset_num_proc") is None:
data["dataset_num_proc"] = get_default_process_count()
return data
@model_validator(mode="before")
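
The deprecation handling in `default_dataset_num_proc` can be read as a plain function over the raw config dict. The sketch below restates the added branch logic outside of pydantic; `remap_dataset_num_proc` and the `default_count` parameter are hypothetical stand-ins (the real validator calls `get_default_process_count()`), and the warning text is abbreviated.

```python
import logging

LOG = logging.getLogger(__name__)


def remap_dataset_num_proc(data: dict, default_count: int) -> dict:
    """Illustrative restatement of the validator's branches."""
    data = dict(data)  # work on a copy so the caller's dict is untouched
    if data.get("dataset_processes") is not None:
        if data.get("dataset_num_proc") is None:
            # Only the deprecated key is set: carry its value over.
            data["dataset_num_proc"] = data["dataset_processes"]
            LOG.warning("dataset_processes is deprecated; use dataset_num_proc.")
        else:
            # Both are set: the new key wins.
            LOG.warning("Both keys set; ignoring dataset_processes.")
        del data["dataset_processes"]
    elif data.get("dataset_num_proc") is None:
        # Neither is set: fall back to the default process count.
        data["dataset_num_proc"] = default_count
    return data
```

The three outcomes are: legacy value carried over, new value kept when both are present, and the default applied when neither is given.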


@@ -0,0 +1,71 @@
"""
FSDP Configuration Schema
"""
from typing import Literal
from pydantic import BaseModel, Field
class FSDPConfig(BaseModel):
"""
FSDP Configuration Schema
"""
activation_checkpointing: bool | None = Field(
default=None,
description="Enable activation checkpointing to reduce memory usage during forward passes",
)
offload_params: bool | None = Field(
default=None,
description="Offload parameters to CPU to reduce GPU memory usage",
)
sync_module_states: bool | None = Field(
default=None,
description="Synchronize module states across all processes",
)
cpu_ram_efficient_loading: bool | None = Field(
default=None,
description="Enable CPU RAM efficient loading to reduce memory usage during model loading",
)
cpu_offload_pin_memory: bool | None = Field(
default=None,
description="Disabling this enables swap memory usage for resource-constrained setups when offload_params is enabled.",
)
use_orig_params: bool | None = Field(
default=None,
description="Use original parameters instead of flattened parameters",
)
state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
default=None,
description="Type of state dict to use for saving/loading checkpoints",
)
final_state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
default=None,
description="Final state dict type to use after training completion",
)
auto_wrap_policy: Literal["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP"] | None = (
Field(
default=None,
description="Policy for automatically wrapping modules with FSDP",
)
)
transformer_layer_cls_to_wrap: str | None = Field(
default=None,
description="Class name of transformer layers to wrap (e.g., 'LlamaDecoderLayer')",
)
reshard_after_forward: bool | None = Field(
default=None,
description="Reshard parameters after forward pass to save memory",
)
mixed_precision_policy: str | None = Field(
default=None,
description="Mixed precision policy for FSDP (e.g., 'fp16', 'bf16')",
)


@@ -176,3 +176,27 @@ class RayConfig(BaseModel):
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
},
)
class OpenTelemetryConfig(BaseModel):
"""OpenTelemetry configuration subset"""
use_otel_metrics: bool | None = Field(
default=False,
json_schema_extra={
"description": "Enable OpenTelemetry metrics collection and Prometheus export"
},
)
otel_metrics_host: str | None = Field(
default="localhost",
json_schema_extra={
"title": "OpenTelemetry Metrics Host",
"description": "Host to bind the OpenTelemetry metrics server to",
},
)
otel_metrics_port: int | None = Field(
default=8000,
json_schema_extra={
"description": "Port for the Prometheus metrics HTTP server"
},
)


@@ -167,3 +167,9 @@ class TRLConfig(BaseModel):
"description": "Whether to exclude truncated completions from loss calculation."
},
)
vllm_enable_sleep_mode: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable sleep mode for vLLM to offload VRAM when idle"
},
)


@@ -783,15 +783,6 @@ class OptimizationValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
if data.get("deepspeed") and data.get("torch_compile"):
raise ValueError(
"torch_compile should be set within your deepspeed config file"
)
return data
@model_validator(mode="before")
@classmethod
def check_xentropy_patch_conflicts(cls, data):
@@ -890,7 +881,7 @@ class OptimizationValidationMixin:
and self.fsdp_config
and self.optimizer
and "8bit" in self.optimizer.value
and self.fsdp_config["offload_params"]
and self.fsdp_config.offload_params
and str(self.fsdp_version) != "2"
):
raise ValueError(


@@ -6,6 +6,7 @@ import os
import random
from contextlib import contextmanager
from functools import partial
from tempfile import NamedTemporaryFile
from typing import List, Optional
import numpy as np
@@ -15,6 +16,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.logging import get_logger
@@ -276,7 +278,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
prior_len = None
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
@@ -316,7 +318,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length",
)
@@ -333,7 +335,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
@@ -342,7 +344,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
@@ -467,7 +469,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
num_processes=cfg.dataset_processes,
num_processes=cfg.dataset_num_proc,
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
)
@@ -540,6 +542,13 @@ def setup_deepspeed_env(cfg, stage=None):
)
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
if isinstance(cfg.deepspeed, DictDefault):
with NamedTemporaryFile(
mode="w", delete=False, suffix=".json", prefix="deepspeed_config_"
) as temp_file:
temp_file.write(json.dumps(cfg.deepspeed.to_dict(), indent=4))
temp_file.close()
cfg.deepspeed = str(temp_file.name)
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(
cfg.gradient_accumulation_steps
@@ -562,6 +571,7 @@ def setup_deepspeed_env(cfg, stage=None):
if (
int(os.environ.get("WORLD_SIZE", "1")) == 1
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
and cfg.use_ray is not True
):
os.environ["WORLD_SIZE"] = "1" # force it in case not set
os.environ["LOCAL_RANK"] = "0" # force it in case not set
@@ -631,6 +641,7 @@ def setup_parallelism_envs(cfg):
def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None:
LOG.warning("P2P support not detected, setting `NCCL_P2P_DISABLE=1`")
os.environ["NCCL_P2P_DISABLE"] = "1"
# TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12
if cfg.fsdp or cfg.fsdp_config:
@@ -638,11 +649,15 @@ def prepare_optim_env(cfg):
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
stage = None
deepspeed_config = None
# check if the cfg.deepspeed is a file
if os.path.isfile(cfg.deepspeed):
if isinstance(cfg.deepspeed, DictDefault):
deepspeed_config = cfg.deepspeed
elif os.path.isfile(cfg.deepspeed):
# parse with json
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
deepspeed_config = json.load(fin)
if deepspeed_config:
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage)
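
The change to `prepare_optim_env` lets `cfg.deepspeed` be either an inline mapping or a path to a JSON file before the ZeRO stage is read. A self-contained sketch of that resolution step follows; `resolve_zero_stage` is a hypothetical name, and `Mapping` is used here as a stand-in for the project's `DictDefault` type.

```python
import json
import os
from typing import Any, Mapping, Optional, Union


def resolve_zero_stage(deepspeed: Union[str, Mapping[str, Any], None]) -> Optional[int]:
    """Return zero_optimization.stage from an inline config or a JSON file."""
    config = None
    if isinstance(deepspeed, Mapping):
        # Inline config: use it directly.
        config = deepspeed
    elif isinstance(deepspeed, str) and os.path.isfile(deepspeed):
        # Path to a config file: parse it as JSON.
        with open(deepspeed, "r", encoding="utf-8") as fin:
            config = json.load(fin)
    if config:
        return config.get("zero_optimization", {}).get("stage")
    # No usable config (missing file, empty dict, or None): stage is unknown.
    return None
```

As in the diff, an unreadable path or an empty config leaves the stage as `None` rather than raising.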


@@ -33,7 +33,6 @@ def parse_requirements():
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
@@ -63,7 +62,6 @@ def parse_requirements():
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))


@@ -440,7 +440,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
]
else:
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
cfg["dataset_processes"] = 4
cfg["dataset_num_proc"] = 4
if cfg_string == "grpo_cfg":
rewards_dir = tmp_path / "rewards_test"


@@ -104,7 +104,6 @@ class TestKnowledgeDistillation:
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
)
@pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
@pytest.mark.parametrize(
"load_in_8bit",
[True, False],
@@ -120,6 +119,10 @@ class TestKnowledgeDistillation:
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.0,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"lora_mlp_kernel": False,
"lora_qkv_kernel": False,
"lora_o_kernel": False,
}
| kd_min_cfg
)


@@ -353,7 +353,6 @@ class TestMultiGPULlama:
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
@@ -431,7 +430,6 @@ class TestMultiGPULlama:
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
@@ -594,7 +592,6 @@ class TestMultiGPULlama:
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,


@@ -13,7 +13,6 @@ from axolotl.utils.dict import DictDefault
from tests.e2e.utils import (
check_tensorboard,
require_torch_2_7_0,
require_torch_lt_2_6_0,
)
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -24,7 +23,7 @@ class TestMultiGPURay:
Test cases for AnyScale Ray post training
"""
@require_torch_lt_2_6_0
@require_torch_2_7_0
def test_lora_ddp(self, temp_dir):
cfg = DictDefault(
{
@@ -83,7 +82,7 @@ class TestMultiGPURay:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_lt_2_6_0
@require_torch_2_7_0
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],


@@ -69,7 +69,7 @@ class TestActivationCheckpointing:
"save_safetensors": True,
"gradient_checkpointing": gradient_checkpointing,
"save_first_step": False,
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)


@@ -29,7 +29,7 @@ class TestPretrainLlama:
"sequence_len": 1024,
"sample_packing": sample_packing,
"pretrain_multipack_attn": pretrain_multipack_attn,
"dataset_processes": 1,
"dataset_num_proc": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},


@@ -8,7 +8,7 @@ import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.prompt_strategies.dpo.chat_template import argilla_chat, default
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@@ -78,6 +78,36 @@ def fixture_custom_assistant_dataset():
)
@pytest.fixture(name="argilla_chat_dataset")
def fixture_argilla_chat_dataset():
return Dataset.from_list(
[
{
"chosen": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "goodbye",
},
],
"rejected": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "party on",
},
],
}
]
)
@pytest.fixture(name="phi3_tokenizer")
@enable_hf_offline
def fixture_phi3_tokenizer():
@@ -216,5 +246,51 @@ class TestAssistantDPOChatTemplateGemma:
assert result["rejected"] == "party on<end_of_turn>"
class TestArgillaChatDPOChatTemplate:
"""
Test class for argilla_chat style datasets (chosen/rejected contain full conversations).
"""
def test_llama3_argilla_chat(self, llama3_tokenizer, argilla_chat_dataset):
transform_fn, _ = argilla_chat(
DictDefault(
{
"chat_template": "llama3",
"datasets": [
{
"type": "chat_template.argilla_chat",
}
],
}
)
)
result = transform_fn(argilla_chat_dataset[0], tokenizer=llama3_tokenizer)
assert result["prompt"] == (
"<|begin_of_text|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert result["chosen"] == "goodbye<|eot_id|>"
assert result["rejected"] == "party on<|eot_id|>"
def test_phi3_argilla_chat(self, phi3_tokenizer, argilla_chat_dataset):
transform_fn, _ = argilla_chat(
DictDefault(
{
"chat_template": "tokenizer_default",
"datasets": [
{
"type": "chat_template.argilla_chat",
}
],
}
)
)
result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)
assert result["prompt"] == "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
assert result["chosen"] == "goodbye<|end|>"
assert result["rejected"] == "party on<|end|>"
if __name__ == "__main__":
unittest.main()


@@ -141,7 +141,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -180,7 +180,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -219,7 +219,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -252,7 +252,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -285,7 +285,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -370,7 +370,7 @@ class TestDatasetPreparation:
"rl": "dpo",
"chat_template": "llama3",
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -471,7 +471,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)


@@ -210,7 +210,7 @@ class TestDeduplicateRLDataset:
ALPACA_MESSAGES_CONFIG_REVISION,
ALPACA_MESSAGES_CONFIG_REVISION,
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
yield fixture


@@ -80,16 +80,26 @@ class TestModelsUtils:
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
)
elif load_in_8bit and self.cfg.adapter is not None:
assert self.model_loader.model_kwargs["load_in_8bit"]
elif load_in_4bit and self.cfg.adapter is not None:
assert self.model_loader.model_kwargs["load_in_4bit"]
if (self.cfg.adapter == "qlora" and load_in_4bit) or (
self.cfg.adapter == "lora" and load_in_8bit
):
assert self.model_loader.model_kwargs.get(
"quantization_config", BitsAndBytesConfig
if self.cfg.adapter == "qlora" and load_in_4bit:
assert isinstance(
self.model_loader.model_kwargs.get("quantization_config"),
BitsAndBytesConfig,
)
assert (
self.model_loader.model_kwargs["quantization_config"]._load_in_4bit
is True
)
if self.cfg.adapter == "lora" and load_in_8bit:
assert isinstance(
self.model_loader.model_kwargs.get("quantization_config"),
BitsAndBytesConfig,
)
assert (
self.model_loader.model_kwargs["quantization_config"]._load_in_8bit
is True
)
def test_message_property_mapping(self):


@@ -111,7 +111,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"regular_param": "value",
}
}
)
@@ -124,7 +123,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
)
self.assertEqual(cfg_with_version.fsdp_config.offload_params, False)
self.assertEqual(cfg_with_version.fsdp_config.cpu_ram_efficient_loading, True)
self.assertEqual(cfg_with_version.fsdp_config.regular_param, "value")
self.assertNotIn("fsdp_auto_wrap_policy", cfg_with_version.fsdp_config)
self.assertNotIn("fsdp_offload_params", cfg_with_version.fsdp_config)
@@ -137,7 +135,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
"fsdp_config": {
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
"fsdp_offload_params": True,
"regular_param": "value",
}
}
)
@@ -149,7 +146,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
cfg_without_version.fsdp_config.auto_wrap_policy, "SIZE_BASED_WRAP"
)
self.assertEqual(cfg_without_version.fsdp_config.offload_params, True)
-self.assertEqual(cfg_without_version.fsdp_config.regular_param, "value")
self.assertNotIn("fsdp_auto_wrap_policy", cfg_without_version.fsdp_config)
self.assertNotIn("fsdp_offload_params", cfg_without_version.fsdp_config)

@@ -0,0 +1,349 @@
"""Tests for OpenTelemetry metrics callback functionality."""
import time
import pytest
from axolotl.utils.dict import DictDefault
@pytest.fixture
def mock_otel_config():
"""Mock configuration for OpenTelemetry callback."""
return DictDefault(
{
"use_otel_metrics": True,
"otel_metrics_host": "localhost",
"otel_metrics_port": 8003, # Use unique port for tests
}
)
@pytest.fixture
def mock_trainer_state():
"""Mock trainer state for callback testing."""
from transformers import TrainerState
state = TrainerState()
state.epoch = 1.0
state.global_step = 100
return state
@pytest.fixture
def mock_training_args():
"""Mock training arguments for callback testing."""
from transformers import TrainingArguments
return TrainingArguments(output_dir="/tmp/test")
@pytest.fixture
def mock_trainer_control():
"""Mock trainer control for callback testing."""
from transformers.trainer_callback import TrainerControl
return TrainerControl()
class TestOpenTelemetryConfig:
"""Test OpenTelemetry configuration schema."""
def test_config_schema_valid(self):
"""Test OpenTelemetry configuration schema validation."""
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
# Test valid config
valid_config = {
"use_otel_metrics": True,
"otel_metrics_host": "localhost",
"otel_metrics_port": 8000,
}
otel_config = OpenTelemetryConfig(**valid_config)
assert otel_config.use_otel_metrics is True
assert otel_config.otel_metrics_host == "localhost"
assert otel_config.otel_metrics_port == 8000
def test_config_defaults(self):
"""Test OpenTelemetry configuration default values."""
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
# Test minimal config with defaults
minimal_config = {"use_otel_metrics": True}
otel_config = OpenTelemetryConfig(**minimal_config)
assert otel_config.use_otel_metrics is True
assert otel_config.otel_metrics_host == "localhost" # default
assert otel_config.otel_metrics_port == 8000 # default
def test_config_disabled_by_default(self):
"""Test that OpenTelemetry is disabled by default."""
from axolotl.utils.schemas.integrations import OpenTelemetryConfig
# Test default config
default_config = OpenTelemetryConfig()
assert default_config.use_otel_metrics is False
class TestOpenTelemetryCallback:
"""Test OpenTelemetry callback functionality."""
def test_callback_import(self):
"""Test that OpenTelemetry callback can be imported."""
from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
assert OpenTelemetryMetricsCallback is not None
def test_callback_graceful_fallback(self, mock_otel_config):
"""Test callback gracefully handles missing dependencies."""
from axolotl.utils.callbacks.opentelemetry import OpenTelemetryMetricsCallback
# This should not raise an exception even if dependencies are missing
callback = OpenTelemetryMetricsCallback(mock_otel_config)
# Callback should exist but may have metrics disabled
assert callback is not None
assert hasattr(callback, "metrics_enabled")
def test_callback_initialization_enabled(self, mock_otel_config):
"""Test callback initialization when OpenTelemetry is available."""
from axolotl.utils.callbacks.opentelemetry import (
OPENTELEMETRY_AVAILABLE,
OpenTelemetryMetricsCallback,
)
callback = OpenTelemetryMetricsCallback(mock_otel_config)
if OPENTELEMETRY_AVAILABLE:
assert callback.metrics_enabled is True
assert callback.cfg == mock_otel_config
assert callback.metrics_host == "localhost"
assert callback.metrics_port == 8003
else:
assert callback.metrics_enabled is False
def test_metrics_server_lifecycle(
self,
mock_otel_config,
mock_trainer_state,
mock_training_args,
mock_trainer_control,
):
"""Test metrics server starts and stops correctly."""
from axolotl.utils.callbacks.opentelemetry import (
OPENTELEMETRY_AVAILABLE,
OpenTelemetryMetricsCallback,
)
if not OPENTELEMETRY_AVAILABLE:
pytest.skip("OpenTelemetry dependencies not available")
callback = OpenTelemetryMetricsCallback(mock_otel_config)
# Start server
callback.on_train_begin(
mock_training_args, mock_trainer_state, mock_trainer_control
)
assert callback.server_started is True
# End training
callback.on_train_end(
mock_training_args, mock_trainer_state, mock_trainer_control
)
def test_metrics_recording(
self,
mock_otel_config,
mock_trainer_state,
mock_training_args,
mock_trainer_control,
):
"""Test that metrics are recorded during training."""
from axolotl.utils.callbacks.opentelemetry import (
OPENTELEMETRY_AVAILABLE,
OpenTelemetryMetricsCallback,
)
if not OPENTELEMETRY_AVAILABLE:
pytest.skip("OpenTelemetry dependencies not available")
callback = OpenTelemetryMetricsCallback(mock_otel_config)
callback.on_train_begin(
mock_training_args, mock_trainer_state, mock_trainer_control
)
# Test logging metrics
test_logs = {
"loss": 0.5,
"learning_rate": 1e-4,
"grad_norm": 0.8,
}
# This should not raise an exception
callback.on_log(
mock_training_args, mock_trainer_state, mock_trainer_control, logs=test_logs
)
assert callback.metrics_enabled is True
def test_evaluation_metrics(
self,
mock_otel_config,
mock_trainer_state,
mock_training_args,
mock_trainer_control,
):
"""Test evaluation metrics recording."""
from axolotl.utils.callbacks.opentelemetry import (
OPENTELEMETRY_AVAILABLE,
OpenTelemetryMetricsCallback,
)
if not OPENTELEMETRY_AVAILABLE:
pytest.skip("OpenTelemetry dependencies not available")
callback = OpenTelemetryMetricsCallback(mock_otel_config)
callback.on_train_begin(
mock_training_args, mock_trainer_state, mock_trainer_control
)
# Test evaluation metrics
eval_logs = {
"eval_loss": 0.3,
"eval_accuracy": 0.95,
}
# This should not raise an exception
callback.on_evaluate(
mock_training_args, mock_trainer_state, mock_trainer_control, eval_logs
)
assert callback.metrics_enabled is True
def test_thread_safety(self, mock_otel_config):
"""Test that callback has thread safety mechanisms."""
from axolotl.utils.callbacks.opentelemetry import (
OPENTELEMETRY_AVAILABLE,
OpenTelemetryMetricsCallback,
)
if not OPENTELEMETRY_AVAILABLE:
pytest.skip("OpenTelemetry dependencies not available")
callback = OpenTelemetryMetricsCallback(mock_otel_config)
assert hasattr(callback, "metrics_lock")
# Check it's a lock-like object
assert hasattr(callback.metrics_lock, "__enter__")
assert hasattr(callback.metrics_lock, "__exit__")
class TestOpenTelemetryIntegration:
"""Integration tests for OpenTelemetry."""
def test_availability_check(self):
"""Test availability check function."""
from axolotl.utils import is_opentelemetry_available
result = is_opentelemetry_available()
assert isinstance(result, bool)
def test_prometheus_endpoint_basic(
self,
mock_otel_config,
mock_trainer_state,
mock_training_args,
mock_trainer_control,
):
"""Test basic Prometheus endpoint functionality."""
from axolotl.utils.callbacks.opentelemetry import (
OPENTELEMETRY_AVAILABLE,
OpenTelemetryMetricsCallback,
)
if not OPENTELEMETRY_AVAILABLE:
pytest.skip("OpenTelemetry dependencies not available")
try:
import requests
except ImportError:
pytest.skip("requests library not available")
callback = OpenTelemetryMetricsCallback(mock_otel_config)
callback.on_train_begin(
mock_training_args, mock_trainer_state, mock_trainer_control
)
if not callback.server_started:
pytest.skip("Metrics server failed to start")
# Give server time to start
time.sleep(1)
# Try to access metrics endpoint
try:
response = requests.get(
f"http://{callback.metrics_host}:{callback.metrics_port}/metrics",
timeout=2,
)
assert response.status_code == 200
# Check for Prometheus format
assert "# TYPE" in response.text or "# HELP" in response.text
except requests.exceptions.RequestException:
pytest.skip(
"Could not connect to metrics endpoint - this is expected in some environments"
)
class TestOpenTelemetryCallbackMethods:
"""Test specific callback methods."""
def test_step_end_callback(
self,
mock_otel_config,
mock_trainer_state,
mock_training_args,
mock_trainer_control,
):
"""Test step end callback method."""
from axolotl.utils.callbacks.opentelemetry import (
OPENTELEMETRY_AVAILABLE,
OpenTelemetryMetricsCallback,
)
if not OPENTELEMETRY_AVAILABLE:
pytest.skip("OpenTelemetry dependencies not available")
callback = OpenTelemetryMetricsCallback(mock_otel_config)
callback.on_train_begin(
mock_training_args, mock_trainer_state, mock_trainer_control
)
# Should not raise an exception
callback.on_step_end(
mock_training_args, mock_trainer_state, mock_trainer_control
)
def test_epoch_end_callback(
self,
mock_otel_config,
mock_trainer_state,
mock_training_args,
mock_trainer_control,
):
"""Test epoch end callback method."""
from axolotl.utils.callbacks.opentelemetry import (
OPENTELEMETRY_AVAILABLE,
OpenTelemetryMetricsCallback,
)
if not OPENTELEMETRY_AVAILABLE:
pytest.skip("OpenTelemetry dependencies not available")
callback = OpenTelemetryMetricsCallback(mock_otel_config)
callback.on_train_begin(
mock_training_args, mock_trainer_state, mock_trainer_control
)
# Should not raise an exception
callback.on_epoch_end(
mock_training_args, mock_trainer_state, mock_trainer_control
)

@@ -55,7 +55,7 @@ class TestPacking(unittest.TestCase):
"type": "alpaca",
},
],
-"dataset_processes": 4,
+"dataset_num_proc": 4,
"num_epochs": 1,
"max_steps": 20,
"save_steps": 10,