feat: add Mistral Small 4 (#3502)

* feat: add mistral small 4

* fix: update mistral common

* fix: deepcopy when passing in tokenizer

* feat: add doc on reasoning and thinking section

* fix: don't use custom tokenizer and quantize experts

* chore: update docs and configs

* chore: update doc to follow official name

* feat: update cce to include mistral4

* chore: move

* fix: naming

* fix: test mock breaking get_text_config check

* fix: enable CCE and add expert block targetting to configs

* chore: docs

* fix: use act checkpointing

* chore: doc

* chore: docs

* chore: docs
This commit is contained in:
NanoCode012
2026-03-17 09:39:05 +07:00
committed by GitHub
parent 7da5f94379
commit a098df527b
20 changed files with 417 additions and 14 deletions

View File

@@ -16,6 +16,7 @@ MOE_ARCH_BLOCK = {
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",
"mistral4": "Mistral4MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
```
## Usage
@@ -73,8 +73,10 @@ plugins:
- ministral3
- mistral
- mistral3
- mistral4
- mixtral
- mllama
- nemotron_h
- olmo
- olmo2
- olmo3

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
)

View File

@@ -25,6 +25,8 @@ SPARSE_MOE_BLOCK = {
"olmoe": "OlmoeSparseMoeBlock",
"mixtral": "MixtralSparseMoeBlock",
"minimax": "MiniMaxSparseMoeBlock",
# softmax -> topk routing (with group-based expert selection)
"mistral4": "Mistral4MoE",
# sigmoid -> topk routing (with group-based expert selection)
"glm_moe_dsa": "GlmMoeDsaMoE",
"deepseek_v3": "DeepseekV3MoE",

View File

@@ -61,9 +61,11 @@ class KernelsPlugin(BasePlugin):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
self._kernelize_model(moe_model_type)
elif cfg.use_sonicmoe:
if not importlib.util.find_spec("sonicmoe"):
raise RuntimeError(
@@ -75,11 +77,9 @@ class KernelsPlugin(BasePlugin):
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
LOG.info(
f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
)
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
patch_sonicmoe(
cfg.model_config_type,
moe_model_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
)

View File

@@ -5,6 +5,7 @@ Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
@@ -45,6 +46,8 @@ def get_model_moe_config(model_type: str):
"minimax",
):
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in ("mistral4",):
return softmax_group_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in (
"glm_moe_dsa",
"deepseek_v3",
@@ -126,6 +129,62 @@ def softmax_topk_routing(
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def softmax_group_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
gate = moe_block.gate
T, H = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
scores_for_choice = router_probs
# Group selection: pick top groups, mask the rest
if n_group > 1:
group_scores = (
scores_for_choice.view(-1, n_group, E // n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
)
group_idx = torch.topk(
group_scores, k=moe_block.topk_group, dim=-1, sorted=False
)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
topk_weights = router_probs.gather(1, topk_indices)
# Renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob:
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
topk_weights = topk_weights * routed_scaling_factor
# Flatten for moe_general_routing_inputs
token_indices = (
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
.unsqueeze(1)
.expand(T, K)
)
flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K]
flat_token_idx = token_indices.reshape(-1) # [T*K]
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def sigmoid_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

View File

@@ -829,8 +829,9 @@ class ModelLoader:
def _set_z3_leaf_modules(self):
from deepspeed.utils import set_z3_leaf_modules
if self.cfg.model_config_type in MOE_ARCH_BLOCK:
moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type]
moe_type = self.cfg.model_config_type_text or self.cfg.model_config_type
if moe_type in MOE_ARCH_BLOCK:
moe_blocks = MOE_ARCH_BLOCK[moe_type]
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
set_z3_leaf_modules(
self.model,

View File

@@ -55,12 +55,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
)
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
processor_kwargs["tokenizer"] = tokenizer
processor = processor_cls.from_pretrained(
cfg.processor_config,
**processor_kwargs,
)
processor.tokenizer = tokenizer
# Attempt to load image size from processor if available
if (

View File

@@ -57,6 +57,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"olmo3",
"ministral",
"ministral3",
"mistral4",
"afmoe",
]

View File

@@ -195,6 +195,15 @@ def normalize_config(cfg):
cfg.model_config_type = model_config.model_type
# Resolve inner text backbone type for VLM wrappers (e.g. mistral3 -> mistral4)
if callable(getattr(model_config, "get_text_config", None)):
text_config = model_config.get_text_config()
if (
hasattr(text_config, "model_type")
and text_config.model_type != model_config.model_type
):
cfg.model_config_type_text = text_config.model_type
# figure out if the model is llama
cfg.is_llama_derived_model = (
(