feat: add Mistral Small 4 (#3502)
* feat: add mistral small 4 * fix: update mistral common * fix: deepcopy when passing in tokenizer * feat: add doc on reasoning and thinking section * fix: don't use custom tokenizer and quantize experts * chore: update docs and configs * chore: update doc to follow official name * feat: update cce to include mistral4 * chore: move * fix: naming * fix: test mock breaking get_text_config check * fix: enable CCE and add expert block targeting to configs * chore: docs * fix: use act checkpointing * chore: doc * chore: docs * chore: docs
This commit is contained in:
@@ -16,6 +16,7 @@ MOE_ARCH_BLOCK = {
|
||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
"deepseek_v3": "DeepseekV3MoE",
|
||||
"mistral4": "Mistral4MoE",
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||
"afmoe": "AfmoeMoE",
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -73,8 +73,10 @@ plugins:
|
||||
- ministral3
|
||||
- mistral
|
||||
- mistral3
|
||||
- mistral4
|
||||
- mixtral
|
||||
- mllama
|
||||
- nemotron_h
|
||||
- olmo
|
||||
- olmo2
|
||||
- olmo3
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ SPARSE_MOE_BLOCK = {
|
||||
"olmoe": "OlmoeSparseMoeBlock",
|
||||
"mixtral": "MixtralSparseMoeBlock",
|
||||
"minimax": "MiniMaxSparseMoeBlock",
|
||||
# softmax -> topk routing (with group-based expert selection)
|
||||
"mistral4": "Mistral4MoE",
|
||||
# sigmoid -> topk routing (with group-based expert selection)
|
||||
"glm_moe_dsa": "GlmMoeDsaMoE",
|
||||
"deepseek_v3": "DeepseekV3MoE",
|
||||
|
||||
@@ -61,9 +61,11 @@ class KernelsPlugin(BasePlugin):
|
||||
return "axolotl.integrations.kernels.KernelsArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
||||
|
||||
if cfg.use_scattermoe:
|
||||
self._register_kernels()
|
||||
self._kernelize_model(cfg.model_config_type)
|
||||
self._kernelize_model(moe_model_type)
|
||||
elif cfg.use_sonicmoe:
|
||||
if not importlib.util.find_spec("sonicmoe"):
|
||||
raise RuntimeError(
|
||||
@@ -75,11 +77,9 @@ class KernelsPlugin(BasePlugin):
|
||||
|
||||
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
|
||||
|
||||
LOG.info(
|
||||
f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
|
||||
)
|
||||
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
|
||||
patch_sonicmoe(
|
||||
cfg.model_config_type,
|
||||
moe_model_type,
|
||||
torch_compile=bool(getattr(cfg, "torch_compile", False)),
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ Different MoE architectures use different routing strategies:
|
||||
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
|
||||
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
|
||||
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
|
||||
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)
|
||||
|
||||
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
||||
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
||||
@@ -45,6 +46,8 @@ def get_model_moe_config(model_type: str):
|
||||
"minimax",
|
||||
):
|
||||
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
|
||||
elif model_type in ("mistral4",):
|
||||
return softmax_group_topk_routing, ActivationType.SWIGLU, "gate"
|
||||
elif model_type in (
|
||||
"glm_moe_dsa",
|
||||
"deepseek_v3",
|
||||
@@ -126,6 +129,62 @@ def softmax_topk_routing(
|
||||
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||
|
||||
|
||||
def softmax_group_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale.

    Args:
        hidden_states: 2-D tensor of shape ``[T, H]`` (tokens x hidden size).
        moe_block: MoE module exposing ``gate.weight`` (shape ``[E, H]``) and
            ``top_k``. The following attributes are read with defaults when
            absent: ``n_routed_experts`` (rows of ``gate.weight``),
            ``n_group`` (1), ``norm_topk_prob`` (True) and
            ``routed_scaling_factor`` (1.0). ``topk_group`` is required
            whenever ``n_group > 1``.

    Returns:
        A ``(flat_scores, flat_token_idx, flat_expert_idx, router_logits)``
        tuple where:

        - ``flat_scores`` is float32 ``[T*K]``: routing weight per
          (token, expert) pair, renormalized and scaled as configured.
        - ``flat_token_idx`` is int32 ``[T*K]``: source token of each pair.
        - ``flat_expert_idx`` is int32 ``[T*K]``: selected expert of each pair.
        - ``router_logits`` is ``[T, E]``: raw (pre-softmax) gate outputs.
    """
    gate = moe_block.gate
    # Unpacking two dims also rejects inputs that are not 2-D.
    T, _ = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)

    router_logits = F.linear(hidden_states, gate.weight)  # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    scores_for_choice = router_probs

    # Group selection: pick top groups, mask the rest
    if n_group > 1:
        group_size = E // n_group
        # A group is scored by the sum of its two best experts. Clamp k so a
        # group holding a single expert is scored by that expert instead of
        # making topk raise on an out-of-range k.
        group_scores = (
            scores_for_choice.view(-1, n_group, group_size)
            .topk(min(2, group_size), dim=-1)[0]
            .sum(dim=-1)
        )  # [T, n_group]
        group_idx = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        # Broadcast the per-group keep flag to every expert of that group.
        score_mask = (
            group_mask.unsqueeze(-1).expand(-1, n_group, group_size).reshape(-1, E)
        )
        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

    topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
    # Weights come from the unmasked probabilities; masking only steers selection.
    topk_weights = router_probs.gather(1, topk_indices)

    # Renormalization + scaling
    norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
    if norm_topk_prob:
        # The epsilon guards against division by zero.
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||
|
||||
|
||||
def sigmoid_topk_routing(
|
||||
hidden_states: torch.Tensor, moe_block
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
|
||||
@@ -829,8 +829,9 @@ class ModelLoader:
|
||||
def _set_z3_leaf_modules(self):
|
||||
from deepspeed.utils import set_z3_leaf_modules
|
||||
|
||||
if self.cfg.model_config_type in MOE_ARCH_BLOCK:
|
||||
moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type]
|
||||
moe_type = self.cfg.model_config_type_text or self.cfg.model_config_type
|
||||
if moe_type in MOE_ARCH_BLOCK:
|
||||
moe_blocks = MOE_ARCH_BLOCK[moe_type]
|
||||
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
|
||||
set_z3_leaf_modules(
|
||||
self.model,
|
||||
|
||||
@@ -55,12 +55,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
)
|
||||
|
||||
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
|
||||
processor_kwargs["tokenizer"] = tokenizer
|
||||
|
||||
processor = processor_cls.from_pretrained(
|
||||
cfg.processor_config,
|
||||
**processor_kwargs,
|
||||
)
|
||||
processor.tokenizer = tokenizer
|
||||
|
||||
# Attempt to load image size from processor if available
|
||||
if (
|
||||
|
||||
@@ -57,6 +57,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"olmo3",
|
||||
"ministral",
|
||||
"ministral3",
|
||||
"mistral4",
|
||||
"afmoe",
|
||||
]
|
||||
|
||||
|
||||
@@ -195,6 +195,15 @@ def normalize_config(cfg):
|
||||
|
||||
cfg.model_config_type = model_config.model_type
|
||||
|
||||
# Resolve inner text backbone type for VLM wrappers (e.g. mistral3 -> mistral4)
|
||||
if callable(getattr(model_config, "get_text_config", None)):
|
||||
text_config = model_config.get_text_config()
|
||||
if (
|
||||
hasattr(text_config, "model_type")
|
||||
and text_config.model_type != model_config.model_type
|
||||
):
|
||||
cfg.model_config_type_text = text_config.model_type
|
||||
|
||||
# figure out if the model is llama
|
||||
cfg.is_llama_derived_model = (
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user