add kernels for gpt oss models (#3020)

* add kernels for gpt oss models

* add support for gpt-oss

* typo incorrect package

* fix: layout for configs and added wandb/epochs

* add gptoss example w offload and set moe leaf for z3

* add support for Mxfp4Config from yaml

* update yaml to use official model

* fix lora and don't allow triton to go above 3.3.1

* fix lr and tweak vram use

* fix range for triton since the pinned version wasn't compatible with torch 2.6.0

* update cce with gpt oss patches

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Wing Lian
2025-08-06 09:47:55 -04:00
committed by GitHub
parent 97e86c6d47
commit ba3dba3e4f
15 changed files with 257 additions and 11 deletions

View File

@@ -13,4 +13,5 @@ MOE_ARCH_BLOCK = {
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"gpt_oss": "GptOssExperts",
}

View File

@@ -567,9 +567,9 @@ class AxolotlTrainer(
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_memory_active(gib)"] = round(active, 2)
logs["memory/max_memory_allocated(gib)"] = round(allocated, 2)
logs["memory/device_memory_reserved(gib)"] = round(reserved, 2)
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"
```
## Usage

View File

@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"`'
)

View File

@@ -202,6 +202,8 @@ class ModelLoader:
self._set_device_map_config()
if self.cfg.revision_of_model:
self.model_kwargs["revision"] = self.cfg.revision_of_model
if self.cfg.use_kernels:
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
self._set_quantization_config()
self._set_attention_config()
@@ -565,8 +567,17 @@ class ModelLoader:
def _set_quantization_config(self):
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
if self.cfg.model_quantization_config == "Mxfp4Config":
from transformers import Mxfp4Config
mxfp4_kwargs = {}
if self.cfg.model_quantization_config_kwargs:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
else:
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
@@ -648,7 +659,9 @@ class ModelLoader:
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.flex_attention:
if self.cfg.attn_implementation:
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
elif self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"

View File

@@ -544,6 +544,13 @@ class AxolotlInputConfig(
eager_attention: bool | None = None
# Optional explicit attention backend name. When set, the model loader passes it
# through as the `attn_implementation` model kwarg and checks it before the
# `flex_attention` flag.
attn_implementation: str | None = Field(
default=None,
json_schema_extra={
"description": "Specify a custom attention implementation, used mostly for kernels."
},
)
unsloth_cross_entropy_loss: bool | None = None
unsloth_lora_mlp: bool | None = None
unsloth_lora_qkv: bool | None = None

View File

@@ -1,5 +1,7 @@
"""Pydantic models for model input / output, etc. configuration"""
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from axolotl.utils.logging import get_logger
@@ -70,6 +72,20 @@ class ModelInputConfig(BaseModel):
},
)
# Opt-in flag for custom kernels; when truthy the model loader forwards it as the
# `use_kernels` model kwarg.
use_kernels: bool | None = Field(
default=None,
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
)
# Name of the quantization config class to build at load time. The Literal
# annotation restricts it to "Mxfp4Config" (or unset).
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
default=None,
json_schema_extra={"description": "Model loading quantization config"},
)
# Keyword arguments unpacked into the selected quantization config's constructor;
# left as None, the loader constructs the config with no arguments.
model_quantization_config_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={"description": "kwargs for model quantization config"},
)
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):

View File

@@ -972,6 +972,16 @@ class SystemValidationMixin:
raise ValueError("deepspeed and fsdp cannot be used together.")
return data
@model_validator(mode="before")
@classmethod
def check_model_quantization_config_vs_bnb(cls, data):
    """Reject configs that set both `model_quantization_config` and a bnb load flag.

    Args:
        data: Raw config mapping, read via `.get()` before field validation.

    Returns:
        The same `data` object, unmodified.

    Raises:
        ValueError: If `model_quantization_config` is truthy while either
            `load_in_8bit` or `load_in_4bit` is also truthy.
    """
    # Nothing to cross-check unless a model quantization config was requested.
    if not data.get("model_quantization_config"):
        return data

    bnb_requested = data.get("load_in_8bit") or data.get("load_in_4bit")
    if bnb_requested:
        raise ValueError(
            "model_quantization_config and load_in_8bit or load_in_4bit cannot be used together."
        )
    return data
@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):