
Finetune OpenAI's GPT-OSS with Axolotl

GPT-OSS is a family of open-weight MoE models trained by OpenAI and released in August 2025. There are two variants: 20B and 120B.

In October 2025, OpenAI released GPT-OSS-Safeguard, a set of safeguard models built on GPT-OSS. They use the same architecture, so the examples below can be reused for them.

This guide shows how to fine-tune these models with Axolotl on multi-turn conversations with proper masking.

Getting started

  1. Install Axolotl following the installation guide.

    Here is an example of how to install from pip:

# Ensure you have PyTorch installed (PyTorch 2.9.1 minimum)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
  2. Choose one of the following configs to train the 20B model (for 120B, see below).
# LoRA SFT linear layers (1x48GB @ ~44GiB)
axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml

# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml

# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml

Note: Memory usage is taken from device_mem_reserved(gib) in the training logs.

Training 120B

On 8xH100s, make sure you have at least ~3TB of free disk space. Each checkpoint is ~720GB, so keeping at least 2 checkpoints alongside the base model and the final model output takes roughly that much.

# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
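
The ~3TB recommendation above can be checked before launching a run. The following is a minimal sketch using only the Python standard library; the helper name `has_enough_disk` and the `"."` path are illustrative assumptions, not part of Axolotl.

```python
import shutil

# Threshold taken from the ~3TB recommendation above (decimal terabytes).
REQUIRED_BYTES = 3 * 1000**4


def has_enough_disk(path: str, required_bytes: int = REQUIRED_BYTES) -> bool:
    """Return True if the filesystem holding `path` has enough free space."""
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= required_bytes


if __name__ == "__main__":
    # "." is a placeholder; point this at the filesystem that holds output_dir.
    ok = has_enough_disk(".")
    print("enough free disk space" if ok else "less than ~3TB free")
```

Point the path at the volume that will hold your configured output_dir, since that is where checkpoints accumulate.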

To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with Baseten to showcase multi-node training of the 120B model using Baseten Truss. You can read more about this recipe on Baseten's blog. The recipe can be found on their GitHub.

ERRATA: Transformers saves the model architecture name with an FSDP prefix, which must be manually renamed in config.json. See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.

sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
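
If you prefer not to edit the file with sed, the same rename can be sketched in Python. This assumes the prefixed name lives in the `architectures` list of config.json (the usual place for Hugging Face model configs); the helper name `strip_fsdp_prefix` is hypothetical. Unlike the sed command, it only touches that one field.

```python
import json
from pathlib import Path


def strip_fsdp_prefix(config_path: str, prefix: str = "FSDP") -> list[str]:
    """Remove `prefix` from each entry of `architectures` in a config.json."""
    path = Path(config_path)
    config = json.loads(path.read_text())
    # Assumption: the architecture names are stored under "architectures".
    config["architectures"] = [
        name[len(prefix):] if name.startswith(prefix) else name
        for name in config.get("architectures", [])
    ]
    path.write_text(json.dumps(config, indent=2) + "\n")
    return config["architectures"]
```

For the output directory used in this guide, the call would be `strip_fsdp_prefix("./outputs/gpt-oss-out/config.json")`. Running it a second time leaves the file unchanged.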

When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your configured output_dir. However, if that step fails due to a disk space error, you can take an additional step to merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded weights to {output_dir}/merged.

axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/

How to set reasoning_effort in the template?

The Harmony template lets you set reasoning_effort during prompt building. The default is medium. To change it, add the following to your config:

chat_template_kwargs:
  reasoning_effort: "high"  # low | medium | high

Currently, this applies globally; there is no way to set it per sample yet. If you are interested in adding this, please feel free to open an issue to discuss.

Inferencing your fine-tuned model

vLLM

GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425 for more information about using a special vllm-openai docker image for inferencing with vLLM.

Optionally, vLLM can be installed from nightly:

uv pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly

and the vLLM server can be started with the following command (modify --tensor-parallel-size 8 to match your environment):

vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8

SGLang

SGLang has 0-day support in main; see https://github.com/sgl-project/sglang/issues/8833 for information on installing SGLang from source. Once you've installed SGLang, run the following command to launch an SGLang server:

python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
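
Once a server is running, you can send it a quick test prompt. Both vLLM and SGLang expose an OpenAI-compatible chat completions route, so the sketch below should work against either one, assuming it is reachable on localhost at port 8888 as in the commands above. The helper names `build_chat_request` and `chat` are illustrative, and the `model` value must match whatever you passed to --served-model-name.

```python
import json
import urllib.request


def build_chat_request(
    prompt: str, model: str, base_url: str = "http://localhost:8888"
) -> urllib.request.Request:
    """Build a POST request for the OpenAI-compatible chat completions route."""
    payload = {
        "model": model,  # must match --served-model-name
        "messages": [{"role": "user", "content": prompt}],
    }
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(prompt: str, model: str, base_url: str = "http://localhost:8888") -> str:
    """Send one user message and return the assistant's reply text."""
    request = build_chat_request(prompt, model, base_url)
    with urllib.request.urlopen(request) as response:
        body = json.load(response)
    return body["choices"][0]["message"]["content"]
```

For example, `chat("Hello!", model="axolotl/gpt-oss-120b")` queries the SGLang server started above; use `axolotl/gpt-oss-20b` for the vLLM command.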

Tool use

GPT-OSS has comprehensive tool understanding. Axolotl supports tool-calling datasets for supervised fine-tuning.

Here is an example dataset config:

datasets:
  - path: Nanobit/text-tools-2k-test
    type: chat_template

See Nanobit/text-tools-2k-test for the sample dataset.

Refer to our docs for more info.
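
To give a feel for what a tool-calling row can look like, here is a rough sketch of one sample written as a Python dict and serialized to a JSON Lines entry. The field layout follows the OpenAI-style messages convention and is an assumption; the tool name `get_weather`, its schema, and all message text are made up. Check the sample dataset and the docs for the exact fields Axolotl expects.

```python
import json

# One illustrative training row. The tool, its schema, and the message text
# are invented; the field names are an assumed OpenAI-style layout.
row = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    "messages": [
        {"role": "user", "content": "What is the weather in Paris?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": {"city": "Paris"},
                    },
                }
            ],
        },
        {"role": "tool", "name": "get_weather", "content": "18C and cloudy"},
        {"role": "assistant", "content": "It is 18C and cloudy in Paris."},
    ],
}

# JSON Lines stores one serialized row per line.
line = json.dumps(row)
```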

Thinking and chat_template masking conflict

OpenAI's Harmony template hides thinking in all non-final turns, which conflicts with Axolotl's chat_template masking.

If your dataset has thinking content in non-final turns, there are two paths we recommend:

  • Train only on the last turn. This can be done by following the chat_template documentation on training on the last turn.

  • Adjust your dataset to only have thinking content in the last turn.
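
The second option can be applied as a preprocessing pass over each conversation. Below is a minimal sketch, assuming thinking is stored under a `thinking` key on assistant messages (adjust `thinking_key` to match your dataset); the helper name is hypothetical and not part of Axolotl.

```python
from copy import deepcopy


def keep_thinking_only_in_last_turn(messages, thinking_key="thinking"):
    """Return a copy where only the final assistant message keeps its thinking."""
    cleaned = deepcopy(messages)  # leave the caller's data untouched
    assistant_indices = [
        i for i, message in enumerate(cleaned) if message.get("role") == "assistant"
    ]
    # Drop thinking from every assistant message except the last one.
    for i in assistant_indices[:-1]:
        cleaned[i].pop(thinking_key, None)
    return cleaned
```

Conversations with a single assistant message, or with no thinking content at all, pass through unchanged.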

TIPS

  • Read more on how to load your own dataset in the docs.
  • The dataset format follows the OpenAI Messages format as seen here.

Optimization Guides