add kernels for gpt oss models (#3020)

* add kernels for gpt oss models

* add support for gpt-oss

* typo incorrect package

* fix: layout for configs and added wandb/epochs

* add gpt-oss example with offload and set moe leaf for z3

* add support for Mxfp4Config from yaml

* update yaml to use official model

* fix lora and don't allow triton to go above 3.3.1

* fix lr and tweak vram use

* fix range for triton since pinned version wasn't compatible with torch 2.6.0

* update cce with gpt oss patches

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Wing Lian
2025-08-06 09:47:55 -04:00
committed by GitHub
parent 97e86c6d47
commit ba3dba3e4f
15 changed files with 257 additions and 11 deletions

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169\""
]
},
{

View File

@@ -0,0 +1,9 @@
# OpenAI's GPT-OSS
GPT-OSS is a 20-billion-parameter MoE model trained by OpenAI, released in August 2025.
- 20B Full Parameter SFT can be trained on 8x48GB GPUs (peak reserved memory ~36GiB/GPU) - [YAML](./gpt-oss-20b-fft-fsdp2.yaml)
- 20B LoRA SFT (all linear layers, plus the experts in the last two layers) can be trained on a single GPU (peak reserved memory ~47GiB)
  - removing the experts from `lora_target_parameters` lets the model fit in ~44GiB of VRAM
  - [YAML](./gpt-oss-20b-sft-lora-singlegpu.yaml)
- 20B Full Parameter SFT with FSDP2 offloading can be trained on 2x24GB GPUs (peak reserved memory ~21GiB/GPU) - [YAML](./gpt-oss-20b-fft-fsdp2-offload.yaml)

View File

@@ -0,0 +1,62 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: winglian/pirate-ultrachat-10k
type: chat_template
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true

View File

@@ -0,0 +1,62 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: winglian/pirate-ultrachat-10k
type: chat_template
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true

View File

@@ -0,0 +1,64 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
datasets:
- path: winglian/pirate-ultrachat-10k
type: chat_template
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true
lora_target_parameters: # target the experts in the last two layers
- "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
- "22._checkpoint_wrapped_module.mlp.experts.down_proj"
- "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
- "23._checkpoint_wrapped_module.mlp.experts.down_proj"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"

View File

@@ -2,7 +2,8 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.46.1
triton>=3.0.0
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
@@ -20,6 +21,7 @@ datasets==4.0.0
deepspeed>=0.17.0
trl==0.20.0
hf_xet==1.1.5
kernels==0.9.0
optimum==1.16.2
hf_transfer

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"'
)

View File

@@ -13,4 +13,5 @@ MOE_ARCH_BLOCK = {
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"gpt_oss": "GptOssExperts",
}
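
The new `gpt_oss` entry maps the model type to its expert block so that block can be registered as a ZeRO-3 leaf module (per the "set moe leaf for z3" commit above). A minimal sketch of that idea, assuming DeepSpeed's `set_z3_leaf_modules` helper; the `mark_moe_leaf_modules` function below is illustrative, not Axolotl's actual code:

```python
from deepspeed.utils import set_z3_leaf_modules

MOE_ARCH_BLOCK = {
    "qwen2_moe": "Qwen2MoeSparseMoeBlock",
    "qwen3_moe": "Qwen3MoeSparseMoeBlock",
    "deepseek_v2": "DeepseekV2MoE",
    "gpt_oss": "GptOssExperts",
}

def mark_moe_leaf_modules(model):
    """Keep the architecture's expert block whole under ZeRO-3 instead of
    partitioning each expert parameter individually (illustrative sketch)."""
    block_name = MOE_ARCH_BLOCK.get(model.config.model_type)
    if block_name is None:
        return
    # Resolve the class from the instantiated model so we don't need to import
    # architecture-specific modules here.
    leaf_classes = {type(m) for m in model.modules() if type(m).__name__ == block_name}
    if leaf_classes:
        set_z3_leaf_modules(model, list(leaf_classes))
```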

View File

@@ -567,9 +567,9 @@ class AxolotlTrainer(
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_memory_active(gib)"] = round(active, 2)
logs["memory/max_memory_allocated(gib)"] = round(allocated, 2)
logs["memory/device_memory_reserved(gib)"] = round(reserved, 2)
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"
```
## Usage

View File

@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"`'
)

View File

@@ -202,6 +202,8 @@ class ModelLoader:
self._set_device_map_config()
if self.cfg.revision_of_model:
self.model_kwargs["revision"] = self.cfg.revision_of_model
if self.cfg.use_kernels:
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
self._set_quantization_config()
self._set_attention_config()
@@ -565,8 +567,17 @@ class ModelLoader:
def _set_quantization_config(self):
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
if self.cfg.model_quantization_config == "Mxfp4Config":
from transformers import Mxfp4Config
mxfp4_kwargs = {}
if self.cfg.model_quantization_config_kwargs:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
else:
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
@@ -648,7 +659,9 @@ class ModelLoader:
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.flex_attention:
if self.cfg.attn_implementation:
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
elif self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"

View File

@@ -544,6 +544,13 @@ class AxolotlInputConfig(
eager_attention: bool | None = None
attn_implementation: str | None = Field(
default=None,
json_schema_extra={
"description": "Specify a custom attention implementation, used mostly for kernels."
},
)
unsloth_cross_entropy_loss: bool | None = None
unsloth_lora_mlp: bool | None = None
unsloth_lora_qkv: bool | None = None

View File

@@ -1,5 +1,7 @@
"""Pydantic models for model input / output, etc. configuration"""
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from axolotl.utils.logging import get_logger
@@ -70,6 +72,20 @@ class ModelInputConfig(BaseModel):
},
)
use_kernels: bool | None = Field(
default=None,
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
)
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
default=None,
json_schema_extra={"description": "Model loading quantization config"},
)
model_quantization_config_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={"description": "kwargs for model quantization config"},
)
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):

View File

@@ -972,6 +972,16 @@ class SystemValidationMixin:
raise ValueError("deepspeed and fsdp cannot be used together.")
return data
@model_validator(mode="before")
@classmethod
def check_model_quantization_config_vs_bnb(cls, data):
if data.get("model_quantization_config"):
if data.get("load_in_8bit") or data.get("load_in_4bit"):
raise ValueError(
"model_quantization_config and load_in_8bit or load_in_4bit cannot be used together."
)
return data
@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):