cfg value
This commit is contained in:
66
examples/deepseek-v3/full-ft.yaml
Normal file
66
examples/deepseek-v3/full-ft.yaml
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
# Example full fine-tuning config for a DeepSeek-V3 MoE model using Axolotl's
|
||||||
|
# vendored Triton contiguous grouped GEMM kernels.
|
||||||
|
# Replace `your-org/deepseek-v3-model` with the name of the model you uploaded to HF.
|
||||||
|
|
||||||
|
base_model: axolotl-ai-co/deepseek-v3-8b
|
||||||
|
model_config_type: deepseek_v3
|
||||||
|
trust_remote_code: true
|
||||||
|
moe_kernels: true
|
||||||
|
|
||||||
|
# --- Data ------------------------------------------------------------------
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/deepseek-v3/full-ft
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
# --- Optimisation ----------------------------------------------------------
|
||||||
|
num_epochs: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
learning_rate: 2e-5
|
||||||
|
lr_scheduler: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.01
|
||||||
|
|
||||||
|
# --- Precision & Performance -----------------------------------------------
|
||||||
|
bf16: auto
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
# enable GC to keep activation memory manageable for the MoE blocks
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
|
||||||
|
# Axolotl automatically applies the DeepSeek-V3 MoE monkeypatch when
|
||||||
|
# model_config_type is set to `deepseek_v3`, routing matmuls through the
|
||||||
|
# vendored Triton kernels.
|
||||||
|
|
||||||
|
# --- Logging & Saving ------------------------------------------------------
|
||||||
|
logging_steps: 1
|
||||||
|
evals_per_epoch: 2
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# Uncomment the section below for multi-GPU training with FSDP
|
||||||
|
# fsdp:
|
||||||
|
# - full_shard
|
||||||
|
# - auto_wrap
|
||||||
|
# fsdp_config:
|
||||||
|
# fsdp_limit_all_gathers: true
|
||||||
|
# fsdp_sync_module_states: true
|
||||||
|
# fsdp_offload_params: true
|
||||||
|
# fsdp_use_orig_params: false
|
||||||
|
# fsdp_cpu_ram_efficient_loading: true
|
||||||
|
# fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
# fsdp_transformer_layer_cls_to_wrap: DeepseekV3MoE
|
||||||
|
# fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
# fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
|
||||||
|
# wandb_project:
|
||||||
|
# wandb_entity:
|
||||||
|
# wandb_name:
|
||||||
@@ -34,8 +34,15 @@ SCENARIOS: tuple[Scenario, ...] = (
|
|||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
parser = argparse.ArgumentParser(description=__doc__)
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
parser.add_argument("--device", default="cuda", choices=["cuda"], help="Execution device")
|
parser.add_argument(
|
||||||
parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"], help="Computation dtype")
|
"--device", default="cuda", choices=["cuda"], help="Execution device"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
default="bf16",
|
||||||
|
choices=["bf16", "fp16", "fp32"],
|
||||||
|
help="Computation dtype",
|
||||||
|
)
|
||||||
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
|
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
|
||||||
parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
|
parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
|
||||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||||
@@ -56,7 +63,9 @@ def pick_dtype(name: str) -> torch.dtype:
|
|||||||
}[name]
|
}[name]
|
||||||
|
|
||||||
|
|
||||||
def make_indices(num_groups: int, group_size: int, device: torch.device) -> torch.Tensor:
|
def make_indices(
|
||||||
|
num_groups: int, group_size: int, device: torch.device
|
||||||
|
) -> torch.Tensor:
|
||||||
indices = torch.arange(num_groups, device=device, dtype=torch.int32)
|
indices = torch.arange(num_groups, device=device, dtype=torch.int32)
|
||||||
return indices.repeat_interleave(group_size)
|
return indices.repeat_interleave(group_size)
|
||||||
|
|
||||||
@@ -82,7 +91,9 @@ def run_scenario(
|
|||||||
group_size_m: int,
|
group_size_m: int,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if scenario.m % scenario.num_groups != 0:
|
if scenario.m % scenario.num_groups != 0:
|
||||||
raise ValueError(f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})")
|
raise ValueError(
|
||||||
|
f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
|
||||||
|
)
|
||||||
group_size = scenario.m // scenario.num_groups
|
group_size = scenario.m // scenario.num_groups
|
||||||
if group_size % group_size_m != 0:
|
if group_size % group_size_m != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -90,7 +101,9 @@ def run_scenario(
|
|||||||
)
|
)
|
||||||
|
|
||||||
inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
|
inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
|
||||||
weights = torch.randn(scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype)
|
weights = torch.randn(
|
||||||
|
scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
|
||||||
|
)
|
||||||
indices = make_indices(scenario.num_groups, group_size, device)
|
indices = make_indices(scenario.num_groups, group_size, device)
|
||||||
|
|
||||||
def persistent():
|
def persistent():
|
||||||
|
|||||||
@@ -190,10 +190,14 @@ class PatchManager:
|
|||||||
|
|
||||||
apply_mistral_tokenizer_image_patch()
|
apply_mistral_tokenizer_image_patch()
|
||||||
|
|
||||||
if self.cfg.model_config_type == "deepseek_v3":
|
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
|
||||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||||
|
|
||||||
patch_deepseek_v3_moe()
|
patch_deepseek_v3_moe()
|
||||||
|
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
|
||||||
|
LOG.info(
|
||||||
|
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
|
||||||
|
)
|
||||||
|
|
||||||
def _apply_fp8_patches(self):
|
def _apply_fp8_patches(self):
|
||||||
"""Apply patches for FP8 support."""
|
"""Apply patches for FP8 support."""
|
||||||
|
|||||||
@@ -106,6 +106,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
|
"description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
moe_kernels: bool = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
|
||||||
|
},
|
||||||
|
)
|
||||||
reinit_weights: bool | None = Field(
|
reinit_weights: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
Reference in New Issue
Block a user