add kernels for gpt oss models (#3020)
* add kernels for gpt oss models * add support for gpt-oss * typo incorrect package * fix: layout for configs and added wandb/epochs * add gptoss example w offload and set moe leaf for z3 * add support for Mxfp4Config from yaml * update yaml to use official model * fix lora and don't allow triton to go above 3.3.1 * fix lr and tweak vram use * fix range for triton since pinned wasn't compatible with toch 2.6.0 * update cce with gpt oss patches --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
9
examples/gpt-oss/README.md
Normal file
9
examples/gpt-oss/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# OpenAI's GPT-OSS
|
||||
|
||||
GPT-OSS is a 20 billion parameter MoE model trained by OpenAI, released in August 2025.
|
||||
|
||||
- 20B Full Parameter SFT can be trained on 8x48GB GPUs (peak reserved memory @ ~36GiB/GPU) - [YAML](./gpt-oss-20b-fft-fsdp2.yaml)
|
||||
- 20B LoRA SFT (all linear layers, and experts in last two layers) can be trained a single GPU (peak reserved memory @ ~47GiB)
|
||||
- removing the experts from `lora_target_parameters` will allow the model to fit around ~44GiB of VRAM
|
||||
- [YAML](./gpt-oss-20b-sft-lora-singlegpu.yaml)
|
||||
- 20B Full Parameter SFT with FSDP2 offloading can be trained on 2x24GB GPUs (peak reserved memory @ ~21GiB/GPU) - [YAML](./gpt-oss-20b-fft-fsdp2-offload.yaml)
|
||||
62
examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
Normal file
62
examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
Normal file
@@ -0,0 +1,62 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: true
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: winglian/pirate-ultrachat-10k
|
||||
type: chat_template
|
||||
split: train
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: SHARDED_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||
reshard_after_forward: true
|
||||
62
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
Normal file
62
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
Normal file
@@ -0,0 +1,62 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: true
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: winglian/pirate-ultrachat-10k
|
||||
type: chat_template
|
||||
split: train
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
state_dict_type: SHARDED_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||
reshard_after_forward: true
|
||||
64
examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
Normal file
64
examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: true
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: winglian/pirate-ultrachat-10k
|
||||
type: chat_template
|
||||
split: train
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
adapter: lora
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
|
||||
lora_target_linear: true
|
||||
lora_target_parameters: # target the experts in the last two layers
|
||||
- "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
- "22._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
- "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
- "23._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-4
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
warmup_ratio: 0.1
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.46.1
|
||||
triton>=3.0.0
|
||||
# triton 3.4.0 is not compatible with CCE
|
||||
triton>=3.0.0,<3.4.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
@@ -20,6 +21,7 @@ datasets==4.0.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.20.0
|
||||
hf_xet==1.1.5
|
||||
kernels==0.9.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"'
|
||||
)
|
||||
|
||||
@@ -13,4 +13,5 @@ MOE_ARCH_BLOCK = {
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
"gpt_oss": "GptOssExperts",
|
||||
}
|
||||
|
||||
@@ -567,9 +567,9 @@ class AxolotlTrainer(
|
||||
# Add memory usage
|
||||
try:
|
||||
active, allocated, reserved = get_gpu_memory_usage()
|
||||
logs["memory/max_memory_active(gib)"] = round(active, 2)
|
||||
logs["memory/max_memory_allocated(gib)"] = round(allocated, 2)
|
||||
logs["memory/device_memory_reserved(gib)"] = round(reserved, 2)
|
||||
logs["memory/max_mem_active(gib)"] = round(active, 2)
|
||||
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
|
||||
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
|
||||
except (ValueError, TypeError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -202,6 +202,8 @@ class ModelLoader:
|
||||
self._set_device_map_config()
|
||||
if self.cfg.revision_of_model:
|
||||
self.model_kwargs["revision"] = self.cfg.revision_of_model
|
||||
if self.cfg.use_kernels:
|
||||
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
|
||||
self._set_quantization_config()
|
||||
self._set_attention_config()
|
||||
|
||||
@@ -565,8 +567,17 @@ class ModelLoader:
|
||||
|
||||
def _set_quantization_config(self):
|
||||
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
|
||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||
|
||||
if self.cfg.model_quantization_config == "Mxfp4Config":
|
||||
from transformers import Mxfp4Config
|
||||
|
||||
mxfp4_kwargs = {}
|
||||
if self.cfg.model_quantization_config_kwargs:
|
||||
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
||||
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
||||
else:
|
||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||
|
||||
if self.cfg.gptq:
|
||||
if not hasattr(self.model_config, "quantization_config"):
|
||||
@@ -648,7 +659,9 @@ class ModelLoader:
|
||||
|
||||
def _set_attention_config(self):
|
||||
"""Sample packing uses custom FA2 patch"""
|
||||
if self.cfg.flex_attention:
|
||||
if self.cfg.attn_implementation:
|
||||
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
||||
elif self.cfg.flex_attention:
|
||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flex_attention"
|
||||
|
||||
@@ -544,6 +544,13 @@ class AxolotlInputConfig(
|
||||
|
||||
eager_attention: bool | None = None
|
||||
|
||||
attn_implementation: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Specify a custom attention implementation, used mostly for kernels."
|
||||
},
|
||||
)
|
||||
|
||||
unsloth_cross_entropy_loss: bool | None = None
|
||||
unsloth_lora_mlp: bool | None = None
|
||||
unsloth_lora_qkv: bool | None = None
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Pydantic models for model input / output, etc. configuration"""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -70,6 +72,20 @@ class ModelInputConfig(BaseModel):
|
||||
},
|
||||
)
|
||||
|
||||
use_kernels: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
|
||||
)
|
||||
|
||||
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Model loading quantization config"},
|
||||
)
|
||||
model_quantization_config_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "kwargs for model quantization config"},
|
||||
)
|
||||
|
||||
@field_validator("trust_remote_code")
|
||||
@classmethod
|
||||
def hint_trust_remote_code(cls, trust_remote_code):
|
||||
|
||||
@@ -972,6 +972,16 @@ class SystemValidationMixin:
|
||||
raise ValueError("deepspeed and fsdp cannot be used together.")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_model_quantization_config_vs_bnb(cls, data):
|
||||
if data.get("model_quantization_config"):
|
||||
if data.get("load_in_8bit") or data.get("load_in_4bit"):
|
||||
raise ValueError(
|
||||
"model_quantization_config and load_in_8bit or load_in_4bit cannot be used together."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_npu_config(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user