diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 1e9819c56..5ac66260a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -817,11 +817,13 @@ def load_model( ) if cfg.model_config_type in MOE_ARCH_BLOCK: + moe_blocks = MOE_ARCH_BLOCK[cfg.model_config_type] + moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks set_z3_leaf_modules( model, [ get_module_class_from_name(model, module_name) - for module_name in MOE_ARCH_BLOCK[cfg.model_config_type] + for module_name in moe_blocks ], )