Compare commits

41 Commits: `cp-sdpa`...`vendor-moe`

| SHA1 |
|---|
| dd85358543 |
| 55d98db0d0 |
| d90ade3b1b |
| 824a641cee |
| e003a05177 |
| 91393c4dc8 |
| d578c53603 |
| 4db7a21ff7 |
| 3b2e05c563 |
| 1037ca3a97 |
| 6369dcd7b8 |
| a81612305c |
| d0da67eb17 |
| 8a1f5ae940 |
| 146ca48cba |
| fd312f6058 |
| ab8fa56b16 |
| 1640cd4006 |
| 3277d44d71 |
| d3e1b0ef1a |
| 5b97633faa |
| 94cbc6d42d |
| 493616fc3d |
| d2b25c7327 |
| b670c45276 |
| 61faf4cbe4 |
| 8d8fa834a2 |
| 9d69c6fb3e |
| 92f2f6e73c |
| e5d2aebe16 |
| 4ab9e3f58b |
| 5788832812 |
| db782430f8 |
| 5c74edeefe |
| 18269ee6a9 |
| 6a45d804f9 |
| 95e607574a |
| f9748c4dc5 |
| 33975ce4bc |
| e8b962d47f |
| 856ff12171 |

@@ -267,6 +267,7 @@ website:
        - docs/dataset_loading.qmd
        - docs/qat.qmd
        - docs/quantize.qmd
        - docs/optimizations.qmd

      - section: "Core Concepts"
        contents:

@@ -212,6 +212,14 @@ Instead of passing `tools` via the system prompt, an alternative method would be
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::

::: {.callout-warning}
If you have tool arguments with the same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments` as a JSON string to prevent `datasets` from running into casting issues.

```
"arguments": "{\"...\": \"...\"}"
```
:::

Example config for Llama4:
```yaml
chat_template: llama4

@@ -61,7 +61,7 @@ While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet

### Pre-training without streaming

- On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
+ In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.

One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.

@@ -140,3 +140,7 @@ description: Frequently asked questions

**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`**

> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
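
A minimal YAML sketch of that fallback; only the `offload_activations` line comes from this answer, and the `gradient_checkpointing` line is an assumption about the surrounding config rather than part of the fix:

```yaml
gradient_checkpointing: true   # assumed to already be enabled in the affected config
offload_activations: legacy    # use the naive implementation instead of OffloadActivations
```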

**Q: `Error parsing tool_calls arguments as JSON.`**

> A: The `arguments` string could not be parsed into a dict. Please check your dataset and the error message for more details.

docs/optimizations.qmd (Normal file, 133 lines)
@@ -0,0 +1,133 @@
---
title: Optimizations Guide
description: A guide to the performance and memory optimizations available in Axolotl.
---

Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.

This guide provides a high-level overview and directs you to the detailed documentation for each feature.

## Speed Optimizations

These optimizations focus on increasing training throughput and reducing total training time.

### Sample Packing

Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below.

- **Config:** `sample_packing: true`
- **Learn more:** [Sample Packing](multipack.qmd)
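
A minimal sketch of that combination; `flash_attention` is just one of the backends listed in the next section, chosen here for illustration:

```yaml
sample_packing: true
flash_attention: true   # packing requires one of the attention backends below
```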

### Attention Implementations

Using an optimized attention implementation is critical for training speed.

- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.

*Note: You should only enable one attention backend.*

### LoRA Optimizations

Leverages optimized kernels to accelerate LoRA training and reduce memory usage.

- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)
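
A hedged sketch of what enabling those kernels can look like; the exact flag names (`lora_mlp_kernel` and friends) are assumptions here, so treat the linked page as the authoritative reference:

```yaml
adapter: lora
# Assumed flag names for the fused LoRA kernels; verify against lora_optims.qmd.
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```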

## Memory Optimizations

These techniques help you fit larger models or use bigger batch sizes on your existing hardware.

### Parameter-Efficient Finetuning (LoRA & QLoRA)

Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.

- **Examples:** Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
- **Config Reference:** See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).
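
For orientation, a hedged QLoRA-style fragment built only from options named in this guide and elsewhere in this diff (`adapter`, `load_in_4bit`, and the `lora_*` keys from the Qwen3-Next example); the values are illustrative:

```yaml
adapter: qlora
load_in_4bit: true    # quantize the frozen base model to 4-bit
lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
```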

### Gradient Checkpointing & Activation Offloading

These techniques save VRAM by changing how activations are handled.

- **Gradient Checkpointing:** re-computes activations during the backward pass, trading compute time for VRAM.
- **Activation Offloading:** moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- **Learn more:** [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
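
A minimal sketch; `gradient_checkpointing` is a standard Axolotl option, while the commented `offload_activations: legacy` value is taken from the FAQ entry earlier in this diff and is only needed if the default offloading path misbehaves:

```yaml
gradient_checkpointing: true
# Optional fallback from the FAQ if the streamed OffloadActivations path errors out:
# offload_activations: legacy
```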

### Cut Cross Entropy (CCE)

Reduces VRAM usage by using an optimized cross-entropy loss calculation.

- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)

### Liger Kernels

Provides efficient Triton kernels to improve training speed and reduce memory usage.

- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
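
Both integrations are enabled through the `plugins` list; the snippet below mirrors the ALST usage example later in this diff, so the only choice left is which plugin to uncomment:

```yaml
plugins:
  # Cut Cross Entropy
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
  # or Liger kernels
  # - axolotl.integrations.liger.LigerPlugin
```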

## Long Context Models

Techniques to train models on sequences longer than their original context window.

### RoPE Scaling

Extends a model's context window by interpolating its Rotary Position Embeddings.

- **Config:** Pass a `rope_scaling` section under `overrides_of_model_config`. To learn how to set the RoPE parameters, check the respective model's config.
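
A hedged sketch of the shape this takes; the keys inside `rope_scaling` vary by model family, so `type` and `factor` below are illustrative rather than definitive:

```yaml
overrides_of_model_config:
  rope_scaling:
    type: linear   # illustrative; some model configs expect rope_type instead
    factor: 2.0    # roughly doubles the usable context window
```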

### Sequence Parallelism

Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.

- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)
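
The ALST usage example later in this diff enables it with `context_parallel_size`; a minimal sketch with an illustrative value:

```yaml
# Split each sequence across this many GPUs (typically a divisor of the GPU count).
context_parallel_size: 2
```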

### Arctic Long Sequence Training (ALST)

ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:

- TiledMLP to reduce memory usage in MLP layers.
- Tiled loss functions (like [CCE](#cut-cross-entropy-cce) or [Liger](#liger-kernels)).
- Activation Offloading to CPU.

- **Example:** [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)

## Large Models (Distributed Training)

To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.

- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
- **Learn more:** [Multi-Node Guide](multi-node.qmd)
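
As a hedged illustration of the DeepSpeed route (the config path below is an assumption about the repo layout; FSDP has its own keys described in the Multi-GPU guide):

```yaml
# Shard weights, gradients, and optimizer states with ZeRO stage 3.
deepspeed: deepspeed_configs/zero3_bf16.json   # illustrative path
```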

### N-D Parallelism (Beta)

For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach for training extremely large models by overcoming multiple bottlenecks at once.

- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)

## Quantization

Techniques to reduce the precision of model weights for memory savings.

### 4-bit Training (QLoRA)

The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Parameter-Efficient Finetuning (LoRA & QLoRA)](#parameter-efficient-finetuning-lora-qlora) for details.

### FP8 Training

Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.

- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)
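
A hedged sketch only; the flag name below is an assumption, and the linked example config is the authoritative reference:

```yaml
fp8: true   # assumed flag; see the FP8 FSDP example above for the real config
```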

### Quantization Aware Training (QAT)

Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.

- **Learn more:** [QAT Documentation](qat.qmd)

### GPTQ

Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.

- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)

@@ -30,6 +30,7 @@ qat:
```

We support the following quantization schemes:

- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
- `Int8DynamicActivationInt4Weight`
- `Float8DynamicActivationFloat8Weight`

@@ -7,3 +7,24 @@ techniques. It is a combination of:
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage

For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).

## Usage

```yaml
tiled_mlp: true

# See Sequence Parallelism docs
# https://docs.axolotl.ai/docs/sequence_parallelism.html
context_parallel_size: int

plugins:
  # See Cut Cross Entropy docs
  # https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

  # or Liger Kernel docs
  # https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels
  - axolotl.integrations.liger.LigerPlugin
  # ...
```

@@ -38,7 +38,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```

- This config uses about 41.7 GiB VRAM.
+ This config uses about 45.62 GiB VRAM.

Let us know how it goes. Happy finetuning! 🚀

@@ -27,6 +27,14 @@ lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_target_modules:
  - linear_attn.in_proj_ba
  - linear_attn.in_proj_qkvz
  - linear_attn.out_proj
  - shared_expert.up_proj
  - shared_expert.down_proj
  - shared_expert.gate_proj
  - shared_expert_gate
  - mlp.gate
  - q_proj
  - v_proj
  - k_proj

scripts/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
"""Utility scripts package."""

scripts/benchmarks/__init__.py (Normal file, 5 lines)
@@ -0,0 +1,5 @@
"""Benchmark helpers."""

from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3

__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"]

scripts/benchmarks/build_deepseek_v3_8b.py (Executable file, 100 lines)
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Instantiate a ~8.3B DeepSeek-V3 MoE model with random weights.

Run this on a GPU-equipped machine (e.g. 1× NVL H100) so the dense
initialization completes quickly:

    python scripts/benchmarks/build_deepseek_v3_8b.py --output deepseek-v3-8b-moe
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch
from transformers import DeepseekV3Config, DeepseekV3ForCausalLM

DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}


def build_config() -> DeepseekV3Config:
    """Return a DeepSeek V3 configuration totaling ~8.3B parameters."""
    return DeepseekV3Config(
        vocab_size=32_000,
        hidden_size=3_072,
        intermediate_size=8_192,
        moe_intermediate_size=2_560,
        num_hidden_layers=20,
        num_attention_heads=24,
        num_key_value_heads=24,
        n_routed_experts=18,
        num_experts_per_tok=4,
        n_group=6,
        topk_group=4,
        kv_lora_rank=192,
        q_lora_rank=384,
        max_position_embeddings=2_048,
        rope_theta=10_000.0,
        rope_interleave=True,
        hidden_act="silu",
        initializer_range=0.02,
        attention_dropout=0.0,
        attention_bias=False,
        n_shared_experts=1,
        routed_scaling_factor=2.5,
        norm_topk_prob=True,
    )


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Directory to save the generated model",
    )
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        choices=DTYPE_MAP.keys(),
        help="Storage dtype for the checkpoint",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Torch RNG seed for reproducibility",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    torch.manual_seed(args.seed)

    output_dir = args.output
    output_dir.mkdir(parents=True, exist_ok=True)

    config = build_config()
    model = DeepseekV3ForCausalLM(config)

    dtype = DTYPE_MAP[args.dtype]
    model.to(dtype=dtype)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Initialized DeepSeek-V3 MoE with {param_count / 1e9:.3f}B parameters")

    model.save_pretrained(output_dir, safe_serialization=True)
    config.save_pretrained(output_dir)
    print(f"Saved model and config to {output_dir.resolve()}")


if __name__ == "__main__":
    main()

scripts/benchmarks/deepseek_v3_group_gemm_table.py (Normal file, 190 lines)
@@ -0,0 +1,190 @@
|
||||
#!/usr/bin/env python
|
||||
"""Reproduce TorchTitan CG GEMM timings for selected problem sizes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
repo_root = candidate / "axolotl"
|
||||
if repo_root.exists():
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
break
|
||||
else:
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from axolotl.kernels.moe import (
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Scenario:
|
||||
num_groups: int
|
||||
m: int
|
||||
n: int
|
||||
k: int
|
||||
|
||||
|
||||
SCENARIOS: tuple[Scenario, ...] = (
|
||||
Scenario(num_groups=4, m=8192, n=4096, k=7168),
|
||||
Scenario(num_groups=4, m=8192, n=7168, k=2048),
|
||||
Scenario(num_groups=8, m=4096, n=4096, k=7168),
|
||||
Scenario(num_groups=8, m=4096, n=7168, k=2048),
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--device", default="cuda", choices=["cuda"], help="Execution device"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="bf16",
|
||||
choices=["bf16", "fp16", "fp32"],
|
||||
help="Computation dtype",
|
||||
)
|
||||
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
|
||||
parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
|
||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
default=128,
|
||||
help="GROUP_SIZE_M expected by the kernel",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def pick_dtype(name: str) -> torch.dtype:
|
||||
return {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[name]
|
||||
|
||||
|
||||
def make_indices(
|
||||
num_groups: int, group_size: int, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
indices = torch.arange(num_groups, device=device, dtype=torch.int32)
|
||||
return indices.repeat_interleave(group_size)
|
||||
|
||||
|
||||
def timed_call(fn, *args, warmup: int, iters: int) -> float:
|
||||
for _ in range(warmup):
|
||||
fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
for _ in range(iters):
|
||||
fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
return (time.perf_counter() - start) * 1000.0 / iters
|
||||
|
||||
|
||||
def run_scenario(
|
||||
scenario: Scenario,
|
||||
*,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
warmup: int,
|
||||
iters: int,
|
||||
group_size_m: int,
|
||||
) -> dict:
|
||||
if scenario.m % scenario.num_groups != 0:
|
||||
raise ValueError(
|
||||
f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
|
||||
)
|
||||
group_size = scenario.m // scenario.num_groups
|
||||
if group_size % group_size_m != 0:
|
||||
raise ValueError(
|
||||
f"Group size {group_size} must be a multiple of GROUP_SIZE_M ({group_size_m}) for the Triton kernel"
|
||||
)
|
||||
|
||||
inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
|
||||
weights = torch.randn(
|
||||
scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
|
||||
)
|
||||
indices = make_indices(scenario.num_groups, group_size, device)
|
||||
|
||||
def persistent():
|
||||
return cg_grouped_gemm_forward(inputs, weights, indices, group_size_m)
|
||||
|
||||
def baseline():
|
||||
return cg_grouped_gemm_forward_dynamic(inputs, weights, indices, group_size_m)
|
||||
|
||||
persistent_ms = timed_call(persistent, warmup=warmup, iters=iters)
|
||||
baseline_ms = timed_call(baseline, warmup=warmup, iters=iters)
|
||||
|
||||
return {
|
||||
"scenario": scenario,
|
||||
"persistent_ms": persistent_ms,
|
||||
"baseline_ms": baseline_ms,
|
||||
"speedup": baseline_ms / persistent_ms if persistent_ms > 0 else float("nan"),
|
||||
}
|
||||
|
||||
|
||||
def main() -> None: # pragma: no cover - utility script
|
||||
args = parse_args()
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
if args.device != "cuda" or not torch.cuda.is_available():
|
||||
raise SystemExit("CUDA device required for this benchmark")
|
||||
|
||||
dtype = pick_dtype(args.dtype)
|
||||
device = torch.device(args.device)
|
||||
|
||||
print(
|
||||
f"device={device} dtype={dtype} warmup={args.warmup} iters={args.iters} group_size={args.group_size}"
|
||||
)
|
||||
print(
|
||||
f"{'groups':>7} {'m':>7} {'n':>7} {'k':>7} {'persistent':>12} {'baseline':>12} {'speedup':>8}"
|
||||
)
|
||||
for result in run_all(
|
||||
SCENARIOS,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
warmup=args.warmup,
|
||||
iters=args.iters,
|
||||
group_size_m=args.group_size,
|
||||
):
|
||||
scen = result["scenario"]
|
||||
print(
|
||||
f"{scen.num_groups:>7} {scen.m:>7} {scen.n:>7} {scen.k:>7}"
|
||||
f" {result['persistent_ms']:>11.3f} ms {result['baseline_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
||||
)
|
||||
|
||||
|
||||
def run_all(
|
||||
scenarios: Iterable[Scenario],
|
||||
*,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
warmup: int,
|
||||
iters: int,
|
||||
group_size_m: int,
|
||||
) -> Iterable[dict]:
|
||||
for scenario in scenarios:
|
||||
yield run_scenario(
|
||||
scenario,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
warmup=warmup,
|
||||
iters=iters,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||

scripts/benchmarks/deepseek_v3_moe.py (Normal file, 301 lines)
@@ -0,0 +1,301 @@
|
||||
#!/usr/bin/env python
|
||||
# mypy: ignore-errors
|
||||
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
|
||||
DeepseekV3Config,
|
||||
)
|
||||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||
except ImportError as exc: # pragma: no cover - utility script
|
||||
raise SystemExit(
|
||||
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
|
||||
) from exc
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
repo_root = candidate / "axolotl"
|
||||
if repo_root.exists():
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
break
|
||||
else: # pragma: no cover - defensive guard
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
|
||||
|
||||
ACCURACY_TOLERANCE = 5e-3
|
||||
|
||||
DTYPE_MAP = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--batch", type=int, default=8, help="batch size")
|
||||
parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
|
||||
parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
|
||||
parser.add_argument(
|
||||
"--moe-intermediate-size",
|
||||
type=int,
|
||||
default=8192,
|
||||
help="MoE intermediate projection size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-experts",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of routed experts",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of experts per token",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--groups",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Router groups (must divide n-experts)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
choices=DTYPE_MAP.keys(),
|
||||
default="bf16",
|
||||
help="Computation dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
default="auto",
|
||||
choices=["auto", "cpu", "cuda"],
|
||||
help="Execution device",
|
||||
)
|
||||
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
|
||||
parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
|
||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--uniform-routing",
|
||||
action="store_true",
|
||||
help="Override router to distribute tokens evenly across experts",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
default=128,
|
||||
help="GROUP_SIZE_M used by the Triton kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["cg", "mg"],
|
||||
default="mg",
|
||||
help="MoE kernel backend to benchmark",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def resolve_device(requested: str) -> torch.device:
|
||||
if requested == "auto":
|
||||
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return torch.device(requested)
|
||||
|
||||
|
||||
def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
|
||||
config = DeepseekV3Config(
|
||||
hidden_size=args.hidden_size,
|
||||
intermediate_size=args.moe_intermediate_size,
|
||||
moe_intermediate_size=args.moe_intermediate_size,
|
||||
n_routed_experts=args.n_experts,
|
||||
num_experts_per_tok=args.top_k,
|
||||
n_group=args.groups,
|
||||
topk_group=max(1, min(args.groups, args.top_k)),
|
||||
n_shared_experts=1,
|
||||
)
|
||||
module = DeepseekV3MoE(config)
|
||||
module.eval()
|
||||
return module
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def timed_forward(
|
||||
module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
|
||||
) -> float:
|
||||
for _ in range(warmup):
|
||||
module(inputs)
|
||||
if inputs.is_cuda:
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
for _ in range(iters):
|
||||
module(inputs)
|
||||
if inputs.is_cuda:
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.perf_counter() - start
|
||||
return (elapsed / iters) * 1000.0
|
||||
|
||||
|
||||
def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
device = resolve_device(args.device)
|
||||
dtype = DTYPE_MAP[args.dtype]
|
||||
|
||||
if args.n_experts % args.groups != 0:
|
||||
raise SystemExit("n-experts must be divisible by groups")
|
||||
if args.top_k > args.n_experts:
|
||||
raise SystemExit("top-k cannot exceed number of experts")
|
||||
|
||||
if device.type == "cuda" and not torch.cuda.is_available():
|
||||
raise SystemExit("CUDA requested but not available")
|
||||
|
||||
baseline_module = build_module(args)
|
||||
original_moe = getattr(
|
||||
DeepseekV3MoE,
|
||||
"_axolotl_triton_original_moe",
|
||||
DeepseekV3MoE.moe,
|
||||
)
|
||||
baseline_module.moe = MethodType(original_moe, baseline_module)
|
||||
state_dict = baseline_module.state_dict()
|
||||
|
||||
patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
|
||||
patched_module = build_module(args)
|
||||
patched_module.load_state_dict(state_dict)
|
||||
|
||||
baseline_module.to(device=device, dtype=dtype)
|
||||
patched_module.to(device=device, dtype=dtype)
|
||||
|
||||
tokens = args.batch * args.seq_len
|
||||
routed_tokens = tokens * args.top_k
|
||||
avg_tokens_per_expert = routed_tokens / args.n_experts
|
||||
|
||||
inputs = torch.randn(
|
||||
args.batch,
|
||||
args.seq_len,
|
||||
args.hidden_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
flat_inputs = inputs.view(-1, args.hidden_size)
|
||||
if args.uniform_routing:
|
||||
total_assignments = flat_inputs.size(0) * args.top_k
|
||||
base = total_assignments // args.n_experts
|
||||
remainder = total_assignments % args.n_experts
|
||||
counts = torch.full(
|
||||
(args.n_experts,),
|
||||
base,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
if remainder:
|
||||
counts[:remainder] += 1
|
||||
assignments = torch.repeat_interleave(
|
||||
torch.arange(args.n_experts, device=device), counts
|
||||
)
|
||||
assignments = assignments[torch.randperm(assignments.size(0))]
|
||||
topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
|
||||
else:
|
||||
topk_idx, _ = patched_module.gate(flat_inputs)
|
||||
|
||||
tokens_per_expert = torch.bincount(
|
||||
topk_idx.reshape(-1), minlength=args.n_experts
|
||||
)
|
||||
min_tokens = int(tokens_per_expert.min().item())
|
||||
max_tokens = int(tokens_per_expert.max().item())
|
||||
|
||||
if args.uniform_routing:
|
||||
weights = torch.full(
|
||||
topk_idx.shape,
|
||||
1.0 / args.top_k,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
def _uniform_gate(self, hidden_states):
|
||||
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
token_count = flat.shape[0]
|
||||
return topk_idx[:token_count], weights[:token_count]
|
||||
|
||||
patched_module.gate.forward = _uniform_gate.__get__(
|
||||
patched_module.gate, patched_module.gate.__class__
|
||||
)
|
||||
baseline_module.gate.forward = _uniform_gate.__get__(
|
||||
baseline_module.gate, baseline_module.gate.__class__
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_output = baseline_module(inputs)
|
||||
patched_output = patched_module(inputs)
|
||||
max_diff = (ref_output - patched_output).abs().max().item()
|
||||
|
||||
baseline_vram = patched_vram = None
|
||||
if device.type == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
|
||||
if device.type == "cuda":
|
||||
baseline_vram = torch.cuda.max_memory_allocated(device)
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
|
||||
if device.type == "cuda":
|
||||
patched_vram = torch.cuda.max_memory_allocated(device)
|
||||
|
||||
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
|
||||
|
||||
return {
|
||||
"device": device,
|
||||
"backend": args.backend,
|
||||
"dtype": dtype,
|
||||
"baseline_ms": baseline_ms,
|
||||
"patched_ms": patched_ms,
|
||||
"speedup": speedup,
|
||||
"max_diff": max_diff,
|
||||
"routed_tokens": routed_tokens,
|
||||
"avg_tokens": avg_tokens_per_expert,
|
||||
"min_tokens": min_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"baseline_vram": baseline_vram,
|
||||
"patched_vram": patched_vram,
|
||||
"accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
|
||||
}
|
||||
|
||||
|
||||
def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
args = parse_args()
|
||||
result = benchmark_deepseek_v3(args)
|
||||
|
||||
print(
|
||||
f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
|
||||
)
|
||||
print(
|
||||
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
|
||||
)
|
||||
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
|
||||
if result["baseline_vram"] is not None:
|
||||
print(
|
||||
f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
|
||||
)
|
||||
print(
|
||||
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
|
||||
)
|
||||
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
|
||||
if not result["accuracy_ok"]:
|
||||
raise RuntimeError(
|
||||
f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||

scripts/benchmarks/deepseek_v3_moe_sweep.py (Normal file, 275 lines)
@@ -0,0 +1,275 @@
|
||||
#!/usr/bin/env python
|
||||
# mypy: ignore-errors
|
||||
"""Sweep a set of DeepSeek V3 MoE benchmark configurations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
repo_root = candidate / "axolotl"
|
||||
if repo_root.exists():
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
break
|
||||
else: # pragma: no cover - defensive guard
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
|
||||
ACCURACY_TOLERANCE,
|
||||
DTYPE_MAP,
|
||||
benchmark_deepseek_v3,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
choices=DTYPE_MAP.keys(),
|
||||
default="bf16",
|
||||
help="Computation dtype for all benchmarks",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
default="auto",
|
||||
choices=["auto", "cpu", "cuda"],
|
||||
help="Execution device",
|
||||
)
|
||||
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
|
||||
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
|
||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
help="Override GROUP_SIZE_M for every configuration",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backends",
|
||||
default="mg",
|
||||
help="Comma separated list of backends to benchmark (subset of cg,mg)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-uniform-routing",
|
||||
action="store_true",
|
||||
help="Disable uniform routing for every configuration",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-mixtral-long",
|
||||
action="store_true",
|
||||
help="Add an 8×8192 Mixtral-style run to the sweep",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
help="Optional CSV file to store results",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_namespace(
|
||||
base: dict, args: argparse.Namespace, backend: str
|
||||
) -> SimpleNamespace:
|
||||
combined = dict(base)
|
||||
combined.update(
|
||||
{
|
||||
"dtype": args.dtype,
|
||||
"device": args.device,
|
||||
"backend": backend,
|
||||
"warmup": args.warmup,
|
||||
"iters": args.iters,
|
||||
"seed": args.seed,
|
||||
"uniform_routing": not args.no_uniform_routing,
|
||||
}
|
||||
)
|
||||
if args.group_size is not None:
|
||||
combined["group_size"] = args.group_size
|
||||
return SimpleNamespace(**combined)
|
||||
|
||||
|
||||
ARCHETYPES = (
|
||||
(
|
||||
"mixtral",
|
||||
{
|
||||
"hidden_size": 4096,
|
||||
"moe_intermediate_size": 14336,
|
||||
"n_experts": 8,
|
||||
"top_k": 2,
|
||||
"groups": 1,
|
||||
"group_size": 128,
|
||||
},
|
||||
[(4, 2048), (8, 4096)],
|
||||
),
|
||||
(
|
||||
"qwen",
|
||||
{
|
||||
"hidden_size": 6144,
|
||||
"moe_intermediate_size": 24576,
|
||||
"n_experts": 16,
|
||||
"top_k": 4,
|
||||
"groups": 8,
|
||||
"group_size": 128,
|
||||
},
|
||||
[(4, 4096), (8, 8192)],
|
||||
),
|
||||
(
|
||||
"deepseek_v3",
|
||||
{
|
||||
"hidden_size": 12288,
|
||||
"moe_intermediate_size": 49152,
|
||||
"n_experts": 128,
|
||||
"top_k": 8,
|
||||
"groups": 16,
|
||||
"group_size": 128,
|
||||
},
|
||||
[(4, 4096), (8, 8192)],
|
||||
),
|
||||
)
|
||||
|
||||
MIXTRAL_LONG_SHAPES = [(8, 8192)]
|
||||
|
||||
|
||||
def main() -> None: # pragma: no cover - utility script
|
||||
args = parse_args()
|
||||
|
||||
grid = []
|
||||
for label, base_cfg, shapes in ARCHETYPES:
|
||||
for batch, seq_len in shapes:
|
||||
cfg = {
|
||||
"label": label,
|
||||
"batch": batch,
|
||||
"seq_len": seq_len,
|
||||
**base_cfg,
|
||||
}
|
||||
if cfg["n_experts"] % cfg["groups"] != 0 or cfg["top_k"] > cfg["n_experts"]:
|
||||
continue
|
||||
grid.append(cfg)
|
||||
|
||||
if args.include_mixtral_long:
|
||||
base_cfg = ARCHETYPES[0][1]
|
||||
for batch, seq_len in MIXTRAL_LONG_SHAPES:
|
||||
grid.append(
|
||||
{
|
||||
"label": "mixtral_long",
|
||||
"batch": batch,
|
||||
"seq_len": seq_len,
|
||||
**base_cfg,
|
||||
}
|
||||
)
|
||||
|
||||
if not grid:
|
||||
raise SystemExit("No valid parameter combinations produced")
|
||||
|
||||
header = (
|
||||
"model",
|
||||
"batch",
|
||||
"seq_len",
|
||||
"hidden_size",
|
||||
"moe_intermediate",
|
||||
"n_experts",
|
||||
"top_k",
|
||||
"groups",
|
||||
"backend",
|
||||
"baseline_ms",
|
||||
"patched_ms",
|
||||
"speedup",
|
||||
"baseline_vram_mib",
|
||||
"patched_vram_mib",
|
||||
"min_tokens",
|
||||
"max_tokens",
|
||||
"max_diff",
|
||||
"accuracy_ok",
|
||||
)
|
||||
rows = []
|
||||
|
||||
raw_backends = [
|
||||
token.strip() for token in args.backends.split(",") if token.strip()
|
||||
]
|
||||
if not raw_backends:
|
||||
raw_backends = ["mg"]
|
||||
valid_backends = []
|
||||
for backend in raw_backends:
|
||||
if backend not in {"cg", "mg"}:
|
||||
raise SystemExit(f"Unsupported backend '{backend}' requested")
|
||||
if backend not in valid_backends:
|
||||
valid_backends.append(backend)
|
||||
|
||||
uniform_flag = not args.no_uniform_routing
|
||||
print(
|
||||
f"Running sweep on device={args.device} dtype={args.dtype} backends={tuple(valid_backends)} uniform_routing={uniform_flag}"
|
||||
)
|
||||
print(
|
||||
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
|
||||
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'diff':>10} {'acc':>5}"
|
||||
)
|
||||
|
||||
for cfg in grid:
|
||||
for backend in valid_backends:
|
||||
ns = make_namespace(cfg, args, backend)
|
||||
result = benchmark_deepseek_v3(ns)
|
||||
baseline_vram_mib = (
|
||||
result["baseline_vram"] / (1024**2)
|
||||
if result["baseline_vram"] is not None
|
||||
else float("nan")
|
||||
)
|
||||
patched_vram_mib = (
|
||||
result["patched_vram"] / (1024**2)
|
||||
if result["patched_vram"] is not None
|
||||
else float("nan")
|
||||
)
|
||||
rows.append(
|
||||
(
|
||||
cfg["label"],
|
||||
cfg["batch"],
|
||||
cfg["seq_len"],
|
||||
cfg["hidden_size"],
|
||||
cfg["moe_intermediate_size"],
|
||||
cfg["n_experts"],
|
||||
cfg["top_k"],
|
||||
cfg["groups"],
|
||||
backend,
|
||||
result["baseline_ms"],
|
||||
result["patched_ms"],
|
||||
result["speedup"],
|
||||
baseline_vram_mib,
|
||||
patched_vram_mib,
|
||||
result["min_tokens"],
|
||||
result["max_tokens"],
|
||||
result["max_diff"],
|
||||
result["accuracy_ok"],
|
||||
)
|
||||
)
|
||||
status = "OK" if result["accuracy_ok"] else "FAIL"
|
||||
print(
|
||||
f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}"
|
||||
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
||||
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {result['max_diff']:>10.3e} {status:>5}"
|
||||
)
|
||||
if not result["accuracy_ok"]:
|
||||
LOG.warning(
|
||||
"Accuracy tolerance exceeded for %s backend=%s: diff=%.3e (> %.1e)",
|
||||
cfg["label"],
|
||||
backend,
|
||||
result["max_diff"],
|
||||
ACCURACY_TOLERANCE,
|
||||
)
|
||||
|
||||
if args.output:
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
with args.output.open("w", newline="") as fp:
|
||||
writer = csv.writer(fp)
|
||||
writer.writerow(header)
|
||||
writer.writerows(rows)
|
||||
print(f"Results written to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||

src/axolotl/kernels/moe/__init__.py (Normal file, 21 lines)
@@ -0,0 +1,21 @@
"""Mixture-of-Experts kernel implementations."""

from .indices import generate_permute_indices
from .tt_cg_gemm import (
    ContiguousGroupedGEMM,
    ContiguousGroupedGEMMForwardOnly,
    cg_grouped_gemm,
    cg_grouped_gemm_forward,
    cg_grouped_gemm_forward_dynamic,
)
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm

__all__ = [
    "cg_grouped_gemm",
    "cg_grouped_gemm_forward",
    "cg_grouped_gemm_forward_dynamic",
    "ContiguousGroupedGEMM",
    "ContiguousGroupedGEMMForwardOnly",
    "generate_permute_indices",
    "mg_grouped_gemm",
]

src/axolotl/kernels/moe/indices/__init__.py (Normal file, 5 lines)
@@ -0,0 +1,5 @@
"""Token permutation utilities for grouped MoE kernels."""

from .indices import generate_permute_indices

__all__ = ["generate_permute_indices"]

src/axolotl/kernels/moe/indices/indices.py (Normal file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
"""Vendored token permutation kernels from TorchTitan."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
__all__ = ["generate_permute_indices"]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fill_indices_kernel(
|
||||
tokens_per_expert_group_ptr,
|
||||
start_index_values_ptr,
|
||||
write_offsets_ptr,
|
||||
output_ptr,
|
||||
experts_per_rank: tl.constexpr,
|
||||
num_ranks: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_programs = tl.num_programs(axis=0)
|
||||
|
||||
for expert_id in range(pid, experts_per_rank, num_programs):
|
||||
write_offset = tl.load(write_offsets_ptr + expert_id)
|
||||
|
||||
for r in range(num_ranks):
|
||||
idx = r * experts_per_rank + expert_id
|
||||
|
||||
start_index = tl.load(start_index_values_ptr + idx)
|
||||
length = tl.load(tokens_per_expert_group_ptr + idx)
|
||||
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
for chunk_start in range(0, length, BLOCK_SIZE):
|
||||
chunk_offsets = chunk_start + offsets
|
||||
mask = chunk_offsets < length
|
||||
values = start_index + chunk_offsets
|
||||
dest_indices = write_offset + chunk_offsets
|
||||
tl.store(output_ptr + dest_indices, values, mask=mask)
|
||||
|
||||
write_offset += length
|
||||
|
||||
|
||||
def fill_indices_wrapper(
|
||||
tokens_per_expert_group: torch.Tensor,
|
||||
start_index_values: torch.Tensor,
|
||||
write_offsets: torch.Tensor,
|
||||
experts_per_rank: int,
|
||||
num_ranks: int,
|
||||
max_len: int,
|
||||
block_size: int = 128,
|
||||
max_blocks: int = 1024,
|
||||
):
|
||||
permuted_indices = torch.full(
|
||||
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
|
||||
)
|
||||
num_blocks = min(experts_per_rank, max_blocks)
|
||||
grid = (num_blocks,)
|
||||
_fill_indices_kernel[grid](
|
||||
tokens_per_expert_group,
|
||||
start_index_values,
|
||||
write_offsets,
|
||||
permuted_indices,
|
||||
experts_per_rank,
|
||||
num_ranks,
|
||||
BLOCK_SIZE=block_size,
|
||||
)
|
||||
return permuted_indices
|
||||
|
||||
|
||||
def fill_indices_cpu(
|
||||
tokens_per_expert_group: torch.Tensor,
|
||||
start_index_values: torch.Tensor,
|
||||
write_offsets: torch.Tensor,
|
||||
experts_per_rank: int,
|
||||
num_ranks: int,
|
||||
max_len: int,
|
||||
):
|
||||
permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
|
||||
for expert_id in range(experts_per_rank):
|
||||
write_start = write_offsets[expert_id].item()
|
||||
for r in range(num_ranks):
|
||||
idx = r * experts_per_rank + expert_id
|
||||
start_index = start_index_values[idx].item()
|
||||
length = tokens_per_expert_group[idx].item()
|
||||
if length > 0:
|
||||
end_idx = min(write_start + length, max_len)
|
||||
permuted_indices[write_start:end_idx] = torch.arange(
|
||||
start_index,
|
||||
start_index + (end_idx - write_start),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
write_start += length
|
||||
return permuted_indices
|
||||
|
||||
|
||||
def generate_permute_indices(
|
||||
tokens_per_expert_group: torch.Tensor,
|
||||
experts_per_rank: int,
|
||||
num_ranks: int,
|
||||
max_len: int,
|
||||
alignment: int,
|
||||
use_cpu: bool = False,
|
||||
):
|
||||
start_index_values = (
|
||||
torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
|
||||
)
|
||||
|
||||
total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
|
||||
total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)
|
||||
|
||||
m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(
|
||||
torch.int32
|
||||
)
|
||||
|
||||
m_offsets = torch.cumsum(m_sizes, 0)
|
||||
write_offsets = m_offsets - m_sizes
|
||||
|
||||
if use_cpu:
|
||||
permuted_indices = fill_indices_cpu(
|
||||
tokens_per_expert_group,
|
||||
start_index_values,
|
||||
write_offsets,
|
||||
experts_per_rank,
|
||||
num_ranks,
|
||||
max_len,
|
||||
)
|
||||
else:
|
||||
permuted_indices = fill_indices_wrapper(
|
||||
tokens_per_expert_group,
|
||||
start_index_values,
|
||||
write_offsets,
|
||||
experts_per_rank,
|
||||
num_ranks,
|
||||
max_len,
|
||||
)
|
||||
|
||||
return permuted_indices, m_sizes, m_offsets.to(torch.int32)
|
||||

src/axolotl/kernels/moe/tt_cg_gemm/__init__.py (Normal file, 17 lines)
@@ -0,0 +1,17 @@
"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan."""

from .cg_backward import ContiguousGroupedGEMM
from .cg_forward import (
    ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly,
    cg_grouped_gemm,
    cg_grouped_gemm_forward,
    cg_grouped_gemm_forward_dynamic,
)

__all__ = [
    "cg_grouped_gemm",
    "cg_grouped_gemm_forward",
    "cg_grouped_gemm_forward_dynamic",
    "ContiguousGroupedGEMM",
    "ContiguousGroupedGEMMForwardOnly",
]

src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py (Normal file, 290 lines)
@@ -0,0 +1,290 @@
|
||||
"""Vendored backward pass for Triton contiguous grouped GEMM."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .cg_forward import cg_grouped_gemm_forward
|
||||
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
|
||||
|
||||
GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=STANDARD_CONFIGS,
|
||||
key=["M_TOTAL", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_cg_backward_dx(
|
||||
grad_output_ptr,
|
||||
b_ptr,
|
||||
grad_input_ptr,
|
||||
indices_ptr,
|
||||
M_TOTAL: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
|
||||
):
|
||||
"""Compute gradients with respect to inputs."""
|
||||
|
||||
pid = tl.program_id(0)
|
||||
|
||||
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
|
||||
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
|
||||
tile_m = pid // num_k_tiles
|
||||
tile_k = pid % num_k_tiles
|
||||
|
||||
m_start = tile_m * BLOCK_SIZE_M
|
||||
k_start = tile_k * BLOCK_SIZE_K
|
||||
|
||||
if m_start < M_TOTAL:
|
||||
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
|
||||
|
||||
mask_m = offs_m < M_TOTAL
|
||||
mask_k = offs_k < K
|
||||
|
||||
group_idx = m_start // GROUP_SIZE_M
|
||||
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
|
||||
|
||||
grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32)
|
||||
|
||||
for n in range(0, N, BLOCK_SIZE_N):
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N) + n
|
||||
mask_n = offs_n < N
|
||||
|
||||
mask_go = mask_m[:, None] & mask_n[None, :]
|
||||
mask_w = mask_n[:, None] & mask_k[None, :]
|
||||
|
||||
go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
|
||||
|
||||
w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
|
||||
w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
|
||||
|
||||
grad_input += tl.dot(go, w)
|
||||
|
||||
grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
mask_gi = mask_m[:, None] & mask_k[None, :]
|
||||
tl.store(grad_input_ptrs, grad_input, mask=mask_gi)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _kernel_cg_backward_dw(
|
||||
grad_output_ptr,
|
||||
inputs_ptr,
|
||||
grad_weights_ptr,
|
||||
indices_ptr,
|
||||
M_TOTAL: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""Simplified kernel for expert weight gradients."""
|
||||
|
||||
pid = tl.program_id(0)
|
||||
|
||||
expert_id = pid // ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
|
||||
position_id = pid % ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
|
||||
|
||||
if expert_id < NUM_EXPERTS:
|
||||
n_tiles = K // BLOCK_SIZE_K
|
||||
tile_n = position_id // n_tiles
|
||||
tile_k = position_id % n_tiles
|
||||
|
||||
n_start = tile_n * BLOCK_SIZE_N
|
||||
k_start = tile_k * BLOCK_SIZE_K
|
||||
|
||||
if n_start < N and k_start < K:
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
|
||||
|
||||
mask_n = offs_n < N
|
||||
mask_k = offs_k < K
|
||||
|
||||
grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32)
|
||||
|
||||
for group_idx in range(0, M_TOTAL // GROUP_SIZE_M):
|
||||
group_start = group_idx * GROUP_SIZE_M
|
||||
group_expert = tl.load(indices_ptr + group_start)
|
||||
|
||||
if group_expert == expert_id:
|
||||
for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M):
|
||||
m_start = group_start + m_offset
|
||||
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
|
||||
|
||||
mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL)
|
||||
|
||||
go_ptrs = (
|
||||
grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
)
|
||||
mask_go = mask_m[:, None] & mask_n[None, :]
|
||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
|
||||
|
||||
in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
mask_in = mask_m[:, None] & mask_k[None, :]
|
||||
inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32)
|
||||
|
||||
go_t = tl.trans(go)
|
||||
grad_weights += tl.dot(go_t, inp)
|
||||
|
||||
grad_w_ptrs = (
|
||||
grad_weights_ptr
|
||||
+ expert_id * N * K
|
||||
+ offs_n[:, None] * K
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
mask_gw = mask_n[:, None] & mask_k[None, :]
|
||||
tl.store(grad_w_ptrs, grad_weights, mask=mask_gw)
|
||||
|
||||
|
||||
def cg_grouped_gemm_backward_weights(
|
||||
grad_output: torch.Tensor,
|
||||
inputs: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
num_experts: int,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Backward pass for expert weights."""
|
||||
|
||||
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
|
||||
assert inputs.is_contiguous(), "Inputs tensor must be contiguous"
|
||||
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
|
||||
|
||||
M_total, N = grad_output.shape
|
||||
_, K = inputs.shape
|
||||
|
||||
if expert_indices.dtype != torch.int32:
|
||||
expert_indices = expert_indices.to(torch.int32)
|
||||
|
||||
grad_weights = torch.zeros(
|
||||
(num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype
|
||||
)
|
||||
|
||||
block_size_n = min(128, N)
|
||||
block_size_k = min(32, K)
|
||||
block_size_m = min(32, group_size_m)
|
||||
|
||||
n_tiles = triton.cdiv(N, block_size_n)
|
||||
k_tiles = triton.cdiv(K, block_size_k)
|
||||
grid = (num_experts * n_tiles * k_tiles,)
|
||||
|
||||
_kernel_cg_backward_dw[grid](
|
||||
grad_output,
|
||||
inputs,
|
||||
grad_weights,
|
||||
expert_indices,
|
||||
M_TOTAL=M_total,
|
||||
N=N,
|
||||
K=K,
|
||||
NUM_EXPERTS=num_experts,
|
||||
GROUP_SIZE_M=group_size_m,
|
||||
BLOCK_SIZE_N=block_size_n,
|
||||
BLOCK_SIZE_K=block_size_k,
|
||||
BLOCK_SIZE_M=block_size_m,
|
||||
)
|
||||
|
||||
return grad_weights
|
||||
|
||||
|
||||
def cg_grouped_gemm_backward_inputs(
|
||||
grad_output: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Backward pass for inputs."""
|
||||
|
||||
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
|
||||
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
|
||||
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
|
||||
|
||||
M_total, N = grad_output.shape
|
||||
num_experts, _, K = expert_weights.shape
|
||||
|
||||
assert M_total % group_size_m == 0, (
|
||||
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
|
||||
)
|
||||
|
||||
grad_inputs = torch.zeros(
|
||||
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
|
||||
)
|
||||
|
||||
grid = lambda meta: (
|
||||
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(K, meta["BLOCK_SIZE_K"]),
|
||||
)
|
||||
|
||||
_kernel_cg_backward_dx[grid](
|
||||
grad_output,
|
||||
expert_weights,
|
||||
grad_inputs,
|
||||
expert_indices,
|
||||
M_TOTAL=M_total,
|
||||
N=N,
|
||||
K=K,
|
||||
NUM_EXPERTS=num_experts,
|
||||
GROUP_SIZE_M=group_size_m,
|
||||
)
|
||||
|
||||
return grad_inputs
|
||||
|
||||
|
||||
class ContiguousGroupedGEMM(torch.autograd.Function):
|
||||
"""Autograd function with full backward support."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
|
||||
ctx.save_for_backward(inputs, expert_weights, expert_indices)
|
||||
ctx.group_size_m = group_size_m
|
||||
|
||||
return cg_grouped_gemm_forward(
|
||||
inputs=inputs,
|
||||
expert_weights=expert_weights,
|
||||
expert_indices=expert_indices,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
inputs, expert_weights, expert_indices = ctx.saved_tensors
|
||||
group_size_m = ctx.group_size_m
|
||||
|
||||
grad_output = grad_output.contiguous()
|
||||
num_experts = expert_weights.shape[0]
|
||||
|
||||
grad_inputs = cg_grouped_gemm_backward_inputs(
|
||||
grad_output=grad_output,
|
||||
expert_weights=expert_weights,
|
||||
expert_indices=expert_indices,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
grad_weights = cg_grouped_gemm_backward_weights(
|
||||
grad_output=grad_output,
|
||||
inputs=inputs,
|
||||
expert_indices=expert_indices,
|
||||
num_experts=num_experts,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
grad_indices = None
|
||||
grad_group_size_m = None
|
||||
|
||||
return grad_inputs, grad_weights, grad_indices, grad_group_size_m
|
||||

src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py (Normal file, 311 lines)
@@ -0,0 +1,311 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Vendored forward Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
|
||||
|
||||
GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m):
|
||||
group_id = tile_id // num_pid_in_group
|
||||
first_pid_m = group_id * super_group_m
|
||||
group_size_m = min(num_pid_m - first_pid_m, super_group_m)
|
||||
pid_m = first_pid_m + (tile_id % group_size_m)
|
||||
pid_n = (tile_id % num_pid_in_group) // group_size_m
|
||||
return pid_m, pid_n
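
# Added note (illustration, not from the vendored source): `_compute_pid` maps
# a flat persistent tile id onto a swizzled (pid_m, pid_n) pair so that
# consecutive program ids sweep the N dimension within a band of SUPER_GROUP_M
# rows of output tiles, improving L2 reuse of the A tiles. For example, with
# num_pid_m=4, num_pid_n=3 and SUPER_GROUP_M=2 (so num_pid_in_group=6),
# tile ids 0..6 map to (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), (2,0).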
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=STANDARD_CONFIGS,
|
||||
key=["M_TOTAL", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_cg_persistent_forward(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
indices_ptr,
|
||||
M_TOTAL: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
NUM_SMS: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
|
||||
SUPER_GROUP_M: tl.constexpr = 32,
|
||||
):
|
||||
"""
|
||||
Contiguous Grouped GEMM kernel forward (persistent variant).
|
||||
"""
|
||||
|
||||
c_type = c_ptr.dtype.element_ty
|
||||
|
||||
start_pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_tiles = num_pid_m * num_pid_n
|
||||
tile_id_c = start_pid - NUM_SMS
|
||||
num_pid_in_group = SUPER_GROUP_M * num_pid_n
|
||||
|
||||
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
|
||||
tile_m_idx, tile_n_idx = _compute_pid(
|
||||
tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M
|
||||
)
|
||||
|
||||
m_start = tile_m_idx * BLOCK_SIZE_M
|
||||
n_start = tile_n_idx * BLOCK_SIZE_N
|
||||
|
||||
if m_start < M_TOTAL:
|
||||
offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for ki in range(k_tiles):
|
||||
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
mask_m = offs_m < M_TOTAL
|
||||
mask_n = offs_n < N
|
||||
mask_k = offs_k < K
|
||||
|
||||
mask_a = mask_m[:, None] & mask_k[None, :]
|
||||
mask_b = mask_n[:, None] & mask_k[None, :]
|
||||
|
||||
group_idx = m_start // GROUP_SIZE_M
|
||||
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
|
||||
|
||||
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
|
||||
|
||||
b_ptrs = (
|
||||
b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
|
||||
)
|
||||
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
|
||||
|
||||
accumulator += tl.dot(a, b.T)
|
||||
|
||||
tile_id_c += NUM_SMS
|
||||
tile_m_idx, tile_n_idx = _compute_pid(
|
||||
tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M
|
||||
)
|
||||
|
||||
offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
mask_m = offs_m < M_TOTAL
|
||||
mask_n = offs_n < N
|
||||
mask_c = mask_m[:, None] & mask_n[None, :]
|
||||
|
||||
c = accumulator.to(tl.float32)
|
||||
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
tl.store(c_ptrs, c.to(c_type), mask=mask_c)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=STANDARD_CONFIGS,
|
||||
key=["M_TOTAL", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_cg_forward_aligned(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
indices_ptr,
|
||||
M_TOTAL: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
|
||||
):
|
||||
"""
|
||||
Contiguous Grouped GEMM kernel forward for aligned inputs.
|
||||
"""
|
||||
|
||||
pid = tl.program_id(0)
|
||||
|
||||
c_type = c_ptr.dtype.element_ty
|
||||
|
||||
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
|
||||
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
tile_m = pid // num_n_tiles
|
||||
tile_n = pid % num_n_tiles
|
||||
|
||||
m_start = tile_m * BLOCK_SIZE_M
|
||||
n_start = tile_n * BLOCK_SIZE_N
|
||||
|
||||
if m_start < M_TOTAL:
|
||||
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
|
||||
|
||||
mask_m = offs_m < M_TOTAL
|
||||
mask_n = offs_n < N
|
||||
|
||||
group_idx = m_start // GROUP_SIZE_M
|
||||
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
|
||||
|
||||
acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)
|
||||
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + k
|
||||
mask_k = offs_k < K
|
||||
|
||||
mask_a = mask_m[:, None] & mask_k[None, :]
|
||||
mask_b = mask_n[:, None] & mask_k[None, :]
|
||||
|
||||
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
|
||||
|
||||
b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
|
||||
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
|
||||
|
||||
acc += tl.dot(a, b.T)
|
||||
|
||||
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
mask_c = mask_m[:, None] & mask_n[None, :]
|
||||
tl.store(c_ptrs, acc.to(c_type), mask=mask_c)
|
||||
|
||||
|
||||
def cg_grouped_gemm_forward(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Contiguous grouped GEMM forward pass for MoE."""
|
||||
|
||||
assert inputs.is_contiguous(), "Input tensor must be contiguous"
|
||||
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
|
||||
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
|
||||
|
||||
M_total, K = inputs.shape
|
||||
assert M_total % group_size_m == 0, (
|
||||
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
|
||||
)
|
||||
|
||||
if expert_indices.dtype != torch.int32:
|
||||
expert_indices = expert_indices.to(torch.int32)
|
||||
|
||||
num_experts, N, K_weights = expert_weights.shape
|
||||
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
|
||||
assert expert_indices.shape[0] == M_total, (
|
||||
"Expert indices length must match M_total"
|
||||
)
|
||||
|
||||
output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16)
|
||||
|
||||
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
|
||||
grid = (NUM_SMS, 1, 1)
|
||||
_kernel_cg_persistent_forward[grid](
|
||||
inputs,
|
||||
expert_weights,
|
||||
output,
|
||||
expert_indices,
|
||||
M_TOTAL=M_total,
|
||||
N=N,
|
||||
K=K,
|
||||
NUM_EXPERTS=num_experts,
|
||||
GROUP_SIZE_M=group_size_m,
|
||||
NUM_SMS=NUM_SMS,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def cg_grouped_gemm_forward_dynamic(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Contiguous grouped GEMM forward pass for MoE with autotuned launch."""
|
||||
|
||||
assert inputs.is_contiguous(), "Input tensor must be contiguous"
|
||||
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
|
||||
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
|
||||
|
||||
M_total, K = inputs.shape
|
||||
assert M_total % group_size_m == 0, (
|
||||
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
|
||||
)
|
||||
|
||||
if expert_indices.dtype != torch.int32:
|
||||
expert_indices = expert_indices.to(torch.int32)
|
||||
|
||||
num_experts, N, K_weights = expert_weights.shape
|
||||
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
|
||||
assert expert_indices.shape[0] == M_total, (
|
||||
"Expert indices length must match M_total"
|
||||
)
|
||||
|
||||
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
|
||||
|
||||
grid = lambda meta: (
|
||||
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(N, meta["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
_kernel_cg_forward_aligned[grid](
|
||||
inputs,
|
||||
expert_weights,
|
||||
output,
|
||||
expert_indices,
|
||||
M_TOTAL=M_total,
|
||||
N=N,
|
||||
K=K,
|
||||
NUM_EXPERTS=num_experts,
|
||||
GROUP_SIZE_M=group_size_m,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ContiguousGroupedGEMM(torch.autograd.Function):
|
||||
"""Autograd function for contiguous grouped GEMM forward pass only."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
|
||||
return cg_grouped_gemm_forward(
|
||||
inputs=inputs,
|
||||
expert_weights=expert_weights,
|
||||
expert_indices=expert_indices,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output): # pragma: no cover - not implemented
|
||||
raise NotImplementedError("Backward pass not implemented")
|
||||
|
||||
|
||||
def cg_grouped_gemm(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Convenience wrapper for the forward-only autograd function."""
|
||||
|
||||
if expert_indices.dtype != torch.int32:
|
||||
expert_indices = expert_indices.to(torch.int32)
|
||||
|
||||
return ContiguousGroupedGEMM.apply(
|
||||
inputs, expert_weights, expert_indices, group_size_m
|
||||
)
|
||||
31
src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py
Normal file
@@ -0,0 +1,31 @@
"""Reference implementation for contiguous grouped GEMM."""

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


def pytorch_reference(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = 128,
) -> torch.Tensor:
    """Simple PyTorch implementation for verification."""

    M_total, K = inputs.shape
    num_experts, N, _ = expert_weights.shape

    output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)

    for i in range(0, M_total, group_size_m):
        end_idx = min(i + group_size_m, M_total)
        expert_idx = expert_indices[i].item()
        expert_weight = expert_weights[expert_idx]
        output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.T)

    return output
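
# Illustrative example (added for documentation; not part of the vendored
# kernel source): runs the Triton contiguous grouped GEMM against
# `pytorch_reference` on a small, already-grouped batch. It assumes a CUDA
# device with compute capability >= 9 and that `cg_grouped_gemm_forward` and
# `GROUP_SIZE_M` are importable from the sibling `cg_forward` module; all
# sizes are hypothetical test values.
def _example_compare_to_reference():
    import torch

    from .cg_forward import GROUP_SIZE_M, cg_grouped_gemm_forward

    num_experts = 4
    x = torch.randn(num_experts * GROUP_SIZE_M, 256, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(num_experts, 512, 256, device="cuda", dtype=torch.bfloat16)
    # Row i belongs to expert idx[i]; rows are already sorted and padded so
    # each expert owns whole groups of GROUP_SIZE_M contiguous rows.
    idx = torch.repeat_interleave(
        torch.arange(num_experts, device="cuda", dtype=torch.int32), GROUP_SIZE_M
    )
    y = cg_grouped_gemm_forward(x, w, idx, group_size_m=GROUP_SIZE_M)
    y_ref = pytorch_reference(x, w, idx, group_size_m=GROUP_SIZE_M)
    torch.testing.assert_close(y.float(), y_ref.float(), rtol=2e-2, atol=2e-2)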
209
src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Autotuning utilities for Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.runtime import driver
|
||||
|
||||
|
||||
class CudaUtils:
|
||||
"""Helper utilities for CUDA specific Triton features."""
|
||||
|
||||
@staticmethod
|
||||
def is_cuda() -> bool:
|
||||
return driver.active.get_current_target().backend == "cuda"
|
||||
|
||||
@staticmethod
|
||||
def verify_tma() -> bool:
|
||||
return (
|
||||
CudaUtils.is_cuda()
|
||||
and torch.cuda.is_available()
|
||||
and torch.cuda.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_num_sms() -> int:
|
||||
if not CudaUtils.is_cuda():
|
||||
raise RuntimeError("Triton is not running on CUDA backend")
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA is not available")
|
||||
return torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
|
||||
|
||||
class TmaDescriptorHelper:
|
||||
"""Helper class for managing TMA descriptors in Triton kernels."""
|
||||
|
||||
class KernelParamWrapper:
|
||||
def __init__(self, desc: torch.Tensor):
|
||||
self.desc = desc
|
||||
|
||||
def tma_desc_cpu_ptr(self) -> int:
|
||||
return self.desc.data_ptr()
|
||||
|
||||
def __init__(self, tma_size: int = 128):
|
||||
if not CudaUtils.verify_tma():
|
||||
raise RuntimeError(
|
||||
"TMA not supported on this device (requires Hopper or newer)"
|
||||
)
|
||||
if "nv_tma_desc_type" not in dir(tl):
|
||||
raise RuntimeError(
|
||||
"TMA grid constant descriptors not supported in your Triton version"
|
||||
)
|
||||
|
||||
self.tma_size = tma_size
|
||||
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
|
||||
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
|
||||
self.descriptors: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def init_tma_descriptor(self, name: str) -> None:
|
||||
self.descriptors[name] = torch.empty(
|
||||
self.tma_size, device="cpu", dtype=torch.int8
|
||||
)
|
||||
|
||||
def fill_1d_tma_descriptor(
|
||||
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
|
||||
) -> None:
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_1d_tma_descriptor_inner(
|
||||
ptr, dim, block_dim, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def fill_2d_tma_descriptor(
|
||||
self,
|
||||
name: str,
|
||||
ptr: int,
|
||||
dim1: int,
|
||||
dim0: int,
|
||||
block_dim1: int,
|
||||
block_dim0: int,
|
||||
element_size: int,
|
||||
) -> None:
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_2d_tma_descriptor_inner(
|
||||
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def get_tma_descriptor_kernel_param(
|
||||
self, name: str
|
||||
) -> "TmaDescriptorHelper.KernelParamWrapper":
|
||||
if name not in self.descriptors or self.descriptors[name] is None:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
return self.KernelParamWrapper(self.descriptors[name])
|
||||
|
||||
|
||||
HOPPER_CONFIGS = [
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=8,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
STANDARD_CONFIGS = [
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=8,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def early_config_prune(configs, args, **kwargs):
|
||||
"""Filter out configurations that would exceed shared memory capacity."""
|
||||
k = kwargs.get("K", 0)
|
||||
valid_configs = [
|
||||
config for config in configs if config.kwargs.get("BLOCK_SIZE_K", 0) <= k
|
||||
]
|
||||
if not valid_configs and configs:
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")),
|
||||
)
|
||||
]
|
||||
|
||||
return valid_configs
|
||||
13
src/axolotl/kernels/moe/tt_mg_gemm/__init__.py
Normal file
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .mg_grouped_gemm import grouped_gemm_forward
from .tma_autotuning import ALIGN_SIZE_M

__all__ = [
    "grouped_gemm_forward",
    "ALIGN_SIZE_M",
]
761
src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
Normal file
@@ -0,0 +1,761 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# credit - flat index forward kernel is derived from FBGemm:
|
||||
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
||||
|
||||
# pyre-unsafe
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .tma_autotuning import (
|
||||
_NV_CONFIGS,
|
||||
CudaUtils,
|
||||
early_config_prune,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
_allocator_registered = False
|
||||
|
||||
|
||||
def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
|
||||
return torch.empty(size, device="cuda", dtype=torch.int8)
|
||||
|
||||
|
||||
def _ensure_triton_allocator() -> None:
|
||||
global _allocator_registered
|
||||
if not _allocator_registered:
|
||||
triton.set_allocator(_torch_allocator)
|
||||
_allocator_registered = True
|
||||
|
||||
|
||||
# ============== Start Triton Kernels ===============
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=_NV_CONFIGS,
|
||||
key=["G", "M_BUCKET", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_mg_forward_hopper(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m_sizes,
|
||||
M_TOTAL,
|
||||
# problem sizes
|
||||
G: tl.constexpr,
|
||||
M_BUCKET: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
# config
|
||||
NUM_SMS: tl.constexpr,
|
||||
USE_EPILOGUE_SUBTILING: tl.constexpr,
|
||||
# tiles
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
) -> None:
|
||||
"""Flat index style forward kernel for Hopper using tensor descriptors."""
|
||||
tbidx = tl.program_id(0)
|
||||
|
||||
c_dtype = c_ptr.dtype.element_ty
|
||||
n_size = N // G
|
||||
|
||||
a_desc = tl.make_tensor_descriptor(
|
||||
a_ptr,
|
||||
shape=[M_TOTAL, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
)
|
||||
b_desc = tl.make_tensor_descriptor(
|
||||
b_ptr,
|
||||
shape=[N, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
)
|
||||
|
||||
M_end = tl.full([], 0, dtype=tl.int32)
|
||||
processed_tiles = 0
|
||||
|
||||
for g in range(G):
|
||||
M_start = M_end
|
||||
m_size = tl.load(m_sizes + g)
|
||||
M_end = M_start + m_size
|
||||
|
||||
if m_size > 0:
|
||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
||||
group_num_tiles = num_m_tiles * num_n_tiles
|
||||
|
||||
while (
|
||||
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
|
||||
):
|
||||
group_index = tbidx - processed_tiles
|
||||
|
||||
tile_m_index = group_index % num_m_tiles
|
||||
tile_n_index = group_index // num_m_tiles
|
||||
|
||||
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
|
||||
cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
|
||||
col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
||||
global_n_offset = (g * n_size + n_offset).to(tl.int32)
|
||||
|
||||
for k_offset in range(0, K, BLOCK_SIZE_K):
|
||||
k_remaining = K - k_offset
|
||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
|
||||
|
||||
a = a_desc.load([m_offset, k_offset])
|
||||
a_mask = row_mask[:, None] & k_mask[None, :]
|
||||
a = tl.where(a_mask, a, tl.zeros_like(a))
|
||||
|
||||
b = b_desc.load([global_n_offset, k_offset])
|
||||
b_mask = col_mask[:, None] & k_mask[None, :]
|
||||
b = tl.where(b_mask, b, tl.zeros_like(b))
|
||||
|
||||
accumulator += tl.dot(a, b.T)
|
||||
|
||||
local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
|
||||
local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
|
||||
row_store_mask = local_row_offsets < m_size
|
||||
global_row = (M_start + local_row_offsets).to(tl.int32)
|
||||
|
||||
local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(
|
||||
0, BLOCK_SIZE_N
|
||||
)
|
||||
col_store_mask = local_col_offsets < n_size
|
||||
|
||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||
|
||||
if USE_EPILOGUE_SUBTILING:
|
||||
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
|
||||
acc = tl.permute(acc, (0, 2, 1))
|
||||
acc0, acc1 = tl.split(acc)
|
||||
|
||||
col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
|
||||
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
|
||||
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
|
||||
tl.store(
|
||||
ptr0,
|
||||
acc0.to(c_dtype),
|
||||
mask=row_store_mask[:, None] & col_mask0[None, :],
|
||||
)
|
||||
|
||||
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
|
||||
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
|
||||
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
|
||||
tl.store(
|
||||
ptr1,
|
||||
acc1.to(c_dtype),
|
||||
mask=row_store_mask[:, None] & col_mask1[None, :],
|
||||
)
|
||||
else:
|
||||
ptr = (
|
||||
c_ptr
|
||||
+ global_row[:, None] * n_size
|
||||
+ local_col_offsets[None, :]
|
||||
)
|
||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||
|
||||
tbidx += NUM_SMS
|
||||
|
||||
processed_tiles += group_num_tiles
|
||||
|
||||
|
||||
"""
|
||||
Backward pass for grouped GEMM with Triton, where grouping is M*G
|
||||
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
|
||||
"""
|
||||
|
||||
|
||||
# ---- dx flat linear indexed ----
|
||||
@triton.autotune(
|
||||
configs=_NV_CONFIGS,
|
||||
key=["G", "M_BUCKET", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_mg_dx_tma(
|
||||
grad_output_ptr,
|
||||
w_ptr,
|
||||
grad_input_ptr,
|
||||
m_sizes,
|
||||
M_TOTAL,
|
||||
# problem sizes
|
||||
G: tl.constexpr,
|
||||
M_BUCKET: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
# config
|
||||
NUM_SMS: tl.constexpr,
|
||||
# tiles
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
) -> None:
|
||||
"""Compute grad_input = grad_output @ w using tensor descriptors."""
|
||||
tbidx = tl.program_id(0)
|
||||
|
||||
c_dtype = grad_input_ptr.dtype.element_ty
|
||||
|
||||
grad_output_desc = tl.make_tensor_descriptor(
|
||||
grad_output_ptr,
|
||||
shape=[M_TOTAL, N],
|
||||
strides=[N, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
)
|
||||
w_desc = tl.make_tensor_descriptor(
|
||||
w_ptr,
|
||||
shape=[N, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
)
|
||||
|
||||
M_end = tl.full([], 0, dtype=tl.int32)
|
||||
processed_tiles = 0
|
||||
|
||||
for g in range(G):
|
||||
M_start = M_end
|
||||
m_size = tl.load(m_sizes + g)
|
||||
M_end = M_start + m_size
|
||||
|
||||
if m_size > 0:
|
||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
group_num_tiles = num_m_tiles * num_k_tiles
|
||||
|
||||
while (
|
||||
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
|
||||
):
|
||||
group_index = tbidx - processed_tiles
|
||||
|
||||
tile_m_index = group_index % num_m_tiles
|
||||
tile_k_index = group_index // num_m_tiles
|
||||
|
||||
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
|
||||
k_offset = tile_k_index * BLOCK_SIZE_K
|
||||
k_remaining_total = K - k_offset
|
||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
|
||||
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
|
||||
for n_offset in range(0, N, BLOCK_SIZE_N):
|
||||
n_remaining = N - n_offset
|
||||
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
|
||||
|
||||
grad_y = grad_output_desc.load([m_offset, n_offset])
|
||||
grad_y_mask = row_mask[:, None] & n_mask[None, :]
|
||||
grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
|
||||
|
||||
w_tile = w_desc.load([n_offset, k_offset])
|
||||
w_mask = n_mask[:, None] & k_mask[None, :]
|
||||
w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
|
||||
|
||||
accumulator += tl.dot(grad_y, w_tile)
|
||||
|
||||
local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(
|
||||
0, BLOCK_SIZE_M
|
||||
)
|
||||
row_store_mask = local_row_offsets < m_size
|
||||
global_row = (M_start + local_row_offsets).to(tl.int32)
|
||||
|
||||
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
|
||||
col_store_mask = col_offsets < K
|
||||
|
||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||
|
||||
ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
|
||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||
|
||||
tbidx += NUM_SMS
|
||||
|
||||
processed_tiles += group_num_tiles
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=_NV_CONFIGS,
|
||||
key=["G", "M_BUCKET", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_mg_dw_tma(
|
||||
x_ptr,
|
||||
grad_output_ptr,
|
||||
grad_weight_ptr,
|
||||
m_sizes,
|
||||
M_TOTAL,
|
||||
# problem sizes
|
||||
G: tl.constexpr,
|
||||
M_BUCKET: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
# config
|
||||
NUM_SMS: tl.constexpr,
|
||||
# tiles
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
) -> None:
|
||||
"""Compute grad_weight = grad_output.T @ x using tensor descriptors."""
|
||||
tbidx = tl.program_id(0)
|
||||
|
||||
c_dtype = grad_weight_ptr.dtype.element_ty
|
||||
|
||||
x_desc = tl.make_tensor_descriptor(
|
||||
x_ptr,
|
||||
shape=[M_TOTAL, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
)
|
||||
grad_output_desc = tl.make_tensor_descriptor(
|
||||
grad_output_ptr,
|
||||
shape=[M_TOTAL, N],
|
||||
strides=[N, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
)
|
||||
|
||||
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
total_tiles = num_n_tiles * num_k_tiles
|
||||
|
||||
for tile_idx in range(tbidx, total_tiles, NUM_SMS):
|
||||
tile_n_idx = tile_idx % num_n_tiles
|
||||
tile_k_idx = tile_idx // num_n_tiles
|
||||
|
||||
n_offset = tile_n_idx * BLOCK_SIZE_N
|
||||
n_remaining = N - n_offset
|
||||
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
|
||||
|
||||
k_offset = tile_k_idx * BLOCK_SIZE_K
|
||||
k_remaining = K - k_offset
|
||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
|
||||
M_end = tl.full([], 0, dtype=tl.int32)
|
||||
for g in range(G):
|
||||
M_start = M_end
|
||||
m_size = tl.load(m_sizes + g)
|
||||
M_end = M_start + m_size
|
||||
|
||||
if m_size > 0:
|
||||
for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
|
||||
rows_remaining = m_size - m_offset_local
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
|
||||
m_offset = (M_start + m_offset_local).to(tl.int32)
|
||||
|
||||
x_block = x_desc.load([m_offset, k_offset])
|
||||
x_mask = row_mask[:, None] & k_mask[None, :]
|
||||
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
|
||||
|
||||
grad_block = grad_output_desc.load([m_offset, n_offset])
|
||||
grad_mask = row_mask[:, None] & n_mask[None, :]
|
||||
grad_block = tl.where(
|
||||
grad_mask, grad_block, tl.zeros_like(grad_block)
|
||||
)
|
||||
|
||||
contribution = tl.dot(
|
||||
grad_block.to(tl.float32).T,
|
||||
x_block.to(tl.float32),
|
||||
)
|
||||
accumulator += contribution
|
||||
|
||||
row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
|
||||
row_store_mask = row_offsets < N
|
||||
|
||||
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
|
||||
col_store_mask = col_offsets < K
|
||||
|
||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||
|
||||
ptr = grad_weight_ptr + row_offsets[:, None] * K + col_offsets[None, :]
|
||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||
|
||||
|
||||
# ======== End Triton kernels ========
|
||||
|
||||
# ======== Triton wrapper functions ========
|
||||
|
||||
# ----- main forward pass wrapper -----
|
||||
|
||||
|
||||
def grouped_gemm_forward(
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
tma_size: int = 128,
|
||||
using_fp8: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Grouped GEMM forward using Hopper TMA kernels."""
|
||||
_ensure_triton_allocator()
|
||||
if not CudaUtils.verify_tma():
|
||||
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
|
||||
if using_fp8:
|
||||
raise NotImplementedError(
|
||||
"FP8 path not implemented with the new Triton API yet"
|
||||
)
|
||||
|
||||
G = m_sizes.shape[0]
|
||||
|
||||
assert x.is_contiguous()
|
||||
assert w.is_contiguous()
|
||||
assert m_sizes.is_contiguous()
|
||||
|
||||
M_total, K = x.shape
|
||||
N = w.shape[0]
|
||||
assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
|
||||
|
||||
y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
|
||||
if M_total == 0:
|
||||
return y
|
||||
|
||||
NUM_SMS = CudaUtils.get_num_sms()
|
||||
USE_EPILOGUE_SUBTILING = False
|
||||
|
||||
def grid(_meta):
|
||||
return (NUM_SMS,)
|
||||
|
||||
M_BUCKET = triton.next_power_of_2(M_total)
|
||||
_kernel_mg_forward_hopper[grid](
|
||||
x,
|
||||
w,
|
||||
y,
|
||||
m_sizes,
|
||||
M_total,
|
||||
G,
|
||||
M_BUCKET,
|
||||
N,
|
||||
K,
|
||||
NUM_SMS,
|
||||
USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
|
||||
)
|
||||
return y
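
# Illustrative example (added for documentation; not part of the vendored
# kernel source): shows the M*G layout `grouped_gemm_forward` expects --
# expert weights concatenated along dim 0 into w of shape [G * n_per_group, K],
# inputs already sorted by group, and m_sizes[g] giving the rows owned by
# group g. Requires a Hopper-class GPU (TMA); all sizes are hypothetical.
def _example_grouped_gemm_forward():
    import torch

    g_count, k_dim, n_per_group = 4, 256, 512
    m_sizes = torch.tensor([128, 256, 0, 128], device="cuda", dtype=torch.int32)
    m_total = int(m_sizes.sum().item())
    x = torch.randn(m_total, k_dim, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(g_count * n_per_group, k_dim, device="cuda", dtype=torch.bfloat16)
    y = grouped_gemm_forward(x, w, m_sizes)
    # Output columns are per-group (N // G); groups with m_sizes[g] == 0 are skipped.
    assert y.shape == (m_total, n_per_group)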
|
||||
|
||||
|
||||
# ======== Improved Backward =============
|
||||
def grouped_gemm_backward(
|
||||
grad_output: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
use_tma: bool = True,
|
||||
tma_size: int = 128,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Unified backward pass for grouped GEMM with M*G grouping.
|
||||
Uses optimized TMA-based implementations for both dx and dw when available.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient of output, shape [M_total, N]
|
||||
x: Input tensor from forward pass, shape [M_total, K]
|
||||
w: Weight tensor from forward pass, shape [N, K]
|
||||
m_sizes: Group sizes tensor, shape [G]
|
||||
use_tma: Whether to try using TMA acceleration (if available)
|
||||
tma_size: Size of TMA descriptor in bytes
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of gradients with respect to x and w: (grad_x, grad_w)
|
||||
"""
|
||||
logging.info("Starting unified grouped_gemm_backward")
|
||||
|
||||
# do this once, seems expensive
|
||||
NUM_SMS = CudaUtils.get_num_sms()
|
||||
|
||||
# Basic validation
|
||||
M_total, K_x = x.shape
|
||||
M_grad, N = grad_output.shape
|
||||
N_w, K_w = w.shape
|
||||
|
||||
# Check dimensions
|
||||
if K_x != K_w:
|
||||
raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
|
||||
if M_total != M_grad:
|
||||
raise ValueError(
|
||||
f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
|
||||
)
|
||||
|
||||
# Check total M matches sum of group sizes
|
||||
sum_m_sizes = m_sizes.sum().item()
|
||||
if M_total != sum_m_sizes:
|
||||
raise ValueError(
|
||||
f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
|
||||
)
|
||||
|
||||
# Make sure inputs are contiguous
|
||||
grad_output = grad_output.contiguous()
|
||||
x = x.contiguous()
|
||||
w = w.contiguous()
|
||||
m_sizes = m_sizes.contiguous()
|
||||
|
||||
# Check TMA support
|
||||
if use_tma and not CudaUtils.verify_tma():
|
||||
logging.info("TMA requested but not supported on this device")
|
||||
use_tma = False
|
||||
|
||||
# Compute grad_x using flat linear implementation
|
||||
try:
|
||||
logging.info("Computing grad_x with flat linear kernel")
|
||||
|
||||
# Use TMA-optimized implementation
|
||||
grad_x = grouped_gemm_dx_tma(
|
||||
grad_output=grad_output,
|
||||
w=w,
|
||||
m_sizes=m_sizes,
|
||||
num_sms=NUM_SMS,
|
||||
tma_size=tma_size,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in grad_x computation: {e}")
|
||||
raise
|
||||
|
||||
# Compute grad_w using flat linear style implementation
|
||||
try:
|
||||
logging.info("Computing grad_w with flat linear kernel")
|
||||
|
||||
grad_w = grouped_gemm_dw_tma(
|
||||
x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in grad_w computation: {e}")
|
||||
raise
|
||||
|
||||
return grad_x, grad_w
|
||||
|
||||
|
||||
# ----- dx backward pass wrapper -----
|
||||
|
||||
|
||||
def grouped_gemm_dx_tma(
|
||||
grad_output: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
num_sms: int = 132,
|
||||
tma_size: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""Compute grad_x using the Hopper grouped GEMM kernel."""
|
||||
_ensure_triton_allocator()
|
||||
if not CudaUtils.verify_tma():
|
||||
raise NotImplementedError("Optimized dx computation requires TMA support")
|
||||
|
||||
grad_output = grad_output.contiguous()
|
||||
w = w.contiguous()
|
||||
m_sizes = m_sizes.contiguous()
|
||||
|
||||
M_total, N = grad_output.shape
|
||||
N_w, K = w.shape
|
||||
if N != N_w:
|
||||
raise ValueError(f"Grad_output N ({N}) must match weight N ({N_w})")
|
||||
|
||||
if m_sizes.sum().item() != M_total:
|
||||
raise ValueError("Sum of m_sizes must equal the number of rows in grad_output")
|
||||
|
||||
grad_x = torch.empty(
|
||||
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
|
||||
)
|
||||
|
||||
NUM_SMS = num_sms
|
||||
|
||||
def grid(_meta):
|
||||
return (NUM_SMS,)
|
||||
|
||||
M_BUCKET = triton.next_power_of_2(M_total)
|
||||
_kernel_mg_dx_tma[grid](
|
||||
grad_output,
|
||||
w,
|
||||
grad_x,
|
||||
m_sizes,
|
||||
M_total,
|
||||
m_sizes.shape[0],
|
||||
M_BUCKET,
|
||||
N,
|
||||
K,
|
||||
NUM_SMS,
|
||||
)
|
||||
return grad_x
|
||||
|
||||
|
||||
# ======== dw wrapper function ==========
|
||||
|
||||
|
||||
def grouped_gemm_dw_tma(
|
||||
x: torch.Tensor,
|
||||
grad_output: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
num_sms: int = 132,
|
||||
tma_size: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""Compute grad_w using the Hopper grouped GEMM kernel."""
|
||||
_ensure_triton_allocator()
|
||||
if not CudaUtils.verify_tma():
|
||||
raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")
|
||||
|
||||
x = x.contiguous()
|
||||
grad_output = grad_output.contiguous()
|
||||
m_sizes = m_sizes.contiguous()
|
||||
|
||||
M_total, K = x.shape
|
||||
M_grad, N = grad_output.shape
|
||||
if M_total != M_grad:
|
||||
raise ValueError("x and grad_output must have matching batch dimension")
|
||||
if m_sizes.sum().item() != M_total:
|
||||
raise ValueError("Sum of m_sizes must equal the number of rows in the inputs")
|
||||
|
||||
grad_w = torch.zeros((N, K), device=x.device, dtype=x.dtype)
|
||||
|
||||
NUM_SMS = num_sms
|
||||
|
||||
def grid(_meta):
|
||||
return (NUM_SMS,)
|
||||
|
||||
M_BUCKET = triton.next_power_of_2(M_total)
|
||||
_kernel_mg_dw_tma[grid](
|
||||
x,
|
||||
grad_output,
|
||||
grad_w,
|
||||
m_sizes,
|
||||
M_total,
|
||||
m_sizes.shape[0],
|
||||
M_BUCKET,
|
||||
N,
|
||||
K,
|
||||
NUM_SMS,
|
||||
)
|
||||
return grad_w
|
||||
|
||||
|
||||
# ======== End Backwards Wrapper Functions =============
|
||||
|
||||
# ======== PyTorch wrapper functions ========
|
||||
|
||||
|
||||
class GroupedGemmMg(torch.autograd.Function):
|
||||
"""
|
||||
Autograd function for GroupedGEMM with M*G grouping.
|
||||
Supports both standard and FP8 quantized operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
|
||||
"""
|
||||
Forward pass of GroupedGEMM.
|
||||
|
||||
Args:
|
||||
x: Input tensor, shape [M_total, K]
|
||||
w: Weight tensor, shape [N, K]
|
||||
m_sizes: Tensor of shape [G] containing the size of each group
|
||||
use_tma: Whether to try using TMA acceleration (if available)
|
||||
tma_size: Size of TMA descriptor in bytes
|
||||
using_fp8: Whether to use FP8 quantization
|
||||
|
||||
Returns:
|
||||
Output tensor, shape [M_total, N]
|
||||
"""
|
||||
|
||||
# Use regular forward without quantization
|
||||
output = grouped_gemm_forward(
|
||||
x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
|
||||
)
|
||||
|
||||
# Save inputs and parameters for backward pass
|
||||
ctx.save_for_backward(x, w, m_sizes)
|
||||
ctx.use_tma = use_tma
|
||||
ctx.tma_size = tma_size
|
||||
|
||||
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
"""
|
||||
Backward pass of M*G GroupedGEMM.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient of output, shape [M_total, N]
|
||||
|
||||
Returns:
|
||||
Tuple of gradients:
|
||||
- grad_x: Gradient with respect to x, shape [M_total, K]
|
||||
- grad_w: Gradient with respect to w, shape [N, K]
|
||||
- None: Gradient with respect to m_sizes (not differentiable)
|
||||
- None: Gradient with respect to use_tma (not differentiable)
|
||||
- None: Gradient with respect to tma_size (not differentiable)
|
||||
|
||||
"""
|
||||
# Retrieve saved tensors and parameters
|
||||
|
||||
x, w, m_sizes = ctx.saved_tensors
|
||||
|
||||
use_tma = ctx.use_tma
|
||||
tma_size = ctx.tma_size
|
||||
|
||||
# Compute gradients using the unified implementation
|
||||
grad_x, grad_w = grouped_gemm_backward(
|
||||
grad_output=grad_output,
|
||||
x=x,
|
||||
w=w,
|
||||
m_sizes=m_sizes,
|
||||
use_tma=use_tma,
|
||||
tma_size=tma_size,
|
||||
)
|
||||
|
||||
# Return gradients for all inputs (None for non-differentiable parameters)
|
||||
return grad_x, grad_w, None, None, None, None
|
||||
|
||||
|
||||
def mg_grouped_gemm(
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
use_tma: bool = True,
|
||||
tma_size: int = 128,
|
||||
using_fp8: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Unified differentiable grouped GEMM operation for M*G grouped GEMM.
|
||||
Supports both standard precision and FP8 quantized operations.
|
||||
|
||||
Args:
|
||||
x: Input tensor, shape [M_total, K]
|
||||
w: Weight tensor, shape [N, K]
|
||||
m_sizes: Tensor of shape [G] containing the size of each group
|
||||
use_tma: Whether to try using TMA acceleration (if available)
|
||||
tma_size: Size of TMA descriptor in bytes
|
||||
using_fp8: Whether to use FP8 quantization
|
||||
|
||||
Returns:
|
||||
Output tensor, shape [M_total, N]
|
||||
"""
|
||||
return GroupedGemmMg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)
|
||||
232
src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
|
||||
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from triton.runtime import driver # @manual
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
# ===== Supporting utils, CUDA and TMA =====
|
||||
|
||||
|
||||
class CudaUtils:
|
||||
@staticmethod
|
||||
def is_cuda() -> bool:
|
||||
"""Check if Triton is running on CUDA backend."""
|
||||
return driver.active.get_current_target().backend == "cuda"
|
||||
|
||||
@staticmethod
|
||||
def verify_tma() -> bool:
|
||||
"""Check if TMA is supported on the current device."""
|
||||
return (
|
||||
CudaUtils.is_cuda()
|
||||
and torch.cuda.is_available()
|
||||
and torch.cuda.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_num_sms() -> int:
|
||||
"""Get the number of streaming multiprocessors on the current device."""
|
||||
if not CudaUtils.is_cuda():
|
||||
raise RuntimeError("Triton is not running on CUDA backend")
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA is not available")
|
||||
return torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
|
||||
|
||||
class TmaDescriptorHelper:
|
||||
"""Helper class for managing TMA descriptors in Triton kernels.
|
||||
|
||||
Args:
|
||||
tma_size: Size of the TMA descriptor in bytes
|
||||
"""
|
||||
|
||||
class KernelParamWrapper:
|
||||
"""Wrapper to implement the TmaDescKernelParam interface."""
|
||||
|
||||
def __init__(self, desc: torch.Tensor):
|
||||
self.desc = desc
|
||||
|
||||
def tma_desc_cpu_ptr(self) -> int:
|
||||
"""Return the CPU pointer to the TMA descriptor."""
|
||||
return self.desc.data_ptr()
|
||||
|
||||
def __init__(self, tma_size: int = 128):
|
||||
if not CudaUtils.verify_tma():
|
||||
raise RuntimeError(
|
||||
"TMA not supported on this device (requires Hopper or newer)"
|
||||
)
|
||||
|
||||
self.tma_size = tma_size
|
||||
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
|
||||
self.descriptors: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def init_tma_descriptor(self, name: str) -> None:
|
||||
"""Initialize a TMA descriptor with the given name.
|
||||
|
||||
Call this method outside of the lambda function for grid size.
|
||||
"""
|
||||
self.descriptors[name] = torch.empty(
|
||||
self.tma_size, device="cpu", dtype=torch.int8
|
||||
)
|
||||
|
||||
def fill_1d_tma_descriptor(
|
||||
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
|
||||
) -> None:
|
||||
"""Fill a 1D TMA descriptor.
|
||||
|
||||
Call this method inside the lambda function for grid size.
|
||||
"""
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_1d_tma_descriptor_inner(
|
||||
ptr, dim, block_dim, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def fill_2d_tma_descriptor(
|
||||
self,
|
||||
name: str,
|
||||
ptr: int,
|
||||
dim1: int,
|
||||
dim0: int,
|
||||
block_dim1: int,
|
||||
block_dim0: int,
|
||||
element_size: int,
|
||||
) -> None:
|
||||
"""Fill a 2D TMA descriptor.
|
||||
|
||||
Call this method inside the lambda function for grid size.
|
||||
"""
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_2d_tma_descriptor_inner(
|
||||
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
|
||||
"""Get the TMA descriptor kernel parameter for the given name."""
|
||||
if name not in self.descriptors or self.descriptors[name] is None:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
return self.KernelParamWrapper(self.descriptors[name])
|
||||
|
||||
|
||||
# ====== Autotuning utilities ======
|
||||
ALIGN_SIZE_M = 128
|
||||
|
||||
_NV_CONFIGS = [
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_size_m,
|
||||
"BLOCK_SIZE_N": block_size_n,
|
||||
"BLOCK_SIZE_K": block_size_k,
|
||||
},
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
)
|
||||
for block_size_m in [
|
||||
ALIGN_SIZE_M,
|
||||
]
|
||||
for block_size_n in [64, 128, 256]
|
||||
for block_size_k in [64, 128, 256]
|
||||
for num_stages in [3, 4]
|
||||
for num_warps in [4, 8]
|
||||
for num_ctas in [1]
|
||||
]
|
||||
|
||||
|
||||
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
|
||||
device = torch.cuda.current_device()
|
||||
# Check for all possible pointer parameter names
|
||||
if "grad_input_ptr" in named_args:
|
||||
ptr_name = "grad_input_ptr"
|
||||
elif "c_ptr" in named_args:
|
||||
ptr_name = "c_ptr"
|
||||
elif "grad_weight_ptr" in named_args:
|
||||
ptr_name = "grad_weight_ptr"
|
||||
else:
|
||||
raise KeyError("No recognized pointer parameter found in kernel arguments")
|
||||
|
||||
if dtsize is None:
|
||||
dtsize = named_args[ptr_name].element_size()
|
||||
if dtype is None:
|
||||
dtype = named_args[ptr_name].dtype
|
||||
|
||||
pruned_configs = []
|
||||
for config in configs:
|
||||
kw = config.kwargs
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
|
||||
kw["BLOCK_SIZE_M"],
|
||||
kw["BLOCK_SIZE_N"],
|
||||
kw["BLOCK_SIZE_K"],
|
||||
config.num_stages,
|
||||
)
|
||||
G, M, N, K = (
|
||||
named_args["G"],
|
||||
named_args["M_BUCKET"],
|
||||
named_args["N"],
|
||||
named_args["K"],
|
||||
)
|
||||
|
||||
# 1. make sure we have enough smem
|
||||
max_shared_memory = driver.active.utils.get_device_properties(device)[
|
||||
"max_shared_mem"
|
||||
]
|
||||
|
||||
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
|
||||
if required_shared_memory > max_shared_memory:
|
||||
continue
|
||||
|
||||
M_PER_GROUP = M // G
|
||||
MIN_M_TILES = 64
|
||||
# 2. make sure we don't load M tiles that are too big
|
||||
if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
|
||||
continue
|
||||
# 3. make sure we don't load M tiles that are too small
|
||||
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
|
||||
continue
|
||||
|
||||
num_sm = driver.active.utils.get_device_properties(device)[
|
||||
"multiprocessor_count"
|
||||
]
|
||||
N_TILES = N // BLOCK_N
|
||||
MIN_N_TILES = 64
|
||||
# 4. make sure we don't load N tiles that are too big
|
||||
if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
|
||||
continue
|
||||
# 5. make sure we don't load N tiles that are too small
|
||||
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
|
||||
continue
|
||||
# 6. make sure K can be evenly divided
|
||||
if K % BLOCK_K != 0:
|
||||
continue
|
||||
|
||||
pruned_configs.append(config)
|
||||
|
||||
return pruned_configs
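
# Added note (illustration, not from the vendored source): rule 1 above keeps
# only configs whose software-pipelined tiles fit in shared memory. For
# example, with bf16 operands (dtsize=2), BLOCK_SIZE_M=128, BLOCK_SIZE_N=128,
# BLOCK_SIZE_K=128 and num_stages=4 the estimate is
# (128 + 128) * 128 * 4 * 2 = 262,144 bytes (256 KiB), which exceeds the
# roughly 227 KiB a thread block can opt into on H100, so that config is
# pruned; dropping to BLOCK_SIZE_K=64 halves the footprint to 128 KiB.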
|
||||
|
||||
|
||||
# ======== End Autotuning utilities ========
|
||||
@@ -84,9 +84,7 @@ class PatchManager:
|
||||
patch_evaluation_loop()
|
||||
patch_maybe_log_save_evaluate()
|
||||
|
||||
if self.cfg.context_parallel_size > 1 and getattr(
|
||||
self.cfg, "flash_attention", False
|
||||
):
|
||||
if self.cfg.context_parallel_size > 1:
|
||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||
patch_prepare_context_parallel_inputs,
|
||||
)
|
||||
@@ -192,6 +190,15 @@ class PatchManager:
|
||||
|
||||
apply_mistral_tokenizer_image_patch()
|
||||
|
||||
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||
|
||||
patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
|
||||
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
|
||||
LOG.info(
|
||||
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
|
||||
)
|
||||
|
||||
def _apply_fp8_patches(self):
|
||||
"""Apply patches for FP8 support."""
|
||||
if self.cfg.fp8:
|
||||
|
||||
401
src/axolotl/monkeypatch/deepseek_v3/__init__.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""Monkeypatches for DeepSeek V3 MoE to use Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.kernels.moe import ContiguousGroupedGEMM
|
||||
from axolotl.kernels.moe.indices import generate_permute_indices
|
||||
from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
_GROUP_SIZE_M = 128
|
||||
_COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
||||
if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
|
||||
return False
|
||||
major, _ = torch.cuda.get_device_capability(hidden_states.device)
|
||||
if major < 9:
|
||||
LOG.debug(
|
||||
"Skipping Triton MoE kernels: requires compute capability >= 90, found %s",
|
||||
major,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _ensure_combined_expert_weights(
|
||||
module, dtype: torch.dtype, device: torch.device, backend: str
|
||||
) -> None:
|
||||
if not hasattr(module, "_axolotl_original_specs"):
|
||||
module._axolotl_original_specs = {}
|
||||
if not hasattr(module, "_axolotl_mg_shapes"):
|
||||
module._axolotl_mg_shapes = {}
|
||||
|
||||
prev_backend = getattr(module, "_axolotl_combined_backend", None)
|
||||
if getattr(module, "_axolotl_combined_weights", False):
|
||||
if prev_backend != backend:
|
||||
_restore_expert_weights(module)
|
||||
else:
|
||||
for name in _COMBINED_SUBMODULES:
|
||||
param_name = f"{name}_weight"
|
||||
param = module.get_parameter(param_name)
|
||||
if param.device != device or param.dtype != dtype:
|
||||
module._parameters[param_name] = torch.nn.Parameter(
|
||||
param.to(device=device, dtype=dtype).contiguous()
|
||||
)
|
||||
module._axolotl_combined_dtype = dtype
|
||||
module._axolotl_combined_device = device
|
||||
module._axolotl_combined_backend = backend
|
||||
return
|
||||
|
||||
module._axolotl_mg_shapes = {}
|
||||
for name in _COMBINED_SUBMODULES:
|
||||
weights = []
|
||||
orig_device = None
|
||||
orig_dtype = None
|
||||
orig_shape = None
|
||||
for expert in module.experts:
|
||||
lin = expert.get_submodule(name)
|
||||
weight_param = lin._parameters.get("weight")
|
||||
if weight_param is None:
|
||||
raise RuntimeError("Expected expert linear layers to have weights")
|
||||
if orig_device is None:
|
||||
orig_device = weight_param.device
|
||||
orig_dtype = weight_param.dtype
|
||||
orig_shape = tuple(weight_param.shape)
|
||||
weights.append(weight_param.detach().to(device=device, dtype=dtype))
|
||||
if "weight" in lin._parameters:
|
||||
del lin._parameters["weight"]
|
||||
if "bias" in lin._parameters:
|
||||
del lin._parameters["bias"]
|
||||
if backend == "cg":
|
||||
combined_weight = torch.stack(weights, dim=0).contiguous()
|
||||
else:
|
||||
combined_weight = torch.cat(weights, dim=0).contiguous()
|
||||
module._axolotl_mg_shapes[name] = orig_shape
|
||||
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight))
|
||||
module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape)
|
||||
|
||||
module._axolotl_combined_weights = True
|
||||
module._axolotl_combined_dtype = dtype
|
||||
module._axolotl_combined_device = device
|
||||
module._axolotl_combined_backend = backend
|
||||
|
||||
|
||||
def _restore_expert_weights(module) -> None:
|
||||
if not getattr(module, "_axolotl_combined_weights", False):
|
||||
return
|
||||
|
||||
for name in _COMBINED_SUBMODULES:
|
||||
param_name = f"{name}_weight"
|
||||
combined = module._parameters.pop(param_name)
|
||||
orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get(
|
||||
name, (combined.device, combined.dtype, None)
|
||||
)
|
||||
rows_per = orig_shape[0] if orig_shape else None
|
||||
for idx, expert in enumerate(module.experts):
|
||||
lin = expert.get_submodule(name)
|
||||
if combined.dim() == 3:
|
||||
slice_tensor = combined[idx]
|
||||
elif rows_per is not None:
|
||||
start = idx * rows_per
|
||||
end = start + rows_per
|
||||
slice_tensor = combined[start:end]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unable to recover expert weight shape during restore"
|
||||
)
|
||||
lin._parameters["weight"] = torch.nn.Parameter(
|
||||
slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype)
|
||||
)
|
||||
|
||||
module._axolotl_combined_weights = False
|
||||
module._axolotl_combined_dtype = None
|
||||
module._axolotl_combined_device = None
|
||||
module._axolotl_combined_backend = None
|
||||
module._axolotl_original_specs = {}
|
||||
module._axolotl_mg_shapes = {}
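
# Illustrative example (added for documentation; not part of the vendored
# source): shows the two layouts `_ensure_combined_expert_weights` builds from
# the per-expert `nn.Linear` weights -- a stacked [E, out, in] tensor for the
# "cg" backend and a concatenated [E * out, in] tensor for the "mg" backend.
# Sizes are hypothetical.
def _example_combined_weight_layouts():
    import torch

    per_expert = [torch.randn(512, 256) for _ in range(4)]  # [out, in] per expert
    cg_weight = torch.stack(per_expert, dim=0)  # [4, 512, 256], indexed by expert id
    mg_weight = torch.cat(per_expert, dim=0)  # [2048, 256]; expert e owns rows e*512:(e+1)*512
    assert torch.equal(cg_weight[2], mg_weight[2 * 512 : 3 * 512])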
|
||||
|
||||
|
||||
def _run_cg_grouped_gemm(
    module,
    grouped_hidden: torch.Tensor,
    m_sizes: torch.Tensor,
    num_experts: int,
    group_size_m: int,
    hidden_dtype: torch.dtype,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    _ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")

    expert_index_tensor = torch.repeat_interleave(
        torch.arange(num_experts, device=device, dtype=torch.int32),
        m_sizes.to(torch.int64),
    )

    gate_weights = module.get_parameter("gate_proj_weight")
    if gate_weights.dim() == 2:
        out_dim = gate_weights.shape[0] // num_experts
        gate_weights = gate_weights.view(num_experts, out_dim, gate_weights.shape[1])

    up_weights = module.get_parameter("up_proj_weight")
    if up_weights.dim() == 2:
        out_dim = up_weights.shape[0] // num_experts
        up_weights = up_weights.view(num_experts, out_dim, up_weights.shape[1])

    down_weights = module.get_parameter("down_proj_weight")
    if down_weights.dim() == 2:
        out_dim = down_weights.shape[0] // num_experts
        down_weights = down_weights.view(num_experts, out_dim, down_weights.shape[1])

    gate_out = ContiguousGroupedGEMM.apply(
        grouped_hidden,
        gate_weights,
        expert_index_tensor,
        group_size_m,
    )
    up_out = ContiguousGroupedGEMM.apply(
        grouped_hidden,
        up_weights,
        expert_index_tensor,
        group_size_m,
    )
    return (
        gate_out.to(hidden_dtype),
        up_out.to(hidden_dtype),
        down_weights,
        expert_index_tensor,
    )

    gate_out = mg_grouped_gemm(
        grouped_hidden,
        module.get_parameter("gate_proj_weight"),
        m_sizes_tensor,
    )
    up_out = mg_grouped_gemm(
        grouped_hidden,
        module.get_parameter("up_proj_weight"),
        m_sizes_tensor,
    )
    down_out = mg_grouped_gemm(
        hidden_grouped,
        module.get_parameter("down_proj_weight"),
        m_sizes_tensor,
    )

    return (
        gate_out.to(hidden_dtype),
        up_out.to(hidden_dtype),
        down_out.to(hidden_dtype),
    )

def _moe_triton_forward(
    module,
    hidden_states: torch.Tensor,
    topk_indices: torch.Tensor,
    topk_weights: torch.Tensor,
    group_size_m: int,
    backend: str,
    fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    if not _is_triton_eligible(hidden_states):
        return fallback(hidden_states, topk_indices, topk_weights)

    device = hidden_states.device
    hidden_dtype = hidden_states.dtype
    num_tokens, hidden_dim = hidden_states.shape
    top_k = topk_indices.size(-1)

    expanded_hidden = hidden_states.repeat_interleave(top_k, dim=0)
    expert_assignments = topk_indices.reshape(-1)
    if expanded_hidden.numel() == 0:
        return torch.zeros_like(hidden_states)

    sort_perm = torch.argsort(expert_assignments)
    sorted_hidden = expanded_hidden.index_select(0, sort_perm)
    sorted_assignments = expert_assignments.index_select(0, sort_perm)

    num_experts = len(module.experts)
    counts = torch.bincount(sorted_assignments, minlength=num_experts)
    total_actual = int(counts.sum().item())
    if total_actual == 0:
        return torch.zeros_like(hidden_states)

    if not getattr(module, "_axolotl_triton_logged", False):
        min_tokens = int(counts.min().item())
        max_tokens = int(counts.max().item())
        LOG.info(
            "DeepseekV3MoE Triton: tokens per expert (min=%s, max=%s, avg=%.1f) with group_size=%s",
            min_tokens,
            max_tokens,
            total_actual / max(1, num_experts),
            group_size_m,
        )
        module._axolotl_triton_logged = True

    counts_int = counts.to(torch.int32)
    aligned_counts = (
        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
    ) * group_size_m
    max_len = int(aligned_counts.sum().item())

    permuted_indices, m_sizes, _ = generate_permute_indices(
        counts_int.to(device),
        experts_per_rank=num_experts,
        num_ranks=1,
        max_len=max_len,
        alignment=group_size_m,
        use_cpu=not hidden_states.is_cuda,
    )

    permuted_indices = permuted_indices.to(device)
    m_sizes = m_sizes.to(device)

    permuted_indices_long = permuted_indices.to(torch.int64)
    valid_mask = permuted_indices_long >= 0
    valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
    source_indices = permuted_indices_long[valid_mask]
    padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)

    grouped_hidden = hidden_states.new_empty((max_len, hidden_dim))
    if valid_positions.numel() > 0:
        grouped_hidden.index_copy_(
            0,
            valid_positions,
            sorted_hidden.index_select(0, source_indices),
        )
    if valid_positions.numel() < max_len:
        grouped_hidden.index_fill_(0, padded_positions, 0)

    m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32)

if backend == "mg":
|
||||
_ensure_combined_expert_weights(module, hidden_dtype, device, backend)
|
||||
gate_out = mg_grouped_gemm(
|
||||
grouped_hidden,
|
||||
module.get_parameter("gate_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
).to(hidden_dtype)
|
||||
up_out = mg_grouped_gemm(
|
||||
grouped_hidden,
|
||||
module.get_parameter("up_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
).to(hidden_dtype)
|
||||
else:
|
||||
gate_out, up_out, down_weights, expert_index_tensor = _run_cg_grouped_gemm(
|
||||
module,
|
||||
grouped_hidden,
|
||||
m_sizes,
|
||||
num_experts,
|
||||
group_size_m,
|
||||
hidden_dtype,
|
||||
device,
|
||||
)
|
||||
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
|
||||
if valid_positions.numel() > 0:
|
||||
gate_valid = gate_out.index_select(0, valid_positions)
|
||||
up_valid = up_out.index_select(0, valid_positions)
|
||||
hidden_concat = act_fn(gate_valid) * up_valid
|
||||
else:
|
||||
hidden_concat = torch.empty(
|
||||
(0, gate_out.shape[-1]), device=device, dtype=hidden_dtype
|
||||
)
|
||||
|
||||
intermediate_dim = hidden_concat.shape[-1]
|
||||
hidden_grouped = hidden_states.new_empty((max_len, intermediate_dim))
|
||||
if valid_positions.numel() > 0:
|
||||
hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
|
||||
if valid_positions.numel() < max_len:
|
||||
hidden_grouped.index_fill_(0, padded_positions, 0)
|
||||
|
||||
if backend == "mg":
|
||||
down_out = mg_grouped_gemm(
|
||||
hidden_grouped,
|
||||
module.get_parameter("down_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
).to(hidden_dtype)
|
||||
else:
|
||||
down_out = ContiguousGroupedGEMM.apply(
|
||||
hidden_grouped,
|
||||
down_weights,
|
||||
expert_index_tensor,
|
||||
group_size_m,
|
||||
).to(hidden_dtype)
|
||||
|
||||
if valid_positions.numel() > 0:
|
||||
down_valid = down_out.index_select(0, valid_positions)
|
||||
else:
|
||||
down_valid = torch.empty(
|
||||
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype
|
||||
)
|
||||
|
||||
sorted_outputs = hidden_states.new_zeros((total_actual, hidden_dim))
|
||||
if down_valid.numel() > 0:
|
||||
sorted_outputs.index_copy_(0, source_indices, down_valid)
|
||||
|
||||
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
|
||||
expanded_output.index_copy_(0, sort_perm, sorted_outputs)
|
||||
expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
|
||||
|
||||
weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
|
||||
return weighted.sum(dim=1)
|
||||
|
||||
|
||||
def patch_deepseek_v3_moe(
    group_size_m: int = _GROUP_SIZE_M, backend: str = "mg"
) -> None:
    """Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""

    from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

    if backend not in {"cg", "mg"}:
        raise ValueError(f"Unsupported MoE kernel backend: {backend}")

    # Record the unpatched implementation so callers can access a true baseline even
    # after the Triton patch has been applied (e.g. repeated microbenchmarks).
    if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
        DeepseekV3MoE._axolotl_triton_original_moe = DeepseekV3MoE.moe

    if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
        return

    original_moe = DeepseekV3MoE._axolotl_triton_original_moe
    DeepseekV3MoE._axolotl_triton_backend = backend
    DeepseekV3MoE._axolotl_group_size_m = group_size_m

    def patched_moe(self, hidden_states, topk_indices, topk_weights):
        backend_sel = getattr(self, "_axolotl_triton_backend", backend)
        group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m)
        if backend_sel == "cg" and group_size_sel != _GROUP_SIZE_M:
            LOG.debug(
                "Adjusting group_size_m=%s to %s for CG backend",
                group_size_sel,
                _GROUP_SIZE_M,
            )
            group_size_sel = _GROUP_SIZE_M
        try:
            return _moe_triton_forward(
                self,
                hidden_states,
                topk_indices,
                topk_weights,
                group_size_sel,
                backend_sel,
                original_moe,
            )
        except Exception as err:  # surface Triton failures explicitly
            _restore_expert_weights(self)
            LOG.error("DeepseekV3MoE Triton path failed: %s", err)
            raise

    DeepseekV3MoE.moe = patched_moe
    DeepseekV3MoE._axolotl_triton_patch = True
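For reference, a minimal sketch of how this patch might be exercised by hand; the checkpoint name and dtype below are placeholders, and in a normal Axolotl run the patch would be driven by the config flags rather than called directly.

```python
# Illustrative sketch only (not part of the diff): apply the vendored patch
# before the model is instantiated so DeepseekV3MoE.moe is already replaced.
import torch
from transformers import AutoModelForCausalLM

patch_deepseek_v3_moe(backend="mg")  # "cg" selects the contiguous kernel instead

# Placeholder checkpoint; any DeepseekV3MoE-based model behaves the same way.
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-V3", torch_dtype=torch.bfloat16
)
# MoE layers now route through _moe_triton_forward, falling back to the
# recorded original implementation when inputs are not Triton-eligible.
```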
@@ -13,10 +13,21 @@ from typing import Callable

import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils as flash_utils
import transformers.modeling_flash_attention_utils
from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal

try:
    from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
    try:
        from transformers.modeling_flash_attention_utils import (
            _flash_supports_window_size as _flash_supports_window,
        )
    except ImportError:
        _flash_supports_window = True

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from axolotl.utils.schemas.enums import RingAttnFunc

@@ -107,7 +118,7 @@ def create_flash_attn_forward_varlen_llama3(

    # Handle sliding window
    use_sliding_windows = (
        _flash_windows_supported()
        _flash_supports_window
        and sliding_window is not None
        and key_states.shape[1] > sliding_window
    )

@@ -183,18 +194,3 @@ def substitute_hf_flash_attn(
    from ring_flash_attn.adapters.hf_adapter import flash_attention_forward

    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward


def _flash_windows_supported() -> bool:
    """Return whether current transformers build advertises sliding-window support."""
    support = getattr(flash_utils, "_flash_supports_window", None)
    if support is None:
        support = getattr(flash_utils, "_flash_supports_window_size", None)

    if support is None:
        return True

    if callable(support):
        return True

    return bool(support)

@@ -13,9 +13,18 @@ from typing import Optional

import torch
import torch.distributed as dist
import transformers.modeling_flash_attention_utils as flash_utils
from torch.distributed import DeviceMesh

try:
    from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
    try:
        from transformers.modeling_flash_attention_utils import (
            _flash_supports_window_size as _flash_supports_window,
        )
    except ImportError:
        _flash_supports_window = True

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RingAttnFunc

@@ -74,7 +83,7 @@ def create_ring_flash_attention_forward(

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_windows_supported()
        _flash_supports_window
        and sliding_window is not None
        and key_states.shape[1] > sliding_window
    )

@@ -216,19 +225,3 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
    cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
    cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
    update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())


def _flash_windows_supported() -> bool:
    """Best-effort check for FlashAttention sliding-window support."""
    support = getattr(flash_utils, "_flash_supports_window", None)
    if support is None:
        support = getattr(flash_utils, "_flash_supports_window_size", None)

    if support is None:
        return True

    if callable(support):
        # Signature differs across versions; assume support when callable.
        return True

    return bool(support)

@@ -2,6 +2,7 @@
HF Chat Templates prompt strategy
"""

import json
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union

@@ -794,6 +795,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            if val is not None:
                transformed_message[key] = val

        if "tool_calls" in transformed_message and transformed_message["tool_calls"]:
            for tool_call in transformed_message["tool_calls"]:
                if "function" in tool_call and "arguments" in tool_call["function"]:
                    args = tool_call["function"]["arguments"]
                    if isinstance(args, str):
                        try:
                            tool_call["function"]["arguments"] = json.loads(args)
                        except json.JSONDecodeError as e:
                            LOG.error(
                                f"Error parsing tool_calls arguments as JSON. "
                                f"Function: {tool_call.get('function', {}).get('name', 'unknown')}, "
                                f"Arguments string: {args!r}, "
                                f"Error: {e}"
                            )
                            raise

        return transformed_message

    def _get_images(self, prompt):

@@ -179,11 +179,7 @@ def execute_training(
            )
        )

    use_flash_cp = cfg.context_parallel_size > 1 and bool(
        getattr(cfg, "flash_attention", False)
    )

    if use_flash_cp:
    if cfg.context_parallel_size > 1:
        models = [trainer.model]
        if hasattr(trainer, "ref_model") and trainer.ref_model:
            models.append(trainer.ref_model)

@@ -113,6 +113,19 @@ class AxolotlInputConfig(
        },
    )

    moe_kernels: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
        },
    )
    moe_kernel_backend: Literal["cg", "mg"] | None = Field(
        default="mg",
        json_schema_extra={
            "description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel."
        },
    )

    trainer_cls: str | None = Field(
        default=None,
        json_schema_extra={

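Given the two fields added above, enabling the vendored kernels from a config would presumably look like the following YAML snippet; the base_model value is just an example:

```yaml
base_model: deepseek-ai/DeepSeek-V3   # example checkpoint with DeepseekV3MoE layers
moe_kernels: true
moe_kernel_backend: mg   # Hopper TMA kernel; use "cg" for the contiguous grouped GEMM
```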
@@ -1,6 +1,7 @@
"""Module with validation methods for config pydantic model."""

import json
import sys
import tempfile
from pathlib import Path

@@ -1313,40 +1314,50 @@ class ComplexValidationMixin:
        if not self.context_parallel_size:
            self.context_parallel_size = 1
        elif self.context_parallel_size > 1:
            use_flash_attention = getattr(self, "flash_attention", False)
            use_sdp_attention = getattr(self, "sdp_attention", False)

            if not (use_flash_attention or use_sdp_attention):
            if not self.flash_attention:
                raise ValueError(
                    "context_parallel_size > 1 requires either flash_attention: true "
                    "or sdp_attention: true"
                    "flash_attention: true must be set with context_parallel_size > 1"
                )

            if use_flash_attention:
                if self.sample_packing and self.micro_batch_size > 1:
                    raise ValueError(
                        "micro_batch_size must be set to 1 when sample_packing is enabled "
                        "due to a `ring-flash-attn` requirement"
                    )

                try:
                    import ring_flash_attn  # noqa: F401 # Required after monkey-patching
                except ImportError as exception:
                    raise ImportError(
                        "context_parallel_size > 1 but ring_flash_attn is not installed. "
                        "Please install it with `pip install axolotl[ring-flash-attn] "
                        "or `pip install ring-flash-attn>=0.1.4`."
                    ) from exception

                LOG.warning(
                    "Sequence parallelism (SP) is enabled with "
                    f"context_parallel_size={self.context_parallel_size}. "
                    "Please note that logged losses may differ slightly to the non-SP "
                    "losses due to transformers Trainer implementation details. "
                    "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
                    "for more details."
            if self.sample_packing and self.micro_batch_size > 1:
                raise ValueError(
                    "micro_batch_size must be set to 1 when sample_packing is enabled "
                    "due to a `ring-flash-attn` requirement"
                )

            try:
                import transformers.modeling_flash_attention_utils
                from transformers.utils import is_flash_attn_greater_or_equal

                transformers.modeling_flash_attention_utils._flash_supports_window = (
                    True
                )
                sys.modules[
                    "transformers.modeling_flash_attention_utils"
                ]._flash_supports_window = True
                sys.modules[
                    "transformers.modeling_flash_attention_utils"
                ]._flash_supports_window_size = True
                sys.modules[
                    "transformers.modeling_flash_attention_utils"
                ].is_flash_attn_greater_or_equal = is_flash_attn_greater_or_equal
                import ring_flash_attn  # noqa: F401 # Required after monkey-patching
            except ImportError as exception:
                raise ImportError(
                    "context_parallel_size > 1 but ring_flash_attn is not installed. "
                    "Please install it with `pip install axolotl[ring-flash-attn] "
                    "or `pip install ring-flash-attn>=0.1.4`."
                ) from exception

            LOG.warning(
                "Sequence parallelism (SP) is enabled with "
                f"context_parallel_size={self.context_parallel_size}. "
                "Please note that logged losses may differ slightly to the non-SP "
                "losses due to transformers Trainer implementation details. "
                "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
                "for more details."
            )

        return self

    @model_validator(mode="after")

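A config that satisfies this validator when context parallelism is enabled might look like the sketch below; values are illustrative:

```yaml
context_parallel_size: 2
flash_attention: true   # required when context_parallel_size > 1
sample_packing: true
micro_batch_size: 1     # must stay at 1 with sample_packing due to ring-flash-attn
```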
@@ -23,8 +23,6 @@ class TestSequenceParallelism:
        pad_to_sequence_len=True,
        ring_attn_func=None,
        threshold=2.0,
        flash_attention=True,
        sdp_attention=False,
    ):
        """Helper method to run sequence parallel tests with different configurations"""
        cfg = DictDefault(

@@ -60,8 +58,7 @@ class TestSequenceParallelism:
                "learning_rate": 0.00001,
                "optimizer": "adamw_8bit",
                "lr_scheduler": "cosine",
                "flash_attention": flash_attention,
                "sdp_attention": sdp_attention,
                "flash_attention": True,
                "loss_watchdog_threshold": 5.0,
                "loss_watchdog_patience": 3,
                "bf16": "auto",

@@ -135,16 +132,3 @@ class TestSequenceParallelism:
            ring_attn_func=ring_attn_func,
            threshold=threshold,
        )

    def test_sequence_parallel_training_sdpa(self, temp_dir):
        """Smoke test for SDPA-based context parallelism."""
        self._run_sequence_parallel_test(
            temp_dir,
            sample_packing=False,
            micro_batch_size=1,
            pad_to_sequence_len=True,
            ring_attn_func=None,
            threshold=3.0,
            flash_attention=False,
            sdp_attention=True,
        )

@@ -1,74 +0,0 @@
"""Tests for PatchManager context parallel patch selection."""

import addict

from axolotl.loaders.patch_manager import PatchManager
from axolotl.utils.dict import DictDefault


def _stub_transformers_patches(monkeypatch):
    """Replace trainer loss patchers with no-ops for isolation."""
    monkeypatch.setattr(
        "axolotl.monkeypatch.transformers.trainer_loss_calc.patch_evaluation_loop",
        lambda: None,
    )
    monkeypatch.setattr(
        "axolotl.monkeypatch.transformers.trainer_loss_calc.patch_maybe_log_save_evaluate",
        lambda: None,
    )


def test_patch_manager_applies_flash_cp_patch(monkeypatch):
    """When flash attention is enabled, we patch Trainer for CP."""
    _stub_transformers_patches(monkeypatch)

    patch_calls = {"count": 0}

    def stub_patch():
        patch_calls["count"] += 1

    monkeypatch.setattr(
        "axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
        stub_patch,
    )

    cfg = DictDefault(
        {
            "context_parallel_size": 2,
            "flash_attention": True,
            "sdp_attention": False,
        }
    )

    manager = PatchManager(cfg, addict.Dict())
    manager._apply_transformers_patches()

    assert patch_calls["count"] == 1


def test_patch_manager_skips_flash_patch_for_sdpa(monkeypatch):
    """When only SDPA is requested, we should not patch Trainer."""
    _stub_transformers_patches(monkeypatch)

    patch_calls = {"count": 0}

    def stub_patch():
        patch_calls["count"] += 1

    monkeypatch.setattr(
        "axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
        stub_patch,
    )

    cfg = DictDefault(
        {
            "context_parallel_size": 2,
            "flash_attention": False,
            "sdp_attention": True,
        }
    )

    manager = PatchManager(cfg, addict.Dict())
    manager._apply_transformers_patches()

    assert patch_calls["count"] == 0

@@ -177,6 +177,15 @@ def fixture_devstral_1_1_tokenizer():
    return tokenizer


@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
    download_qwen3_half_billion_model,
):  # pylint: disable=unused-argument,redefined-outer-name
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

    return tokenizer


@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'
@@ -6,7 +6,6 @@ import json

import pytest
from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.prompt_strategies.chat_template import StrategyLoader
from axolotl.utils.dict import DictDefault

@@ -23,15 +22,6 @@ def fixture_messages_w_tools():
    return Dataset.from_list(rows)


@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
    download_qwen3_half_billion_model,
):
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

    return tokenizer


@pytest.fixture(name="qwen3_prompt_strategy")
def qwen3_chat_template_strategy(qwen3_tokenizer):
    cfg = DictDefault(

@@ -4,7 +4,6 @@ Tests for splitting reasoning/thinking from content into separate field

import pytest
from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.prompt_strategies.chat_template import (
    load,

@@ -56,15 +55,6 @@ def messages_w_reasoning_fixture():
    )


@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
    download_qwen3_half_billion_model,
):
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

    return tokenizer


class TestSplitThinking:
    """
    test class to make sure datasets with reasoning content conforms to the chat_template strategy

@@ -0,0 +1,214 @@
"""
Tests for handling json tool content
"""

import json

import pytest
from datasets import Dataset

from axolotl.prompt_strategies.chat_template import (
    load,
)
from axolotl.utils.dict import DictDefault


@pytest.fixture(name="qwen3_instruct_prompt_strategy")
def qwen3_instruct_chat_template_strategy(qwen3_tokenizer):
    strategy = load(
        qwen3_tokenizer,
        DictDefault(
            {
                "train_on_inputs": False,
                "sequence_len": 512,
            }
        ),
        DictDefault(
            {
                "chat_template": "qwen3",
                "message_field_role": "role",
                "message_field_content": "content",
                "message_property_mappings": {
                    "role": "role",
                    "content": "content",
                },
                "roles": {
                    "user": ["user"],
                    "assistant": ["assistant"],
                    "system": ["system"],
                },
                "field_messages": "messages",
            }
        ),
    )
    return strategy


class TestQwen3IdenticalConversationArgs:
    """
    Test Qwen3 tools is identical between JSON and dict
    """

    @pytest.fixture(name="conversation_dict_args_dataset")
    def fixture_conversation_dict_args_dataset(self):
        """
        Provides a dataset with conversation where arguments is a dict.
        """
        user_content = "What is the weather in Boston?"
        function_name = "get_current_weather"
        arguments_dict = {"location": "Boston, MA", "unit": "celsius"}

        data = [
            {
                "messages": [
                    {"role": "user", "content": user_content},
                    {
                        "role": "assistant",
                        "content": "",
                        "tool_calls": [
                            {
                                "function": {
                                    "name": function_name,
                                    "arguments": arguments_dict,  # dict format
                                }
                            }
                        ],
                    },
                ],
            }
        ]
        return Dataset.from_list(data)

    @pytest.fixture(name="conversation_str_args_dataset")
    def fixture_conversation_str_args_dataset(self):
        """
        Provides a dataset with conversation where arguments is a JSON string.
        """
        user_content = "What is the weather in Boston?"
        function_name = "get_current_weather"
        arguments_dict = {"location": "Boston, MA", "unit": "celsius"}
        arguments_str = json.dumps(arguments_dict)

        data = [
            {
                "messages": [
                    {"role": "user", "content": user_content},
                    {
                        "role": "assistant",
                        "content": "",
                        "tool_calls": [
                            {
                                "function": {
                                    "name": function_name,
                                    "arguments": arguments_str,  # str format
                                }
                            }
                        ],
                    },
                ],
            }
        ]
        return Dataset.from_list(data)

    @pytest.fixture(name="conversation_mixed_time_types_dataset")
    def fixture_conversation_mixed_time_types_dataset(self):
        """
        Provides a dataset where 'time' field has different types in different tool calls.
        """
        data = [
            {
                "messages": [
                    {
                        "role": "user",
                        "content": "Get weather information at different times",
                    },
                    {
                        "role": "assistant",
                        "content": "",
                        "tool_calls": [
                            {
                                "function": {
                                    "name": "func1",
                                    "arguments": json.dumps(
                                        {"time": "2025-08-01"}
                                    ),  # string type
                                }
                            },
                            {
                                "function": {
                                    "name": "func2",
                                    "arguments": json.dumps(
                                        {"time": 1690876800}
                                    ),  # number type
                                }
                            },
                        ],
                    },
                ],
            }
        ]
        return Dataset.from_list(data)

    def test_dict_and_str_args_produce_identical_output(
        self,
        conversation_dict_args_dataset,
        conversation_str_args_dataset,
        qwen3_instruct_prompt_strategy,
        qwen3_tokenizer,
    ):
        """
        Tests that after tokenization and decoding, the outputs for both
        dict and string `arguments` are exactly the same.
        """
        processed_dict_args = conversation_dict_args_dataset.map(
            qwen3_instruct_prompt_strategy.tokenize_prompt,
            batched=True,
            remove_columns=["messages"],
        )

        processed_str_args = conversation_str_args_dataset.map(
            qwen3_instruct_prompt_strategy.tokenize_prompt,
            batched=True,
            remove_columns=["messages"],
        )

        decoded_prompt_from_dict = qwen3_tokenizer.decode(
            processed_dict_args[0]["input_ids"]
        )

        decoded_prompt_from_str = qwen3_tokenizer.decode(
            processed_str_args[0]["input_ids"]
        )

        assert decoded_prompt_from_dict == decoded_prompt_from_str, (
            f"Dict format output:\n{decoded_prompt_from_dict}\n"
            f"String format output:\n{decoded_prompt_from_str}"
        )

        assert (
            processed_dict_args[0]["input_ids"] == processed_str_args[0]["input_ids"]
        ), "The tokenized input_ids should be identical for dict and str arguments"

    def test_str_args_with_mixed_time_types_no_error(
        self,
        conversation_mixed_time_types_dataset,
        qwen3_instruct_prompt_strategy,
        qwen3_tokenizer,
    ):
        """
        Tests that when 'time' field has different types (string vs number)
        in different tool calls, str format arguments don't cause errors.
        """
        processed = conversation_mixed_time_types_dataset.map(
            qwen3_instruct_prompt_strategy.tokenize_prompt,
            batched=True,
            remove_columns=["messages"],
        )

        assert len(processed) == 1
        assert "input_ids" in processed[0]
        assert len(processed[0]["input_ids"]) > 0

        decoded = qwen3_tokenizer.decode(processed[0]["input_ids"])
        assert "2025-08-01" in decoded, "String time value should be present"
        assert "1690876800" in decoded, "Number time value should be present"

@@ -1,111 +0,0 @@
"""Unit tests for choosing the correct context parallel implementation."""

from types import SimpleNamespace

from axolotl.train import execute_training
from axolotl.utils.dict import DictDefault


class DummyTrainer:
    """Minimal trainer stub to exercise execute_training."""

    def __init__(self):
        self.model = object()
        self.ref_model = None
        self.accelerator = SimpleNamespace(torch_device_mesh=None)
        self.train_called = False

    def train(self, resume_from_checkpoint=None):  # pylint: disable=unused-argument
        self.train_called = True


class DummyPluginManager:
    """Minimal plugin manager stub."""

    @staticmethod
    def post_train(cfg, model):  # pylint: disable=unused-argument
        return None


class DummyContext:
    """Test context manager that records entries/exits."""

    def __init__(self, recorder, **kwargs):
        recorder.append({"kwargs": kwargs})
        self.recorder = recorder

    def __enter__(self):
        self.recorder[-1]["entered"] = True
        return self

    def __exit__(self, exc_type, exc, tb):  # pylint: disable=unused-argument
        self.recorder[-1]["exited"] = True
        return False


def _base_cfg(**overrides):
    base = {
        "context_parallel_size": 2,
        "gradient_accumulation_steps": 1,
        "ring_attn_func": None,
        "heads_k_stride": None,
        "rl": None,
        "flash_optimum": False,
    }
    base.update(overrides)
    return DictDefault(base)


def test_execute_training_uses_ring_when_flash(monkeypatch):
    """FlashAttention CP should engage the custom ring context manager."""
    recorder: list[dict] = []

    monkeypatch.setattr(
        "axolotl.train.SequenceParallelContextManager",
        lambda **kwargs: DummyContext(recorder, **kwargs),
    )
    monkeypatch.setattr(
        "axolotl.train.PluginManager.get_instance",
        lambda: DummyPluginManager(),
    )

    cfg = _base_cfg(flash_attention=True, sdp_attention=False)
    trainer = DummyTrainer()

    execute_training(cfg, trainer, resume_from_checkpoint=None)

    assert trainer.train_called
    assert len(recorder) == 1
    assert recorder[0]["kwargs"]["context_parallel_size"] == 2
    assert recorder[0].get("entered") is True
    assert recorder[0].get("exited") is True


def test_execute_training_uses_transformers_cp_for_sdpa(monkeypatch):
    """SDPA CP should bypass the ring context manager."""
    invoked = {"count": 0}

    class NoOpContext:
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):  # pylint: disable=unused-argument
            return False

    monkeypatch.setattr(
        "axolotl.train.SequenceParallelContextManager",
        lambda **kwargs: invoked.__setitem__("count", invoked["count"] + 1)
        or NoOpContext(),
    )
    monkeypatch.setattr(
        "axolotl.train.PluginManager.get_instance",
        lambda: DummyPluginManager(),
    )

    cfg = _base_cfg(flash_attention=False, sdp_attention=True)
    trainer = DummyTrainer()

    execute_training(cfg, trainer, resume_from_checkpoint=None)

    assert trainer.train_called
    assert invoked["count"] == 0