add kernels for gpt oss models (#3020)
* add kernels for gpt oss models * add support for gpt-oss * typo incorrect package * fix: layout for configs and added wandb/epochs * add gptoss example w offload and set moe leaf for z3 * add support for Mxfp4Config from yaml * update yaml to use official model * fix lora and don't allow triton to go above 3.3.1 * fix lr and tweak vram use * fix range for triton since pinned wasn't compatible with torch 2.6.0 * update cce with gpt oss patches --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
9
examples/gpt-oss/README.md
Normal file
9
examples/gpt-oss/README.md
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# OpenAI's GPT-OSS
|
||||||
|
|
||||||
|
GPT-OSS is a 20 billion parameter MoE model trained by OpenAI, released in August 2025.
|
||||||
|
|
||||||
|
- 20B Full Parameter SFT can be trained on 8x48GB GPUs (peak reserved memory @ ~36GiB/GPU) - [YAML](./gpt-oss-20b-fft-fsdp2.yaml)
|
||||||
|
- 20B LoRA SFT (all linear layers, and experts in last two layers) can be trained on a single GPU (peak reserved memory @ ~47GiB)
|
||||||
|
- removing the experts from `lora_target_parameters` will allow the model to fit in ~44GiB of VRAM
|
||||||
|
- [YAML](./gpt-oss-20b-sft-lora-singlegpu.yaml)
|
||||||
|
- 20B Full Parameter SFT with FSDP2 offloading can be trained on 2x24GB GPUs (peak reserved memory @ ~21GiB/GPU) - [YAML](./gpt-oss-20b-fft-fsdp2-offload.yaml)
|
||||||
62
examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
Normal file
62
examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
base_model: openai/gpt-oss-20b
|
||||||
|
use_kernels: true
|
||||||
|
model_quantization_config: Mxfp4Config
|
||||||
|
model_quantization_config_kwargs:
|
||||||
|
dequantize: true
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: winglian/pirate-ultrachat-10k
|
||||||
|
type: chat_template
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/gpt-oss-out/
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
|
||||||
|
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
|
||||||
|
lr_scheduler: constant_with_warmup
|
||||||
|
learning_rate: 2e-5
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
flash_attention: true
|
||||||
|
attn_implementation: kernels-community/vllm-flash-attn3
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
activation_offloading: true
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
eot_tokens:
|
||||||
|
- "<|end|>"
|
||||||
|
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: true
|
||||||
|
state_dict_type: SHARDED_STATE_DICT
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||||
|
reshard_after_forward: true
|
||||||
62
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
Normal file
62
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
base_model: openai/gpt-oss-20b
|
||||||
|
use_kernels: true
|
||||||
|
model_quantization_config: Mxfp4Config
|
||||||
|
model_quantization_config_kwargs:
|
||||||
|
dequantize: true
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: winglian/pirate-ultrachat-10k
|
||||||
|
type: chat_template
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/gpt-oss-out/
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: constant_with_warmup
|
||||||
|
learning_rate: 2e-5
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
flash_attention: true
|
||||||
|
attn_implementation: kernels-community/vllm-flash-attn3
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
activation_offloading: true
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
eot_tokens:
|
||||||
|
- "<|end|>"
|
||||||
|
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: false
|
||||||
|
state_dict_type: SHARDED_STATE_DICT
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||||
|
reshard_after_forward: true
|
||||||
64
examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
Normal file
64
examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
base_model: openai/gpt-oss-20b
|
||||||
|
use_kernels: true
|
||||||
|
model_quantization_config: Mxfp4Config
|
||||||
|
model_quantization_config_kwargs:
|
||||||
|
dequantize: true
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: winglian/pirate-ultrachat-10k
|
||||||
|
type: chat_template
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/gpt-oss-out/
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_r: 8
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_target_parameters: # target the experts in the last two layers
|
||||||
|
- "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||||
|
- "22._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||||
|
- "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||||
|
- "23._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: constant_with_warmup
|
||||||
|
learning_rate: 2e-4
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
flash_attention: true
|
||||||
|
attn_implementation: kernels-community/vllm-flash-attn3
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
activation_offloading: true
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
eot_tokens:
|
||||||
|
- "<|end|>"
|
||||||
@@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.46.1
|
bitsandbytes==0.46.1
|
||||||
triton>=3.0.0
|
# triton 3.4.0 is not compatible with CCE
|
||||||
|
triton>=3.0.0,<3.4.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
@@ -20,6 +21,7 @@ datasets==4.0.0
|
|||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.20.0
|
trl==0.20.0
|
||||||
hf_xet==1.1.5
|
hf_xet==1.1.5
|
||||||
|
kernels==0.9.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,4 +13,5 @@ MOE_ARCH_BLOCK = {
|
|||||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||||
"deepseek_v2": "DeepseekV2MoE",
|
"deepseek_v2": "DeepseekV2MoE",
|
||||||
|
"gpt_oss": "GptOssExperts",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -567,9 +567,9 @@ class AxolotlTrainer(
|
|||||||
# Add memory usage
|
# Add memory usage
|
||||||
try:
|
try:
|
||||||
active, allocated, reserved = get_gpu_memory_usage()
|
active, allocated, reserved = get_gpu_memory_usage()
|
||||||
logs["memory/max_memory_active(gib)"] = round(active, 2)
|
logs["memory/max_mem_active(gib)"] = round(active, 2)
|
||||||
logs["memory/max_memory_allocated(gib)"] = round(allocated, 2)
|
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
|
||||||
logs["memory/device_memory_reserved(gib)"] = round(reserved, 2)
|
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
|
||||||
except (ValueError, TypeError, FileNotFoundError):
|
except (ValueError, TypeError, FileNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -202,6 +202,8 @@ class ModelLoader:
|
|||||||
self._set_device_map_config()
|
self._set_device_map_config()
|
||||||
if self.cfg.revision_of_model:
|
if self.cfg.revision_of_model:
|
||||||
self.model_kwargs["revision"] = self.cfg.revision_of_model
|
self.model_kwargs["revision"] = self.cfg.revision_of_model
|
||||||
|
if self.cfg.use_kernels:
|
||||||
|
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
|
||||||
self._set_quantization_config()
|
self._set_quantization_config()
|
||||||
self._set_attention_config()
|
self._set_attention_config()
|
||||||
|
|
||||||
@@ -565,8 +567,17 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _set_quantization_config(self):
|
def _set_quantization_config(self):
|
||||||
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
|
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
|
||||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
|
||||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
if self.cfg.model_quantization_config == "Mxfp4Config":
|
||||||
|
from transformers import Mxfp4Config
|
||||||
|
|
||||||
|
mxfp4_kwargs = {}
|
||||||
|
if self.cfg.model_quantization_config_kwargs:
|
||||||
|
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
||||||
|
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
||||||
|
else:
|
||||||
|
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||||
|
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||||
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if not hasattr(self.model_config, "quantization_config"):
|
if not hasattr(self.model_config, "quantization_config"):
|
||||||
@@ -648,7 +659,9 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _set_attention_config(self):
|
def _set_attention_config(self):
|
||||||
"""Sample packing uses custom FA2 patch"""
|
"""Sample packing uses custom FA2 patch"""
|
||||||
if self.cfg.flex_attention:
|
if self.cfg.attn_implementation:
|
||||||
|
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
||||||
|
elif self.cfg.flex_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"flex_attention"
|
"flex_attention"
|
||||||
|
|||||||
@@ -544,6 +544,13 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
eager_attention: bool | None = None
|
eager_attention: bool | None = None
|
||||||
|
|
||||||
|
attn_implementation: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Specify a custom attention implementation, used mostly for kernels."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
unsloth_cross_entropy_loss: bool | None = None
|
unsloth_cross_entropy_loss: bool | None = None
|
||||||
unsloth_lora_mlp: bool | None = None
|
unsloth_lora_mlp: bool | None = None
|
||||||
unsloth_lora_qkv: bool | None = None
|
unsloth_lora_qkv: bool | None = None
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Pydantic models for model input / output, etc. configuration"""
|
"""Pydantic models for model input / output, etc. configuration"""
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -70,6 +72,20 @@ class ModelInputConfig(BaseModel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
use_kernels: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
|
||||||
|
)
|
||||||
|
|
||||||
|
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Model loading quantization config"},
|
||||||
|
)
|
||||||
|
model_quantization_config_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "kwargs for model quantization config"},
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("trust_remote_code")
|
@field_validator("trust_remote_code")
|
||||||
@classmethod
|
@classmethod
|
||||||
def hint_trust_remote_code(cls, trust_remote_code):
|
def hint_trust_remote_code(cls, trust_remote_code):
|
||||||
|
|||||||
@@ -972,6 +972,16 @@ class SystemValidationMixin:
|
|||||||
raise ValueError("deepspeed and fsdp cannot be used together.")
|
raise ValueError("deepspeed and fsdp cannot be used together.")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_model_quantization_config_vs_bnb(cls, data):
|
||||||
|
if data.get("model_quantization_config"):
|
||||||
|
if data.get("load_in_8bit") or data.get("load_in_4bit"):
|
||||||
|
raise ValueError(
|
||||||
|
"model_quantization_config and load_in_8bit or load_in_4bit cannot be used together."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_npu_config(cls, data):
|
def check_npu_config(cls, data):
|
||||||
|
|||||||
Reference in New Issue
Block a user