diff --git a/examples/deepseek-v3/full-ft.yaml b/examples/deepseek-v3/full-ft.yaml
new file mode 100644
index 000000000..dec8ab7c2
--- /dev/null
+++ b/examples/deepseek-v3/full-ft.yaml
@@ -0,0 +1,66 @@
+# Example full fine-tuning config for a DeepSeek-V3 MoE model using Axolotl's
+# vendored Triton contiguous grouped GEMM kernels.
+# Swap `base_model` below for the checkpoint you uploaded to HF.
+
+base_model: axolotl-ai-co/deepseek-v3-8b
+model_config_type: deepseek_v3
+trust_remote_code: true
+moe_kernels: true
+
+# --- Data ------------------------------------------------------------------
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+
+val_set_size: 0.0
+output_dir: ./outputs/deepseek-v3/full-ft
+
+sequence_len: 4096
+sample_packing: true
+
+# --- Optimisation ----------------------------------------------------------
+num_epochs: 1
+micro_batch_size: 1
+gradient_accumulation_steps: 8
+optimizer: adamw_torch_fused
+learning_rate: 2e-5
+lr_scheduler: cosine
+warmup_ratio: 0.1
+weight_decay: 0.01
+
+# --- Precision & Performance -----------------------------------------------
+bf16: auto
+flash_attention: true
+
+# enable gradient checkpointing to keep activation memory manageable for MoE blocks
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+
+# With `moe_kernels: true`, Axolotl applies the DeepSeek-V3 MoE monkeypatch,
+# routing the expert matmuls through the vendored Triton contiguous grouped
+# GEMM kernels.
+
+# --- Logging & Saving ------------------------------------------------------
+logging_steps: 1
+# evals_per_epoch: 2  # requires a non-zero val_set_size
+saves_per_epoch: 1
+
+# Uncomment the section below for multi-GPU training with FSDP
+# fsdp:
+#   - full_shard
+#   - auto_wrap
+# fsdp_config:
+#   fsdp_limit_all_gathers: true
+#   fsdp_sync_module_states: true
+#   fsdp_offload_params: true
+#   fsdp_use_orig_params: false
+#   fsdp_cpu_ram_efficient_loading: true
+#   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+#   fsdp_transformer_layer_cls_to_wrap: DeepseekV3MoE
+#   fsdp_state_dict_type: FULL_STATE_DICT
+#   fsdp_sharding_strategy: FULL_SHARD
+
+# wandb_project:
+# wandb_entity:
+# wandb_name:
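The example above can be exercised end to end with the standard CLI. A typical invocation is shown below (a sketch assuming a recent Axolotl install that ships the `axolotl` entry point; older setups use `accelerate launch -m axolotl.cli.train` instead):

```bash
# tokenize and pack the dataset once, then launch the full fine-tune
axolotl preprocess examples/deepseek-v3/full-ft.yaml
axolotl train examples/deepseek-v3/full-ft.yaml
```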
diff --git a/scripts/benchmarks/deepseek_v3_group_gemm_table.py b/scripts/benchmarks/deepseek_v3_group_gemm_table.py
index d029ad21d..e0aae9cda 100644
--- a/scripts/benchmarks/deepseek_v3_group_gemm_table.py
+++ b/scripts/benchmarks/deepseek_v3_group_gemm_table.py
@@ -34,8 +34,15 @@ SCENARIOS: tuple[Scenario, ...] = (
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description=__doc__)
-    parser.add_argument("--device", default="cuda", choices=["cuda"], help="Execution device")
-    parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"], help="Computation dtype")
+    parser.add_argument(
+        "--device", default="cuda", choices=["cuda"], help="Execution device"
+    )
+    parser.add_argument(
+        "--dtype",
+        default="bf16",
+        choices=["bf16", "fp16", "fp32"],
+        help="Computation dtype",
+    )
     parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
     parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
     parser.add_argument("--seed", type=int, default=0, help="Random seed")
@@ -56,7 +63,9 @@ def pick_dtype(name: str) -> torch.dtype:
     }[name]
 
 
-def make_indices(num_groups: int, group_size: int, device: torch.device) -> torch.Tensor:
+def make_indices(
+    num_groups: int, group_size: int, device: torch.device
+) -> torch.Tensor:
     indices = torch.arange(num_groups, device=device, dtype=torch.int32)
     return indices.repeat_interleave(group_size)
 
@@ -82,7 +91,9 @@ def run_scenario(
     group_size_m: int,
 ) -> dict:
     if scenario.m % scenario.num_groups != 0:
-        raise ValueError(f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})")
+        raise ValueError(
+            f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
+        )
     group_size = scenario.m // scenario.num_groups
     if group_size % group_size_m != 0:
         raise ValueError(
@@ -90,7 +101,9 @@ def run_scenario(
         )
 
     inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
-    weights = torch.randn(scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype)
+    weights = torch.randn(
+        scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
+    )
     indices = make_indices(scenario.num_groups, group_size, device)
 
     def persistent():
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index dafa8a28c..446c73640 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -190,10 +190,14 @@ class PatchManager:
 
             apply_mistral_tokenizer_image_patch()
 
-        if self.cfg.model_config_type == "deepseek_v3":
+        if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
             from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
 
             patch_deepseek_v3_moe()
+        elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
+            LOG.info(
+                "Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
+            )
 
     def _apply_fp8_patches(self):
         """Apply patches for FP8 support."""
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 0177b19f6..20698d920 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -106,6 +106,12 @@ class AxolotlInputConfig(
             "description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
         },
     )
+    moe_kernels: bool = Field(
+        default=False,
+        json_schema_extra={
+            "description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
+        },
+    )
     reinit_weights: bool | None = Field(
         default=None,
         json_schema_extra={
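For orientation, the contiguous grouped GEMM that the benchmark above times is equivalent to a naive per-group loop. Below is a minimal PyTorch sketch matching the shapes in `run_scenario` (`inputs` is `[M, K]`, `weights` is `[num_groups, N, K]`, `indices` comes from `make_indices`); `grouped_gemm_reference` is a hypothetical name for illustration, not the vendored Triton kernel's API:

```python
import torch


def grouped_gemm_reference(
    inputs: torch.Tensor,   # [M, K], rows laid out so each group is contiguous
    weights: torch.Tensor,  # [num_groups, N, K], one weight matrix per expert
    indices: torch.Tensor,  # [M], group id per row (see make_indices above)
) -> torch.Tensor:
    """Reference: out[i] = inputs[i] @ weights[indices[i]].T via a per-group loop."""
    out = inputs.new_empty(inputs.shape[0], weights.shape[1])
    for group in range(weights.shape[0]):
        rows = indices == group
        # all rows of a group share one expert weight, so each group collapses
        # to a single dense matmul
        out[rows] = inputs[rows] @ weights[group].transpose(0, 1)
    return out
```

The vendored Triton kernel computes the same contraction without the Python-level loop over groups, which is what the benchmark's persistent variant measures against.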