diff --git a/src/axolotl/integrations/kernels/constants.py b/src/axolotl/integrations/kernels/constants.py index a7d513b5e..8002b3f79 100644 --- a/src/axolotl/integrations/kernels/constants.py +++ b/src/axolotl/integrations/kernels/constants.py @@ -15,6 +15,7 @@ SPARSE_MOE_BLOCK = { "qwen2_moe": "Qwen2MoeSparseMoeBlock", "qwen3_moe": "Qwen3MoeSparseMoeBlock", "qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock", + "qwen3_5_moe_text": "Qwen3_5MoeSparseMoeBlock", "qwen3_next": "Qwen3NextSparseMoeBlock", "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock", # qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate) @@ -58,7 +59,16 @@ def resolve_moe_block_classes(model_type: str): cls_names = entry if isinstance(entry, list) else [entry] module_path = f"transformers.models.{model_type}.modeling_{model_type}" - module = importlib.import_module(module_path) + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError: + # Text sub-model types (e.g. qwen3_5_moe_text) share the parent module + if model_type.endswith("_text"): + parent_type = model_type.removesuffix("_text") + module_path = f"transformers.models.{parent_type}.modeling_{parent_type}" + module = importlib.import_module(module_path) + else: + raise classes = [] for cls_name in cls_names: