Compare commits
48 Commits
moekernels
...
vendor-moe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dd85358543 | ||
|
|
55d98db0d0 | ||
|
|
d90ade3b1b | ||
|
|
824a641cee | ||
|
|
e003a05177 | ||
|
|
91393c4dc8 | ||
|
|
d578c53603 | ||
|
|
4db7a21ff7 | ||
|
|
3b2e05c563 | ||
|
|
1037ca3a97 | ||
|
|
6369dcd7b8 | ||
|
|
a81612305c | ||
|
|
d0da67eb17 | ||
|
|
8a1f5ae940 | ||
|
|
146ca48cba | ||
|
|
fd312f6058 | ||
|
|
ab8fa56b16 | ||
|
|
1640cd4006 | ||
|
|
3277d44d71 | ||
|
|
d3e1b0ef1a | ||
|
|
5b97633faa | ||
|
|
94cbc6d42d | ||
|
|
493616fc3d | ||
|
|
d2b25c7327 | ||
|
|
b670c45276 | ||
|
|
61faf4cbe4 | ||
|
|
8d8fa834a2 | ||
|
|
9d69c6fb3e | ||
|
|
92f2f6e73c | ||
|
|
e5d2aebe16 | ||
|
|
4ab9e3f58b | ||
|
|
5788832812 | ||
|
|
db782430f8 | ||
|
|
5c74edeefe | ||
|
|
18269ee6a9 | ||
|
|
6a45d804f9 | ||
|
|
95e607574a | ||
|
|
f9748c4dc5 | ||
|
|
33975ce4bc | ||
|
|
e8b962d47f | ||
|
|
856ff12171 | ||
|
|
6bc959342b | ||
|
|
b3b92687c4 | ||
|
|
55d1be2ae6 | ||
|
|
08d831c3d5 | ||
|
|
7be8740c5c | ||
|
|
c51d6b06c3 | ||
|
|
09959fac70 |
@@ -267,6 +267,7 @@ website:
|
||||
- docs/dataset_loading.qmd
|
||||
- docs/qat.qmd
|
||||
- docs/quantize.qmd
|
||||
- docs/optimizations.qmd
|
||||
|
||||
- section: "Core Concepts"
|
||||
contents:
|
||||
@@ -285,7 +286,6 @@ website:
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/gradient_checkpointing.qmd
|
||||
- docs/moe_backends.md
|
||||
- docs/nd_parallelism.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
|
||||
@@ -212,6 +212,14 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues.
|
||||
|
||||
```
|
||||
"arguments": "{\"...\": \"...\"}"
|
||||
```
|
||||
:::
|
||||
|
||||
Example config for Llama4:
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
|
||||
@@ -61,7 +61,7 @@ While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet
|
||||
|
||||
### Pre-training without streaming
|
||||
|
||||
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
In the case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
|
||||
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
|
||||
|
||||
|
||||
@@ -140,3 +140,7 @@ description: Frequently asked questions
|
||||
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
|
||||
|
||||
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
|
||||
|
||||
**Q: `Error parsing tool_calls arguments as JSON.`
|
||||
|
||||
> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
MoE Backends in Axolotl
|
||||
|
||||
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
|
||||
|
||||
- Set `moe_backend: auto|torch_grouped|naive`
|
||||
|
||||
Behavior
|
||||
- auto (default): prefers PyTorch 2.8+ grouped GEMM; otherwise naive.
|
||||
- torch_grouped: targets PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
|
||||
- naive: keeps the reference per-expert loop.
|
||||
|
||||
Notes
|
||||
- Current implementation wires the backend selector and routes Mixtral MoE through it. Torch grouped uses cuBLASLt grouped GEMM when available; otherwise, the code falls back to the naive per-expert loop.
|
||||
- No changes to training scripts are required; selection happens inside the model forward.
|
||||
|
||||
Example
|
||||
moe_backend: torch_grouped
|
||||
accelerate launch -m axolotl.cli.train path/to/config.yaml
|
||||
@@ -13,6 +13,7 @@ format:
|
||||
- [Pixtral](#sec-pixtral)
|
||||
- [Llava-1.5](#sec-llava-15)
|
||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||
- [Magistral-Small-2509](#sec-magistral-small-2509)
|
||||
- [Voxtral](#sec-voxtral)
|
||||
- [Gemma-3](#sec-gemma-3)
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
@@ -41,7 +42,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
# (optional) if doing lora, only finetune the Language model,
|
||||
# leave the vision model and vision tower frozen
|
||||
@@ -94,10 +94,22 @@ chat_template: llava
|
||||
|
||||
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
```
|
||||
|
||||
chat_template: mistral_v7_tekken
|
||||
### Magistral-Small-2509 {#sec-magistral-small-2509}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Magistral-Small-2509
|
||||
```
|
||||
|
||||
### Voxtral {#sec-voxtral}
|
||||
|
||||
133
docs/optimizations.qmd
Normal file
133
docs/optimizations.qmd
Normal file
@@ -0,0 +1,133 @@
|
||||
---
|
||||
title: Optimizations Guide
|
||||
description: A guide to the performance and memory optimizations available in Axolotl.
|
||||
---
|
||||
|
||||
Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.
|
||||
|
||||
This guide provides a high-level overview and directs you to the detailed documentation for each feature.
|
||||
|
||||
## Speed Optimizations
|
||||
|
||||
These optimizations focus on increasing training throughput and reducing total training time.
|
||||
|
||||
### Sample Packing
|
||||
|
||||
Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below.
|
||||
|
||||
- **Config:** `sample_packing: true`
|
||||
- **Learn more:** [Sample Packing](multipack.qmd)
|
||||
|
||||
### Attention Implementations
|
||||
|
||||
Using an optimized attention implementation is critical for training speed.
|
||||
|
||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
|
||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
|
||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
|
||||
|
||||
*Note: You should only enable one attention backend.*
|
||||
|
||||
### LoRA Optimizations
|
||||
|
||||
Leverages optimized kernels to accelerate LoRA training and reduce memory usage.
|
||||
|
||||
- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)
|
||||
|
||||
## Memory Optimizations
|
||||
|
||||
These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
|
||||
|
||||
### Parameter Efficient Finetuning (LoRA & QLoRA)
|
||||
|
||||
Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.
|
||||
|
||||
- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
|
||||
- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).
|
||||
|
||||
### Gradient Checkpointing & Activation Offloading
|
||||
|
||||
These techniques save VRAM by changing how activations are handled.
|
||||
|
||||
- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.
|
||||
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
||||
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
||||
|
||||
### Cut Cross Entropy (CCE)
|
||||
|
||||
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
||||
|
||||
- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)
|
||||
|
||||
### Liger Kernels
|
||||
|
||||
Provides efficient Triton kernels to improve training speed and reduce memory usage.
|
||||
|
||||
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
|
||||
|
||||
## Long Context Models
|
||||
|
||||
Techniques to train models on sequences longer than their original context window.
|
||||
|
||||
### RoPE Scaling
|
||||
|
||||
Extends a model's context window by interpolating its Rotary Position Embeddings.
|
||||
|
||||
- **Config:** Pass the `rope_scaling` config under the `overrides_of_model_config: `. To learn how to set RoPE, check the respective model config.
|
||||
|
||||
### Sequence Parallelism
|
||||
|
||||
Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.
|
||||
|
||||
- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)
|
||||
|
||||
### Artic Long Sequence Training (ALST)
|
||||
|
||||
ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:
|
||||
|
||||
- TiledMLP to reduce memory usage in MLP layers.
|
||||
- Tiled Loss functions (like [CCE](#cut-cross-entropy-(cce) or [Liger](#liger-kernels)).
|
||||
- Activation Offloading to CPU.
|
||||
|
||||
- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)
|
||||
|
||||
## Large Models (Distributed Training)
|
||||
|
||||
To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.
|
||||
|
||||
- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
|
||||
- **Learn more:** [Multi-Node Guide](multi-node.qmd)
|
||||
|
||||
### N-D Parallelism (Beta)
|
||||
|
||||
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.
|
||||
|
||||
- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)
|
||||
|
||||
|
||||
## Quantization
|
||||
|
||||
Techniques to reduce the precision of model weights for memory savings.
|
||||
|
||||
### 4-bit Training (QLoRA)
|
||||
|
||||
The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Adapter Finetuning](#adapter-finetuning-lora-qlora) for details.
|
||||
|
||||
### FP8 Training
|
||||
|
||||
Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.
|
||||
|
||||
- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)
|
||||
|
||||
### Quantization Aware Training (QAT)
|
||||
|
||||
Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.
|
||||
|
||||
- **Learn more:** [QAT Documentation](qat.qmd)
|
||||
|
||||
### GPTQ
|
||||
|
||||
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
|
||||
|
||||
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
|
||||
@@ -30,6 +30,7 @@ qat:
|
||||
```
|
||||
|
||||
We support the following quantization schemas:
|
||||
|
||||
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
|
||||
- `Int8DynamicActivationInt4Weight`
|
||||
- `Float8DynamicActivationFloat8Weight`
|
||||
|
||||
@@ -7,3 +7,24 @@ techniques. It is a combination of:
|
||||
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
|
||||
|
||||
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
tiled_mlp: true
|
||||
|
||||
# See Sequence Parallelism docs
|
||||
# https://docs.axolotl.ai/docs/sequence_parallelism.html
|
||||
context_parallel_size: int
|
||||
|
||||
plugins:
|
||||
# See Cut Cross Entropy docs
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# or Liger Kernel docs
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
# ...
|
||||
|
||||
```
|
||||
|
||||
110
examples/apertus/README.md
Normal file
110
examples/apertus/README.md
Normal file
@@ -0,0 +1,110 @@
|
||||
# Finetune Swiss-AI's Apertus with Axolotl
|
||||
|
||||
[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of opensource models trained by Swiss-ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Apertus is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. (Optional, highly recommended) Install XIELU CUDA
|
||||
|
||||
```bash
|
||||
## Recommended for reduced VRAM and faster speeds
|
||||
|
||||
# Point to CUDA toolkit directory
|
||||
# For those using our Docker image, use the below path.
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
|
||||
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/apertus/apertus-8b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 8.7 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`.
|
||||
- You can instead use full paremter fine-tuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
### XIELU Installation Issues
|
||||
|
||||
#### `ModuleNotFoundError: No module named 'torch'`
|
||||
|
||||
Please check these one by one:
|
||||
- Running in correct environment
|
||||
- Env has PyTorch installed
|
||||
- CUDA toolkit is at `CUDA_HOME`
|
||||
|
||||
If those didn't help, please try the below solutions:
|
||||
|
||||
1. Pass env for CMAKE and try install again:
|
||||
|
||||
```bash
|
||||
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
2. Git clone the repo and manually hardcode python path:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/nickjbrowning/XIELU
|
||||
cd xielu
|
||||
git checkout 59d6031
|
||||
|
||||
cd xielu
|
||||
nano CMakeLists.txt # or vi depending on your preference
|
||||
```
|
||||
|
||||
```diff
|
||||
execute_process(
|
||||
- COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
|
||||
+ COMMAND /root/miniconda3/envs/py3.11/bin/python -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
|
||||
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
|
||||
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
|
||||
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
|
||||
)
|
||||
```
|
||||
|
||||
```bash
|
||||
pip3 install . --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
64
examples/apertus/apertus-8b-qlora.yaml
Normal file
64
examples/apertus/apertus-8b-qlora.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: swiss-ai/Apertus-8B-Instruct-2509
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -19,6 +19,9 @@ cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -9,10 +9,6 @@ strict: false
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -18,7 +18,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -23,7 +23,15 @@ pip3 install timm==1.0.17
|
||||
pip3 install librosa==0.11.0
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
3. Download sample dataset files
|
||||
|
||||
```bash
|
||||
# for text + vision + audio only
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# text only
|
||||
|
||||
@@ -12,15 +12,6 @@ chat_template: llama3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -46,7 +46,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -45,7 +45,6 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Finetune Magistral Small with Axolotl
|
||||
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506) and [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506), [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)), and [2509](https://huggingface.co/mistralai/Magistral-Small-2509) (see [Vision](#vision)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
|
||||
|
||||
## Getting started
|
||||
|
||||
@@ -36,29 +36,17 @@ Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Thinking
|
||||
|
||||
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities. The model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
|
||||
|
||||
Example format:
|
||||
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [{ "type": "text", "text": "..."}]},
|
||||
{"role": "assistant", "content": [{ "type": "thinking", "thinking": "..."}, { "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
### Vision
|
||||
|
||||
Example config: `./magistral-small-think-qlora.yaml`.
|
||||
MistralAI has released their [2509](https://huggingface.co/mistralai/Magistral-Small-2509) model with vision capabilities.
|
||||
|
||||
The `thinking` section also supports an optional arg `closed: bool` (`True` default) which controls adding the closing `[/THINK]` tag.
|
||||
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
|
||||
|
||||
Limitations:
|
||||
- You cannot mix `content: str` with `content: list[dict]` as the `dataset.load_dataset` may complain about different types for `content` key.
|
||||
- This mode does not work with custom `train_detail` and `training` at the moment.
|
||||
|
||||
### TIPS
|
||||
### Tips
|
||||
|
||||
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||
@@ -89,5 +77,5 @@ In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
||||
- Add parity to Preference Tuning, RL, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
|
||||
73
examples/magistral/think/README.md
Normal file
73
examples/magistral/think/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Magistral Small Thinking Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mistralai/Magistral-Small-2507) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
Run the thinking model fine-tuning:
|
||||
|
||||
```bash
|
||||
axolotl train magistral-small-think-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 19.1 GiB VRAM.
|
||||
|
||||
### Tips
|
||||
|
||||
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
|
||||
- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
|
||||
Example format:
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Options
|
||||
|
||||
The `thinking` section supports an optional `closed` parameter:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "Internal reasoning here...",
|
||||
"closed": true // Default: true, controls adding the closing [/THINK] tag
|
||||
}
|
||||
```
|
||||
60
examples/magistral/vision/README.md
Normal file
60
examples/magistral/vision/README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# Magistral Small Vision Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mistralai/Magistral-Small-2509) with vision capabilities using Axolotl.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.5'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
```
|
||||
|
||||
3. Run the fine-tuning:
|
||||
```bash
|
||||
axolotl train magistral-small-vision-24B-qlora.yml
|
||||
```
|
||||
|
||||
This config uses about 17GiB VRAM.
|
||||
|
||||
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
|
||||
|
||||
### Tips
|
||||
|
||||
Key differences from text-only model:
|
||||
- `max_tokens: 131072` for inference
|
||||
- Multi-modal dataset format required
|
||||
- Sample packing not supported
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [
|
||||
{ "type": "text", "text": "What's in this image?"},
|
||||
{"type": "image", "path": "path/to/image.jpg" }
|
||||
]},
|
||||
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Sample Packing is not supported for multi-modality training currently.
|
||||
@@ -0,0 +1,64 @@
|
||||
base_model: mistralai/Magistral-Small-2509
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
# sample dataset below requires downloading image in advance
|
||||
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
datasets:
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,6 +1,9 @@
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
load_in_8bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
@@ -8,12 +11,12 @@ skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: mistral_v7_tekken
|
||||
# sample dataset below requires downloading image in advance
|
||||
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
@@ -48,8 +51,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
# flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
||||
sdp_attention: true
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
@@ -1,53 +0,0 @@
|
||||
base_model: Qwen/Qwen1.5-MoE-A2.7B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
# Keep VRAM low
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/qwen2-moe-qlora-10gb
|
||||
|
||||
# Train small to fit 10GB
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 5
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.03
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
special_tokens:
|
||||
@@ -12,15 +12,6 @@ chat_template: phi_3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -45,8 +45,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
# flash_attention: # PixtralVisionModel does not support Flash Attention 2.0 yet
|
||||
sdp_attention: true
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -11,7 +11,7 @@ datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
64
examples/qwen3-next/README.md
Normal file
64
examples/qwen3-next/README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# Finetune Qwen3-Next with Axolotl
|
||||
|
||||
[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) represents the next-generation foundation models optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Install Qwen3-Next transformers commit
|
||||
```bash
|
||||
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
|
||||
```
|
||||
|
||||
3. Install FLA for improved performance
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 45.62 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. See [Multi-GPU](#optimization-guides) section below.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
68
examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
Normal file
68
examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: Qwen/Qwen3-Next-80B-A3B-Instruct
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 8
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- linear_attn.in_proj_ba
|
||||
- linear_attn.in_proj_qkvz
|
||||
- linear_attn.out_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert_gate
|
||||
- mlp.gate
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -27,7 +27,14 @@ pip3 install 'mistral_common[audio]==1.8.3'
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
3. Download sample dataset files
|
||||
|
||||
```bash
|
||||
# for text + audio only
|
||||
wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# text only
|
||||
|
||||
@@ -70,4 +70,4 @@ schedulefree==1.4.1
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.5
|
||||
|
||||
mistral-common==1.8.3
|
||||
mistral-common==1.8.5
|
||||
|
||||
1
scripts/__init__.py
Normal file
1
scripts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Utility scripts package."""
|
||||
@@ -1,209 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Benchmark Hugging Face Qwen2 MoE block with and without grouped_mm."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
|
||||
try:
|
||||
from axolotl.kernels.moe import torch_grouped as tg
|
||||
except Exception: # pragma: no cover
|
||||
tg = None
|
||||
|
||||
|
||||
def bench(run, *, iters: int, warmup: int, sync: bool = True) -> float:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
for _ in range(warmup):
|
||||
run()
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times = []
|
||||
for _ in range(iters):
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
run()
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - start) * 1000.0)
|
||||
return sum(times) / len(times)
|
||||
|
||||
|
||||
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
|
||||
return 6.0 * tokens * top_k * hidden * inter
|
||||
|
||||
|
||||
def load_hf_block(
|
||||
hidden: int,
|
||||
inter: int,
|
||||
experts: int,
|
||||
top_k: int,
|
||||
*,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
transformers_src = project_root / "transformers" / "src"
|
||||
if transformers_src.exists() and str(transformers_src) not in sys.path:
|
||||
sys.path.append(str(transformers_src))
|
||||
|
||||
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
|
||||
cfg = Qwen2MoeConfig(
|
||||
hidden_size=hidden,
|
||||
moe_intermediate_size=inter,
|
||||
shared_expert_intermediate_size=inter,
|
||||
num_experts=experts,
|
||||
num_experts_per_tok=top_k,
|
||||
norm_topk_prob=True,
|
||||
qkv_bias=True,
|
||||
)
|
||||
|
||||
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
|
||||
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
|
||||
block_grouped.load_state_dict(block.state_dict())
|
||||
return block, block_grouped
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
|
||||
p.add_argument("--bsz", type=int, default=8)
|
||||
p.add_argument("--seq", type=int, default=1024)
|
||||
p.add_argument("--hidden", type=int, default=4096)
|
||||
p.add_argument("--inter", type=int, default=14336)
|
||||
p.add_argument("--experts", type=int, default=32)
|
||||
p.add_argument("--top_k", type=int, default=4)
|
||||
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
||||
p.add_argument("--iters", type=int, default=50)
|
||||
p.add_argument("--warmup", type=int, default=10)
|
||||
p.add_argument("--profile", action="store_true")
|
||||
p.add_argument(
|
||||
"--compile",
|
||||
action="store_true",
|
||||
help="Torch.compile both paths before benchmarking",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[args.dtype]
|
||||
|
||||
torch.manual_seed(0)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
block_naive, block_grouped = load_hf_block(
|
||||
args.hidden,
|
||||
args.inter,
|
||||
args.experts,
|
||||
args.top_k,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
tokens = args.bsz * args.seq
|
||||
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
|
||||
print(
|
||||
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
|
||||
f"experts={args.experts} top_k={args.top_k}"
|
||||
)
|
||||
|
||||
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
|
||||
|
||||
# Optional torch.compile
|
||||
run_grouped_impl = None
|
||||
if args.compile:
|
||||
dynamo.config.capture_scalar_outputs = True
|
||||
dynamo.config.allow_unspec_int_on_nn_module = True
|
||||
try:
|
||||
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
||||
except Exception as exc: # pragma: no cover
|
||||
print(f"torch.compile naive failed ({exc}); using eager")
|
||||
else:
|
||||
|
||||
def grouped_forward(inp, *, block=block_grouped):
|
||||
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
|
||||
y, _ = tg.moe_ffn_forward_grouped(
|
||||
inp, block.gate, block.experts, block.top_k
|
||||
)
|
||||
return y
|
||||
|
||||
try:
|
||||
run_grouped_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
|
||||
except Exception as exc: # pragma: no cover
|
||||
print(f"torch.compile grouped failed ({exc}); using eager")
|
||||
run_grouped_impl = None
|
||||
|
||||
def run_naive(block=block_naive, data=x):
|
||||
y, _ = block(data)
|
||||
return y
|
||||
|
||||
def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
|
||||
if impl is not None:
|
||||
return impl(data)
|
||||
if tg is None or not tg.available():
|
||||
return torch.empty(0)
|
||||
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
|
||||
y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
|
||||
return y if y is not None else torch.empty(0)
|
||||
|
||||
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
|
||||
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
|
||||
print(
|
||||
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
y_ref = run_naive()
|
||||
|
||||
if tg is None or not tg.available():
|
||||
print("torch_grouped\tN/A (unavailable)")
|
||||
return
|
||||
|
||||
y_grouped = run_grouped()
|
||||
if y_grouped.numel() == 0:
|
||||
print("torch_grouped\tN/A (op not callable)")
|
||||
return
|
||||
|
||||
t_grouped = bench(run_grouped, iters=args.iters, warmup=args.warmup)
|
||||
tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
|
||||
speedup = t_naive / t_grouped
|
||||
print(
|
||||
f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
|
||||
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
|
||||
)
|
||||
|
||||
diff = (y_ref.float() - y_grouped.float()).abs()
|
||||
print(
|
||||
"torch_grouped_check: "
|
||||
f"max_abs={diff.max().item():.3e} mean_abs={diff.mean().item():.3e} "
|
||||
f"rel_l2={(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item():.3e}"
|
||||
)
|
||||
|
||||
if args.profile:
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
|
||||
) as prof:
|
||||
run_naive()
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
|
||||
) as prof:
|
||||
run_grouped()
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,311 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Sweep grouped_mm vs naive performance for Qwen2 MoE block."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import sys
|
||||
import time
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
|
||||
try:
|
||||
from axolotl.kernels.moe import torch_grouped as tg
|
||||
except Exception: # pragma: no cover
|
||||
tg = None
|
||||
|
||||
|
||||
def _parse_list(arg: str) -> List[int]:
|
||||
return [int(v) for v in arg.split(",") if v]
|
||||
|
||||
|
||||
def _bench(run, *, iters: int, warmup: int, device: torch.device) -> float:
|
||||
for _ in range(warmup):
|
||||
run()
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times: List[float] = []
|
||||
for _ in range(iters):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
run()
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - start) * 1000.0)
|
||||
return sum(times) / len(times)
|
||||
|
||||
|
||||
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
|
||||
return 6.0 * tokens * top_k * hidden * inter
|
||||
|
||||
|
||||
def _load_block(
|
||||
hidden: int,
|
||||
inter: int,
|
||||
experts: int,
|
||||
top_k: int,
|
||||
*,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
transformers_src = project_root / "transformers" / "src"
|
||||
if transformers_src.exists() and str(transformers_src) not in sys.path:
|
||||
sys.path.append(str(transformers_src))
|
||||
|
||||
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
|
||||
cfg = Qwen2MoeConfig(
|
||||
hidden_size=hidden,
|
||||
moe_intermediate_size=inter,
|
||||
shared_expert_intermediate_size=inter,
|
||||
num_experts=experts,
|
||||
num_experts_per_tok=top_k,
|
||||
norm_topk_prob=True,
|
||||
qkv_bias=True,
|
||||
)
|
||||
|
||||
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
|
||||
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
|
||||
block_grouped.load_state_dict(block.state_dict())
|
||||
return block, block_grouped
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
bsz: int
|
||||
seq: int
|
||||
hidden: int
|
||||
inter: int
|
||||
experts: int
|
||||
top_k: int
|
||||
dtype: str
|
||||
naive_ms: float
|
||||
grouped_ms: float
|
||||
speedup: float
|
||||
naive_tflops: float
|
||||
grouped_tflops: float
|
||||
max_abs: float
|
||||
mean_abs: float
|
||||
rel_l2: float
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description="Grouped MoE sweep")
|
||||
p.add_argument("--batch-sizes", default="4,8,16")
|
||||
p.add_argument("--seq-lens", default="512,1024,2048")
|
||||
p.add_argument("--hidden", default="2048,4096")
|
||||
p.add_argument("--inter", default="5632,8192,14336")
|
||||
p.add_argument("--experts", default="8,16,32")
|
||||
p.add_argument("--top-k", default="1,2,4")
|
||||
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
||||
p.add_argument("--iters", type=int, default=25)
|
||||
p.add_argument("--warmup", type=int, default=5)
|
||||
p.add_argument("--csv", type=Path, default=None)
|
||||
p.add_argument("--compile", action="store_true")
|
||||
args = p.parse_args()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[args.dtype]
|
||||
|
||||
if tg is None or not tg.available():
|
||||
print("torch_grouped unavailable; sweep aborted")
|
||||
return
|
||||
|
||||
bs_list = _parse_list(args.batch_sizes)
|
||||
seq_list = _parse_list(args.seq_lens)
|
||||
hidden_list = _parse_list(args.hidden)
|
||||
inter_list = _parse_list(args.inter)
|
||||
expert_list = _parse_list(args.experts)
|
||||
topk_list = _parse_list(args.top_k)
|
||||
|
||||
results: List[Result] = []
|
||||
|
||||
print(
|
||||
"bsz\tseq\thidden\tinter\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
|
||||
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
|
||||
)
|
||||
|
||||
for bsz in bs_list:
|
||||
for seq in seq_list:
|
||||
tokens = bsz * seq
|
||||
for hidden in hidden_list:
|
||||
for inter in inter_list:
|
||||
for experts in expert_list:
|
||||
for top_k in topk_list:
|
||||
torch.manual_seed(0)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
block_naive, block_grouped = _load_block(
|
||||
hidden,
|
||||
inter,
|
||||
experts,
|
||||
top_k,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
x = torch.randn(
|
||||
bsz, seq, hidden, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
compiled_impl = None
|
||||
if args.compile:
|
||||
dynamo.config.capture_scalar_outputs = True
|
||||
dynamo.config.allow_unspec_int_on_nn_module = True
|
||||
try:
|
||||
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
print(
|
||||
f"torch.compile naive failed ({exc}); using eager"
|
||||
)
|
||||
else:
|
||||
|
||||
def grouped_forward(inp, *, block=block_grouped):
|
||||
block.experts._ax_parent_block_ref = (
|
||||
weakref.ref(block)
|
||||
) # type: ignore[attr-defined]
|
||||
y, _ = tg.moe_ffn_forward_grouped(
|
||||
inp,
|
||||
block.gate,
|
||||
block.experts,
|
||||
block.top_k,
|
||||
)
|
||||
return y
|
||||
|
||||
try:
|
||||
compiled_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
print(
|
||||
f"torch.compile grouped failed ({exc}); using eager"
|
||||
)
|
||||
compiled_impl = None
|
||||
|
||||
def run_naive(block=block_naive, data=x):
|
||||
y, _ = block(data)
|
||||
return y
|
||||
|
||||
def run_grouped(
|
||||
block=block_grouped, data=x, impl=compiled_impl
|
||||
):
|
||||
if impl is not None:
|
||||
return impl(data)
|
||||
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
|
||||
y, _ = tg.moe_ffn_forward_grouped(
|
||||
data,
|
||||
block.gate,
|
||||
block.experts,
|
||||
block.top_k,
|
||||
)
|
||||
return y
|
||||
|
||||
naive_ms = _bench(
|
||||
run_naive,
|
||||
iters=args.iters,
|
||||
warmup=args.warmup,
|
||||
device=device,
|
||||
)
|
||||
y_naive = run_naive()
|
||||
|
||||
grouped_ms = _bench(
|
||||
run_grouped,
|
||||
iters=args.iters,
|
||||
warmup=args.warmup,
|
||||
device=device,
|
||||
)
|
||||
y_grouped = run_grouped()
|
||||
|
||||
diff = (y_naive.float() - y_grouped.float()).abs()
|
||||
res = Result(
|
||||
bsz,
|
||||
seq,
|
||||
hidden,
|
||||
inter,
|
||||
experts,
|
||||
top_k,
|
||||
args.dtype,
|
||||
naive_ms,
|
||||
grouped_ms,
|
||||
naive_ms / grouped_ms,
|
||||
_estimate_flops(tokens, hidden, inter, top_k)
|
||||
/ ((naive_ms / 1000.0) * 1e12),
|
||||
_estimate_flops(tokens, hidden, inter, top_k)
|
||||
/ ((grouped_ms / 1000.0) * 1e12),
|
||||
diff.max().item(),
|
||||
diff.mean().item(),
|
||||
(
|
||||
(
|
||||
diff.pow(2).sum()
|
||||
/ (y_naive.float().pow(2).sum() + 1e-12)
|
||||
)
|
||||
.sqrt()
|
||||
.item()
|
||||
),
|
||||
)
|
||||
results.append(res)
|
||||
print(
|
||||
f"{bsz}\t{seq}\t{hidden}\t{inter}\t{experts}\t{top_k}\t{res.naive_ms:.2f}\t"
|
||||
f"{res.grouped_ms:.2f}\t{res.speedup:.2f}\t{res.naive_tflops:.2f}\t"
|
||||
f"{res.grouped_tflops:.2f}\t{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
|
||||
)
|
||||
|
||||
if args.csv:
|
||||
fieldnames = [
|
||||
"bsz",
|
||||
"seq",
|
||||
"hidden",
|
||||
"inter",
|
||||
"experts",
|
||||
"top_k",
|
||||
"dtype",
|
||||
"naive_ms",
|
||||
"grouped_ms",
|
||||
"speedup",
|
||||
"naive_tflops",
|
||||
"grouped_tflops",
|
||||
"max_abs",
|
||||
"mean_abs",
|
||||
"rel_l2",
|
||||
]
|
||||
with args.csv.open("w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for r in results:
|
||||
writer.writerow(
|
||||
{
|
||||
"bsz": r.bsz,
|
||||
"seq": r.seq,
|
||||
"hidden": r.hidden,
|
||||
"inter": r.inter,
|
||||
"experts": r.experts,
|
||||
"top_k": r.top_k,
|
||||
"dtype": r.dtype,
|
||||
"naive_ms": f"{r.naive_ms:.4f}",
|
||||
"grouped_ms": f"{r.grouped_ms:.4f}",
|
||||
"speedup": f"{r.speedup:.4f}",
|
||||
"naive_tflops": f"{r.naive_tflops:.4f}",
|
||||
"grouped_tflops": f"{r.grouped_tflops:.4f}",
|
||||
"max_abs": f"{r.max_abs:.6e}",
|
||||
"mean_abs": f"{r.mean_abs:.6e}",
|
||||
"rel_l2": f"{r.rel_l2:.6e}",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import weakref
|
||||
|
||||
main()
|
||||
@@ -1,205 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Benchmark Torchtitan MoE grouped vs naive expert execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
# Ensure torchtitan is importable when running from the axolotl tree
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
|
||||
if str(_TITAN_PATH) not in sys.path:
|
||||
sys.path.append(str(_TITAN_PATH))
|
||||
|
||||
from torchtitan.models.moe import MoE, MoEArgs
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="Torchtitan MoE microbenchmark")
|
||||
p.add_argument("--bsz", type=int, default=8)
|
||||
p.add_argument("--seq", type=int, default=1024)
|
||||
p.add_argument("--hidden", type=int, default=4096)
|
||||
p.add_argument("--inter", type=int, default=14336)
|
||||
p.add_argument("--experts", type=int, default=8)
|
||||
p.add_argument("--top_k", type=int, default=2)
|
||||
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
||||
p.add_argument("--iters", type=int, default=50)
|
||||
p.add_argument("--warmup", type=int, default=10)
|
||||
p.add_argument("--init-std", type=float, default=0.02)
|
||||
p.add_argument(
|
||||
"--score-before",
|
||||
action="store_true",
|
||||
help="Apply routing scores before expert computation (default: after)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--score-func",
|
||||
choices=["softmax", "sigmoid"],
|
||||
default="softmax",
|
||||
)
|
||||
p.add_argument(
|
||||
"--route-norm",
|
||||
action="store_true",
|
||||
help="Enable Torchtitan router normalization when using sigmoid scores.",
|
||||
)
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def _map_dtype(arg: str) -> torch.dtype:
|
||||
return {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[arg]
|
||||
|
||||
|
||||
def _estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
|
||||
# Two up projections + one down projection per expert/token combination.
|
||||
return 6.0 * tokens * top_k * hidden * inter
|
||||
|
||||
|
||||
def _prepare_module(
|
||||
moe: MoE,
|
||||
*,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> MoE:
|
||||
moe = moe.to(device=device)
|
||||
for param in moe.parameters():
|
||||
param.data = param.data.to(dtype)
|
||||
if param.grad is not None:
|
||||
param.grad = None
|
||||
|
||||
buffers = dict(moe.named_buffers())
|
||||
for name, buf in buffers.items():
|
||||
if name == "tokens_per_expert":
|
||||
moe._buffers[name] = torch.zeros_like(
|
||||
buf, dtype=torch.float32, device=device
|
||||
)
|
||||
elif name == "expert_bias" and buf is not None:
|
||||
moe._buffers[name] = torch.zeros_like(
|
||||
buf, dtype=torch.float32, device=device
|
||||
)
|
||||
else:
|
||||
moe._buffers[name] = buf.to(device=device, dtype=dtype)
|
||||
moe.eval()
|
||||
return moe
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _forward_fn(module: MoE, x: torch.Tensor) -> torch.Tensor:
|
||||
return module(x)
|
||||
|
||||
|
||||
def _bench(fn, *, iters: int, warmup: int, sync: bool = True) -> float:
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times = []
|
||||
for _ in range(iters):
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
fn()
|
||||
if sync and device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - start) * 1000.0)
|
||||
return sum(times) / len(times)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = _parse_args()
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
dtype = _map_dtype(args.dtype)
|
||||
|
||||
torch.manual_seed(0)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
moe_args_grouped = MoEArgs(
|
||||
num_experts=args.experts,
|
||||
num_shared_experts=0,
|
||||
score_func=args.score_func,
|
||||
route_norm=args.route_norm,
|
||||
top_k=args.top_k,
|
||||
use_grouped_mm=True,
|
||||
score_before_experts=args.score_before,
|
||||
load_balance_coeff=None,
|
||||
)
|
||||
moe_grouped = MoE(moe_args_grouped, dim=args.hidden, hidden_dim=args.inter)
|
||||
moe_grouped.init_weights(args.init_std, buffer_device=device)
|
||||
|
||||
moe_args_naive = MoEArgs(
|
||||
num_experts=args.experts,
|
||||
num_shared_experts=0,
|
||||
score_func=args.score_func,
|
||||
route_norm=args.route_norm,
|
||||
top_k=args.top_k,
|
||||
use_grouped_mm=False,
|
||||
score_before_experts=args.score_before,
|
||||
load_balance_coeff=None,
|
||||
)
|
||||
moe_naive = MoE(moe_args_naive, dim=args.hidden, hidden_dim=args.inter)
|
||||
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
|
||||
|
||||
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
|
||||
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
|
||||
|
||||
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
|
||||
|
||||
tokens = args.bsz * args.seq
|
||||
print(
|
||||
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} "
|
||||
f"inter={args.inter} experts={args.experts} top_k={args.top_k}"
|
||||
)
|
||||
|
||||
def run_naive():
|
||||
return _forward_fn(moe_naive, x)
|
||||
|
||||
def run_grouped():
|
||||
return _forward_fn(moe_grouped, x)
|
||||
|
||||
if hasattr(moe_naive, "tokens_per_expert"):
|
||||
moe_naive.tokens_per_expert.zero_()
|
||||
if hasattr(moe_grouped, "tokens_per_expert"):
|
||||
moe_grouped.tokens_per_expert.zero_()
|
||||
|
||||
t_naive = _bench(run_naive, iters=args.iters, warmup=args.warmup)
|
||||
flops = _estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
|
||||
tflops_naive = flops / ((t_naive / 1000.0) * 1e12)
|
||||
print(
|
||||
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t"
|
||||
f"{tflops_naive:.2f} TFLOP/s"
|
||||
)
|
||||
|
||||
y_naive = run_naive()
|
||||
|
||||
if hasattr(moe_grouped, "tokens_per_expert"):
|
||||
moe_grouped.tokens_per_expert.zero_()
|
||||
|
||||
t_grouped = _bench(run_grouped, iters=args.iters, warmup=args.warmup)
|
||||
tflops_grouped = flops / ((t_grouped / 1000.0) * 1e12)
|
||||
speedup = t_naive / t_grouped if t_grouped > 0 else float("nan")
|
||||
print(
|
||||
f"grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
|
||||
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
|
||||
)
|
||||
|
||||
y_grouped = run_grouped()
|
||||
diff = (y_naive.float() - y_grouped.float()).abs()
|
||||
max_abs = diff.max().item()
|
||||
mean_abs = diff.mean().item()
|
||||
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
|
||||
print(
|
||||
f"grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,328 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Sweep Torchtitan MoE grouped vs naive configurations and report performance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
|
||||
import torch
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
|
||||
if str(_TITAN_PATH) not in sys.path:
|
||||
sys.path.append(str(_TITAN_PATH))
|
||||
|
||||
from torchtitan.models.moe import MoE, MoEArgs
|
||||
|
||||
|
||||
def _parse_int_list(value: str) -> List[int]:
|
||||
return [int(v) for v in value.split(",") if v]
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
|
||||
p.add_argument(
|
||||
"--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
|
||||
)
|
||||
p.add_argument(
|
||||
"--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
|
||||
)
|
||||
p.add_argument(
|
||||
"--experts", default="8,16,32,64", help="Comma separated expert counts"
|
||||
)
|
||||
p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
|
||||
p.add_argument("--hidden", type=int, default=4096)
|
||||
p.add_argument("--inter", type=int, default=14336)
|
||||
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
||||
p.add_argument("--iters", type=int, default=25)
|
||||
p.add_argument("--warmup", type=int, default=5)
|
||||
p.add_argument("--init-std", type=float, default=0.02)
|
||||
p.add_argument("--score-before", action="store_true")
|
||||
p.add_argument("--score-func", choices=["softmax", "sigmoid"], default="softmax")
|
||||
p.add_argument("--route-norm", action="store_true")
|
||||
p.add_argument("--csv", type=Path, default=None, help="Optional CSV output path")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def _map_dtype(arg: str) -> torch.dtype:
|
||||
return {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[arg]
|
||||
|
||||
|
||||
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
|
||||
return 6.0 * tokens * top_k * hidden * inter
|
||||
|
||||
|
||||
def _prepare_module(module: MoE, *, device: torch.device, dtype: torch.dtype) -> MoE:
|
||||
module = module.to(device=device)
|
||||
for param in module.parameters():
|
||||
param.data = param.data.to(dtype)
|
||||
if param.grad is not None:
|
||||
param.grad = None
|
||||
for name, buf in module.named_buffers():
|
||||
if name == "tokens_per_expert":
|
||||
module._buffers[name] = torch.zeros_like(
|
||||
buf, dtype=torch.float32, device=device
|
||||
)
|
||||
elif name == "expert_bias" and buf is not None:
|
||||
module._buffers[name] = torch.zeros_like(
|
||||
buf, dtype=torch.float32, device=device
|
||||
)
|
||||
else:
|
||||
module._buffers[name] = buf.to(device=device, dtype=dtype)
|
||||
module.eval()
|
||||
return module
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _forward(module: MoE, x: torch.Tensor) -> torch.Tensor:
|
||||
return module(x)
|
||||
|
||||
|
||||
def _bench(callable_, *, iters: int, warmup: int, device: torch.device) -> float:
|
||||
for _ in range(warmup):
|
||||
callable_()
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
timings: List[float] = []
|
||||
for _ in range(iters):
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
callable_()
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
timings.append((time.perf_counter() - start) * 1000.0)
|
||||
return sum(timings) / len(timings)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepResult:
|
||||
bsz: int
|
||||
seq: int
|
||||
experts: int
|
||||
top_k: int
|
||||
dtype: str
|
||||
naive_ms: float
|
||||
grouped_ms: float
|
||||
speedup: float
|
||||
naive_tflops: float
|
||||
grouped_tflops: float
|
||||
max_abs: float
|
||||
mean_abs: float
|
||||
rel_l2: float
|
||||
|
||||
|
||||
def _run_case(
|
||||
*,
|
||||
bsz: int,
|
||||
seq: int,
|
||||
experts: int,
|
||||
top_k: int,
|
||||
hidden: int,
|
||||
inter: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
iters: int,
|
||||
warmup: int,
|
||||
init_std: float,
|
||||
score_before: bool,
|
||||
score_func: str,
|
||||
route_norm: bool,
|
||||
) -> SweepResult:
|
||||
torch.manual_seed(0)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.manual_seed(0)
|
||||
|
||||
moe_args_grouped = MoEArgs(
|
||||
num_experts=experts,
|
||||
num_shared_experts=0,
|
||||
score_func=score_func,
|
||||
route_norm=route_norm,
|
||||
top_k=top_k,
|
||||
use_grouped_mm=True,
|
||||
score_before_experts=score_before,
|
||||
load_balance_coeff=None,
|
||||
)
|
||||
moe_grouped = MoE(moe_args_grouped, dim=hidden, hidden_dim=inter)
|
||||
moe_grouped.init_weights(init_std, buffer_device=device)
|
||||
|
||||
moe_args_naive = MoEArgs(
|
||||
num_experts=experts,
|
||||
num_shared_experts=0,
|
||||
score_func=score_func,
|
||||
route_norm=route_norm,
|
||||
top_k=top_k,
|
||||
use_grouped_mm=False,
|
||||
score_before_experts=score_before,
|
||||
load_balance_coeff=None,
|
||||
)
|
||||
moe_naive = MoE(moe_args_naive, dim=hidden, hidden_dim=inter)
|
||||
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
|
||||
|
||||
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
|
||||
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
|
||||
|
||||
x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype)
|
||||
|
||||
def run_naive():
|
||||
if hasattr(moe_naive, "tokens_per_expert"):
|
||||
moe_naive.tokens_per_expert.zero_()
|
||||
return _forward(moe_naive, x)
|
||||
|
||||
def run_grouped():
|
||||
if hasattr(moe_grouped, "tokens_per_expert"):
|
||||
moe_grouped.tokens_per_expert.zero_()
|
||||
return _forward(moe_grouped, x)
|
||||
|
||||
naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device)
|
||||
y_naive = run_naive()
|
||||
|
||||
grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device)
|
||||
y_grouped = run_grouped()
|
||||
|
||||
diff = (y_naive.float() - y_grouped.float()).abs()
|
||||
max_abs = diff.max().item()
|
||||
mean_abs = diff.mean().item()
|
||||
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
|
||||
|
||||
tokens = bsz * seq
|
||||
flops = _estimate_flops(tokens, hidden, inter, top_k)
|
||||
naive_tflops = flops / ((naive_ms / 1000.0) * 1e12)
|
||||
grouped_tflops = flops / ((grouped_ms / 1000.0) * 1e12)
|
||||
speedup = naive_ms / grouped_ms if grouped_ms > 0 else float("nan")
|
||||
|
||||
return SweepResult(
|
||||
bsz=bsz,
|
||||
seq=seq,
|
||||
experts=experts,
|
||||
top_k=top_k,
|
||||
dtype=str(dtype),
|
||||
naive_ms=naive_ms,
|
||||
grouped_ms=grouped_ms,
|
||||
speedup=speedup,
|
||||
naive_tflops=naive_tflops,
|
||||
grouped_tflops=grouped_tflops,
|
||||
max_abs=max_abs,
|
||||
mean_abs=mean_abs,
|
||||
rel_l2=rel_l2,
|
||||
)
|
||||
|
||||
|
||||
def _print_header(
|
||||
hidden: int, inter: int, dtype: torch.dtype, device: torch.device
|
||||
) -> None:
|
||||
print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
|
||||
print(
|
||||
"bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
|
||||
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
|
||||
)
|
||||
|
||||
|
||||
def _print_result(res: SweepResult) -> None:
|
||||
print(
|
||||
f"{res.bsz}\t{res.seq}\t{res.experts}\t{res.top_k}\t"
|
||||
f"{res.naive_ms:.2f}\t{res.grouped_ms:.2f}\t{res.speedup:.2f}\t"
|
||||
f"{res.naive_tflops:.2f}\t{res.grouped_tflops:.2f}\t"
|
||||
f"{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
|
||||
)
|
||||
|
||||
|
||||
def _write_csv(path: Path, results: Iterable[SweepResult]) -> None:
|
||||
fieldnames = [
|
||||
"batch_size",
|
||||
"seq_len",
|
||||
"experts",
|
||||
"top_k",
|
||||
"dtype",
|
||||
"naive_ms",
|
||||
"grouped_ms",
|
||||
"speedup",
|
||||
"naive_tflops",
|
||||
"grouped_tflops",
|
||||
"max_abs",
|
||||
"mean_abs",
|
||||
"rel_l2",
|
||||
]
|
||||
with path.open("w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for r in results:
|
||||
writer.writerow(
|
||||
{
|
||||
"batch_size": r.bsz,
|
||||
"seq_len": r.seq,
|
||||
"experts": r.experts,
|
||||
"top_k": r.top_k,
|
||||
"dtype": r.dtype,
|
||||
"naive_ms": f"{r.naive_ms:.4f}",
|
||||
"grouped_ms": f"{r.grouped_ms:.4f}",
|
||||
"speedup": f"{r.speedup:.4f}",
|
||||
"naive_tflops": f"{r.naive_tflops:.4f}",
|
||||
"grouped_tflops": f"{r.grouped_tflops:.4f}",
|
||||
"max_abs": f"{r.max_abs:.6e}",
|
||||
"mean_abs": f"{r.mean_abs:.6e}",
|
||||
"rel_l2": f"{r.rel_l2:.6e}",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = _parse_args()
|
||||
dtype = _map_dtype(args.dtype)
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
batch_sizes = _parse_int_list(args.batch_sizes)
|
||||
seq_lens = _parse_int_list(args.seq_lens)
|
||||
experts_list = _parse_int_list(args.experts)
|
||||
top_ks = _parse_int_list(args.top_ks)
|
||||
|
||||
results: List[SweepResult] = []
|
||||
_print_header(args.hidden, args.inter, dtype, device)
|
||||
|
||||
for bsz in batch_sizes:
|
||||
for seq in seq_lens:
|
||||
for experts in experts_list:
|
||||
for top_k in top_ks:
|
||||
try:
|
||||
res = _run_case(
|
||||
bsz=bsz,
|
||||
seq=seq,
|
||||
experts=experts,
|
||||
top_k=top_k,
|
||||
hidden=args.hidden,
|
||||
inter=args.inter,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
iters=args.iters,
|
||||
warmup=args.warmup,
|
||||
init_std=args.init_std,
|
||||
score_before=args.score_before,
|
||||
score_func=args.score_func,
|
||||
route_norm=args.route_norm,
|
||||
)
|
||||
except RuntimeError as err:
|
||||
print(
|
||||
f"{bsz}\t{seq}\t{experts}\t{top_k}\tERROR: {err}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
results.append(res)
|
||||
_print_result(res)
|
||||
|
||||
if args.csv and results:
|
||||
_write_csv(args.csv, results)
|
||||
print(f"Wrote {len(results)} rows to {args.csv}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
scripts/benchmarks/__init__.py
Normal file
5
scripts/benchmarks/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Benchmark helpers."""
|
||||
|
||||
from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3
|
||||
|
||||
__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"]
|
||||
100
scripts/benchmarks/build_deepseek_v3_8b.py
Executable file
100
scripts/benchmarks/build_deepseek_v3_8b.py
Executable file
@@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Instantiate a ~8.3B DeepSeek-V3 MoE model with random weights.
|
||||
|
||||
Run this on a GPU-equipped machine (e.g. 1× NVL H100) so the dense
|
||||
initialization completes quickly:
|
||||
|
||||
python scripts/benchmarks/build_deepseek_v3_8b.py --output deepseek-v3-8b-moe
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import DeepseekV3Config, DeepseekV3ForCausalLM
|
||||
|
||||
DTYPE_MAP = {
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float16": torch.float16,
|
||||
}
|
||||
|
||||
|
||||
def build_config() -> DeepseekV3Config:
|
||||
"""Return a DeepSeek V3 configuration totaling ~8.3B parameters."""
|
||||
|
||||
return DeepseekV3Config(
|
||||
vocab_size=32_000,
|
||||
hidden_size=3_072,
|
||||
intermediate_size=8_192,
|
||||
moe_intermediate_size=2_560,
|
||||
num_hidden_layers=20,
|
||||
num_attention_heads=24,
|
||||
num_key_value_heads=24,
|
||||
n_routed_experts=18,
|
||||
num_experts_per_tok=4,
|
||||
n_group=6,
|
||||
topk_group=4,
|
||||
kv_lora_rank=192,
|
||||
q_lora_rank=384,
|
||||
max_position_embeddings=2_048,
|
||||
rope_theta=10_000.0,
|
||||
rope_interleave=True,
|
||||
hidden_act="silu",
|
||||
initializer_range=0.02,
|
||||
attention_dropout=0.0,
|
||||
attention_bias=False,
|
||||
n_shared_experts=1,
|
||||
routed_scaling_factor=2.5,
|
||||
norm_topk_prob=True,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory to save the generated model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="bfloat16",
|
||||
choices=DTYPE_MAP.keys(),
|
||||
help="Storage dtype for the checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Torch RNG seed for reproducibility",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
output_dir = args.output
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config = build_config()
|
||||
model = DeepseekV3ForCausalLM(config)
|
||||
|
||||
dtype = DTYPE_MAP[args.dtype]
|
||||
model.to(dtype=dtype)
|
||||
|
||||
param_count = sum(p.numel() for p in model.parameters())
|
||||
print(f"Initialized DeepSeek-V3 MoE with {param_count / 1e9:.3f}B parameters")
|
||||
|
||||
model.save_pretrained(output_dir, safe_serialization=True)
|
||||
config.save_pretrained(output_dir)
|
||||
print(f"Saved model and config to {output_dir.resolve()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
190
scripts/benchmarks/deepseek_v3_group_gemm_table.py
Normal file
190
scripts/benchmarks/deepseek_v3_group_gemm_table.py
Normal file
@@ -0,0 +1,190 @@
|
||||
#!/usr/bin/env python
|
||||
"""Reproduce TorchTitan CG GEMM timings for selected problem sizes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
repo_root = candidate / "axolotl"
|
||||
if repo_root.exists():
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
break
|
||||
else:
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from axolotl.kernels.moe import (
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Scenario:
|
||||
num_groups: int
|
||||
m: int
|
||||
n: int
|
||||
k: int
|
||||
|
||||
|
||||
SCENARIOS: tuple[Scenario, ...] = (
|
||||
Scenario(num_groups=4, m=8192, n=4096, k=7168),
|
||||
Scenario(num_groups=4, m=8192, n=7168, k=2048),
|
||||
Scenario(num_groups=8, m=4096, n=4096, k=7168),
|
||||
Scenario(num_groups=8, m=4096, n=7168, k=2048),
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--device", default="cuda", choices=["cuda"], help="Execution device"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="bf16",
|
||||
choices=["bf16", "fp16", "fp32"],
|
||||
help="Computation dtype",
|
||||
)
|
||||
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
|
||||
parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
|
||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
default=128,
|
||||
help="GROUP_SIZE_M expected by the kernel",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def pick_dtype(name: str) -> torch.dtype:
|
||||
return {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}[name]
|
||||
|
||||
|
||||
def make_indices(
|
||||
num_groups: int, group_size: int, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
indices = torch.arange(num_groups, device=device, dtype=torch.int32)
|
||||
return indices.repeat_interleave(group_size)
|
||||
|
||||
|
||||
def timed_call(fn, *args, warmup: int, iters: int) -> float:
|
||||
for _ in range(warmup):
|
||||
fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
for _ in range(iters):
|
||||
fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
return (time.perf_counter() - start) * 1000.0 / iters
|
||||
|
||||
|
||||
def run_scenario(
|
||||
scenario: Scenario,
|
||||
*,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
warmup: int,
|
||||
iters: int,
|
||||
group_size_m: int,
|
||||
) -> dict:
|
||||
if scenario.m % scenario.num_groups != 0:
|
||||
raise ValueError(
|
||||
f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
|
||||
)
|
||||
group_size = scenario.m // scenario.num_groups
|
||||
if group_size % group_size_m != 0:
|
||||
raise ValueError(
|
||||
f"Group size {group_size} must be a multiple of GROUP_SIZE_M ({group_size_m}) for the Triton kernel"
|
||||
)
|
||||
|
||||
inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
|
||||
weights = torch.randn(
|
||||
scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
|
||||
)
|
||||
indices = make_indices(scenario.num_groups, group_size, device)
|
||||
|
||||
def persistent():
|
||||
return cg_grouped_gemm_forward(inputs, weights, indices, group_size_m)
|
||||
|
||||
def baseline():
|
||||
return cg_grouped_gemm_forward_dynamic(inputs, weights, indices, group_size_m)
|
||||
|
||||
persistent_ms = timed_call(persistent, warmup=warmup, iters=iters)
|
||||
baseline_ms = timed_call(baseline, warmup=warmup, iters=iters)
|
||||
|
||||
return {
|
||||
"scenario": scenario,
|
||||
"persistent_ms": persistent_ms,
|
||||
"baseline_ms": baseline_ms,
|
||||
"speedup": baseline_ms / persistent_ms if persistent_ms > 0 else float("nan"),
|
||||
}
|
||||
|
||||
|
||||
def main() -> None: # pragma: no cover - utility script
|
||||
args = parse_args()
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
if args.device != "cuda" or not torch.cuda.is_available():
|
||||
raise SystemExit("CUDA device required for this benchmark")
|
||||
|
||||
dtype = pick_dtype(args.dtype)
|
||||
device = torch.device(args.device)
|
||||
|
||||
print(
|
||||
f"device={device} dtype={dtype} warmup={args.warmup} iters={args.iters} group_size={args.group_size}"
|
||||
)
|
||||
print(
|
||||
f"{'groups':>7} {'m':>7} {'n':>7} {'k':>7} {'persistent':>12} {'baseline':>12} {'speedup':>8}"
|
||||
)
|
||||
for result in run_all(
|
||||
SCENARIOS,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
warmup=args.warmup,
|
||||
iters=args.iters,
|
||||
group_size_m=args.group_size,
|
||||
):
|
||||
scen = result["scenario"]
|
||||
print(
|
||||
f"{scen.num_groups:>7} {scen.m:>7} {scen.n:>7} {scen.k:>7}"
|
||||
f" {result['persistent_ms']:>11.3f} ms {result['baseline_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
||||
)
|
||||
|
||||
|
||||
def run_all(
|
||||
scenarios: Iterable[Scenario],
|
||||
*,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
warmup: int,
|
||||
iters: int,
|
||||
group_size_m: int,
|
||||
) -> Iterable[dict]:
|
||||
for scenario in scenarios:
|
||||
yield run_scenario(
|
||||
scenario,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
warmup=warmup,
|
||||
iters=iters,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
301
scripts/benchmarks/deepseek_v3_moe.py
Normal file
301
scripts/benchmarks/deepseek_v3_moe.py
Normal file
@@ -0,0 +1,301 @@
|
||||
#!/usr/bin/env python
|
||||
# mypy: ignore-errors
|
||||
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
|
||||
DeepseekV3Config,
|
||||
)
|
||||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||
except ImportError as exc: # pragma: no cover - utility script
|
||||
raise SystemExit(
|
||||
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
|
||||
) from exc
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
repo_root = candidate / "axolotl"
|
||||
if repo_root.exists():
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
break
|
||||
else: # pragma: no cover - defensive guard
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
|
||||
|
||||
ACCURACY_TOLERANCE = 5e-3
|
||||
|
||||
DTYPE_MAP = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
"fp32": torch.float32,
|
||||
}
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--batch", type=int, default=8, help="batch size")
|
||||
parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
|
||||
parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
|
||||
parser.add_argument(
|
||||
"--moe-intermediate-size",
|
||||
type=int,
|
||||
default=8192,
|
||||
help="MoE intermediate projection size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-experts",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of routed experts",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of experts per token",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--groups",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Router groups (must divide n-experts)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
choices=DTYPE_MAP.keys(),
|
||||
default="bf16",
|
||||
help="Computation dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
default="auto",
|
||||
choices=["auto", "cpu", "cuda"],
|
||||
help="Execution device",
|
||||
)
|
||||
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
|
||||
parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
|
||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--uniform-routing",
|
||||
action="store_true",
|
||||
help="Override router to distribute tokens evenly across experts",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
default=128,
|
||||
help="GROUP_SIZE_M used by the Triton kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["cg", "mg"],
|
||||
default="mg",
|
||||
help="MoE kernel backend to benchmark",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def resolve_device(requested: str) -> torch.device:
|
||||
if requested == "auto":
|
||||
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return torch.device(requested)
|
||||
|
||||
|
||||
def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
|
||||
config = DeepseekV3Config(
|
||||
hidden_size=args.hidden_size,
|
||||
intermediate_size=args.moe_intermediate_size,
|
||||
moe_intermediate_size=args.moe_intermediate_size,
|
||||
n_routed_experts=args.n_experts,
|
||||
num_experts_per_tok=args.top_k,
|
||||
n_group=args.groups,
|
||||
topk_group=max(1, min(args.groups, args.top_k)),
|
||||
n_shared_experts=1,
|
||||
)
|
||||
module = DeepseekV3MoE(config)
|
||||
module.eval()
|
||||
return module
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def timed_forward(
|
||||
module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
|
||||
) -> float:
|
||||
for _ in range(warmup):
|
||||
module(inputs)
|
||||
if inputs.is_cuda:
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
for _ in range(iters):
|
||||
module(inputs)
|
||||
if inputs.is_cuda:
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.perf_counter() - start
|
||||
return (elapsed / iters) * 1000.0
|
||||
|
||||
|
||||
def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
device = resolve_device(args.device)
|
||||
dtype = DTYPE_MAP[args.dtype]
|
||||
|
||||
if args.n_experts % args.groups != 0:
|
||||
raise SystemExit("n-experts must be divisible by groups")
|
||||
if args.top_k > args.n_experts:
|
||||
raise SystemExit("top-k cannot exceed number of experts")
|
||||
|
||||
if device.type == "cuda" and not torch.cuda.is_available():
|
||||
raise SystemExit("CUDA requested but not available")
|
||||
|
||||
baseline_module = build_module(args)
|
||||
original_moe = getattr(
|
||||
DeepseekV3MoE,
|
||||
"_axolotl_triton_original_moe",
|
||||
DeepseekV3MoE.moe,
|
||||
)
|
||||
baseline_module.moe = MethodType(original_moe, baseline_module)
|
||||
state_dict = baseline_module.state_dict()
|
||||
|
||||
patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
|
||||
patched_module = build_module(args)
|
||||
patched_module.load_state_dict(state_dict)
|
||||
|
||||
baseline_module.to(device=device, dtype=dtype)
|
||||
patched_module.to(device=device, dtype=dtype)
|
||||
|
||||
tokens = args.batch * args.seq_len
|
||||
routed_tokens = tokens * args.top_k
|
||||
avg_tokens_per_expert = routed_tokens / args.n_experts
|
||||
|
||||
inputs = torch.randn(
|
||||
args.batch,
|
||||
args.seq_len,
|
||||
args.hidden_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
flat_inputs = inputs.view(-1, args.hidden_size)
|
||||
if args.uniform_routing:
|
||||
total_assignments = flat_inputs.size(0) * args.top_k
|
||||
base = total_assignments // args.n_experts
|
||||
remainder = total_assignments % args.n_experts
|
||||
counts = torch.full(
|
||||
(args.n_experts,),
|
||||
base,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
if remainder:
|
||||
counts[:remainder] += 1
|
||||
assignments = torch.repeat_interleave(
|
||||
torch.arange(args.n_experts, device=device), counts
|
||||
)
|
||||
assignments = assignments[torch.randperm(assignments.size(0))]
|
||||
topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
|
||||
else:
|
||||
topk_idx, _ = patched_module.gate(flat_inputs)
|
||||
|
||||
tokens_per_expert = torch.bincount(
|
||||
topk_idx.reshape(-1), minlength=args.n_experts
|
||||
)
|
||||
min_tokens = int(tokens_per_expert.min().item())
|
||||
max_tokens = int(tokens_per_expert.max().item())
|
||||
|
||||
if args.uniform_routing:
|
||||
weights = torch.full(
|
||||
topk_idx.shape,
|
||||
1.0 / args.top_k,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
def _uniform_gate(self, hidden_states):
|
||||
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
token_count = flat.shape[0]
|
||||
return topk_idx[:token_count], weights[:token_count]
|
||||
|
||||
patched_module.gate.forward = _uniform_gate.__get__(
|
||||
patched_module.gate, patched_module.gate.__class__
|
||||
)
|
||||
baseline_module.gate.forward = _uniform_gate.__get__(
|
||||
baseline_module.gate, baseline_module.gate.__class__
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_output = baseline_module(inputs)
|
||||
patched_output = patched_module(inputs)
|
||||
max_diff = (ref_output - patched_output).abs().max().item()
|
||||
|
||||
baseline_vram = patched_vram = None
|
||||
if device.type == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
|
||||
if device.type == "cuda":
|
||||
baseline_vram = torch.cuda.max_memory_allocated(device)
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
|
||||
if device.type == "cuda":
|
||||
patched_vram = torch.cuda.max_memory_allocated(device)
|
||||
|
||||
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
|
||||
|
||||
return {
|
||||
"device": device,
|
||||
"backend": args.backend,
|
||||
"dtype": dtype,
|
||||
"baseline_ms": baseline_ms,
|
||||
"patched_ms": patched_ms,
|
||||
"speedup": speedup,
|
||||
"max_diff": max_diff,
|
||||
"routed_tokens": routed_tokens,
|
||||
"avg_tokens": avg_tokens_per_expert,
|
||||
"min_tokens": min_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"baseline_vram": baseline_vram,
|
||||
"patched_vram": patched_vram,
|
||||
"accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
|
||||
}
|
||||
|
||||
|
||||
def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
args = parse_args()
|
||||
result = benchmark_deepseek_v3(args)
|
||||
|
||||
print(
|
||||
f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
|
||||
)
|
||||
print(
|
||||
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
|
||||
)
|
||||
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
|
||||
if result["baseline_vram"] is not None:
|
||||
print(
|
||||
f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
|
||||
)
|
||||
print(
|
||||
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
|
||||
)
|
||||
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
|
||||
if not result["accuracy_ok"]:
|
||||
raise RuntimeError(
|
||||
f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
275
scripts/benchmarks/deepseek_v3_moe_sweep.py
Normal file
275
scripts/benchmarks/deepseek_v3_moe_sweep.py
Normal file
@@ -0,0 +1,275 @@
|
||||
#!/usr/bin/env python
|
||||
# mypy: ignore-errors
|
||||
"""Sweep a set of DeepSeek V3 MoE benchmark configurations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
repo_root = candidate / "axolotl"
|
||||
if repo_root.exists():
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
break
|
||||
else: # pragma: no cover - defensive guard
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
|
||||
ACCURACY_TOLERANCE,
|
||||
DTYPE_MAP,
|
||||
benchmark_deepseek_v3,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
choices=DTYPE_MAP.keys(),
|
||||
default="bf16",
|
||||
help="Computation dtype for all benchmarks",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
default="auto",
|
||||
choices=["auto", "cpu", "cuda"],
|
||||
help="Execution device",
|
||||
)
|
||||
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
|
||||
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
|
||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
help="Override GROUP_SIZE_M for every configuration",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backends",
|
||||
default="mg",
|
||||
help="Comma separated list of backends to benchmark (subset of cg,mg)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-uniform-routing",
|
||||
action="store_true",
|
||||
help="Disable uniform routing for every configuration",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-mixtral-long",
|
||||
action="store_true",
|
||||
help="Add an 8×8192 Mixtral-style run to the sweep",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
help="Optional CSV file to store results",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_namespace(
|
||||
base: dict, args: argparse.Namespace, backend: str
|
||||
) -> SimpleNamespace:
|
||||
combined = dict(base)
|
||||
combined.update(
|
||||
{
|
||||
"dtype": args.dtype,
|
||||
"device": args.device,
|
||||
"backend": backend,
|
||||
"warmup": args.warmup,
|
||||
"iters": args.iters,
|
||||
"seed": args.seed,
|
||||
"uniform_routing": not args.no_uniform_routing,
|
||||
}
|
||||
)
|
||||
if args.group_size is not None:
|
||||
combined["group_size"] = args.group_size
|
||||
return SimpleNamespace(**combined)
|
||||
|
||||
|
||||
ARCHETYPES = (
|
||||
(
|
||||
"mixtral",
|
||||
{
|
||||
"hidden_size": 4096,
|
||||
"moe_intermediate_size": 14336,
|
||||
"n_experts": 8,
|
||||
"top_k": 2,
|
||||
"groups": 1,
|
||||
"group_size": 128,
|
||||
},
|
||||
[(4, 2048), (8, 4096)],
|
||||
),
|
||||
(
|
||||
"qwen",
|
||||
{
|
||||
"hidden_size": 6144,
|
||||
"moe_intermediate_size": 24576,
|
||||
"n_experts": 16,
|
||||
"top_k": 4,
|
||||
"groups": 8,
|
||||
"group_size": 128,
|
||||
},
|
||||
[(4, 4096), (8, 8192)],
|
||||
),
|
||||
(
|
||||
"deepseek_v3",
|
||||
{
|
||||
"hidden_size": 12288,
|
||||
"moe_intermediate_size": 49152,
|
||||
"n_experts": 128,
|
||||
"top_k": 8,
|
||||
"groups": 16,
|
||||
"group_size": 128,
|
||||
},
|
||||
[(4, 4096), (8, 8192)],
|
||||
),
|
||||
)
|
||||
|
||||
MIXTRAL_LONG_SHAPES = [(8, 8192)]
|
||||
|
||||
|
||||
def main() -> None: # pragma: no cover - utility script
|
||||
args = parse_args()
|
||||
|
||||
grid = []
|
||||
for label, base_cfg, shapes in ARCHETYPES:
|
||||
for batch, seq_len in shapes:
|
||||
cfg = {
|
||||
"label": label,
|
||||
"batch": batch,
|
||||
"seq_len": seq_len,
|
||||
**base_cfg,
|
||||
}
|
||||
if cfg["n_experts"] % cfg["groups"] != 0 or cfg["top_k"] > cfg["n_experts"]:
|
||||
continue
|
||||
grid.append(cfg)
|
||||
|
||||
if args.include_mixtral_long:
|
||||
base_cfg = ARCHETYPES[0][1]
|
||||
for batch, seq_len in MIXTRAL_LONG_SHAPES:
|
||||
grid.append(
|
||||
{
|
||||
"label": "mixtral_long",
|
||||
"batch": batch,
|
||||
"seq_len": seq_len,
|
||||
**base_cfg,
|
||||
}
|
||||
)
|
||||
|
||||
if not grid:
|
||||
raise SystemExit("No valid parameter combinations produced")
|
||||
|
||||
header = (
|
||||
"model",
|
||||
"batch",
|
||||
"seq_len",
|
||||
"hidden_size",
|
||||
"moe_intermediate",
|
||||
"n_experts",
|
||||
"top_k",
|
||||
"groups",
|
||||
"backend",
|
||||
"baseline_ms",
|
||||
"patched_ms",
|
||||
"speedup",
|
||||
"baseline_vram_mib",
|
||||
"patched_vram_mib",
|
||||
"min_tokens",
|
||||
"max_tokens",
|
||||
"max_diff",
|
||||
"accuracy_ok",
|
||||
)
|
||||
rows = []
|
||||
|
||||
raw_backends = [
|
||||
token.strip() for token in args.backends.split(",") if token.strip()
|
||||
]
|
||||
if not raw_backends:
|
||||
raw_backends = ["mg"]
|
||||
valid_backends = []
|
||||
for backend in raw_backends:
|
||||
if backend not in {"cg", "mg"}:
|
||||
raise SystemExit(f"Unsupported backend '{backend}' requested")
|
||||
if backend not in valid_backends:
|
||||
valid_backends.append(backend)
|
||||
|
||||
uniform_flag = not args.no_uniform_routing
|
||||
print(
|
||||
f"Running sweep on device={args.device} dtype={args.dtype} backends={tuple(valid_backends)} uniform_routing={uniform_flag}"
|
||||
)
|
||||
print(
|
||||
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
|
||||
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'diff':>10} {'acc':>5}"
|
||||
)
|
||||
|
||||
for cfg in grid:
|
||||
for backend in valid_backends:
|
||||
ns = make_namespace(cfg, args, backend)
|
||||
result = benchmark_deepseek_v3(ns)
|
||||
baseline_vram_mib = (
|
||||
result["baseline_vram"] / (1024**2)
|
||||
if result["baseline_vram"] is not None
|
||||
else float("nan")
|
||||
)
|
||||
patched_vram_mib = (
|
||||
result["patched_vram"] / (1024**2)
|
||||
if result["patched_vram"] is not None
|
||||
else float("nan")
|
||||
)
|
||||
rows.append(
|
||||
(
|
||||
cfg["label"],
|
||||
cfg["batch"],
|
||||
cfg["seq_len"],
|
||||
cfg["hidden_size"],
|
||||
cfg["moe_intermediate_size"],
|
||||
cfg["n_experts"],
|
||||
cfg["top_k"],
|
||||
cfg["groups"],
|
||||
backend,
|
||||
result["baseline_ms"],
|
||||
result["patched_ms"],
|
||||
result["speedup"],
|
||||
baseline_vram_mib,
|
||||
patched_vram_mib,
|
||||
result["min_tokens"],
|
||||
result["max_tokens"],
|
||||
result["max_diff"],
|
||||
result["accuracy_ok"],
|
||||
)
|
||||
)
|
||||
status = "OK" if result["accuracy_ok"] else "FAIL"
|
||||
print(
|
||||
f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}"
|
||||
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
||||
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {result['max_diff']:>10.3e} {status:>5}"
|
||||
)
|
||||
if not result["accuracy_ok"]:
|
||||
LOG.warning(
|
||||
"Accuracy tolerance exceeded for %s backend=%s: diff=%.3e (> %.1e)",
|
||||
cfg["label"],
|
||||
backend,
|
||||
result["max_diff"],
|
||||
ACCURACY_TOLERANCE,
|
||||
)
|
||||
|
||||
if args.output:
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
with args.output.open("w", newline="") as fp:
|
||||
writer = csv.writer(fp)
|
||||
writer.writerow(header)
|
||||
writer.writerows(rows)
|
||||
print(f"Results written to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"'
|
||||
)
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Inspect Qwen2 MoE expert implementations for grouped-mm debugging."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
sys.path.extend(
|
||||
[
|
||||
str(ROOT / "transformers" / "src"),
|
||||
str(ROOT / "src"),
|
||||
]
|
||||
)
|
||||
|
||||
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
|
||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
|
||||
from axolotl.kernels.moe.torch_grouped import _iter_expert_impls
|
||||
|
||||
|
||||
def main() -> None:
|
||||
cfg = Qwen2MoeConfig(
|
||||
hidden_size=4096,
|
||||
moe_intermediate_size=14336,
|
||||
shared_expert_intermediate_size=14336,
|
||||
num_experts=32,
|
||||
num_experts_per_tok=4,
|
||||
)
|
||||
|
||||
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
|
||||
experts = block.experts
|
||||
experts._ax_parent_block = block
|
||||
|
||||
impls = _iter_expert_impls(experts)
|
||||
print(f"impl count: {len(impls)}")
|
||||
for idx, impl in enumerate(impls[:8]):
|
||||
has_gate = hasattr(impl, "gate_proj")
|
||||
has_up = hasattr(impl, "up_proj")
|
||||
print(
|
||||
f"impl[{idx}] type={impl.__class__.__name__} has_gate={has_gate} has_up={has_up}"
|
||||
)
|
||||
if has_gate:
|
||||
print(f" gate shape {tuple(impl.gate_proj.weight.shape)}")
|
||||
print(f" up shape {tuple(impl.up_proj.weight.shape)}")
|
||||
print(f" down shape {tuple(impl.down_proj.weight.shape)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,47 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Probe PyTorch for grouped GEMM operator names and namespaces.
|
||||
Run: python scripts/probe_torch_grouped_ops.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
import torch
|
||||
except Exception as e:
|
||||
print("Failed to import torch:", e)
|
||||
sys.exit(1)
|
||||
|
||||
print("torch version:", torch.__version__)
|
||||
namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
|
||||
print("ops namespaces:", namespaces)
|
||||
|
||||
found_any = False
|
||||
for ns in namespaces:
|
||||
obj = getattr(torch.ops, ns, None)
|
||||
ops = []
|
||||
if obj is not None:
|
||||
try:
|
||||
ops = dir(obj)
|
||||
except Exception as e:
|
||||
print(f"warning: failed to list ops for namespace {ns}: {e}")
|
||||
cands = [
|
||||
o
|
||||
for o in ops
|
||||
if ("group" in o.lower())
|
||||
or ("mm_grouped" in o.lower())
|
||||
or ("matmul_grouped" in o.lower())
|
||||
or ("grouped" in o.lower())
|
||||
]
|
||||
if cands:
|
||||
found_any = True
|
||||
print(f"namespace {ns} candidates:", cands)
|
||||
|
||||
if not found_any:
|
||||
print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
setup.py
1
setup.py
@@ -124,7 +124,6 @@ extras_require = {
|
||||
"ring-flash-attn": [
|
||||
"flash-attn==2.8.3",
|
||||
"ring-flash-attn>=0.1.7",
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.17.5",
|
||||
|
||||
@@ -120,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
else:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl is RLType.SIMPO:
|
||||
@@ -129,10 +134,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
# Handle when max_prompt_length == max_length from defaults
|
||||
# CPOTrainer requires strictly less than
|
||||
if (
|
||||
training_args_kwargs["max_prompt_length"]
|
||||
== training_args_kwargs["max_length"]
|
||||
):
|
||||
training_args_kwargs["max_prompt_length"] -= 1
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
@@ -144,9 +155,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kto_undesirable_weight or 1.0
|
||||
)
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl is RLType.GRPO:
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any, Mapping
|
||||
|
||||
def chat_message_transform_builder(
|
||||
train_on_inputs=False,
|
||||
conversations_field: str = "conversations",
|
||||
conversations_field: str = "messages",
|
||||
message_field_role: str | list[str] | None = None, # commonly "role"
|
||||
message_field_content: str | list[str] | None = None, # commonly "content"
|
||||
message_field_training: str | list[str] | None = None, # commonly "weight"
|
||||
@@ -20,13 +20,13 @@ def chat_message_transform_builder(
|
||||
If True, the transform will train on the inputs. If False, the transform will train on the targets.
|
||||
Defaults to False.
|
||||
conversations_field (str, optional):
|
||||
The field name of the conversations. Defaults to "conversations".
|
||||
The field name of the conversations. Defaults to "messages".
|
||||
message_field_role (str | list[str], optional):
|
||||
The field name of the role. Defaults to "role".
|
||||
The field name of the role.
|
||||
message_field_content (str | list[str], optional):
|
||||
The field name of the message content. Defaults to "content".
|
||||
The field name of the message content.
|
||||
message_field_training (str | list[str], optional):
|
||||
The field name of the train/weight. Defaults to "weight".
|
||||
The field name of the train/weight.
|
||||
|
||||
Returns:
|
||||
Callable:
|
||||
|
||||
@@ -27,7 +27,6 @@ class DPOStrategy:
|
||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -65,6 +65,7 @@ plugins:
|
||||
- qwen2_5_vl
|
||||
- qwen3
|
||||
- qwen3_moe
|
||||
- qwen3_next
|
||||
- smollm3
|
||||
- seed_oss
|
||||
- voxtral
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,21 @@
|
||||
from .backends import MOEBackend, get_moe_backend_name
|
||||
"""Mixture-of-Experts kernel implementations."""
|
||||
|
||||
__all__ = ["get_moe_backend_name", "MOEBackend"]
|
||||
from .indices import generate_permute_indices
|
||||
from .tt_cg_gemm import (
|
||||
ContiguousGroupedGEMM,
|
||||
ContiguousGroupedGEMMForwardOnly,
|
||||
cg_grouped_gemm,
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
|
||||
|
||||
__all__ = [
|
||||
"cg_grouped_gemm",
|
||||
"cg_grouped_gemm_forward",
|
||||
"cg_grouped_gemm_forward_dynamic",
|
||||
"ContiguousGroupedGEMM",
|
||||
"ContiguousGroupedGEMMForwardOnly",
|
||||
"generate_permute_indices",
|
||||
"mg_grouped_gemm",
|
||||
]
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
import warnings
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MOEBackend(str, Enum):
|
||||
AUTO = "auto"
|
||||
TORCH_GROUPED = "torch_grouped"
|
||||
NAIVE = "naive"
|
||||
|
||||
|
||||
def _probe_torch_grouped() -> bool:
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
|
||||
# Prefer a simple version check; exact APIs may vary across 2.8+.
|
||||
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
|
||||
return ver >= (2, 8)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
|
||||
"""
|
||||
Resolve the desired MoE backend using, in order of precedence:
|
||||
- explicit preferred argument (e.g., from config)
|
||||
- auto detection
|
||||
"""
|
||||
choice = (preferred or "auto").lower()
|
||||
try:
|
||||
selected = MOEBackend(choice)
|
||||
except ValueError:
|
||||
warnings.warn(
|
||||
f"Unknown moe backend '{choice}', falling back to auto", stacklevel=2
|
||||
)
|
||||
selected = MOEBackend.AUTO
|
||||
|
||||
if selected == MOEBackend.AUTO:
|
||||
if _probe_torch_grouped():
|
||||
return MOEBackend.TORCH_GROUPED
|
||||
return MOEBackend.NAIVE
|
||||
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
|
||||
warnings.warn(
|
||||
"torch_grouped requested but torch>=2.8 not detected; falling back to naive",
|
||||
stacklevel=2,
|
||||
)
|
||||
return MOEBackend.NAIVE
|
||||
return selected
|
||||
5
src/axolotl/kernels/moe/indices/__init__.py
Normal file
5
src/axolotl/kernels/moe/indices/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Token permutation utilities for grouped MoE kernels."""
|
||||
|
||||
from .indices import generate_permute_indices
|
||||
|
||||
__all__ = ["generate_permute_indices"]
|
||||
144
src/axolotl/kernels/moe/indices/indices.py
Normal file
144
src/axolotl/kernels/moe/indices/indices.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Vendored token permutation kernels from TorchTitan."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
__all__ = ["generate_permute_indices"]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fill_indices_kernel(
|
||||
tokens_per_expert_group_ptr,
|
||||
start_index_values_ptr,
|
||||
write_offsets_ptr,
|
||||
output_ptr,
|
||||
experts_per_rank: tl.constexpr,
|
||||
num_ranks: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_programs = tl.num_programs(axis=0)
|
||||
|
||||
for expert_id in range(pid, experts_per_rank, num_programs):
|
||||
write_offset = tl.load(write_offsets_ptr + expert_id)
|
||||
|
||||
for r in range(num_ranks):
|
||||
idx = r * experts_per_rank + expert_id
|
||||
|
||||
start_index = tl.load(start_index_values_ptr + idx)
|
||||
length = tl.load(tokens_per_expert_group_ptr + idx)
|
||||
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
for chunk_start in range(0, length, BLOCK_SIZE):
|
||||
chunk_offsets = chunk_start + offsets
|
||||
mask = chunk_offsets < length
|
||||
values = start_index + chunk_offsets
|
||||
dest_indices = write_offset + chunk_offsets
|
||||
tl.store(output_ptr + dest_indices, values, mask=mask)
|
||||
|
||||
write_offset += length
|
||||
|
||||
|
||||
def fill_indices_wrapper(
|
||||
tokens_per_expert_group: torch.Tensor,
|
||||
start_index_values: torch.Tensor,
|
||||
write_offsets: torch.Tensor,
|
||||
experts_per_rank: int,
|
||||
num_ranks: int,
|
||||
max_len: int,
|
||||
block_size: int = 128,
|
||||
max_blocks: int = 1024,
|
||||
):
|
||||
permuted_indices = torch.full(
|
||||
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
|
||||
)
|
||||
num_blocks = min(experts_per_rank, max_blocks)
|
||||
grid = (num_blocks,)
|
||||
_fill_indices_kernel[grid](
|
||||
tokens_per_expert_group,
|
||||
start_index_values,
|
||||
write_offsets,
|
||||
permuted_indices,
|
||||
experts_per_rank,
|
||||
num_ranks,
|
||||
BLOCK_SIZE=block_size,
|
||||
)
|
||||
return permuted_indices
|
||||
|
||||
|
||||
def fill_indices_cpu(
|
||||
tokens_per_expert_group: torch.Tensor,
|
||||
start_index_values: torch.Tensor,
|
||||
write_offsets: torch.Tensor,
|
||||
experts_per_rank: int,
|
||||
num_ranks: int,
|
||||
max_len: int,
|
||||
):
|
||||
permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
|
||||
for expert_id in range(experts_per_rank):
|
||||
write_start = write_offsets[expert_id].item()
|
||||
for r in range(num_ranks):
|
||||
idx = r * experts_per_rank + expert_id
|
||||
start_index = start_index_values[idx].item()
|
||||
length = tokens_per_expert_group[idx].item()
|
||||
if length > 0:
|
||||
end_idx = min(write_start + length, max_len)
|
||||
permuted_indices[write_start:end_idx] = torch.arange(
|
||||
start_index,
|
||||
start_index + (end_idx - write_start),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
write_start += length
|
||||
return permuted_indices
|
||||
|
||||
|
||||
def generate_permute_indices(
|
||||
tokens_per_expert_group: torch.Tensor,
|
||||
experts_per_rank: int,
|
||||
num_ranks: int,
|
||||
max_len: int,
|
||||
alignment: int,
|
||||
use_cpu: bool = False,
|
||||
):
|
||||
start_index_values = (
|
||||
torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
|
||||
)
|
||||
|
||||
total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
|
||||
total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)
|
||||
|
||||
m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(
|
||||
torch.int32
|
||||
)
|
||||
|
||||
m_offsets = torch.cumsum(m_sizes, 0)
|
||||
write_offsets = m_offsets - m_sizes
|
||||
|
||||
if use_cpu:
|
||||
permuted_indices = fill_indices_cpu(
|
||||
tokens_per_expert_group,
|
||||
start_index_values,
|
||||
write_offsets,
|
||||
experts_per_rank,
|
||||
num_ranks,
|
||||
max_len,
|
||||
)
|
||||
else:
|
||||
permuted_indices = fill_indices_wrapper(
|
||||
tokens_per_expert_group,
|
||||
start_index_values,
|
||||
write_offsets,
|
||||
experts_per_rank,
|
||||
num_ranks,
|
||||
max_len,
|
||||
)
|
||||
|
||||
return permuted_indices, m_sizes, m_offsets.to(torch.int32)
|
||||
@@ -1,371 +0,0 @@
|
||||
"""Minimal grouped GEMM fast path for MoE experts using PyTorch _grouped_mm."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
_LOGGER = logging.getLogger("axolotl.moe.grouped")
|
||||
|
||||
|
||||
def available() -> bool:
|
||||
try:
|
||||
major, minor = map(int, torch.__version__.split("+")[0].split(".")[:2])
|
||||
if (major, minor) < (2, 8):
|
||||
return False
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
sm, _ = torch.cuda.get_device_capability()
|
||||
if sm < 9:
|
||||
return False
|
||||
return hasattr(torch.ops, "_grouped_mm")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _iter_expert_impls(
|
||||
experts_module, visited: Optional[set[int]] = None
|
||||
) -> List[torch.nn.Module]:
|
||||
if visited is None:
|
||||
visited = set()
|
||||
module_id = id(experts_module)
|
||||
if module_id in visited:
|
||||
return []
|
||||
visited.add(module_id)
|
||||
|
||||
impls: List[torch.nn.Module] = []
|
||||
for exp in experts_module:
|
||||
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
|
||||
if hasattr(candidate, "gate_proj") and hasattr(candidate, "up_proj"):
|
||||
impls.append(candidate)
|
||||
continue
|
||||
nested = getattr(candidate, "experts", None)
|
||||
if nested is not None:
|
||||
impls.extend(_iter_expert_impls(nested, visited))
|
||||
continue
|
||||
raise RuntimeError(
|
||||
"torch_grouped: unable to resolve expert implementation for module"
|
||||
)
|
||||
return impls
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GroupedWeightStorage:
|
||||
pattern: str
|
||||
gate: torch.Tensor
|
||||
up: torch.Tensor
|
||||
down: torch.Tensor
|
||||
fused_gate_up: torch.Tensor
|
||||
dtype: torch.dtype
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _allocate_fused_gate_up(
|
||||
num_experts: int,
|
||||
gate_shape: torch.Size,
|
||||
up_shape: torch.Size,
|
||||
*,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if gate_shape[1] != up_shape[1]:
|
||||
raise RuntimeError(
|
||||
"torch_grouped: gate and up projections must share the hidden dimension"
|
||||
)
|
||||
|
||||
fused = torch.empty(
|
||||
(num_experts, gate_shape[0] + up_shape[0], gate_shape[1]),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
gate_view = fused[:, : gate_shape[0]]
|
||||
up_view = fused[:, gate_shape[0] : gate_shape[0] + up_shape[0]]
|
||||
return fused, gate_view, up_view
|
||||
|
||||
|
||||
def _ensure_grouped_weights(
|
||||
experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
|
||||
) -> _GroupedWeightStorage:
|
||||
storage: Optional[_GroupedWeightStorage] = getattr(
|
||||
experts_module, "_ax_grouped_storage", None
|
||||
)
|
||||
|
||||
def _store(new_storage: _GroupedWeightStorage) -> _GroupedWeightStorage:
|
||||
experts_module._ax_grouped_storage = new_storage
|
||||
return new_storage
|
||||
|
||||
# Identify expert parameter layout
|
||||
if (
|
||||
hasattr(sample_mod, "w1")
|
||||
and hasattr(sample_mod, "w3")
|
||||
and hasattr(sample_mod, "w2")
|
||||
):
|
||||
pattern = "swi_glu"
|
||||
num_experts = len(expert_impls)
|
||||
w1_shape = sample_mod.w1.weight.shape
|
||||
w3_shape = sample_mod.w3.weight.shape
|
||||
w2_shape = sample_mod.w2.weight.shape
|
||||
if (
|
||||
storage is not None
|
||||
and storage.pattern == pattern
|
||||
and storage.dtype == sample_mod.w1.weight.dtype
|
||||
and storage.device == sample_mod.w1.weight.device
|
||||
and storage.gate.shape[1:] == w1_shape
|
||||
):
|
||||
return storage
|
||||
|
||||
fused, gate, up = _allocate_fused_gate_up(
|
||||
num_experts,
|
||||
w1_shape,
|
||||
w3_shape,
|
||||
device=sample_mod.w1.weight.device,
|
||||
dtype=sample_mod.w1.weight.dtype,
|
||||
)
|
||||
down = torch.empty(
|
||||
(num_experts, *w2_shape),
|
||||
device=sample_mod.w2.weight.device,
|
||||
dtype=sample_mod.w2.weight.dtype,
|
||||
)
|
||||
with torch.no_grad():
|
||||
for idx, mod in enumerate(expert_impls):
|
||||
gate[idx].copy_(mod.w1.weight.detach())
|
||||
up[idx].copy_(mod.w3.weight.detach())
|
||||
down[idx].copy_(mod.w2.weight.detach())
|
||||
mod.w1.weight.detach_()
|
||||
mod.w1.weight.set_(gate[idx])
|
||||
mod.w3.weight.detach_()
|
||||
mod.w3.weight.set_(up[idx])
|
||||
mod.w2.weight.detach_()
|
||||
mod.w2.weight.set_(down[idx])
|
||||
|
||||
return _store(
|
||||
_GroupedWeightStorage(
|
||||
pattern=pattern,
|
||||
gate=gate,
|
||||
up=up,
|
||||
down=down,
|
||||
fused_gate_up=fused,
|
||||
dtype=gate.dtype,
|
||||
device=gate.device,
|
||||
)
|
||||
)
|
||||
|
||||
if hasattr(sample_mod, "gate_up_proj") and hasattr(sample_mod, "down_proj"):
|
||||
pattern = "fused_gate_up"
|
||||
gate_weight = sample_mod.gate_up_proj.weight
|
||||
down_weight = sample_mod.down_proj.weight
|
||||
if (
|
||||
storage is not None
|
||||
and storage.pattern == pattern
|
||||
and storage.dtype == gate_weight.dtype
|
||||
and storage.device == gate_weight.device
|
||||
and storage.gate.shape[1:]
|
||||
== (gate_weight.shape[0] // 2, gate_weight.shape[1])
|
||||
):
|
||||
return storage
|
||||
|
||||
num_experts = len(expert_impls)
|
||||
gate_full = torch.empty(
|
||||
(num_experts, *gate_weight.shape),
|
||||
device=gate_weight.device,
|
||||
dtype=gate_weight.dtype,
|
||||
)
|
||||
down = torch.empty(
|
||||
(num_experts, *down_weight.shape),
|
||||
device=down_weight.device,
|
||||
dtype=down_weight.dtype,
|
||||
)
|
||||
with torch.no_grad():
|
||||
for idx, mod in enumerate(expert_impls):
|
||||
gate_full[idx].copy_(mod.gate_up_proj.weight.detach())
|
||||
down[idx].copy_(mod.down_proj.weight.detach())
|
||||
mod.gate_up_proj.weight.detach_()
|
||||
mod.gate_up_proj.weight.set_(gate_full[idx])
|
||||
mod.down_proj.weight.detach_()
|
||||
mod.down_proj.weight.set_(down[idx])
|
||||
|
||||
inter = gate_weight.shape[0] // 2
|
||||
gate = gate_full[:, :inter]
|
||||
up = gate_full[:, inter:]
|
||||
return _store(
|
||||
_GroupedWeightStorage(
|
||||
pattern=pattern,
|
||||
gate=gate,
|
||||
up=up,
|
||||
down=down,
|
||||
fused_gate_up=gate_full,
|
||||
dtype=gate.dtype,
|
||||
device=gate.device,
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(sample_mod, "up_proj")
|
||||
and hasattr(sample_mod, "gate_proj")
|
||||
and hasattr(sample_mod, "down_proj")
|
||||
):
|
||||
pattern = "dual_proj"
|
||||
up_weight = sample_mod.up_proj.weight
|
||||
gate_weight = sample_mod.gate_proj.weight
|
||||
down_weight = sample_mod.down_proj.weight
|
||||
if (
|
||||
storage is not None
|
||||
and storage.pattern == pattern
|
||||
and storage.dtype == sample_mod.up_proj.weight.dtype
|
||||
and storage.device == sample_mod.up_proj.weight.device
|
||||
and storage.gate.shape[1:] == gate_weight.shape
|
||||
):
|
||||
return storage
|
||||
|
||||
num_experts = len(expert_impls)
|
||||
fused, gate, up = _allocate_fused_gate_up(
|
||||
num_experts,
|
||||
gate_weight.shape,
|
||||
up_weight.shape,
|
||||
device=gate_weight.device,
|
||||
dtype=gate_weight.dtype,
|
||||
)
|
||||
down = torch.empty(
|
||||
(num_experts, *down_weight.shape),
|
||||
device=down_weight.device,
|
||||
dtype=down_weight.dtype,
|
||||
)
|
||||
with torch.no_grad():
|
||||
for idx, mod in enumerate(expert_impls):
|
||||
gate[idx].copy_(mod.gate_proj.weight.detach())
|
||||
up[idx].copy_(mod.up_proj.weight.detach())
|
||||
down[idx].copy_(mod.down_proj.weight.detach())
|
||||
mod.up_proj.weight.detach_()
|
||||
mod.up_proj.weight.set_(up[idx])
|
||||
mod.gate_proj.weight.detach_()
|
||||
mod.gate_proj.weight.set_(gate[idx])
|
||||
mod.down_proj.weight.detach_()
|
||||
mod.down_proj.weight.set_(down[idx])
|
||||
|
||||
return _store(
|
||||
_GroupedWeightStorage(
|
||||
pattern=pattern,
|
||||
gate=gate,
|
||||
up=up,
|
||||
down=down,
|
||||
fused_gate_up=fused,
|
||||
dtype=gate.dtype,
|
||||
device=gate.device,
|
||||
)
|
||||
)
|
||||
|
||||
raise RuntimeError(
|
||||
"torch_grouped: unsupported expert module layout for grouped weights"
|
||||
)
|
||||
|
||||
|
||||
def moe_ffn_forward_grouped(
|
||||
hidden_states: torch.Tensor,
|
||||
gate_linear: torch.nn.Linear,
|
||||
experts_module,
|
||||
top_k: int,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if not available():
|
||||
return None, None
|
||||
|
||||
bsz, seqlen, hdim = hidden_states.shape
|
||||
tokens = bsz * seqlen
|
||||
device = hidden_states.device
|
||||
|
||||
routing_dtype = gate_linear.weight.dtype
|
||||
expert_dtype = hidden_states.dtype
|
||||
|
||||
if expert_dtype not in (torch.bfloat16, torch.float16):
|
||||
_LOGGER.debug(
|
||||
"torch_grouped: unsupported expert dtype %s; falling back to naive",
|
||||
expert_dtype,
|
||||
)
|
||||
return None, None
|
||||
|
||||
parent_block = None
|
||||
parent_ref = getattr(experts_module, "_ax_parent_block_ref", None)
|
||||
if parent_ref is not None:
|
||||
try:
|
||||
parent_block = parent_ref()
|
||||
except TypeError:
|
||||
parent_block = None
|
||||
|
||||
expert_container = getattr(experts_module, "experts", experts_module)
|
||||
|
||||
expert_impls = _iter_expert_impls(expert_container)
|
||||
sample_mod = expert_impls[0]
|
||||
storage = _ensure_grouped_weights(expert_container, expert_impls, sample_mod)
|
||||
w_gate = storage.gate
|
||||
w_up = storage.up
|
||||
w2 = storage.down
|
||||
|
||||
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
||||
router_logits = gate_linear(x_flat.to(routing_dtype))
|
||||
|
||||
shared_out_flat: Optional[torch.Tensor] = None
|
||||
shared_owner = parent_block if parent_block is not None else experts_module
|
||||
if hasattr(shared_owner, "shared_expert"):
|
||||
shared_expert = shared_owner.shared_expert
|
||||
shared_out_flat = shared_expert(x_flat)
|
||||
shared_out_flat = shared_out_flat.to(expert_dtype)
|
||||
shared_gate = getattr(shared_owner, "shared_expert_gate", None)
|
||||
if shared_gate is not None:
|
||||
gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
|
||||
gate_vals = torch.sigmoid(gate_input)
|
||||
shared_out_flat.mul_(gate_vals.to(expert_dtype))
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
|
||||
|
||||
flat_idx = topk_idx.view(-1)
|
||||
num_experts = len(expert_impls)
|
||||
if flat_idx.numel() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
sorted_experts, perm = torch.sort(flat_idx)
|
||||
assignments = torch.bincount(sorted_experts, minlength=num_experts)
|
||||
if assignments.sum() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
|
||||
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
|
||||
|
||||
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
|
||||
routed_input = torch.gather(x_flat, 0, gather_index)
|
||||
|
||||
counts_i32 = assignments.to(device=device, dtype=torch.int32)
|
||||
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
|
||||
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
|
||||
routed_in = routed_input.to(mm_dtype)
|
||||
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
|
||||
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
|
||||
w2_t = w2.transpose(-2, -1).to(mm_dtype)
|
||||
|
||||
routed_in = routed_in.contiguous()
|
||||
w_gate_t = w_gate_t.contiguous()
|
||||
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
||||
torch.ops.aten.silu_(gate_out)
|
||||
w_up_t = w_up_t.contiguous()
|
||||
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
|
||||
gate_out.mul_(up_out)
|
||||
gate_out = gate_out.contiguous()
|
||||
w2_t = w2_t.contiguous()
|
||||
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
|
||||
|
||||
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
|
||||
down_out.mul_(weights)
|
||||
|
||||
combined = torch.zeros_like(x_flat)
|
||||
combined.scatter_add_(0, gather_index, down_out)
|
||||
|
||||
output = combined.view(bsz, seqlen, hdim)
|
||||
if shared_out_flat is not None:
|
||||
output = output + shared_out_flat.view(bsz, seqlen, hdim)
|
||||
return output, router_logits
|
||||
17
src/axolotl/kernels/moe/tt_cg_gemm/__init__.py
Normal file
17
src/axolotl/kernels/moe/tt_cg_gemm/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan."""
|
||||
|
||||
from .cg_backward import ContiguousGroupedGEMM
|
||||
from .cg_forward import (
|
||||
ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly,
|
||||
cg_grouped_gemm,
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"cg_grouped_gemm",
|
||||
"cg_grouped_gemm_forward",
|
||||
"cg_grouped_gemm_forward_dynamic",
|
||||
"ContiguousGroupedGEMM",
|
||||
"ContiguousGroupedGEMMForwardOnly",
|
||||
]
|
||||
290
src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py
Normal file
290
src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Vendored backward pass for Triton contiguous grouped GEMM."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .cg_forward import cg_grouped_gemm_forward
|
||||
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
|
||||
|
||||
GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=STANDARD_CONFIGS,
|
||||
key=["M_TOTAL", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_cg_backward_dx(
|
||||
grad_output_ptr,
|
||||
b_ptr,
|
||||
grad_input_ptr,
|
||||
indices_ptr,
|
||||
M_TOTAL: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
|
||||
):
|
||||
"""Compute gradients with respect to inputs."""
|
||||
|
||||
pid = tl.program_id(0)
|
||||
|
||||
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
|
||||
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
|
||||
tile_m = pid // num_k_tiles
|
||||
tile_k = pid % num_k_tiles
|
||||
|
||||
m_start = tile_m * BLOCK_SIZE_M
|
||||
k_start = tile_k * BLOCK_SIZE_K
|
||||
|
||||
if m_start < M_TOTAL:
|
||||
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
|
||||
|
||||
mask_m = offs_m < M_TOTAL
|
||||
mask_k = offs_k < K
|
||||
|
||||
group_idx = m_start // GROUP_SIZE_M
|
||||
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
|
||||
|
||||
grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32)
|
||||
|
||||
for n in range(0, N, BLOCK_SIZE_N):
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N) + n
|
||||
mask_n = offs_n < N
|
||||
|
||||
mask_go = mask_m[:, None] & mask_n[None, :]
|
||||
mask_w = mask_n[:, None] & mask_k[None, :]
|
||||
|
||||
go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
|
||||
|
||||
w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
|
||||
w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
|
||||
|
||||
grad_input += tl.dot(go, w)
|
||||
|
||||
grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
mask_gi = mask_m[:, None] & mask_k[None, :]
|
||||
tl.store(grad_input_ptrs, grad_input, mask=mask_gi)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _kernel_cg_backward_dw(
|
||||
grad_output_ptr,
|
||||
inputs_ptr,
|
||||
grad_weights_ptr,
|
||||
indices_ptr,
|
||||
M_TOTAL: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""Simplified kernel for expert weight gradients."""
|
||||
|
||||
pid = tl.program_id(0)
|
||||
|
||||
expert_id = pid // ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
|
||||
position_id = pid % ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
|
||||
|
||||
if expert_id < NUM_EXPERTS:
|
||||
n_tiles = K // BLOCK_SIZE_K
|
||||
tile_n = position_id // n_tiles
|
||||
tile_k = position_id % n_tiles
|
||||
|
||||
n_start = tile_n * BLOCK_SIZE_N
|
||||
k_start = tile_k * BLOCK_SIZE_K
|
||||
|
||||
if n_start < N and k_start < K:
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
|
||||
|
||||
mask_n = offs_n < N
|
||||
mask_k = offs_k < K
|
||||
|
||||
grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32)
|
||||
|
||||
for group_idx in range(0, M_TOTAL // GROUP_SIZE_M):
|
||||
group_start = group_idx * GROUP_SIZE_M
|
||||
group_expert = tl.load(indices_ptr + group_start)
|
||||
|
||||
if group_expert == expert_id:
|
||||
for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M):
|
||||
m_start = group_start + m_offset
|
||||
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
|
||||
|
||||
mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL)
|
||||
|
||||
go_ptrs = (
|
||||
grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
)
|
||||
mask_go = mask_m[:, None] & mask_n[None, :]
|
||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
|
||||
|
||||
in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
mask_in = mask_m[:, None] & mask_k[None, :]
|
||||
inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32)
|
||||
|
||||
go_t = tl.trans(go)
|
||||
grad_weights += tl.dot(go_t, inp)
|
||||
|
||||
grad_w_ptrs = (
|
||||
grad_weights_ptr
|
||||
+ expert_id * N * K
|
||||
+ offs_n[:, None] * K
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
mask_gw = mask_n[:, None] & mask_k[None, :]
|
||||
tl.store(grad_w_ptrs, grad_weights, mask=mask_gw)
|
||||
|
||||
|
||||
def cg_grouped_gemm_backward_weights(
|
||||
grad_output: torch.Tensor,
|
||||
inputs: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
num_experts: int,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Backward pass for expert weights."""
|
||||
|
||||
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
|
||||
assert inputs.is_contiguous(), "Inputs tensor must be contiguous"
|
||||
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
|
||||
|
||||
M_total, N = grad_output.shape
|
||||
_, K = inputs.shape
|
||||
|
||||
if expert_indices.dtype != torch.int32:
|
||||
expert_indices = expert_indices.to(torch.int32)
|
||||
|
||||
grad_weights = torch.zeros(
|
||||
(num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype
|
||||
)
|
||||
|
||||
block_size_n = min(128, N)
|
||||
block_size_k = min(32, K)
|
||||
block_size_m = min(32, group_size_m)
|
||||
|
||||
n_tiles = triton.cdiv(N, block_size_n)
|
||||
k_tiles = triton.cdiv(K, block_size_k)
|
||||
grid = (num_experts * n_tiles * k_tiles,)
|
||||
|
||||
_kernel_cg_backward_dw[grid](
|
||||
grad_output,
|
||||
inputs,
|
||||
grad_weights,
|
||||
expert_indices,
|
||||
M_TOTAL=M_total,
|
||||
N=N,
|
||||
K=K,
|
||||
NUM_EXPERTS=num_experts,
|
||||
GROUP_SIZE_M=group_size_m,
|
||||
BLOCK_SIZE_N=block_size_n,
|
||||
BLOCK_SIZE_K=block_size_k,
|
||||
BLOCK_SIZE_M=block_size_m,
|
||||
)
|
||||
|
||||
return grad_weights
|
||||
|
||||
|
||||
def cg_grouped_gemm_backward_inputs(
|
||||
grad_output: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Backward pass for inputs."""
|
||||
|
||||
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
|
||||
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
|
||||
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
|
||||
|
||||
M_total, N = grad_output.shape
|
||||
num_experts, _, K = expert_weights.shape
|
||||
|
||||
assert M_total % group_size_m == 0, (
|
||||
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
|
||||
)
|
||||
|
||||
grad_inputs = torch.zeros(
|
||||
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
|
||||
)
|
||||
|
||||
grid = lambda meta: (
|
||||
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(K, meta["BLOCK_SIZE_K"]),
|
||||
)
|
||||
|
||||
_kernel_cg_backward_dx[grid](
|
||||
grad_output,
|
||||
expert_weights,
|
||||
grad_inputs,
|
||||
expert_indices,
|
||||
M_TOTAL=M_total,
|
||||
N=N,
|
||||
K=K,
|
||||
NUM_EXPERTS=num_experts,
|
||||
GROUP_SIZE_M=group_size_m,
|
||||
)
|
||||
|
||||
return grad_inputs
|
||||
|
||||
|
||||
class ContiguousGroupedGEMM(torch.autograd.Function):
|
||||
"""Autograd function with full backward support."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
|
||||
ctx.save_for_backward(inputs, expert_weights, expert_indices)
|
||||
ctx.group_size_m = group_size_m
|
||||
|
||||
return cg_grouped_gemm_forward(
|
||||
inputs=inputs,
|
||||
expert_weights=expert_weights,
|
||||
expert_indices=expert_indices,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
inputs, expert_weights, expert_indices = ctx.saved_tensors
|
||||
group_size_m = ctx.group_size_m
|
||||
|
||||
grad_output = grad_output.contiguous()
|
||||
num_experts = expert_weights.shape[0]
|
||||
|
||||
grad_inputs = cg_grouped_gemm_backward_inputs(
|
||||
grad_output=grad_output,
|
||||
expert_weights=expert_weights,
|
||||
expert_indices=expert_indices,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
grad_weights = cg_grouped_gemm_backward_weights(
|
||||
grad_output=grad_output,
|
||||
inputs=inputs,
|
||||
expert_indices=expert_indices,
|
||||
num_experts=num_experts,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
grad_indices = None
|
||||
grad_group_size_m = None
|
||||
|
||||
return grad_inputs, grad_weights, grad_indices, grad_group_size_m
|
||||
311
src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py
Normal file
311
src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py
Normal file
@@ -0,0 +1,311 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Vendored forward Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
|
||||
|
||||
GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m):
|
||||
group_id = tile_id // num_pid_in_group
|
||||
first_pid_m = group_id * super_group_m
|
||||
group_size_m = min(num_pid_m - first_pid_m, super_group_m)
|
||||
pid_m = first_pid_m + (tile_id % group_size_m)
|
||||
pid_n = (tile_id % num_pid_in_group) // group_size_m
|
||||
return pid_m, pid_n
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=STANDARD_CONFIGS,
|
||||
key=["M_TOTAL", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_cg_persistent_forward(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
indices_ptr,
|
||||
M_TOTAL: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
NUM_SMS: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
|
||||
SUPER_GROUP_M: tl.constexpr = 32,
|
||||
):
|
||||
"""
|
||||
Contiguous Grouped GEMM kernel forward (persistent variant).
|
||||
"""
|
||||
|
||||
c_type = c_ptr.dtype.element_ty
|
||||
|
||||
start_pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_tiles = num_pid_m * num_pid_n
|
||||
tile_id_c = start_pid - NUM_SMS
|
||||
num_pid_in_group = SUPER_GROUP_M * num_pid_n
|
||||
|
||||
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
|
||||
tile_m_idx, tile_n_idx = _compute_pid(
|
||||
tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M
|
||||
)
|
||||
|
||||
m_start = tile_m_idx * BLOCK_SIZE_M
|
||||
n_start = tile_n_idx * BLOCK_SIZE_N
|
||||
|
||||
if m_start < M_TOTAL:
|
||||
offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for ki in range(k_tiles):
|
||||
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
mask_m = offs_m < M_TOTAL
|
||||
mask_n = offs_n < N
|
||||
mask_k = offs_k < K
|
||||
|
||||
mask_a = mask_m[:, None] & mask_k[None, :]
|
||||
mask_b = mask_n[:, None] & mask_k[None, :]
|
||||
|
||||
group_idx = m_start // GROUP_SIZE_M
|
||||
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
|
||||
|
||||
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
|
||||
|
||||
b_ptrs = (
|
||||
b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
|
||||
)
|
||||
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
|
||||
|
||||
accumulator += tl.dot(a, b.T)
|
||||
|
||||
tile_id_c += NUM_SMS
|
||||
tile_m_idx, tile_n_idx = _compute_pid(
|
||||
tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M
|
||||
)
|
||||
|
||||
offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
mask_m = offs_m < M_TOTAL
|
||||
mask_n = offs_n < N
|
||||
mask_c = mask_m[:, None] & mask_n[None, :]
|
||||
|
||||
c = accumulator.to(tl.float32)
|
||||
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
tl.store(c_ptrs, c.to(c_type), mask=mask_c)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=STANDARD_CONFIGS,
|
||||
key=["M_TOTAL", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_cg_forward_aligned(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
indices_ptr,
|
||||
M_TOTAL: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
NUM_EXPERTS: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
|
||||
):
|
||||
"""
|
||||
Contiguous Grouped GEMM kernel forward for aligned inputs.
|
||||
"""
|
||||
|
||||
pid = tl.program_id(0)
|
||||
|
||||
c_type = c_ptr.dtype.element_ty
|
||||
|
||||
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
|
||||
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
tile_m = pid // num_n_tiles
|
||||
tile_n = pid % num_n_tiles
|
||||
|
||||
m_start = tile_m * BLOCK_SIZE_M
|
||||
n_start = tile_n * BLOCK_SIZE_N
|
||||
|
||||
if m_start < M_TOTAL:
|
||||
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
|
||||
|
||||
mask_m = offs_m < M_TOTAL
|
||||
mask_n = offs_n < N
|
||||
|
||||
group_idx = m_start // GROUP_SIZE_M
|
||||
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
|
||||
|
||||
acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)
|
||||
|
||||
for k in range(0, K, BLOCK_SIZE_K):
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + k
|
||||
mask_k = offs_k < K
|
||||
|
||||
mask_a = mask_m[:, None] & mask_k[None, :]
|
||||
mask_b = mask_n[:, None] & mask_k[None, :]
|
||||
|
||||
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
|
||||
|
||||
b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
|
||||
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
|
||||
|
||||
acc += tl.dot(a, b.T)
|
||||
|
||||
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
mask_c = mask_m[:, None] & mask_n[None, :]
|
||||
tl.store(c_ptrs, acc.to(c_type), mask=mask_c)
|
||||
|
||||
|
||||
def cg_grouped_gemm_forward(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Contiguous grouped GEMM forward pass for MoE."""
|
||||
|
||||
assert inputs.is_contiguous(), "Input tensor must be contiguous"
|
||||
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
|
||||
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
|
||||
|
||||
M_total, K = inputs.shape
|
||||
assert M_total % group_size_m == 0, (
|
||||
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
|
||||
)
|
||||
|
||||
if expert_indices.dtype != torch.int32:
|
||||
expert_indices = expert_indices.to(torch.int32)
|
||||
|
||||
num_experts, N, K_weights = expert_weights.shape
|
||||
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
|
||||
assert expert_indices.shape[0] == M_total, (
|
||||
"Expert indices length must match M_total"
|
||||
)
|
||||
|
||||
output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16)
|
||||
|
||||
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
|
||||
grid = (NUM_SMS, 1, 1)
|
||||
_kernel_cg_persistent_forward[grid](
|
||||
inputs,
|
||||
expert_weights,
|
||||
output,
|
||||
expert_indices,
|
||||
M_TOTAL=M_total,
|
||||
N=N,
|
||||
K=K,
|
||||
NUM_EXPERTS=num_experts,
|
||||
GROUP_SIZE_M=group_size_m,
|
||||
NUM_SMS=NUM_SMS,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def cg_grouped_gemm_forward_dynamic(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Contiguous grouped GEMM forward pass for MoE with autotuned launch."""
|
||||
|
||||
assert inputs.is_contiguous(), "Input tensor must be contiguous"
|
||||
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
|
||||
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
|
||||
|
||||
M_total, K = inputs.shape
|
||||
assert M_total % group_size_m == 0, (
|
||||
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
|
||||
)
|
||||
|
||||
if expert_indices.dtype != torch.int32:
|
||||
expert_indices = expert_indices.to(torch.int32)
|
||||
|
||||
num_experts, N, K_weights = expert_weights.shape
|
||||
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
|
||||
assert expert_indices.shape[0] == M_total, (
|
||||
"Expert indices length must match M_total"
|
||||
)
|
||||
|
||||
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
|
||||
|
||||
grid = lambda meta: (
|
||||
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(N, meta["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
_kernel_cg_forward_aligned[grid](
|
||||
inputs,
|
||||
expert_weights,
|
||||
output,
|
||||
expert_indices,
|
||||
M_TOTAL=M_total,
|
||||
N=N,
|
||||
K=K,
|
||||
NUM_EXPERTS=num_experts,
|
||||
GROUP_SIZE_M=group_size_m,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ContiguousGroupedGEMM(torch.autograd.Function):
|
||||
"""Autograd function for contiguous grouped GEMM forward pass only."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
|
||||
return cg_grouped_gemm_forward(
|
||||
inputs=inputs,
|
||||
expert_weights=expert_weights,
|
||||
expert_indices=expert_indices,
|
||||
group_size_m=group_size_m,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output): # pragma: no cover - not implemented
|
||||
raise NotImplementedError("Backward pass not implemented")
|
||||
|
||||
|
||||
def cg_grouped_gemm(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = GROUP_SIZE_M,
|
||||
) -> torch.Tensor:
|
||||
"""Convenience wrapper for the forward-only autograd function."""
|
||||
|
||||
if expert_indices.dtype != torch.int32:
|
||||
expert_indices = expert_indices.to(torch.int32)
|
||||
|
||||
return ContiguousGroupedGEMM.apply(
|
||||
inputs, expert_weights, expert_indices, group_size_m
|
||||
)
|
||||
31
src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py
Normal file
31
src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Reference implementation for contiguous grouped GEMM."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def pytorch_reference(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
expert_indices: torch.Tensor,
|
||||
group_size_m: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""Simple PyTorch implementation for verification."""
|
||||
|
||||
M_total, K = inputs.shape
|
||||
num_experts, N, _ = expert_weights.shape
|
||||
|
||||
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
|
||||
|
||||
for i in range(0, M_total, group_size_m):
|
||||
end_idx = min(i + group_size_m, M_total)
|
||||
expert_idx = expert_indices[i].item()
|
||||
expert_weight = expert_weights[expert_idx]
|
||||
output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.T)
|
||||
|
||||
return output
|
||||
209
src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py
Normal file
209
src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Autotuning utilities for Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.runtime import driver
|
||||
|
||||
|
||||
class CudaUtils:
|
||||
"""Helper utilities for CUDA specific Triton features."""
|
||||
|
||||
@staticmethod
|
||||
def is_cuda() -> bool:
|
||||
return driver.active.get_current_target().backend == "cuda"
|
||||
|
||||
@staticmethod
|
||||
def verify_tma() -> bool:
|
||||
return (
|
||||
CudaUtils.is_cuda()
|
||||
and torch.cuda.is_available()
|
||||
and torch.cuda.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_num_sms() -> int:
|
||||
if not CudaUtils.is_cuda():
|
||||
raise RuntimeError("Triton is not running on CUDA backend")
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA is not available")
|
||||
return torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
|
||||
|
||||
class TmaDescriptorHelper:
|
||||
"""Helper class for managing TMA descriptors in Triton kernels."""
|
||||
|
||||
class KernelParamWrapper:
|
||||
def __init__(self, desc: torch.Tensor):
|
||||
self.desc = desc
|
||||
|
||||
def tma_desc_cpu_ptr(self) -> int:
|
||||
return self.desc.data_ptr()
|
||||
|
||||
def __init__(self, tma_size: int = 128):
|
||||
if not CudaUtils.verify_tma():
|
||||
raise RuntimeError(
|
||||
"TMA not supported on this device (requires Hopper or newer)"
|
||||
)
|
||||
if "nv_tma_desc_type" not in dir(tl):
|
||||
raise RuntimeError(
|
||||
"TMA grid constant descriptors not supported in your Triton version"
|
||||
)
|
||||
|
||||
self.tma_size = tma_size
|
||||
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
|
||||
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
|
||||
self.descriptors: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def init_tma_descriptor(self, name: str) -> None:
|
||||
self.descriptors[name] = torch.empty(
|
||||
self.tma_size, device="cpu", dtype=torch.int8
|
||||
)
|
||||
|
||||
def fill_1d_tma_descriptor(
|
||||
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
|
||||
) -> None:
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_1d_tma_descriptor_inner(
|
||||
ptr, dim, block_dim, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def fill_2d_tma_descriptor(
|
||||
self,
|
||||
name: str,
|
||||
ptr: int,
|
||||
dim1: int,
|
||||
dim0: int,
|
||||
block_dim1: int,
|
||||
block_dim0: int,
|
||||
element_size: int,
|
||||
) -> None:
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_2d_tma_descriptor_inner(
|
||||
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def get_tma_descriptor_kernel_param(
|
||||
self, name: str
|
||||
) -> "TmaDescriptorHelper.KernelParamWrapper":
|
||||
if name not in self.descriptors or self.descriptors[name] is None:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
return self.KernelParamWrapper(self.descriptors[name])
|
||||
|
||||
|
||||
HOPPER_CONFIGS = [
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=8,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
STANDARD_CONFIGS = [
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=8,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def early_config_prune(configs, args, **kwargs):
|
||||
"""Filter out configurations that would exceed shared memory capacity."""
|
||||
k = kwargs.get("K", 0)
|
||||
valid_configs = [
|
||||
config for config in configs if config.kwargs.get("BLOCK_SIZE_K", 0) <= k
|
||||
]
|
||||
if not valid_configs and configs:
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")),
|
||||
)
|
||||
]
|
||||
|
||||
return valid_configs
|
||||
13
src/axolotl/kernels/moe/tt_mg_gemm/__init__.py
Normal file
13
src/axolotl/kernels/moe/tt_mg_gemm/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .mg_grouped_gemm import grouped_gemm_forward
|
||||
from .tma_autotuning import ALIGN_SIZE_M
|
||||
|
||||
__all__ = [
|
||||
"grouped_gemm_forward",
|
||||
"ALIGN_SIZE_M",
|
||||
]
|
||||
761
src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
Normal file
761
src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
Normal file
@@ -0,0 +1,761 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# credit - flat index forward kernel is derived from FBGemm:
|
||||
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
||||
|
||||
# pyre-unsafe
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .tma_autotuning import (
|
||||
_NV_CONFIGS,
|
||||
CudaUtils,
|
||||
early_config_prune,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
_allocator_registered = False
|
||||
|
||||
|
||||
def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
|
||||
return torch.empty(size, device="cuda", dtype=torch.int8)
|
||||
|
||||
|
||||
def _ensure_triton_allocator() -> None:
|
||||
global _allocator_registered
|
||||
if not _allocator_registered:
|
||||
triton.set_allocator(_torch_allocator)
|
||||
_allocator_registered = True
|
||||
|
||||
|
||||
# ============== Start Triton Kernels ===============
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=_NV_CONFIGS,
|
||||
key=["G", "M_BUCKET", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_mg_forward_hopper(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m_sizes,
|
||||
M_TOTAL,
|
||||
# problem sizes
|
||||
G: tl.constexpr,
|
||||
M_BUCKET: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
# config
|
||||
NUM_SMS: tl.constexpr,
|
||||
USE_EPILOGUE_SUBTILING: tl.constexpr,
|
||||
# tiles
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
) -> None:
|
||||
"""Flat index style forward kernel for Hopper using tensor descriptors."""
|
||||
tbidx = tl.program_id(0)
|
||||
|
||||
c_dtype = c_ptr.dtype.element_ty
|
||||
n_size = N // G
|
||||
|
||||
a_desc = tl.make_tensor_descriptor(
|
||||
a_ptr,
|
||||
shape=[M_TOTAL, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
)
|
||||
b_desc = tl.make_tensor_descriptor(
|
||||
b_ptr,
|
||||
shape=[N, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
)
|
||||
|
||||
M_end = tl.full([], 0, dtype=tl.int32)
|
||||
processed_tiles = 0
|
||||
|
||||
for g in range(G):
|
||||
M_start = M_end
|
||||
m_size = tl.load(m_sizes + g)
|
||||
M_end = M_start + m_size
|
||||
|
||||
if m_size > 0:
|
||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
|
||||
group_num_tiles = num_m_tiles * num_n_tiles
|
||||
|
||||
while (
|
||||
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
|
||||
):
|
||||
group_index = tbidx - processed_tiles
|
||||
|
||||
tile_m_index = group_index % num_m_tiles
|
||||
tile_n_index = group_index // num_m_tiles
|
||||
|
||||
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
|
||||
cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
|
||||
col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
|
||||
global_n_offset = (g * n_size + n_offset).to(tl.int32)
|
||||
|
||||
for k_offset in range(0, K, BLOCK_SIZE_K):
|
||||
k_remaining = K - k_offset
|
||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
|
||||
|
||||
a = a_desc.load([m_offset, k_offset])
|
||||
a_mask = row_mask[:, None] & k_mask[None, :]
|
||||
a = tl.where(a_mask, a, tl.zeros_like(a))
|
||||
|
||||
b = b_desc.load([global_n_offset, k_offset])
|
||||
b_mask = col_mask[:, None] & k_mask[None, :]
|
||||
b = tl.where(b_mask, b, tl.zeros_like(b))
|
||||
|
||||
accumulator += tl.dot(a, b.T)
|
||||
|
||||
local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
|
||||
local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
|
||||
row_store_mask = local_row_offsets < m_size
|
||||
global_row = (M_start + local_row_offsets).to(tl.int32)
|
||||
|
||||
local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(
|
||||
0, BLOCK_SIZE_N
|
||||
)
|
||||
col_store_mask = local_col_offsets < n_size
|
||||
|
||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||
|
||||
if USE_EPILOGUE_SUBTILING:
|
||||
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
|
||||
acc = tl.permute(acc, (0, 2, 1))
|
||||
acc0, acc1 = tl.split(acc)
|
||||
|
||||
col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
|
||||
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
|
||||
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
|
||||
tl.store(
|
||||
ptr0,
|
||||
acc0.to(c_dtype),
|
||||
mask=row_store_mask[:, None] & col_mask0[None, :],
|
||||
)
|
||||
|
||||
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
|
||||
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
|
||||
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
|
||||
tl.store(
|
||||
ptr1,
|
||||
acc1.to(c_dtype),
|
||||
mask=row_store_mask[:, None] & col_mask1[None, :],
|
||||
)
|
||||
else:
|
||||
ptr = (
|
||||
c_ptr
|
||||
+ global_row[:, None] * n_size
|
||||
+ local_col_offsets[None, :]
|
||||
)
|
||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||
|
||||
tbidx += NUM_SMS
|
||||
|
||||
processed_tiles += group_num_tiles
|
||||
|
||||
|
||||
"""
|
||||
Backward pass for grouped GEMM with Triton, where grouping is M*G
|
||||
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
|
||||
"""
|
||||
|
||||
|
||||
# ---- dx flat linear indexed ----
|
||||
@triton.autotune(
|
||||
configs=_NV_CONFIGS,
|
||||
key=["G", "M_BUCKET", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_mg_dx_tma(
|
||||
grad_output_ptr,
|
||||
w_ptr,
|
||||
grad_input_ptr,
|
||||
m_sizes,
|
||||
M_TOTAL,
|
||||
# problem sizes
|
||||
G: tl.constexpr,
|
||||
M_BUCKET: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
# config
|
||||
NUM_SMS: tl.constexpr,
|
||||
# tiles
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
) -> None:
|
||||
"""Compute grad_input = grad_output @ w using tensor descriptors."""
|
||||
tbidx = tl.program_id(0)
|
||||
|
||||
c_dtype = grad_input_ptr.dtype.element_ty
|
||||
|
||||
grad_output_desc = tl.make_tensor_descriptor(
|
||||
grad_output_ptr,
|
||||
shape=[M_TOTAL, N],
|
||||
strides=[N, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
)
|
||||
w_desc = tl.make_tensor_descriptor(
|
||||
w_ptr,
|
||||
shape=[N, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
|
||||
)
|
||||
|
||||
M_end = tl.full([], 0, dtype=tl.int32)
|
||||
processed_tiles = 0
|
||||
|
||||
for g in range(G):
|
||||
M_start = M_end
|
||||
m_size = tl.load(m_sizes + g)
|
||||
M_end = M_start + m_size
|
||||
|
||||
if m_size > 0:
|
||||
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
|
||||
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
group_num_tiles = num_m_tiles * num_k_tiles
|
||||
|
||||
while (
|
||||
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
|
||||
):
|
||||
group_index = tbidx - processed_tiles
|
||||
|
||||
tile_m_index = group_index % num_m_tiles
|
||||
tile_k_index = group_index // num_m_tiles
|
||||
|
||||
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
|
||||
k_offset = tile_k_index * BLOCK_SIZE_K
|
||||
k_remaining_total = K - k_offset
|
||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
|
||||
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
|
||||
|
||||
for n_offset in range(0, N, BLOCK_SIZE_N):
|
||||
n_remaining = N - n_offset
|
||||
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
|
||||
|
||||
grad_y = grad_output_desc.load([m_offset, n_offset])
|
||||
grad_y_mask = row_mask[:, None] & n_mask[None, :]
|
||||
grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
|
||||
|
||||
w_tile = w_desc.load([n_offset, k_offset])
|
||||
w_mask = n_mask[:, None] & k_mask[None, :]
|
||||
w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
|
||||
|
||||
accumulator += tl.dot(grad_y, w_tile)
|
||||
|
||||
local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(
|
||||
0, BLOCK_SIZE_M
|
||||
)
|
||||
row_store_mask = local_row_offsets < m_size
|
||||
global_row = (M_start + local_row_offsets).to(tl.int32)
|
||||
|
||||
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
|
||||
col_store_mask = col_offsets < K
|
||||
|
||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||
|
||||
ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
|
||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||
|
||||
tbidx += NUM_SMS
|
||||
|
||||
processed_tiles += group_num_tiles
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=_NV_CONFIGS,
|
||||
key=["G", "M_BUCKET", "N", "K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune},
|
||||
)
|
||||
@triton.jit
|
||||
def _kernel_mg_dw_tma(
|
||||
x_ptr,
|
||||
grad_output_ptr,
|
||||
grad_weight_ptr,
|
||||
m_sizes,
|
||||
M_TOTAL,
|
||||
# problem sizes
|
||||
G: tl.constexpr,
|
||||
M_BUCKET: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
# config
|
||||
NUM_SMS: tl.constexpr,
|
||||
# tiles
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
) -> None:
|
||||
"""Compute grad_weight = grad_output.T @ x using tensor descriptors."""
|
||||
tbidx = tl.program_id(0)
|
||||
|
||||
c_dtype = grad_weight_ptr.dtype.element_ty
|
||||
|
||||
x_desc = tl.make_tensor_descriptor(
|
||||
x_ptr,
|
||||
shape=[M_TOTAL, K],
|
||||
strides=[K, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
|
||||
)
|
||||
grad_output_desc = tl.make_tensor_descriptor(
|
||||
grad_output_ptr,
|
||||
shape=[M_TOTAL, N],
|
||||
strides=[N, 1],
|
||||
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
|
||||
)
|
||||
|
||||
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
total_tiles = num_n_tiles * num_k_tiles
|
||||
|
||||
for tile_idx in range(tbidx, total_tiles, NUM_SMS):
|
||||
tile_n_idx = tile_idx % num_n_tiles
|
||||
tile_k_idx = tile_idx // num_n_tiles
|
||||
|
||||
n_offset = tile_n_idx * BLOCK_SIZE_N
|
||||
n_remaining = N - n_offset
|
||||
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
|
||||
|
||||
k_offset = tile_k_idx * BLOCK_SIZE_K
|
||||
k_remaining = K - k_offset
|
||||
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
|
||||
|
||||
M_end = tl.full([], 0, dtype=tl.int32)
|
||||
for g in range(G):
|
||||
M_start = M_end
|
||||
m_size = tl.load(m_sizes + g)
|
||||
M_end = M_start + m_size
|
||||
|
||||
if m_size > 0:
|
||||
for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
|
||||
rows_remaining = m_size - m_offset_local
|
||||
rows_remaining = tl.maximum(rows_remaining, 0)
|
||||
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
|
||||
|
||||
m_offset = (M_start + m_offset_local).to(tl.int32)
|
||||
|
||||
x_block = x_desc.load([m_offset, k_offset])
|
||||
x_mask = row_mask[:, None] & k_mask[None, :]
|
||||
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
|
||||
|
||||
grad_block = grad_output_desc.load([m_offset, n_offset])
|
||||
grad_mask = row_mask[:, None] & n_mask[None, :]
|
||||
grad_block = tl.where(
|
||||
grad_mask, grad_block, tl.zeros_like(grad_block)
|
||||
)
|
||||
|
||||
contribution = tl.dot(
|
||||
grad_block.to(tl.float32).T,
|
||||
x_block.to(tl.float32),
|
||||
)
|
||||
accumulator += contribution
|
||||
|
||||
row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
|
||||
row_store_mask = row_offsets < N
|
||||
|
||||
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
|
||||
col_store_mask = col_offsets < K
|
||||
|
||||
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
|
||||
|
||||
ptr = grad_weight_ptr + row_offsets[:, None] * K + col_offsets[None, :]
|
||||
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
|
||||
|
||||
|
||||
# ======== End Triton kernels ========
|
||||
# ======== End Triton kernels ========
|
||||
|
||||
# ======== Triton wrapper functions ========
|
||||
|
||||
# ----- main forward pass wrapper -----
|
||||
|
||||
|
||||
def grouped_gemm_forward(
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
tma_size: int = 128,
|
||||
using_fp8: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Grouped GEMM forward using Hopper TMA kernels."""
|
||||
_ensure_triton_allocator()
|
||||
if not CudaUtils.verify_tma():
|
||||
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
|
||||
if using_fp8:
|
||||
raise NotImplementedError(
|
||||
"FP8 path not implemented with the new Triton API yet"
|
||||
)
|
||||
|
||||
G = m_sizes.shape[0]
|
||||
|
||||
assert x.is_contiguous()
|
||||
assert w.is_contiguous()
|
||||
assert m_sizes.is_contiguous()
|
||||
|
||||
M_total, K = x.shape
|
||||
N = w.shape[0]
|
||||
assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
|
||||
|
||||
y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
|
||||
if M_total == 0:
|
||||
return y
|
||||
|
||||
NUM_SMS = CudaUtils.get_num_sms()
|
||||
USE_EPILOGUE_SUBTILING = False
|
||||
|
||||
def grid(_meta):
|
||||
return (NUM_SMS,)
|
||||
|
||||
M_BUCKET = triton.next_power_of_2(M_total)
|
||||
_kernel_mg_forward_hopper[grid](
|
||||
x,
|
||||
w,
|
||||
y,
|
||||
m_sizes,
|
||||
M_total,
|
||||
G,
|
||||
M_BUCKET,
|
||||
N,
|
||||
K,
|
||||
NUM_SMS,
|
||||
USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
|
||||
)
|
||||
return y
|
||||
|
||||
|
||||
# ======== Improved Backward =============
|
||||
def grouped_gemm_backward(
|
||||
grad_output: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
use_tma: bool = True,
|
||||
tma_size: int = 128,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Unified backward pass for grouped GeMM with M*G grouping.
|
||||
Uses optimized TMA-based implementations for both dx and dw when available.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient of output, shape [M_total, N]
|
||||
x: Input tensor from forward pass, shape [M_total, K]
|
||||
w: Weight tensor from forward pass, shape [N, K]
|
||||
m_sizes: Group sizes tensor, shape [G]
|
||||
use_tma: Whether to try using TMA acceleration (if available)
|
||||
tma_size: Size of TMA descriptor in bytes
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of gradients with respect to x and w: (grad_x, grad_w)
|
||||
"""
|
||||
logging.info("Starting unified grouped_gemm_backward")
|
||||
|
||||
# do this once, seems expensive
|
||||
NUM_SMS = CudaUtils.get_num_sms()
|
||||
|
||||
# Basic validation
|
||||
M_total, K_x = x.shape
|
||||
M_grad, N = grad_output.shape
|
||||
N_w, K_w = w.shape
|
||||
|
||||
# Check dimensions
|
||||
if K_x != K_w:
|
||||
raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
|
||||
if M_total != M_grad:
|
||||
raise ValueError(
|
||||
f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
|
||||
)
|
||||
|
||||
# Check total M matches sum of group sizes
|
||||
sum_m_sizes = m_sizes.sum().item()
|
||||
if M_total != sum_m_sizes:
|
||||
raise ValueError(
|
||||
f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
|
||||
)
|
||||
|
||||
# Make sure inputs are contiguous
|
||||
grad_output = grad_output.contiguous()
|
||||
x = x.contiguous()
|
||||
w = w.contiguous()
|
||||
m_sizes = m_sizes.contiguous()
|
||||
|
||||
# Check TMA support
|
||||
if use_tma and not CudaUtils.verify_tma():
|
||||
logging.info("TMA requested but not supported on this device")
|
||||
use_tma = False
|
||||
|
||||
# Compute grad_x using flat linear implementation
|
||||
try:
|
||||
logging.info("Computing grad_x with flat linear kernel")
|
||||
|
||||
# Use TMA-optimized implementation
|
||||
grad_x = grouped_gemm_dx_tma(
|
||||
grad_output=grad_output,
|
||||
w=w,
|
||||
m_sizes=m_sizes,
|
||||
num_sms=NUM_SMS,
|
||||
tma_size=tma_size,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error in grad_x computation: {e}")
|
||||
raise
|
||||
|
||||
# Compute grad_w using flat linear style implementation
|
||||
try:
|
||||
logging.info("Computing grad_w with flat linear kernel")
|
||||
|
||||
grad_w = grouped_gemm_dw_tma(
|
||||
x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in grad_w computation: {e}")
|
||||
raise
|
||||
|
||||
return grad_x, grad_w
|
||||
|
||||
|
||||
# ----- dx backward pass wrapper -----
|
||||
|
||||
|
||||
def grouped_gemm_dx_tma(
|
||||
grad_output: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
num_sms: int = 132,
|
||||
tma_size: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""Compute grad_x using the Hopper grouped GEMM kernel."""
|
||||
_ensure_triton_allocator()
|
||||
if not CudaUtils.verify_tma():
|
||||
raise NotImplementedError("Optimized dx computation requires TMA support")
|
||||
|
||||
grad_output = grad_output.contiguous()
|
||||
w = w.contiguous()
|
||||
m_sizes = m_sizes.contiguous()
|
||||
|
||||
M_total, N = grad_output.shape
|
||||
N_w, K = w.shape
|
||||
if N != N_w:
|
||||
raise ValueError(f"Grad_output N ({N}) must match weight N ({N_w})")
|
||||
|
||||
if m_sizes.sum().item() != M_total:
|
||||
raise ValueError("Sum of m_sizes must equal the number of rows in grad_output")
|
||||
|
||||
grad_x = torch.empty(
|
||||
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
|
||||
)
|
||||
|
||||
NUM_SMS = num_sms
|
||||
|
||||
def grid(_meta):
|
||||
return (NUM_SMS,)
|
||||
|
||||
M_BUCKET = triton.next_power_of_2(M_total)
|
||||
_kernel_mg_dx_tma[grid](
|
||||
grad_output,
|
||||
w,
|
||||
grad_x,
|
||||
m_sizes,
|
||||
M_total,
|
||||
m_sizes.shape[0],
|
||||
M_BUCKET,
|
||||
N,
|
||||
K,
|
||||
NUM_SMS,
|
||||
)
|
||||
return grad_x
|
||||
|
||||
|
||||
# ======== dw wrapper function ==========
|
||||
|
||||
|
||||
def grouped_gemm_dw_tma(
|
||||
x: torch.Tensor,
|
||||
grad_output: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
num_sms: int = 132,
|
||||
tma_size: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""Compute grad_w using the Hopper grouped GEMM kernel."""
|
||||
_ensure_triton_allocator()
|
||||
if not CudaUtils.verify_tma():
|
||||
raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")
|
||||
|
||||
x = x.contiguous()
|
||||
grad_output = grad_output.contiguous()
|
||||
m_sizes = m_sizes.contiguous()
|
||||
|
||||
M_total, K = x.shape
|
||||
M_grad, N = grad_output.shape
|
||||
if M_total != M_grad:
|
||||
raise ValueError("x and grad_output must have matching batch dimension")
|
||||
if m_sizes.sum().item() != M_total:
|
||||
raise ValueError("Sum of m_sizes must equal the number of rows in the inputs")
|
||||
|
||||
grad_w = torch.zeros((N, K), device=x.device, dtype=x.dtype)
|
||||
|
||||
NUM_SMS = num_sms
|
||||
|
||||
def grid(_meta):
|
||||
return (NUM_SMS,)
|
||||
|
||||
M_BUCKET = triton.next_power_of_2(M_total)
|
||||
_kernel_mg_dw_tma[grid](
|
||||
x,
|
||||
grad_output,
|
||||
grad_w,
|
||||
m_sizes,
|
||||
M_total,
|
||||
m_sizes.shape[0],
|
||||
M_BUCKET,
|
||||
N,
|
||||
K,
|
||||
NUM_SMS,
|
||||
)
|
||||
return grad_w
|
||||
|
||||
|
||||
# ======== End Backwards Wrapper Functions =============
|
||||
|
||||
# ======== PyTorch wrapper functions ========
|
||||
|
||||
|
||||
class GroupedGemmMg(torch.autograd.Function):
|
||||
"""
|
||||
Autograd function for GroupedGEMM with M*G grouping.
|
||||
Supports both standard and FP8 quantized operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
|
||||
"""
|
||||
Forward pass of GroupedGEMM.
|
||||
|
||||
Args:
|
||||
x: Input tensor, shape [M_total, K]
|
||||
w: Weight tensor, shape [N, K]
|
||||
m_sizes: Tensor of shape [G] containing the size of each group
|
||||
use_tma: Whether to try using TMA acceleration (if available)
|
||||
tma_size: Size of TMA descriptor in bytes
|
||||
using_fp8: Whether to use FP8 quantization
|
||||
|
||||
Returns:
|
||||
Output tensor, shape [M_total, N]
|
||||
"""
|
||||
|
||||
# Use regular forward without quantization
|
||||
output = grouped_gemm_forward(
|
||||
x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
|
||||
)
|
||||
|
||||
# Save inputs and parameters for backward pass
|
||||
ctx.save_for_backward(x, w, m_sizes)
|
||||
ctx.use_tma = use_tma
|
||||
ctx.tma_size = tma_size
|
||||
|
||||
ctx.save_for_backward(x, w, m_sizes)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
"""
|
||||
Backward pass of M*G GroupedGEMM.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient of output, shape [M_total, N]
|
||||
|
||||
Returns:
|
||||
Tuple of gradients:
|
||||
- grad_x: Gradient with respect to x, shape [M_total, K]
|
||||
- grad_w: Gradient with respect to w, shape [N, K]
|
||||
- None: Gradient with respect to m_sizes (not differentiable)
|
||||
- None: Gradient with respect to use_tma (not differentiable)
|
||||
- None: Gradient with respect to tma_size (not differentiable)
|
||||
|
||||
"""
|
||||
# Retrieve saved tensors and parameters
|
||||
|
||||
x, w, m_sizes = ctx.saved_tensors
|
||||
|
||||
use_tma = ctx.use_tma
|
||||
tma_size = ctx.tma_size
|
||||
|
||||
# Compute gradients using the unified implementation
|
||||
grad_x, grad_w = grouped_gemm_backward(
|
||||
grad_output=grad_output,
|
||||
x=x,
|
||||
w=w,
|
||||
m_sizes=m_sizes,
|
||||
use_tma=use_tma,
|
||||
tma_size=tma_size,
|
||||
)
|
||||
|
||||
# Return gradients for all inputs (None for non-differentiable parameters)
|
||||
return grad_x, grad_w, None, None, None, None
|
||||
|
||||
|
||||
def mg_grouped_gemm(
|
||||
x: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
use_tma: bool = True,
|
||||
tma_size: int = 128,
|
||||
using_fp8: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Unified differentiable grouped GEMM operation for M*G grouped GEMM.
|
||||
Supports both standard precision and FP8 quantized operations.
|
||||
|
||||
Args:
|
||||
x: Input tensor, shape [M_total, K]
|
||||
w: Weight tensor, shape [N, K]
|
||||
m_sizes: Tensor of shape [G] containing the size of each group
|
||||
use_tma: Whether to try using TMA acceleration (if available)
|
||||
tma_size: Size of TMA descriptor in bytes
|
||||
using_fp8: Whether to use FP8 quantization
|
||||
|
||||
Returns:
|
||||
Output tensor, shape [M_total, N]
|
||||
"""
|
||||
return GroupedGemmMg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)
|
||||
232
src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py
Normal file
232
src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
|
||||
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from triton.runtime import driver # @manual
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
# ===== Supporting utils, CUDA and TMA =====
|
||||
|
||||
|
||||
class CudaUtils:
|
||||
@staticmethod
|
||||
def is_cuda() -> bool:
|
||||
"""Check if Triton is running on CUDA backend."""
|
||||
return driver.active.get_current_target().backend == "cuda"
|
||||
|
||||
@staticmethod
|
||||
def verify_tma() -> bool:
|
||||
"""Check if TMA is supported on the current device."""
|
||||
return (
|
||||
CudaUtils.is_cuda()
|
||||
and torch.cuda.is_available()
|
||||
and torch.cuda.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_num_sms() -> int:
|
||||
"""Get the number of streaming multiprocessors on the current device."""
|
||||
if not CudaUtils.is_cuda():
|
||||
raise RuntimeError("Triton is not running on CUDA backend")
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA is not available")
|
||||
return torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
|
||||
|
||||
class TmaDescriptorHelper:
|
||||
"""Helper class for managing TMA descriptors in Triton kernels.
|
||||
|
||||
Args:
|
||||
tma_size: Size of the TMA descriptor in bytes
|
||||
"""
|
||||
|
||||
class KernelParamWrapper:
|
||||
"""Wrapper to implement the TmaDescKernelParam interface."""
|
||||
|
||||
def __init__(self, desc: torch.Tensor):
|
||||
self.desc = desc
|
||||
|
||||
def tma_desc_cpu_ptr(self) -> int:
|
||||
"""Return the CPU pointer to the TMA descriptor."""
|
||||
return self.desc.data_ptr()
|
||||
|
||||
def __init__(self, tma_size: int = 128):
|
||||
if not CudaUtils.verify_tma():
|
||||
raise RuntimeError(
|
||||
"TMA not supported on this device (requires Hopper or newer)"
|
||||
)
|
||||
|
||||
self.tma_size = tma_size
|
||||
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
|
||||
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
|
||||
self.descriptors: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def init_tma_descriptor(self, name: str) -> None:
|
||||
"""Initialize a TMA descriptor with the given name.
|
||||
|
||||
Call this method outside of the lambda function for grid size.
|
||||
"""
|
||||
self.descriptors[name] = torch.empty(
|
||||
self.tma_size, device="cpu", dtype=torch.int8
|
||||
)
|
||||
|
||||
def fill_1d_tma_descriptor(
|
||||
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
|
||||
) -> None:
|
||||
"""Fill a 1D TMA descriptor.
|
||||
|
||||
Call this method inside the lambda function for grid size.
|
||||
"""
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_1d_tma_descriptor_inner(
|
||||
ptr, dim, block_dim, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def fill_2d_tma_descriptor(
|
||||
self,
|
||||
name: str,
|
||||
ptr: int,
|
||||
dim1: int,
|
||||
dim0: int,
|
||||
block_dim1: int,
|
||||
block_dim0: int,
|
||||
element_size: int,
|
||||
) -> None:
|
||||
"""Fill a 2D TMA descriptor.
|
||||
|
||||
Call this method inside the lambda function for grid size.
|
||||
"""
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_2d_tma_descriptor_inner(
|
||||
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
|
||||
"""Get the TMA descriptor kernel parameter for the given name."""
|
||||
if name not in self.descriptors or self.descriptors[name] is None:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
return self.KernelParamWrapper(self.descriptors[name])
|
||||
|
||||
|
||||
# ====== Autotuning utilities ======
|
||||
ALIGN_SIZE_M = 128
|
||||
|
||||
_NV_CONFIGS = [
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_size_m,
|
||||
"BLOCK_SIZE_N": block_size_n,
|
||||
"BLOCK_SIZE_K": block_size_k,
|
||||
},
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
)
|
||||
for block_size_m in [
|
||||
ALIGN_SIZE_M,
|
||||
]
|
||||
for block_size_n in [64, 128, 256]
|
||||
for block_size_k in [64, 128, 256]
|
||||
for num_stages in [3, 4]
|
||||
for num_warps in [4, 8]
|
||||
for num_ctas in [1]
|
||||
]
|
||||
|
||||
|
||||
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
|
||||
device = torch.cuda.current_device()
|
||||
# Check for all possible pointer parameter names
|
||||
if "grad_input_ptr" in named_args:
|
||||
ptr_name = "grad_input_ptr"
|
||||
elif "c_ptr" in named_args:
|
||||
ptr_name = "c_ptr"
|
||||
elif "grad_weight_ptr" in named_args:
|
||||
ptr_name = "grad_weight_ptr"
|
||||
else:
|
||||
raise KeyError("No recognized pointer parameter found in kernel arguments")
|
||||
|
||||
if dtsize is None:
|
||||
dtsize = named_args[ptr_name].element_size()
|
||||
if dtype is None:
|
||||
dtype = named_args[ptr_name].dtype
|
||||
|
||||
pruned_configs = []
|
||||
for config in configs:
|
||||
kw = config.kwargs
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
|
||||
kw["BLOCK_SIZE_M"],
|
||||
kw["BLOCK_SIZE_N"],
|
||||
kw["BLOCK_SIZE_K"],
|
||||
config.num_stages,
|
||||
)
|
||||
G, M, N, K = (
|
||||
named_args["G"],
|
||||
named_args["M_BUCKET"],
|
||||
named_args["N"],
|
||||
named_args["K"],
|
||||
)
|
||||
|
||||
# 1. make sure we have enough smem
|
||||
max_shared_memory = driver.active.utils.get_device_properties(device)[
|
||||
"max_shared_mem"
|
||||
]
|
||||
|
||||
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
|
||||
if required_shared_memory > max_shared_memory:
|
||||
continue
|
||||
|
||||
M_PER_GROUP = M // G
|
||||
MIN_M_TILES = 64
|
||||
# 2. make sure we don't load M tiles that are too big
|
||||
if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
|
||||
continue
|
||||
# 3. make sure we don't load N tiles that are too small
|
||||
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
|
||||
continue
|
||||
|
||||
num_sm = driver.active.utils.get_device_properties(device)[
|
||||
"multiprocessor_count"
|
||||
]
|
||||
N_TILES = N // BLOCK_N
|
||||
MIN_N_TILES = 64
|
||||
# 4. make sure we don't load N tiles that are too big
|
||||
if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
|
||||
continue
|
||||
# 5. make sure we don't load N tiles that are too small
|
||||
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
|
||||
continue
|
||||
# 6. make sure K can be evenly divided
|
||||
if K % BLOCK_K != 0:
|
||||
continue
|
||||
|
||||
pruned_configs.append(config)
|
||||
|
||||
return pruned_configs
|
||||
|
||||
|
||||
# ======== End Autotuning utilities ========
|
||||
@@ -12,7 +12,6 @@ import transformers
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.moe_grouped import apply_grouped_to_moe_blocks
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
patch_for_multipack,
|
||||
@@ -58,8 +57,6 @@ class PatchManager:
|
||||
self._apply_fsdp_patches()
|
||||
self._apply_adapter_patches()
|
||||
self._apply_model_specific_patches()
|
||||
# Apply MoE grouped GEMM patches (cfg.moe_backend)
|
||||
apply_grouped_to_moe_blocks(self.cfg)
|
||||
self._apply_fp8_patches()
|
||||
self._apply_flash_attention_peft_patches()
|
||||
self._apply_gradient_checkpointing_patches()
|
||||
@@ -71,11 +68,12 @@ class PatchManager:
|
||||
self._apply_self_attention_lora_patch()
|
||||
self._apply_fsdp2_bnb_patches()
|
||||
self._apply_patch_deepspeed_zero3()
|
||||
self._apply_voxtral_patches()
|
||||
self._apply_apertus_patches()
|
||||
|
||||
def apply_post_plugin_pre_model_load_patches(self):
|
||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||
self._apply_voxtral_patches()
|
||||
|
||||
def _apply_transformers_patches(self):
|
||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||
@@ -86,6 +84,13 @@ class PatchManager:
|
||||
patch_evaluation_loop()
|
||||
patch_maybe_log_save_evaluate()
|
||||
|
||||
if self.cfg.context_parallel_size > 1:
|
||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||
patch_prepare_context_parallel_inputs,
|
||||
)
|
||||
|
||||
patch_prepare_context_parallel_inputs()
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
self._apply_llama_flash_attn_patches(model)
|
||||
@@ -171,6 +176,29 @@ class PatchManager:
|
||||
|
||||
patch_llama4_linearized_modeling()
|
||||
|
||||
if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
|
||||
from axolotl.monkeypatch.models.qwen3_next.modeling import (
|
||||
patch_qwen3_next_modeling_packing,
|
||||
)
|
||||
|
||||
patch_qwen3_next_modeling_packing()
|
||||
|
||||
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
|
||||
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
|
||||
apply_mistral_tokenizer_image_patch,
|
||||
)
|
||||
|
||||
apply_mistral_tokenizer_image_patch()
|
||||
|
||||
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||
|
||||
patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
|
||||
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
|
||||
LOG.info(
|
||||
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
|
||||
)
|
||||
|
||||
def _apply_fp8_patches(self):
|
||||
"""Apply patches for FP8 support."""
|
||||
if self.cfg.fp8:
|
||||
@@ -272,7 +300,6 @@ class PatchManager:
|
||||
self.cfg.model_config_type,
|
||||
model_name=self.cfg.base_model,
|
||||
has_remote_code=has_remote_code,
|
||||
cfg=self.cfg,
|
||||
)
|
||||
|
||||
if self.cfg.sample_packing:
|
||||
@@ -338,6 +365,13 @@ class PatchManager:
|
||||
|
||||
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
|
||||
|
||||
if self.model_config.model_type in ("mistral3", "llava"):
|
||||
from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
|
||||
apply_patch_is_packed_sequence,
|
||||
)
|
||||
|
||||
apply_patch_is_packed_sequence()
|
||||
|
||||
def _patch_loss_llama(self):
|
||||
"""Patch loss functions and other optimizations for LLaMA models."""
|
||||
if not self.cfg.is_llama_derived_model:
|
||||
@@ -483,3 +517,12 @@ class PatchManager:
|
||||
apply_deepspeed_patches()
|
||||
except ImportError as e:
|
||||
LOG.warning(f"DeepSpeed patches not applied: {e}")
|
||||
|
||||
def _apply_apertus_patches(self):
|
||||
"""Apply patches for Apertus model."""
|
||||
if self.cfg.model_config_type == "apertus":
|
||||
from axolotl.monkeypatch.models.apertus.activation import (
|
||||
patch_apertus_xielu_activation,
|
||||
)
|
||||
|
||||
patch_apertus_xielu_activation()
|
||||
|
||||
@@ -21,6 +21,13 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
if cfg.processor_type:
|
||||
processor_cls = getattr(transformers, cfg.processor_type)
|
||||
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
from axolotl.utils.mistral import Mistral3Processor
|
||||
|
||||
return Mistral3Processor(
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
processor = processor_cls.from_pretrained(
|
||||
cfg.processor_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
|
||||
@@ -124,13 +124,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
|
||||
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
||||
"""Load mistral-common tokenizer"""
|
||||
from transformers import tokenization_mistral_common
|
||||
|
||||
from axolotl.utils.mistral import HFMistralTokenizer
|
||||
|
||||
# patch
|
||||
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
|
||||
|
||||
# Load the HF-compatible wrapper around MistralTokenizer
|
||||
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
|
||||
|
||||
|
||||
401
src/axolotl/monkeypatch/deepseek_v3/__init__.py
Normal file
401
src/axolotl/monkeypatch/deepseek_v3/__init__.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""Monkeypatches for DeepSeek V3 MoE to use Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.kernels.moe import ContiguousGroupedGEMM
|
||||
from axolotl.kernels.moe.indices import generate_permute_indices
|
||||
from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
_GROUP_SIZE_M = 128
|
||||
_COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
||||
if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
|
||||
return False
|
||||
major, _ = torch.cuda.get_device_capability(hidden_states.device)
|
||||
if major < 9:
|
||||
LOG.debug(
|
||||
"Skipping Triton MoE kernels: requires compute capability >= 90, found %s",
|
||||
major,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _ensure_combined_expert_weights(
|
||||
module, dtype: torch.dtype, device: torch.device, backend: str
|
||||
) -> None:
|
||||
if not hasattr(module, "_axolotl_original_specs"):
|
||||
module._axolotl_original_specs = {}
|
||||
if not hasattr(module, "_axolotl_mg_shapes"):
|
||||
module._axolotl_mg_shapes = {}
|
||||
|
||||
prev_backend = getattr(module, "_axolotl_combined_backend", None)
|
||||
if getattr(module, "_axolotl_combined_weights", False):
|
||||
if prev_backend != backend:
|
||||
_restore_expert_weights(module)
|
||||
else:
|
||||
for name in _COMBINED_SUBMODULES:
|
||||
param_name = f"{name}_weight"
|
||||
param = module.get_parameter(param_name)
|
||||
if param.device != device or param.dtype != dtype:
|
||||
module._parameters[param_name] = torch.nn.Parameter(
|
||||
param.to(device=device, dtype=dtype).contiguous()
|
||||
)
|
||||
module._axolotl_combined_dtype = dtype
|
||||
module._axolotl_combined_device = device
|
||||
module._axolotl_combined_backend = backend
|
||||
return
|
||||
|
||||
module._axolotl_mg_shapes = {}
|
||||
for name in _COMBINED_SUBMODULES:
|
||||
weights = []
|
||||
orig_device = None
|
||||
orig_dtype = None
|
||||
orig_shape = None
|
||||
for expert in module.experts:
|
||||
lin = expert.get_submodule(name)
|
||||
weight_param = lin._parameters.get("weight")
|
||||
if weight_param is None:
|
||||
raise RuntimeError("Expected expert linear layers to have weights")
|
||||
if orig_device is None:
|
||||
orig_device = weight_param.device
|
||||
orig_dtype = weight_param.dtype
|
||||
orig_shape = tuple(weight_param.shape)
|
||||
weights.append(weight_param.detach().to(device=device, dtype=dtype))
|
||||
if "weight" in lin._parameters:
|
||||
del lin._parameters["weight"]
|
||||
if "bias" in lin._parameters:
|
||||
del lin._parameters["bias"]
|
||||
if backend == "cg":
|
||||
combined_weight = torch.stack(weights, dim=0).contiguous()
|
||||
else:
|
||||
combined_weight = torch.cat(weights, dim=0).contiguous()
|
||||
module._axolotl_mg_shapes[name] = orig_shape
|
||||
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight))
|
||||
module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape)
|
||||
|
||||
module._axolotl_combined_weights = True
|
||||
module._axolotl_combined_dtype = dtype
|
||||
module._axolotl_combined_device = device
|
||||
module._axolotl_combined_backend = backend
|
||||
|
||||
|
||||
def _restore_expert_weights(module) -> None:
|
||||
if not getattr(module, "_axolotl_combined_weights", False):
|
||||
return
|
||||
|
||||
for name in _COMBINED_SUBMODULES:
|
||||
param_name = f"{name}_weight"
|
||||
combined = module._parameters.pop(param_name)
|
||||
orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get(
|
||||
name, (combined.device, combined.dtype, None)
|
||||
)
|
||||
rows_per = orig_shape[0] if orig_shape else None
|
||||
for idx, expert in enumerate(module.experts):
|
||||
lin = expert.get_submodule(name)
|
||||
if combined.dim() == 3:
|
||||
slice_tensor = combined[idx]
|
||||
elif rows_per is not None:
|
||||
start = idx * rows_per
|
||||
end = start + rows_per
|
||||
slice_tensor = combined[start:end]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unable to recover expert weight shape during restore"
|
||||
)
|
||||
lin._parameters["weight"] = torch.nn.Parameter(
|
||||
slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype)
|
||||
)
|
||||
|
||||
module._axolotl_combined_weights = False
|
||||
module._axolotl_combined_dtype = None
|
||||
module._axolotl_combined_device = None
|
||||
module._axolotl_combined_backend = None
|
||||
module._axolotl_original_specs = {}
|
||||
module._axolotl_mg_shapes = {}
|
||||
|
||||
|
||||
def _run_cg_grouped_gemm(
|
||||
module,
|
||||
grouped_hidden: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
num_experts: int,
|
||||
group_size_m: int,
|
||||
hidden_dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
_ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
|
||||
|
||||
expert_index_tensor = torch.repeat_interleave(
|
||||
torch.arange(num_experts, device=device, dtype=torch.int32),
|
||||
m_sizes.to(torch.int64),
|
||||
)
|
||||
|
||||
gate_weights = module.get_parameter("gate_proj_weight")
|
||||
if gate_weights.dim() == 2:
|
||||
out_dim = gate_weights.shape[0] // num_experts
|
||||
gate_weights = gate_weights.view(num_experts, out_dim, gate_weights.shape[1])
|
||||
|
||||
up_weights = module.get_parameter("up_proj_weight")
|
||||
if up_weights.dim() == 2:
|
||||
out_dim = up_weights.shape[0] // num_experts
|
||||
up_weights = up_weights.view(num_experts, out_dim, up_weights.shape[1])
|
||||
|
||||
down_weights = module.get_parameter("down_proj_weight")
|
||||
if down_weights.dim() == 2:
|
||||
out_dim = down_weights.shape[0] // num_experts
|
||||
down_weights = down_weights.view(num_experts, out_dim, down_weights.shape[1])
|
||||
|
||||
gate_out = ContiguousGroupedGEMM.apply(
|
||||
grouped_hidden,
|
||||
gate_weights,
|
||||
expert_index_tensor,
|
||||
group_size_m,
|
||||
)
|
||||
up_out = ContiguousGroupedGEMM.apply(
|
||||
grouped_hidden,
|
||||
up_weights,
|
||||
expert_index_tensor,
|
||||
group_size_m,
|
||||
)
|
||||
return (
|
||||
gate_out.to(hidden_dtype),
|
||||
up_out.to(hidden_dtype),
|
||||
down_weights,
|
||||
expert_index_tensor,
|
||||
)
|
||||
|
||||
gate_out = mg_grouped_gemm(
|
||||
grouped_hidden,
|
||||
module.get_parameter("gate_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
)
|
||||
up_out = mg_grouped_gemm(
|
||||
grouped_hidden,
|
||||
module.get_parameter("up_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
)
|
||||
down_out = mg_grouped_gemm(
|
||||
hidden_grouped,
|
||||
module.get_parameter("down_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
)
|
||||
|
||||
return (
|
||||
gate_out.to(hidden_dtype),
|
||||
up_out.to(hidden_dtype),
|
||||
down_out.to(hidden_dtype),
|
||||
)
|
||||
|
||||
|
||||
def _moe_triton_forward(
|
||||
module,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
group_size_m: int,
|
||||
backend: str,
|
||||
fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if not _is_triton_eligible(hidden_states):
|
||||
return fallback(hidden_states, topk_indices, topk_weights)
|
||||
|
||||
device = hidden_states.device
|
||||
hidden_dtype = hidden_states.dtype
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
top_k = topk_indices.size(-1)
|
||||
|
||||
expanded_hidden = hidden_states.repeat_interleave(top_k, dim=0)
|
||||
expert_assignments = topk_indices.reshape(-1)
|
||||
if expanded_hidden.numel() == 0:
|
||||
return hidden_states.new_zeros_like(hidden_states)
|
||||
|
||||
sort_perm = torch.argsort(expert_assignments)
|
||||
sorted_hidden = expanded_hidden.index_select(0, sort_perm)
|
||||
sorted_assignments = expert_assignments.index_select(0, sort_perm)
|
||||
|
||||
num_experts = len(module.experts)
|
||||
counts = torch.bincount(sorted_assignments, minlength=num_experts)
|
||||
total_actual = int(counts.sum().item())
|
||||
if total_actual == 0:
|
||||
return hidden_states.new_zeros_like(hidden_states)
|
||||
|
||||
if not getattr(module, "_axolotl_triton_logged", False):
|
||||
min_tokens = int(counts.min().item())
|
||||
max_tokens = int(counts.max().item())
|
||||
LOG.info(
|
||||
"DeepseekV3MoE Triton: tokens per expert (min=%s, max=%s, avg=%.1f) with group_size=%s",
|
||||
min_tokens,
|
||||
max_tokens,
|
||||
total_actual / max(1, num_experts),
|
||||
group_size_m,
|
||||
)
|
||||
module._axolotl_triton_logged = True
|
||||
|
||||
counts_int = counts.to(torch.int32)
|
||||
aligned_counts = (
|
||||
(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
|
||||
) * group_size_m
|
||||
max_len = int(aligned_counts.sum().item())
|
||||
|
||||
permuted_indices, m_sizes, _ = generate_permute_indices(
|
||||
counts_int.to(device),
|
||||
experts_per_rank=num_experts,
|
||||
num_ranks=1,
|
||||
max_len=max_len,
|
||||
alignment=group_size_m,
|
||||
use_cpu=not hidden_states.is_cuda,
|
||||
)
|
||||
|
||||
permuted_indices = permuted_indices.to(device)
|
||||
m_sizes = m_sizes.to(device)
|
||||
|
||||
permuted_indices_long = permuted_indices.to(torch.int64)
|
||||
valid_mask = permuted_indices_long >= 0
|
||||
valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
|
||||
source_indices = permuted_indices_long[valid_mask]
|
||||
padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)
|
||||
|
||||
grouped_hidden = hidden_states.new_empty((max_len, hidden_dim))
|
||||
if valid_positions.numel() > 0:
|
||||
grouped_hidden.index_copy_(
|
||||
0,
|
||||
valid_positions,
|
||||
sorted_hidden.index_select(0, source_indices),
|
||||
)
|
||||
if valid_positions.numel() < max_len:
|
||||
grouped_hidden.index_fill_(0, padded_positions, 0)
|
||||
|
||||
m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32)
|
||||
|
||||
if backend == "mg":
|
||||
_ensure_combined_expert_weights(module, hidden_dtype, device, backend)
|
||||
gate_out = mg_grouped_gemm(
|
||||
grouped_hidden,
|
||||
module.get_parameter("gate_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
).to(hidden_dtype)
|
||||
up_out = mg_grouped_gemm(
|
||||
grouped_hidden,
|
||||
module.get_parameter("up_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
).to(hidden_dtype)
|
||||
else:
|
||||
gate_out, up_out, down_weights, expert_index_tensor = _run_cg_grouped_gemm(
|
||||
module,
|
||||
grouped_hidden,
|
||||
m_sizes,
|
||||
num_experts,
|
||||
group_size_m,
|
||||
hidden_dtype,
|
||||
device,
|
||||
)
|
||||
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
|
||||
if valid_positions.numel() > 0:
|
||||
gate_valid = gate_out.index_select(0, valid_positions)
|
||||
up_valid = up_out.index_select(0, valid_positions)
|
||||
hidden_concat = act_fn(gate_valid) * up_valid
|
||||
else:
|
||||
hidden_concat = torch.empty(
|
||||
(0, gate_out.shape[-1]), device=device, dtype=hidden_dtype
|
||||
)
|
||||
|
||||
intermediate_dim = hidden_concat.shape[-1]
|
||||
hidden_grouped = hidden_states.new_empty((max_len, intermediate_dim))
|
||||
if valid_positions.numel() > 0:
|
||||
hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
|
||||
if valid_positions.numel() < max_len:
|
||||
hidden_grouped.index_fill_(0, padded_positions, 0)
|
||||
|
||||
if backend == "mg":
|
||||
down_out = mg_grouped_gemm(
|
||||
hidden_grouped,
|
||||
module.get_parameter("down_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
).to(hidden_dtype)
|
||||
else:
|
||||
down_out = ContiguousGroupedGEMM.apply(
|
||||
hidden_grouped,
|
||||
down_weights,
|
||||
expert_index_tensor,
|
||||
group_size_m,
|
||||
).to(hidden_dtype)
|
||||
|
||||
if valid_positions.numel() > 0:
|
||||
down_valid = down_out.index_select(0, valid_positions)
|
||||
else:
|
||||
down_valid = torch.empty(
|
||||
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype
|
||||
)
|
||||
|
||||
sorted_outputs = hidden_states.new_zeros((total_actual, hidden_dim))
|
||||
if down_valid.numel() > 0:
|
||||
sorted_outputs.index_copy_(0, source_indices, down_valid)
|
||||
|
||||
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
|
||||
expanded_output.index_copy_(0, sort_perm, sorted_outputs)
|
||||
expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
|
||||
|
||||
weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
|
||||
return weighted.sum(dim=1)
|
||||
|
||||
|
||||
def patch_deepseek_v3_moe(
|
||||
group_size_m: int = _GROUP_SIZE_M, backend: str = "mg"
|
||||
) -> None:
|
||||
"""Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""
|
||||
|
||||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||
|
||||
if backend not in {"cg", "mg"}:
|
||||
raise ValueError(f"Unsupported MoE kernel backend: {backend}")
|
||||
|
||||
# Record the unpatched implementation so callers can access a true baseline even
|
||||
# after the Triton patch has been applied (e.g. repeated microbenchmarks).
|
||||
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
|
||||
DeepseekV3MoE._axolotl_triton_original_moe = DeepseekV3MoE.moe
|
||||
|
||||
if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
|
||||
return
|
||||
|
||||
original_moe = DeepseekV3MoE._axolotl_triton_original_moe
|
||||
DeepseekV3MoE._axolotl_triton_backend = backend
|
||||
DeepseekV3MoE._axolotl_group_size_m = group_size_m
|
||||
|
||||
def patched_moe(self, hidden_states, topk_indices, topk_weights):
|
||||
backend_sel = getattr(self, "_axolotl_triton_backend", backend)
|
||||
group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m)
|
||||
if backend_sel == "cg" and group_size_sel != _GROUP_SIZE_M:
|
||||
LOG.debug(
|
||||
"Adjusting group_size_m=%s to %s for CG backend",
|
||||
group_size_sel,
|
||||
_GROUP_SIZE_M,
|
||||
)
|
||||
group_size_sel = _GROUP_SIZE_M
|
||||
try:
|
||||
return _moe_triton_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
topk_indices,
|
||||
topk_weights,
|
||||
group_size_sel,
|
||||
backend_sel,
|
||||
original_moe,
|
||||
)
|
||||
except Exception as err: # surface Triton failures explicitly
|
||||
_restore_expert_weights(self)
|
||||
LOG.error("DeepseekV3MoE Triton path failed: %s", err)
|
||||
raise
|
||||
|
||||
DeepseekV3MoE.moe = patched_moe
|
||||
DeepseekV3MoE._axolotl_triton_patch = True
|
||||
@@ -5,14 +5,9 @@ Patches to support multipack for mixtral
|
||||
import torch
|
||||
|
||||
|
||||
def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
|
||||
import warnings
|
||||
|
||||
def patch_mixtral_moe_forward_zero3() -> None:
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.kernels.moe import backends as _moe_backends
|
||||
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
|
||||
|
||||
def mlp_forward(self, hidden_states):
|
||||
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
|
||||
hidden_states
|
||||
@@ -26,32 +21,21 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
|
||||
backend = get_moe_backend_name(preferred)
|
||||
if (
|
||||
backend == MOEBackend.TORCH_GROUPED
|
||||
and not _moe_backends._probe_torch_grouped()
|
||||
):
|
||||
warnings.warn(
|
||||
"torch_grouped selected but not available; falling back to naive",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
topk_weight, topk_idx = torch.topk(
|
||||
routing_weights, self.top_k, dim=-1, sorted=False
|
||||
)
|
||||
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
topk_weight = topk_weight.to(hidden_states.dtype)
|
||||
|
||||
hidden_states_rep = hidden_states.repeat_interleave(self.top_k, dim=0)
|
||||
y = torch.empty_like(hidden_states_rep)
|
||||
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
|
||||
y = torch.empty_like(hidden_states)
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
for i in range(self.num_experts):
|
||||
expert = self.experts[i]
|
||||
sel = flat_topk_idx == i
|
||||
if sel.any():
|
||||
y[sel] = expert(hidden_states_rep[sel])
|
||||
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
@@ -62,23 +46,4 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
|
||||
)
|
||||
|
||||
MixtralBlockSparseTop2MLP.forward = mlp_forward
|
||||
# Wrap forward to support optional torch_grouped backend via config
|
||||
from axolotl.kernels.moe import torch_grouped as _tg
|
||||
|
||||
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
|
||||
backend = get_moe_backend_name(preferred)
|
||||
|
||||
if backend == MOEBackend.TORCH_GROUPED and _tg.available():
|
||||
|
||||
def moe_forward_grouped(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
bsz, seqlen, hdim = hidden_states.shape
|
||||
y, router_logits = _tg.moe_ffn_forward_grouped(
|
||||
hidden_states, self.gate, self.experts, self.top_k
|
||||
)
|
||||
if y is None:
|
||||
return moe_forward(self, hidden_states)
|
||||
return y, router_logits
|
||||
|
||||
MixtralSparseMoeBlock.forward = moe_forward_grouped
|
||||
else:
|
||||
MixtralSparseMoeBlock.forward = moe_forward
|
||||
MixtralSparseMoeBlock.forward = moe_forward
|
||||
|
||||
0
src/axolotl/monkeypatch/models/apertus/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/apertus/__init__.py
Normal file
52
src/axolotl/monkeypatch/models/apertus/activation.py
Normal file
52
src/axolotl/monkeypatch/models/apertus/activation.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Monkeypatch for Apertus to dtype mismatch in XIELU act"""
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def patch_apertus_xielu_activation():
|
||||
try:
|
||||
from transformers.activations import XIELUActivation
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Cannot import XIELUActivation. "
|
||||
"Please make sure to update your transformers version >= 4.56.1."
|
||||
) from err
|
||||
|
||||
from transformers.activations import logger
|
||||
|
||||
# Store the original method
|
||||
old_fn = XIELUActivation._xielu_cuda
|
||||
|
||||
def _xielu_cuda_fixed(self, x: Tensor) -> Tensor:
|
||||
"""Firewall function to prevent torch.compile from seeing .item() calls"""
|
||||
original_shape = x.shape
|
||||
# CUDA kernel expects 3D tensors, reshape if needed
|
||||
while x.dim() < 3:
|
||||
x = x.unsqueeze(0)
|
||||
if x.dim() > 3:
|
||||
x = x.view(-1, 1, x.size(-1))
|
||||
if original_shape != x.shape:
|
||||
logger.warning_once(
|
||||
"Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
|
||||
original_shape,
|
||||
x.shape,
|
||||
)
|
||||
result = self._xielu_cuda_obj.forward(
|
||||
x,
|
||||
self.alpha_p.to(x.dtype),
|
||||
self.alpha_n.to(x.dtype),
|
||||
# Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
|
||||
self._beta_scalar,
|
||||
self._eps_scalar,
|
||||
self.with_vector_loads,
|
||||
)
|
||||
return result.view(original_shape)
|
||||
|
||||
# Apply the patch
|
||||
XIELUActivation._xielu_cuda = _xielu_cuda_fixed
|
||||
|
||||
def unpatch():
|
||||
"""Restore the original method"""
|
||||
XIELUActivation._xielu_cuda = old_fn
|
||||
|
||||
return unpatch
|
||||
0
src/axolotl/monkeypatch/models/mistral3/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/mistral3/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def apply_mistral_tokenizer_image_patch():
|
||||
"""Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion."""
|
||||
from transformers.tokenization_mistral_common import MistralCommonTokenizer
|
||||
|
||||
# Get original source
|
||||
original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
|
||||
original_source, _ = detab_code(original_source)
|
||||
|
||||
# Define the replacement
|
||||
original_tensor_conversion = (
|
||||
" pixel_values = torch.tensor(images)"
|
||||
)
|
||||
|
||||
patched_tensor_conversion = """ if isinstance(images, list) and len(images) > 0 and isinstance(images[0], np.ndarray):
|
||||
pixel_values = torch.tensor(np.array(images))
|
||||
else:
|
||||
pixel_values = torch.tensor(images)"""
|
||||
|
||||
# Apply the replacement
|
||||
if original_tensor_conversion in original_source:
|
||||
patched_source = original_source.replace(
|
||||
original_tensor_conversion, patched_tensor_conversion
|
||||
)
|
||||
patched_source = patched_source.replace(
|
||||
"def apply_chat_template(",
|
||||
"def patched_apply_chat_template(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Load necessary imports from the module
|
||||
module_name = MistralCommonTokenizer.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Detect what needs to be imported
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in patched_source and not item.startswith("_"):
|
||||
items_to_import.append(item)
|
||||
|
||||
# Execute imports in global scope
|
||||
if items_to_import:
|
||||
exec( # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
|
||||
# Also need standard imports that might be used
|
||||
exec("import numpy as np", globals()) # nosec B102
|
||||
exec("import torch", globals()) # nosec B102
|
||||
exec("from typing import Union, Optional, List, Dict, Any, Callable", globals()) # nosec B102
|
||||
exec("from pathlib import Path", globals()) # nosec B102
|
||||
|
||||
# Import other dependencies that might be needed
|
||||
try:
|
||||
exec("from transformers.utils import is_torch_available", globals()) # nosec B102
|
||||
exec(
|
||||
"from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType",
|
||||
globals(),
|
||||
) # nosec B102
|
||||
exec("from transformers.utils import logging", globals()) # nosec B102
|
||||
exec("logger = logging.get_logger(__name__)", globals()) # nosec B102
|
||||
except ImportError as e:
|
||||
LOG.warning(f"Could not import some dependencies: {e}")
|
||||
|
||||
# Execute the patched source
|
||||
exec(patched_source, globals()) # nosec B102
|
||||
|
||||
# Replace the method
|
||||
MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template
|
||||
LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")
|
||||
else:
|
||||
LOG.warning("Could not find target code for MistralCommonTokenizer patching")
|
||||
0
src/axolotl/monkeypatch/models/pixtral/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/pixtral/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Monkeypatch for FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def apply_patch_is_packed_sequence():
|
||||
"""Apply patch to FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""
|
||||
from transformers import modeling_flash_attention_utils
|
||||
|
||||
def fixed_is_packed_sequence(position_ids, batch_size):
|
||||
"""
|
||||
Check the position ids whether packed sequences are indicated or not
|
||||
1. Position ids exist
|
||||
2. Flattened sequences only are supported
|
||||
3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
|
||||
"""
|
||||
if position_ids is None:
|
||||
return False
|
||||
|
||||
if position_ids.ndim == 1:
|
||||
position_ids = position_ids.unsqueeze(0) # [N] -> [1, N]
|
||||
|
||||
increasing_position_sequences = (
|
||||
torch.arange(position_ids.shape[1], device=position_ids.device)
|
||||
+ position_ids.min()
|
||||
)
|
||||
return (
|
||||
batch_size == 1
|
||||
and (increasing_position_sequences - position_ids).abs().sum().bool().item()
|
||||
)
|
||||
|
||||
# Store original method
|
||||
old_fn = modeling_flash_attention_utils._is_packed_sequence
|
||||
|
||||
# Apply the patch
|
||||
modeling_flash_attention_utils._is_packed_sequence = fixed_is_packed_sequence
|
||||
|
||||
def unpatch():
|
||||
"""Restore the original method"""
|
||||
modeling_flash_attention_utils._is_packed_sequence = old_fn
|
||||
|
||||
return unpatch
|
||||
1
src/axolotl/monkeypatch/models/qwen3_next/__init__.py
Normal file
1
src/axolotl/monkeypatch/models/qwen3_next/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Qwen3_Next model monkeypatches."""
|
||||
317
src/axolotl/monkeypatch/models/qwen3_next/modeling.py
Normal file
317
src/axolotl/monkeypatch/models/qwen3_next/modeling.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""Monkeypatch for Qwen3_Next model to pass position_ids to linear attention."""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def get_cu_seqlens(position_ids):
|
||||
"""
|
||||
Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.
|
||||
|
||||
https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316
|
||||
"""
|
||||
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
|
||||
|
||||
position_ids = position_ids.view(-1)
|
||||
indices_q = (position_ids == 0).nonzero().view(-1)
|
||||
|
||||
cu_seq_lens_q = torch.cat(
|
||||
(
|
||||
indices_q.to(**tensor_kwargs),
|
||||
torch.tensor(position_ids.size(), **tensor_kwargs),
|
||||
)
|
||||
)
|
||||
|
||||
return cu_seq_lens_q
|
||||
|
||||
|
||||
def patch_qwen3_next_decoder_layer():
|
||||
"""Patch Qwen3NextDecoderLayer to pass position_ids to linear attention."""
|
||||
try:
|
||||
from transformers.models.qwen3_next.modeling_qwen3_next import (
|
||||
Qwen3NextDecoderLayer,
|
||||
)
|
||||
except ImportError:
|
||||
LOG.warning("Qwen3Next model not found, skipping patch")
|
||||
return
|
||||
|
||||
# Store original forward method
|
||||
original_decoder_forward = Qwen3NextDecoderLayer.forward
|
||||
|
||||
def patched_decoder_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[torch.Tensor]] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Token Mixer
|
||||
if self.layer_type == "linear_attention":
|
||||
hidden_states = self.linear_attn(
|
||||
hidden_states=hidden_states,
|
||||
cache_params=past_key_values,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
elif self.layer_type == "full_attention":
|
||||
# Self Attention
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
# For the MoE layers, we need to unpack
|
||||
if isinstance(hidden_states, Tuple):
|
||||
hidden_states, _ = hidden_states
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
# Apply the patches
|
||||
Qwen3NextDecoderLayer.forward = patched_decoder_forward
|
||||
|
||||
def unpatch():
|
||||
"""Restore the original forward method"""
|
||||
Qwen3NextDecoderLayer.forward = original_decoder_forward
|
||||
|
||||
return unpatch
|
||||
|
||||
|
||||
def patch_qwen3_next_gateddelta_layer():
|
||||
"""Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule"""
|
||||
try:
|
||||
from transformers.models.qwen3_next.modeling_qwen3_next import (
|
||||
Qwen3NextDynamicCache,
|
||||
Qwen3NextGatedDeltaNet,
|
||||
apply_mask_to_padding_states,
|
||||
)
|
||||
except ImportError:
|
||||
LOG.warning("Qwen3Next model not found, skipping patch")
|
||||
return
|
||||
|
||||
# Store original forward method
|
||||
original_gated_delta_net_forward = Qwen3NextGatedDeltaNet.forward
|
||||
|
||||
def patched_gated_delta_net_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cache_params: Optional[Qwen3NextDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||
|
||||
# Set up dimensions for reshapes later
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
use_precomputed_states = (
|
||||
cache_params is not None
|
||||
and cache_params.has_previous_state
|
||||
and seq_len == 1
|
||||
and cache_position is not None
|
||||
)
|
||||
|
||||
# getting projected states from cache if it exists
|
||||
if cache_params is not None:
|
||||
conv_state = cache_params.conv_states[self.layer_idx]
|
||||
recurrent_state = cache_params.recurrent_states[self.layer_idx]
|
||||
|
||||
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
||||
projected_states_ba = self.in_proj_ba(hidden_states)
|
||||
query, key, value, z, b, a = self.fix_query_key_value_ordering(
|
||||
projected_states_qkvz, projected_states_ba
|
||||
)
|
||||
query, key, value = (
|
||||
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
|
||||
)
|
||||
|
||||
mixed_qkv = torch.cat((query, key, value), dim=-1)
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
|
||||
if use_precomputed_states:
|
||||
# 2. Convolution sequence transformation
|
||||
# NOTE: the conv state is updated in `causal_conv1d_update`
|
||||
mixed_qkv = self.causal_conv1d_update(
|
||||
mixed_qkv,
|
||||
conv_state,
|
||||
self.conv1d.weight.squeeze(1),
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
)
|
||||
else:
|
||||
if cache_params is not None:
|
||||
conv_state = F.pad(
|
||||
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
|
||||
)
|
||||
cache_params.conv_states[self.layer_idx] = conv_state
|
||||
if self.causal_conv1d_fn is not None:
|
||||
mixed_qkv = self.causal_conv1d_fn(
|
||||
x=mixed_qkv,
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
seq_idx=None,
|
||||
)
|
||||
else:
|
||||
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
||||
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
query, key, value = torch.split(
|
||||
mixed_qkv,
|
||||
[
|
||||
self.key_dim,
|
||||
self.key_dim,
|
||||
self.value_dim,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
|
||||
key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
|
||||
value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)
|
||||
|
||||
beta = b.sigmoid()
|
||||
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
||||
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
|
||||
if self.num_v_heads // self.num_k_heads > 1:
|
||||
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||
|
||||
if not use_precomputed_states:
|
||||
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
|
||||
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
g=g,
|
||||
beta=beta,
|
||||
initial_state=None,
|
||||
output_final_state=cache_params is not None,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
else:
|
||||
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
g=g,
|
||||
beta=beta,
|
||||
initial_state=recurrent_state,
|
||||
output_final_state=cache_params is not None,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
# Update cache
|
||||
if cache_params is not None:
|
||||
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
|
||||
|
||||
z_shape_og = z.shape
|
||||
# reshape input data into 2D tensor
|
||||
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
core_attn_out = self.norm(core_attn_out, z)
|
||||
core_attn_out = core_attn_out.reshape(z_shape_og)
|
||||
core_attn_out = core_attn_out.reshape(
|
||||
core_attn_out.shape[0], core_attn_out.shape[1], -1
|
||||
)
|
||||
|
||||
output = self.out_proj(core_attn_out)
|
||||
return output
|
||||
|
||||
# Apply the patches
|
||||
Qwen3NextGatedDeltaNet.forward = patched_gated_delta_net_forward
|
||||
|
||||
def unpatch():
|
||||
"""Restore the original forward method"""
|
||||
Qwen3NextGatedDeltaNet.forward = original_gated_delta_net_forward
|
||||
|
||||
return unpatch
|
||||
|
||||
|
||||
def patch_qwen3_next_imports():
|
||||
"""Patch Qwen3Next imports to use try/except instead of is_flash_linear_attention_available."""
|
||||
try:
|
||||
import transformers.models.qwen3_next.modeling_qwen3_next as qwen3_modeling
|
||||
except ImportError:
|
||||
LOG.warning("Qwen3Next model not found, skipping import patch")
|
||||
return
|
||||
|
||||
# Save original values for unpatch
|
||||
original_FusedRMSNormGated = getattr(qwen3_modeling, "FusedRMSNormGated", None)
|
||||
original_chunk_gated_delta_rule = getattr(
|
||||
qwen3_modeling, "chunk_gated_delta_rule", None
|
||||
)
|
||||
original_fused_recurrent_gated_delta_rule = getattr(
|
||||
qwen3_modeling, "fused_recurrent_gated_delta_rule", None
|
||||
)
|
||||
original_is_fast_path_available = getattr(
|
||||
qwen3_modeling, "is_fast_path_available", False
|
||||
)
|
||||
|
||||
try:
|
||||
from fla.modules import FusedRMSNormGated
|
||||
from fla.ops.gated_delta_rule import (
|
||||
chunk_gated_delta_rule,
|
||||
fused_recurrent_gated_delta_rule,
|
||||
)
|
||||
|
||||
qwen3_modeling.FusedRMSNormGated = FusedRMSNormGated
|
||||
qwen3_modeling.chunk_gated_delta_rule = chunk_gated_delta_rule
|
||||
qwen3_modeling.fused_recurrent_gated_delta_rule = (
|
||||
fused_recurrent_gated_delta_rule
|
||||
)
|
||||
|
||||
# Force is_fast_path_available to be True
|
||||
# fla has triton kernels for causal_conv1d
|
||||
qwen3_modeling.is_fast_path_available = True
|
||||
except ImportError:
|
||||
qwen3_modeling.chunk_gated_delta_rule = None
|
||||
qwen3_modeling.fused_recurrent_gated_delta_rule = None
|
||||
qwen3_modeling.FusedRMSNormGated = None
|
||||
|
||||
def unpatch():
|
||||
"""Restore the original import values"""
|
||||
qwen3_modeling.FusedRMSNormGated = original_FusedRMSNormGated
|
||||
qwen3_modeling.chunk_gated_delta_rule = original_chunk_gated_delta_rule
|
||||
qwen3_modeling.fused_recurrent_gated_delta_rule = (
|
||||
original_fused_recurrent_gated_delta_rule
|
||||
)
|
||||
qwen3_modeling.is_fast_path_available = original_is_fast_path_available
|
||||
|
||||
return unpatch
|
||||
|
||||
|
||||
def patch_qwen3_next_modeling_packing():
|
||||
"""Apply all Qwen3Next model patches."""
|
||||
patch_qwen3_next_imports()
|
||||
patch_qwen3_next_decoder_layer()
|
||||
patch_qwen3_next_gateddelta_layer()
|
||||
|
||||
LOG.info("Applied Qwen3Next patch for packing")
|
||||
@@ -1,133 +0,0 @@
|
||||
import logging
|
||||
import weakref
|
||||
from functools import wraps
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
|
||||
|
||||
_LOG = logging.getLogger("axolotl.moe.patch")
|
||||
|
||||
|
||||
def _patch_block_forward(block_cls, grouped_fn):
|
||||
"""Replace block_cls.forward with grouped_fn preserving signature."""
|
||||
block_cls.forward = grouped_fn
|
||||
|
||||
|
||||
def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
||||
"""
|
||||
Attempt to patch all known MoE block classes to use the torch_grouped backend
|
||||
when cfg.moe_backend resolves to 'torch_grouped' and the op is available.
|
||||
Falls back to original forwards otherwise.
|
||||
"""
|
||||
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
|
||||
backend = get_moe_backend_name(preferred)
|
||||
if backend != MOEBackend.TORCH_GROUPED:
|
||||
_LOG.info(
|
||||
f"moe_backend is '{backend}', not 'torch_grouped'; skipping grouped patches"
|
||||
)
|
||||
return
|
||||
try:
|
||||
from axolotl.kernels.moe import torch_grouped as _tg
|
||||
except Exception:
|
||||
_LOG.warning("torch_grouped backend import failed; skipping grouped patches")
|
||||
return
|
||||
if not _tg.available():
|
||||
_LOG.warning(
|
||||
"torch_grouped requested but unavailable (op smoke test failed); skipping grouped patches"
|
||||
)
|
||||
return
|
||||
|
||||
# Map of architecture key to (modeling module path, class name or list of class names)
|
||||
model_mods = {
|
||||
"mixtral": (
|
||||
"transformers.models.mixtral.modeling_mixtral",
|
||||
MOE_ARCH_BLOCK.get("mixtral"),
|
||||
),
|
||||
"qwen2_moe": (
|
||||
"transformers.models.qwen2_moe.modeling_qwen2_moe",
|
||||
MOE_ARCH_BLOCK.get("qwen2_moe"),
|
||||
),
|
||||
"qwen3_moe": (
|
||||
"transformers.models.qwen3_moe.modeling_qwen3_moe",
|
||||
MOE_ARCH_BLOCK.get("qwen3_moe"),
|
||||
),
|
||||
"jamba": (
|
||||
"transformers.models.jamba.modeling_jamba",
|
||||
MOE_ARCH_BLOCK.get("jamba"),
|
||||
),
|
||||
"deepseek_v2": (
|
||||
"transformers.models.deepseek_v2.modeling_deepseek_v2",
|
||||
MOE_ARCH_BLOCK.get("deepseek_v2"),
|
||||
),
|
||||
# Others may not follow standard paths; best-effort import
|
||||
"dbrx": ("transformers.models.dbrx.modeling_dbrx", MOE_ARCH_BLOCK.get("dbrx")),
|
||||
"jetmoe": (
|
||||
"transformers.models.jetmoe.modeling_jetmoe",
|
||||
MOE_ARCH_BLOCK.get("jetmoe"),
|
||||
),
|
||||
"gpt_oss": (
|
||||
"transformers.models.gpt_oss.modeling_gpt_oss",
|
||||
MOE_ARCH_BLOCK.get("gpt_oss"),
|
||||
),
|
||||
}
|
||||
|
||||
def make_grouped_forward(orig_forward):
|
||||
@wraps(orig_forward)
|
||||
def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
|
||||
bsz, seqlen, hdim = hidden_states.shape
|
||||
# expose parent block so grouped backend can access shared expert context
|
||||
try:
|
||||
self.experts._ax_parent_block_ref = weakref.ref(self)
|
||||
except Exception:
|
||||
pass
|
||||
y, router_logits = _tg.moe_ffn_forward_grouped(
|
||||
hidden_states, self.gate, self.experts, self.top_k
|
||||
)
|
||||
# One-time log per block instance indicating whether grouped engaged or fallback occurred
|
||||
if not getattr(self, "_ax_grouped_wrapper_logged", False):
|
||||
if y is None:
|
||||
_LOG.warning(
|
||||
"Grouped wrapper active but fell back to naive for %s",
|
||||
self.__class__.__name__,
|
||||
)
|
||||
else:
|
||||
_LOG.info(
|
||||
f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})"
|
||||
)
|
||||
self._ax_grouped_wrapper_logged = True
|
||||
if y is None:
|
||||
return orig_forward(self, hidden_states, *args, **kwargs)
|
||||
return y, router_logits
|
||||
|
||||
return _grouped_forward
|
||||
|
||||
patched = 0
|
||||
for key, (mod_path, cls_names) in model_mods.items():
|
||||
if not cls_names:
|
||||
continue
|
||||
try:
|
||||
import importlib
|
||||
|
||||
modeling = importlib.import_module(mod_path)
|
||||
names = cls_names if isinstance(cls_names, list) else [cls_names]
|
||||
for name in names:
|
||||
if not hasattr(modeling, name):
|
||||
continue
|
||||
block_cls = getattr(modeling, name)
|
||||
orig_forward = getattr(block_cls, "forward", None)
|
||||
if orig_forward is None:
|
||||
continue
|
||||
_patch_block_forward(block_cls, make_grouped_forward(orig_forward))
|
||||
patched += 1
|
||||
_LOG.info(f"Patched MoE block for grouped GEMM: {mod_path}.{name}")
|
||||
except Exception as e:
|
||||
# Best effort; log and skip this entry
|
||||
_LOG.warning(f"Skipping MoE patch for arch '{key}' ({mod_path}): {e}")
|
||||
if patched == 0:
|
||||
_LOG.warning(
|
||||
"No MoE blocks patched for grouped GEMM; model may not use known MoE classes"
|
||||
)
|
||||
else:
|
||||
_LOG.info(f"Grouped GEMM patches applied to {patched} MoE block class(es)")
|
||||
@@ -11,6 +11,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"apertus",
|
||||
"mllama_text_model",
|
||||
"llama",
|
||||
"llama4",
|
||||
@@ -20,6 +21,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"qwen2_moe",
|
||||
"qwen3",
|
||||
"qwen3_moe",
|
||||
"qwen3_next",
|
||||
"falcon",
|
||||
"phi",
|
||||
"phi3",
|
||||
@@ -46,7 +48,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type, model_name=None, has_remote_code=False, cfg=None):
|
||||
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
||||
if has_remote_code:
|
||||
patch_remote(model_name)
|
||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||
@@ -57,7 +59,7 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False, cfg=
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
||||
|
||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3(cfg)
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
|
||||
|
||||
def patch_remote(model_name):
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
|
||||
PATCHED_GUARD = (
|
||||
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
|
||||
)
|
||||
|
||||
|
||||
def patch_prepare_context_parallel_inputs() -> None:
|
||||
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
|
||||
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
|
||||
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
|
||||
return
|
||||
|
||||
try:
|
||||
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
|
||||
except OSError as exc: # pragma: no cover - occurs when source is unavailable
|
||||
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
|
||||
return
|
||||
|
||||
if GUARD_PATTERN not in original_source:
|
||||
LOG.warning(
|
||||
"Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
|
||||
"skipping FlashAttention context parallelism patch"
|
||||
)
|
||||
return
|
||||
|
||||
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
|
||||
patched_source, _ = detab_code(patched_source)
|
||||
patched_source = patched_source.replace(
|
||||
"def _prepare_context_parallel_inputs(",
|
||||
"def axolotl_prepare_context_parallel_inputs(",
|
||||
1,
|
||||
)
|
||||
|
||||
module_name = Trainer.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# import symbols referenced in the method so exec can succeed
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
|
||||
exec(patched_source, globals())
|
||||
|
||||
Trainer._original_prepare_context_parallel_inputs = (
|
||||
Trainer._prepare_context_parallel_inputs
|
||||
)
|
||||
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
|
||||
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
|
||||
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
|
||||
LOG.debug(
|
||||
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
|
||||
)
|
||||
@@ -11,6 +11,7 @@ from transformers.image_utils import load_image
|
||||
|
||||
from axolotl.utils.dict import remove_none_values
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -421,6 +422,36 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy):
|
||||
]
|
||||
|
||||
|
||||
class Mistral3ProcessingStrategy(ProcessingStrategy):
|
||||
"""Processing Strategy class for Mistral3"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processor: Mistral3Processor,
|
||||
chat_template: Optional[str] = None,
|
||||
image_size: int | tuple[int, int] | None = None,
|
||||
image_resize_algorithm: Resampling | None = None,
|
||||
):
|
||||
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
||||
special_ids = (
|
||||
processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder.special_ids
|
||||
)
|
||||
|
||||
self.image_token = special_ids.img
|
||||
self.image_break_token = special_ids.img_break
|
||||
self.image_end_token = special_ids.img_end
|
||||
|
||||
def process_labels(self, input_ids):
|
||||
labels = input_ids.clone()
|
||||
|
||||
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == self.image_token] = -100
|
||||
labels[labels == self.image_break_token] = -100
|
||||
labels[labels == self.image_end_token] = -100
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
def get_processing_strategy(
|
||||
processor: ProcessorMixin,
|
||||
chat_template,
|
||||
@@ -463,6 +494,11 @@ def get_processing_strategy(
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(processor, Mistral3Processor):
|
||||
return Mistral3ProcessingStrategy(
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
# llama3_2_vision, llama4, llava
|
||||
# mistral_v7_tekken, pixtral, lfm2vl
|
||||
return ProcessingStrategy(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
HF Chat Templates prompt strategy
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
|
||||
|
||||
@@ -794,6 +795,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if val is not None:
|
||||
transformed_message[key] = val
|
||||
|
||||
if "tool_calls" in transformed_message and transformed_message["tool_calls"]:
|
||||
for tool_call in transformed_message["tool_calls"]:
|
||||
if "function" in tool_call and "arguments" in tool_call["function"]:
|
||||
args = tool_call["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
tool_call["function"]["arguments"] = json.loads(args)
|
||||
except json.JSONDecodeError as e:
|
||||
LOG.error(
|
||||
f"Error parsing tool_calls arguments as JSON. "
|
||||
f"Function: {tool_call.get('function', {}).get('name', 'unknown')}, "
|
||||
f"Arguments string: {args!r}, "
|
||||
f"Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
return transformed_message
|
||||
|
||||
def _get_images(self, prompt):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Init for `axolotl.utils.mistral` module."""
|
||||
|
||||
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
|
||||
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
__all__ = ["HFMistralTokenizer"]
|
||||
__all__ = ["HFMistralTokenizer", "Mistral3Processor"]
|
||||
|
||||
169
src/axolotl/utils/mistral/mistral3_processor.py
Normal file
169
src/axolotl/utils/mistral/mistral3_processor.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Processor for Mistral3 multimodal models with image support"""
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.processing_utils import ProcessingKwargs
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
|
||||
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
|
||||
class Mistral3ProcessorKwargs(ProcessingKwargs):
|
||||
_defaults: Dict[str, Dict[str, Any]] = {
|
||||
"text_kwargs": {
|
||||
"padding": True,
|
||||
},
|
||||
"common_kwargs": {
|
||||
"return_tensors": "pt",
|
||||
"return_dict": True,
|
||||
"tokenize": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Mistral3Processor(ProcessorMixin):
|
||||
"""
|
||||
Processor for Mistral3 multimodal models that handles text and images.
|
||||
Wraps HFMistralTokenizer and adds image processing capabilities.
|
||||
"""
|
||||
|
||||
attributes = ["tokenizer"]
|
||||
tokenizer_class = "HFMistralTokenizer"
|
||||
|
||||
def __init__(self, tokenizer: HFMistralTokenizer):
|
||||
# Don't call super().__init__ to avoid the class validation issue
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@property
|
||||
def chat_template(self) -> None:
|
||||
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def audio_tokenizer(self) -> None:
|
||||
"""Audio tokenizer is not supported. Dummy method to satisfy HuggingFace API."""
|
||||
return None
|
||||
|
||||
def _merge_kwargs(
|
||||
self, processor_kwargs_class: Any, **kwargs: Any
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""Merge kwargs with defaults similar to ProcessorMixin"""
|
||||
defaults = processor_kwargs_class._defaults
|
||||
output_kwargs: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
for kwarg_type, default_values in defaults.items():
|
||||
output_kwargs[kwarg_type] = {**default_values}
|
||||
|
||||
# Update with provided kwargs
|
||||
for key, value in kwargs.items():
|
||||
# Try to match key to appropriate kwarg type
|
||||
if key in ["padding", "truncation", "max_length"]:
|
||||
output_kwargs.setdefault("text_kwargs", {}).update({key: value})
|
||||
elif key in ["return_tensors", "return_dict", "tokenize"]:
|
||||
output_kwargs.setdefault("common_kwargs", {}).update({key: value})
|
||||
else:
|
||||
# Add to text_kwargs by default
|
||||
output_kwargs.setdefault("text_kwargs", {}).update({key: value})
|
||||
|
||||
return output_kwargs
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
|
||||
**kwargs: Any,
|
||||
) -> Union[BatchFeature, str, list[str]]:
|
||||
"""
|
||||
Apply chat template with image support for Mistral3.
|
||||
|
||||
Similar to VoxtralProcessor, this method extracts images from the conversation,
|
||||
calls the tokenizer's apply_chat_template, then adds pixel_values and image_sizes
|
||||
to the result.
|
||||
"""
|
||||
output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
|
||||
text_kwargs = output_kwargs["text_kwargs"]
|
||||
common_kwargs = output_kwargs["common_kwargs"]
|
||||
|
||||
return_tensors = common_kwargs.pop("return_tensors", "pt")
|
||||
if return_tensors != "pt":
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} only supports `return_tensors='pt'`."
|
||||
)
|
||||
|
||||
return_dict = common_kwargs.pop("return_dict", False)
|
||||
tokenize = common_kwargs.pop("tokenize", False)
|
||||
|
||||
# Determine if batched
|
||||
if isinstance(conversation, (list, tuple)) and (
|
||||
isinstance(conversation[0], (list, tuple))
|
||||
or hasattr(conversation[0], "content")
|
||||
):
|
||||
is_batched = True
|
||||
conversations = conversation
|
||||
else:
|
||||
is_batched = False
|
||||
conversations = [conversation] # type: ignore
|
||||
|
||||
# Call tokenizer's apply_chat_template
|
||||
tokenizer_kwargs = {**text_kwargs, **common_kwargs}
|
||||
tokenizer_kwargs["return_tensors"] = return_tensors
|
||||
tokenizer_kwargs["tokenize"] = tokenize
|
||||
tokenizer_kwargs["return_dict"] = return_dict
|
||||
|
||||
encoded_instruct_inputs = self.tokenizer.apply_chat_template(
|
||||
conversations,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
if tokenize:
|
||||
if return_dict:
|
||||
# The tokenizer already handles pixel_values, we just need to add image_sizes
|
||||
if hasattr(encoded_instruct_inputs, "items"):
|
||||
data: Dict[str, Any] = dict(encoded_instruct_inputs) # type: ignore
|
||||
elif hasattr(encoded_instruct_inputs, "data"):
|
||||
data = encoded_instruct_inputs.data # type: ignore
|
||||
else:
|
||||
raise ValueError("Unknown data type")
|
||||
|
||||
if "pixel_values" in data:
|
||||
pixel_values = data["pixel_values"]
|
||||
|
||||
# MistralTokenizer returns a Double, so we convert to fp32
|
||||
data["pixel_values"] = pixel_values.to(dtype=torch.float32)
|
||||
|
||||
# Always batched: [B, C, H, W] -> image_sizes: [B, 2]
|
||||
# Since tensor is homogeneous, all images have same H, W
|
||||
batch_size = pixel_values.shape[0]
|
||||
image_sizes = torch.tensor([pixel_values.shape[-2:]] * batch_size)
|
||||
data["image_sizes"] = image_sizes
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
if not is_batched:
|
||||
return encoded_instruct_inputs[0]
|
||||
|
||||
return encoded_instruct_inputs
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[
|
||||
Union[
|
||||
TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
|
||||
]
|
||||
],
|
||||
**kwargs: Any,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Forward text processing to the tokenizer.
|
||||
This method does not support images - use apply_chat_template instead.
|
||||
"""
|
||||
output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
|
||||
text_kwargs = output_kwargs["text_kwargs"]
|
||||
common_kwargs = output_kwargs["common_kwargs"]
|
||||
|
||||
out = self.tokenizer(text, **text_kwargs)
|
||||
return BatchFeature(
|
||||
data=out, tensor_type=common_kwargs.pop("return_tensors", None)
|
||||
)
|
||||
@@ -113,6 +113,19 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
moe_kernels: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
|
||||
},
|
||||
)
|
||||
moe_kernel_backend: Literal["cg", "mg"] | None = Field(
|
||||
default="mg",
|
||||
json_schema_extra={
|
||||
"description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel."
|
||||
},
|
||||
)
|
||||
|
||||
trainer_cls: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -132,14 +145,6 @@ class AxolotlInputConfig(
|
||||
vllm: VllmConfig | None = Field(
|
||||
default_factory=lambda: VllmConfig(),
|
||||
)
|
||||
moe_backend: Literal["auto", "torch_grouped", "naive"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Mixture-of-Experts backend to use: 'auto', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
|
||||
},
|
||||
)
|
||||
|
||||
# Value is constrained by the Literal type; no normalization needed.
|
||||
qat: QATConfig | None = None
|
||||
quantization: PTQConfig | None = None
|
||||
reward_model: bool | None = Field(
|
||||
@@ -444,8 +449,8 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
min_sample_len: int | None = None
|
||||
max_prompt_len: int = Field(
|
||||
default=512,
|
||||
max_prompt_len: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "maximum prompt length for RL training"},
|
||||
)
|
||||
sample_packing: bool | None = Field(
|
||||
|
||||
35
tests/monkeypatch/test_mistral_tokenizer_patch.py
Normal file
35
tests/monkeypatch/test_mistral_tokenizer_patch.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Integration tests for MistralCommonTokenizer patches."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestMistralTokenizerPatchIntegration:
|
||||
"""Test MistralCommonTokenizer patch integration."""
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_mistral_tokenizer_image_patch(self):
|
||||
"""Test that MistralCommonTokenizer image patch can be applied."""
|
||||
try:
|
||||
from transformers.tokenization_mistral_common import MistralCommonTokenizer
|
||||
except ImportError:
|
||||
pytest.skip("MistralCommonTokenizer not available")
|
||||
|
||||
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
|
||||
apply_mistral_tokenizer_image_patch,
|
||||
)
|
||||
|
||||
# Store original method
|
||||
original_apply_chat_template = MistralCommonTokenizer.apply_chat_template
|
||||
|
||||
# Apply patch
|
||||
apply_mistral_tokenizer_image_patch()
|
||||
|
||||
# Verify patch was applied
|
||||
assert (
|
||||
MistralCommonTokenizer.apply_chat_template != original_apply_chat_template
|
||||
), "apply_chat_template was not patched"
|
||||
|
||||
# Verify the method is still callable
|
||||
assert callable(MistralCommonTokenizer.apply_chat_template), (
|
||||
"Patched method is not callable"
|
||||
)
|
||||
@@ -1,258 +0,0 @@
|
||||
import sys
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from axolotl.kernels.moe import (
|
||||
backends as moe_backends,
|
||||
torch_grouped as torch_grouped_module,
|
||||
)
|
||||
from axolotl.monkeypatch import moe_grouped
|
||||
|
||||
|
||||
class DummyExperts(nn.Module):
|
||||
def __init__(self, layers):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.num_experts = len(layers)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.layers[idx]
|
||||
|
||||
|
||||
class DummyQwenMLP(nn.Module):
|
||||
def __init__(self, idx: int, hidden: int, intermediate: int):
|
||||
super().__init__()
|
||||
self.gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate, hidden, bias=False)
|
||||
nn.init.constant_(self.gate_up_proj.weight, float(idx + 1))
|
||||
nn.init.constant_(self.down_proj.weight, float((idx + 1) * 10))
|
||||
|
||||
|
||||
class DummyQwenExpert(nn.Module):
|
||||
def __init__(self, idx: int, hidden: int, intermediate: int):
|
||||
super().__init__()
|
||||
self.mlp = DummyQwenMLP(idx, hidden, intermediate)
|
||||
|
||||
|
||||
def _make_transformers_stub(monkeypatch, block_cls):
|
||||
# ensure we start from the original forward for each test
|
||||
if block_cls is DummyMixtralBlock:
|
||||
DummyMixtralBlock.forward = _DUMMY_MIXTRAL_ORIG_FORWARD
|
||||
|
||||
transformers_mod = types.ModuleType("transformers")
|
||||
models_mod = types.ModuleType("transformers.models")
|
||||
mixtral_mod = types.ModuleType("transformers.models.mixtral")
|
||||
modeling_mixtral = types.ModuleType("transformers.models.mixtral.modeling_mixtral")
|
||||
modeling_mixtral.MixtralSparseMoeBlock = block_cls
|
||||
|
||||
transformers_mod.models = models_mod
|
||||
models_mod.mixtral = mixtral_mod
|
||||
mixtral_mod.modeling_mixtral = modeling_mixtral
|
||||
|
||||
monkeypatch.setitem(sys.modules, "transformers", transformers_mod)
|
||||
monkeypatch.setitem(sys.modules, "transformers.models", models_mod)
|
||||
monkeypatch.setitem(sys.modules, "transformers.models.mixtral", mixtral_mod)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"transformers.models.mixtral.modeling_mixtral",
|
||||
modeling_mixtral,
|
||||
)
|
||||
|
||||
|
||||
def test_grouped_uses_per_expert_nested_modules(monkeypatch):
|
||||
hidden = 4
|
||||
intermediate = 2
|
||||
num_experts = 2
|
||||
|
||||
experts = DummyExperts(
|
||||
[DummyQwenExpert(i, hidden, intermediate) for i in range(num_experts)]
|
||||
)
|
||||
|
||||
gate = nn.Linear(hidden, num_experts, bias=False)
|
||||
nn.init.zeros_(gate.weight)
|
||||
|
||||
captured = []
|
||||
|
||||
def fake_grouped_mm(As, Bs, dtype):
|
||||
captured.append([b.detach().clone() for b in Bs])
|
||||
return [
|
||||
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
|
||||
for a, b in zip(As, Bs, strict=False)
|
||||
]
|
||||
|
||||
monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
|
||||
|
||||
hidden_states = torch.randn(1, 2, hidden)
|
||||
y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
|
||||
hidden_states, gate, experts, top_k=2
|
||||
)
|
||||
|
||||
assert y is not None
|
||||
assert router_logits is not None
|
||||
assert captured, "Grouped GEMM path should have been invoked"
|
||||
first_call = captured[0]
|
||||
expected0 = experts[0].mlp.gate_up_proj.weight.t()
|
||||
expected1 = experts[1].mlp.gate_up_proj.weight.t()
|
||||
assert torch.equal(first_call[0], expected0)
|
||||
assert torch.equal(first_call[1], expected1)
|
||||
assert not torch.equal(first_call[0], first_call[1])
|
||||
|
||||
|
||||
def test_grouped_accepts_module_list_experts(monkeypatch):
|
||||
hidden = 4
|
||||
intermediate = 2
|
||||
experts = nn.ModuleList(
|
||||
[DummyQwenExpert(i, hidden, intermediate) for i in range(2)]
|
||||
)
|
||||
gate = nn.Linear(hidden, len(experts), bias=False)
|
||||
nn.init.zeros_(gate.weight)
|
||||
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_grouped_mm(As, Bs, dtype):
|
||||
calls["count"] += 1
|
||||
return [
|
||||
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
|
||||
for a, b in zip(As, Bs, strict=False)
|
||||
]
|
||||
|
||||
monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
|
||||
|
||||
hidden_states = torch.randn(1, 2, hidden)
|
||||
y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
|
||||
hidden_states, gate, experts, top_k=2
|
||||
)
|
||||
|
||||
assert y is not None
|
||||
assert router_logits is not None
|
||||
assert calls["count"] > 0
|
||||
|
||||
|
||||
class _DummyCfg:
|
||||
moe_backend = "torch_grouped"
|
||||
|
||||
|
||||
class DummyMixtralBlock(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.top_k = 1
|
||||
self.gate = lambda x: x
|
||||
self.experts = object()
|
||||
self._calls = []
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask=None):
|
||||
self._calls.append((hidden_states, attention_mask))
|
||||
tokens = hidden_states.shape[0] * hidden_states.shape[1]
|
||||
router = torch.ones(
|
||||
tokens, 2, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)
|
||||
return hidden_states + 5, router
|
||||
|
||||
|
||||
_DUMMY_MIXTRAL_ORIG_FORWARD = DummyMixtralBlock.forward
|
||||
|
||||
|
||||
def test_apply_grouped_forward_handles_args(monkeypatch):
|
||||
_make_transformers_stub(monkeypatch, DummyMixtralBlock)
|
||||
import axolotl.common.architectures as arch
|
||||
|
||||
original_map = arch.MOE_ARCH_BLOCK.copy()
|
||||
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
|
||||
for key in list(original_map.keys()):
|
||||
if key != "mixtral":
|
||||
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
|
||||
|
||||
monkeypatch.setattr(
|
||||
moe_grouped,
|
||||
"get_moe_backend_name",
|
||||
lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
|
||||
)
|
||||
|
||||
results = {}
|
||||
|
||||
def fake_grouped_forward(hidden_states, gate, experts, top_k):
|
||||
results["called"] = True
|
||||
router = torch.zeros(
|
||||
hidden_states.shape[0] * hidden_states.shape[1],
|
||||
2,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
return hidden_states + 1, router
|
||||
|
||||
monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
|
||||
monkeypatch.setattr(
|
||||
torch_grouped_module,
|
||||
"moe_ffn_forward_grouped",
|
||||
fake_grouped_forward,
|
||||
)
|
||||
|
||||
cfg = _DummyCfg()
|
||||
moe_grouped.apply_grouped_to_moe_blocks(cfg)
|
||||
|
||||
block = DummyMixtralBlock()
|
||||
hidden_states = torch.ones(1, 2, 3)
|
||||
mask = torch.zeros(1, 2)
|
||||
out, router = block.forward(hidden_states, attention_mask=mask)
|
||||
|
||||
assert results.get("called") is True
|
||||
assert torch.equal(out, hidden_states + 1)
|
||||
assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
|
||||
|
||||
|
||||
def test_apply_grouped_forward_fallback(monkeypatch):
|
||||
_make_transformers_stub(monkeypatch, DummyMixtralBlock)
|
||||
import axolotl.common.architectures as arch
|
||||
|
||||
original_map = arch.MOE_ARCH_BLOCK.copy()
|
||||
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
|
||||
for key in list(original_map.keys()):
|
||||
if key != "mixtral":
|
||||
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
|
||||
|
||||
monkeypatch.setattr(
|
||||
moe_grouped,
|
||||
"get_moe_backend_name",
|
||||
lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
|
||||
)
|
||||
monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
|
||||
monkeypatch.setattr(
|
||||
torch_grouped_module,
|
||||
"moe_ffn_forward_grouped",
|
||||
lambda *args, **kwargs: (None, None),
|
||||
)
|
||||
|
||||
cfg = _DummyCfg()
|
||||
moe_grouped.apply_grouped_to_moe_blocks(cfg)
|
||||
|
||||
block = DummyMixtralBlock()
|
||||
hidden_states = torch.ones(1, 2, 3)
|
||||
mask = torch.zeros(1, 2)
|
||||
out, router = block.forward(hidden_states, attention_mask=mask)
|
||||
|
||||
assert torch.equal(out, hidden_states + 5)
|
||||
assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
|
||||
assert block._calls, "Original forward should have been invoked"
|
||||
call_hidden, call_mask = block._calls[-1]
|
||||
assert torch.equal(call_hidden, hidden_states)
|
||||
assert torch.equal(call_mask, mask)
|
||||
|
||||
|
||||
def test_get_moe_backend_name_prefers_probe(monkeypatch):
|
||||
monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: True)
|
||||
assert moe_backends.get_moe_backend_name() == moe_backends.MOEBackend.TORCH_GROUPED
|
||||
|
||||
|
||||
def test_get_moe_backend_name_falls_back(monkeypatch):
|
||||
warnings_captured = []
|
||||
|
||||
def fake_warn(msg, *, stacklevel=None): # noqa: ARG001
|
||||
warnings_captured.append(msg)
|
||||
|
||||
monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: False)
|
||||
monkeypatch.setattr(moe_backends.warnings, "warn", fake_warn)
|
||||
backend = moe_backends.get_moe_backend_name("torch_grouped")
|
||||
assert backend == moe_backends.MOEBackend.NAIVE
|
||||
assert warnings_captured, "Expected warning when torch_grouped unavailable"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user