From ba3dba3e4f6fbe845b0249f517c3bff88d898e22 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 6 Aug 2025 09:47:55 -0400
Subject: [PATCH] add kernels for gpt oss models (#3020)

* add kernels for gpt oss models
* add support for gpt-oss
* typo: incorrect package
* fix: layout for configs and added wandb/epochs
* add gpt-oss example w/ offload and set moe leaf for z3
* add support for Mxfp4Config from yaml (see the sketch below)
* update yaml to use official model
* fix lora and don't allow triton to go above 3.3.1
* fix lr and tweak vram use
* fix range for triton since pinned wasn't compatible with torch 2.6.0
* update cce with gpt oss patches
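In practice, these changes wire GPT-OSS's MXFP4 checkpoint handling and Hub-hosted kernels into model loading. A rough transformers-level equivalent of what the new example configs request (a hedged sketch, assuming transformers >= 4.55 with `Mxfp4Config` and the `kernels` package installed; this code is not part of the diff):

```python
import torch
from transformers import AutoModelForCausalLM, Mxfp4Config

# model_quantization_config: Mxfp4Config / dequantize: true
# -> load the MXFP4-quantized experts dequantized into a trainable dtype
quant_config = Mxfp4Config(dequantize=True)

model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
    use_kernels=True,  # use_kernels: true -> fetch optimized kernels from the Hub
    # attn_implementation: kernels-community/vllm-flash-attn3 in the YAMLs
    attn_implementation="kernels-community/vllm-flash-attn3",
)
```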
---------

Co-authored-by: NanoCode012
---
 .../colab-axolotl-example.ipynb               |  2 +-
 examples/gpt-oss/README.md                    |  9 +++
 .../gpt-oss-20b-fft-fsdp2-offload.yaml        | 62 ++++++++++++++++++
 examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml   | 62 ++++++++++++++++++
 .../gpt-oss-20b-sft-lora-singlegpu.yaml       | 64 +++++++++++++++++++
 requirements.txt                              |  4 +-
 scripts/cutcrossentropy_install.py            |  2 +-
 src/axolotl/common/architectures.py           |  1 +
 src/axolotl/core/trainers/base.py             |  6 +-
 .../integrations/cut_cross_entropy/README.md  |  2 +-
 .../cut_cross_entropy/__init__.py             |  2 +-
 src/axolotl/loaders/model.py                  | 19 +++++-
 src/axolotl/utils/schemas/config.py           |  7 ++
 src/axolotl/utils/schemas/model.py            | 16 +++++
 src/axolotl/utils/schemas/validation.py       | 10 +++
 15 files changed, 257 insertions(+), 11 deletions(-)
 create mode 100644 examples/gpt-oss/README.md
 create mode 100644 examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
 create mode 100644 examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
 create mode 100644 examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml

diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb
index 6c6e21f94..c283092be 100644
--- a/examples/colab-notebooks/colab-axolotl-example.ipynb
+++ b/examples/colab-notebooks/colab-axolotl-example.ipynb
@@ -40,7 +40,7 @@
     "%%capture\n",
     "# This step can take ~5-10 minutes to install dependencies\n",
     "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
+    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169\""
    ]
   },
   {
diff --git a/examples/gpt-oss/README.md b/examples/gpt-oss/README.md
new file mode 100644
index 000000000..7157806af
--- /dev/null
+++ b/examples/gpt-oss/README.md
@@ -0,0 +1,9 @@
+# OpenAI's GPT-OSS
+
+GPT-OSS is a 20-billion-parameter MoE model trained by OpenAI, released in August 2025.
+
+- 20B Full Parameter SFT can be trained on 8x48GB GPUs (peak reserved memory @ ~36GiB/GPU) - [YAML](./gpt-oss-20b-fft-fsdp2.yaml)
+- 20B LoRA SFT (all linear layers, plus the experts in the last two layers) can be trained on a single GPU (peak reserved memory @ ~47GiB)
+  - removing the experts from `lora_target_parameters` lets the model fit in ~44GiB of VRAM
+  - [YAML](./gpt-oss-20b-sft-lora-singlegpu.yaml)
+- 20B Full Parameter SFT with FSDP2 offloading can be trained on 2x24GB GPUs (peak reserved memory @ ~21GiB/GPU) - [YAML](./gpt-oss-20b-fft-fsdp2-offload.yaml)
diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
new file mode 100644
index 000000000..d55a272ba
--- /dev/null
+++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
@@ -0,0 +1,62 @@
+base_model: openai/gpt-oss-20b
+use_kernels: true
+model_quantization_config: Mxfp4Config
+model_quantization_config_kwargs:
+  dequantize: true
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+experimental_skip_move_to_device: true # prevent OOM by NOT moving the model to GPU before sharding
+
+datasets:
+  - path: winglian/pirate-ultrachat-10k
+    type: chat_template
+    split: train
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0
+output_dir: ./outputs/gpt-oss-out/
+
+sequence_len: 4096
+sample_packing: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 1
+num_epochs: 1
+
+optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
+lr_scheduler: constant_with_warmup
+learning_rate: 2e-5
+
+bf16: true
+tf32: true
+
+flash_attention: true
+attn_implementation: kernels-community/vllm-flash-attn3
+
+gradient_checkpointing: true
+activation_offloading: true
+
+logging_steps: 1
+saves_per_epoch: 1
+
+warmup_ratio: 0.1
+
+special_tokens:
+eot_tokens:
+  - "<|end|>"
+
+fsdp_version: 2
+fsdp_config:
+  offload_params: true
+  state_dict_type: SHARDED_STATE_DICT
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: GptOssDecoderLayer
+  reshard_after_forward: true
diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
new file mode 100644
index 000000000..f9f2c1dce
--- /dev/null
+++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
@@ -0,0 +1,62 @@
+base_model: openai/gpt-oss-20b
+use_kernels: true
+model_quantization_config: Mxfp4Config
+model_quantization_config_kwargs:
+  dequantize: true
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+experimental_skip_move_to_device: true # prevent OOM by NOT moving the model to GPU before sharding
+
+datasets:
+  - path: winglian/pirate-ultrachat-10k
+    type: chat_template
+    split: train
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0
+output_dir: ./outputs/gpt-oss-out/
+
+sequence_len: 4096
+sample_packing: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 2
+micro_batch_size: 1
+num_epochs: 1
+
+optimizer: adamw_torch_8bit
+lr_scheduler: constant_with_warmup
+learning_rate: 2e-5
+
+bf16: true
+tf32: true
+
+flash_attention: true
+attn_implementation: kernels-community/vllm-flash-attn3
+
+gradient_checkpointing: true
+activation_offloading: true
+
+logging_steps: 1
+saves_per_epoch: 1
+
+warmup_ratio: 0.1
+
+special_tokens:
+eot_tokens:
+  - "<|end|>"
+
+fsdp_version: 2
+fsdp_config:
+  offload_params: false
+  state_dict_type: SHARDED_STATE_DICT
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: GptOssDecoderLayer
+  reshard_after_forward: true
diff --git a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
new file mode 100644
index 000000000..f7c332dfe
--- /dev/null
+++ b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
@@ -0,0 +1,64 @@
+base_model: openai/gpt-oss-20b
+use_kernels: true
+model_quantization_config: Mxfp4Config
+model_quantization_config_kwargs:
+  dequantize: true
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+experimental_skip_move_to_device: true # prevent OOM by not moving the model to GPU before sharding
+
+datasets:
+  - path: winglian/pirate-ultrachat-10k
+    type: chat_template
+    split: train
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0
+output_dir: ./outputs/gpt-oss-out/
+
+sequence_len: 4096
+sample_packing: true
+
+adapter: lora
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
+lora_target_linear: true
+lora_target_parameters: # target the experts in the last two layers
+  - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
+  - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
+  - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
+  - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 8
+micro_batch_size: 1
+num_epochs: 1
+
+optimizer: adamw_torch_8bit
+lr_scheduler: constant_with_warmup
+learning_rate: 2e-4
+
+bf16: true
+tf32: true
+
+flash_attention: true
+attn_implementation: kernels-community/vllm-flash-attn3
+
+gradient_checkpointing: true
+activation_offloading: true
+
+logging_steps: 1
+saves_per_epoch: 1
+warmup_ratio: 0.1
+
+special_tokens:
+eot_tokens:
+  - "<|end|>"
diff --git a/requirements.txt b/requirements.txt
index 244d1239c..0103ba919 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,8 @@
 # START section of dependencies that don't install on Darwin/MacOS
 bitsandbytes==0.46.1
-triton>=3.0.0
+# triton 3.4.0 is not compatible with CCE
+triton>=3.0.0,<3.4.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
 autoawq==0.2.7.post3
@@ -20,6 +21,7 @@ datasets==4.0.0
 deepspeed>=0.17.0
 trl==0.20.0
 hf_xet==1.1.5
+kernels==0.9.0
 optimum==1.16.2
 hf_transfer
diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py
index e76749493..cf9ced60c 100644
--- a/scripts/cutcrossentropy_install.py
+++ b/scripts/cutcrossentropy_install.py
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
 
 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"'
 )
diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py
index 2f77b613e..58d557e7e 100644
--- a/src/axolotl/common/architectures.py
+++ b/src/axolotl/common/architectures.py
@@ -13,4 +13,5 @@ MOE_ARCH_BLOCK = {
     "qwen2_moe": "Qwen2MoeSparseMoeBlock",
     "qwen3_moe": "Qwen3MoeSparseMoeBlock",
     "deepseek_v2": "DeepseekV2MoE",
+    "gpt_oss": "GptOssExperts",
 }
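The new `MOE_ARCH_BLOCK` entry is what the "set moe leaf for z3" bullet refers to: under DeepSpeed ZeRO-3, the mapped expert class is marked as a leaf module so ZeRO gathers its partitioned parameters as one unit instead of hooking the conditionally-executed experts individually. Roughly (a hedged sketch using DeepSpeed's public `set_z3_leaf_modules` API and the transformers 4.55 class path, not Axolotl's exact call site):

```python
import deepspeed.utils
# Class path as of transformers 4.55; an assumption, not pinned by this diff
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts


def mark_moe_leaf_modules(model) -> None:
    # Treat each GptOssExperts block as a single unit when gathering
    # ZeRO-3 partitioned parameters, avoiding stalls from skipped experts
    deepspeed.utils.set_z3_leaf_modules(model, [GptOssExperts])
```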
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 617506eb2..3540fb6a1 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -567,9 +567,9 @@ class AxolotlTrainer(
         # Add memory usage
         try:
             active, allocated, reserved = get_gpu_memory_usage()
-            logs["memory/max_memory_active(gib)"] = round(active, 2)
-            logs["memory/max_memory_allocated(gib)"] = round(allocated, 2)
-            logs["memory/device_memory_reserved(gib)"] = round(reserved, 2)
+            logs["memory/max_mem_active(gib)"] = round(active, 2)
+            logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
+            logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
         except (ValueError, TypeError, FileNotFoundError):
             pass
diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md
index 048559789..e0ff14db8 100644
--- a/src/axolotl/integrations/cut_cross_entropy/README.md
+++ b/src/axolotl/integrations/cut_cross_entropy/README.md
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
 - If you are installing from pip
 
 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"
 ```
 
 ## Usage
diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py
index d1419e27e..24cd7b6a7 100644
--- a/src/axolotl/integrations/cut_cross_entropy/__init__.py
+++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py
@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
 
 _CCE_INSTALL_MESSAGE = (
     "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@48b5169"`'
 )
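The loader change below gives an explicit `attn_implementation` precedence over the existing `flex_attention` flag; a minimal sketch of that resolution order (illustrative only, not the `ModelLoader` source):

```python
def resolve_attn_implementation(cfg: dict) -> str | None:
    """Mirror the precedence in the patched _set_attention_config."""
    if cfg.get("attn_implementation"):
        # new config key; may name a kernels-hub repo
        return cfg["attn_implementation"]
    if cfg.get("flex_attention"):
        return "flex_attention"
    return None  # fall through to the FA2/SDPA handling elsewhere


# An explicit implementation wins even if flex_attention is also set:
assert resolve_attn_implementation(
    {"attn_implementation": "kernels-community/vllm-flash-attn3", "flex_attention": True}
) == "kernels-community/vllm-flash-attn3"
```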
"flex_attention" self.model_config._attn_implementation = ( # pylint: disable=protected-access "flex_attention" diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index beaee57c9..e3de6e37b 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -544,6 +544,13 @@ class AxolotlInputConfig( eager_attention: bool | None = None + attn_implementation: str | None = Field( + default=None, + json_schema_extra={ + "description": "Specify a custom attention implementation, used mostly for kernels." + }, + ) + unsloth_cross_entropy_loss: bool | None = None unsloth_lora_mlp: bool | None = None unsloth_lora_qkv: bool | None = None diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index eae8dacb6..eb751bfcc 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -1,5 +1,7 @@ """Pydantic models for model input / output, etc. configuration""" +from typing import Any, Literal + from pydantic import BaseModel, Field, field_validator from axolotl.utils.logging import get_logger @@ -70,6 +72,20 @@ class ModelInputConfig(BaseModel): }, ) + use_kernels: bool | None = Field( + default=None, + json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."}, + ) + + model_quantization_config: Literal["Mxfp4Config"] | None = Field( + default=None, + json_schema_extra={"description": "Model loading quantization config"}, + ) + model_quantization_config_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={"description": "kwargs for model quantization config"}, + ) + @field_validator("trust_remote_code") @classmethod def hint_trust_remote_code(cls, trust_remote_code): diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index e15adf077..ac3355f74 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -972,6 +972,16 @@ class SystemValidationMixin: raise ValueError("deepspeed and fsdp cannot be used together.") return data + @model_validator(mode="before") + @classmethod + def check_model_quantization_config_vs_bnb(cls, data): + if data.get("model_quantization_config"): + if data.get("load_in_8bit") or data.get("load_in_4bit"): + raise ValueError( + "model_quantization_config and load_in_8bit or load_in_4bit cannot be used together." + ) + return data + @model_validator(mode="before") @classmethod def check_npu_config(cls, data):