From 66b2ab8414e3ed96d8ff020347d7f4b955db16e3 Mon Sep 17 00:00:00 2001
From: lhl
Date: Thu, 13 Nov 2025 04:06:27 +0900
Subject: [PATCH] move configs from global config to plugin specific args

---
Review notes (free-form text between the `---` separator and the diffstat;
not part of the commit message or of any hunk):

* Purpose: the six `moe_*` options leave the global `AxolotlInputConfig`
  schema and are declared on a new pydantic model, `AuxFreeRouterArgs`, that
  lives next to the plugin.
* The moved fields keep their names, annotations and `default=None`. The
  description strings are split across adjacent literals in the new file;
  the concatenated text matches the removed one-line strings.
* `AuxFreeMoEPlugin.get_input_args()` returns the dotted path
  "axolotl.integrations.aux_free_router.AuxFreeRouterArgs". That path resolves
  through the package `__init__.py`, which this patch changes to import and
  re-export both `AuxFreeRouterArgs` and `AuxFreeMoEPlugin`.
* `validate_config()` now calls `prepare_plugins(cfg)` first whenever
  `cfg.plugins` is set.
  NOTE(review): `prepare_plugins` is neither defined nor imported in any hunk
  shown here -- presumably it already exists in
  `src/axolotl/utils/config/__init__.py`; confirm before merging.
* NOTE(review): the new `args.py` ends with an added empty line (line 72 of
  the new file); consider dropping it in a follow-up.
* NOTE(review): the ranges mentioned in the descriptions ("0–1",
  "Recommended: 0.005–0.05") are documentation only; no field declares a
  validation bound.
* NOTE(review): `get_input_args()` is added without a docstring or return
  annotation; it returns a `str`.
* NOTE(review): the `moe_*` keys are now declared only by the plugin schema.
  Whether a config that sets them without listing the plugin still validates
  depends on the base model's extra-field policy, which is not visible in
  this patch.

 .../integrations/aux_free_router/__init__.py  |  7 ++
 .../integrations/aux_free_router/args.py      | 72 +++++++++++++++++++
 .../integrations/aux_free_router/plugin.py    |  3 +
 src/axolotl/utils/config/__init__.py          |  1 +
 src/axolotl/utils/schemas/config.py           | 38 ----------
 5 files changed, 83 insertions(+), 38 deletions(-)
 create mode 100644 src/axolotl/integrations/aux_free_router/args.py

diff --git a/src/axolotl/integrations/aux_free_router/__init__.py b/src/axolotl/integrations/aux_free_router/__init__.py
index b3f78049b..8eac77224 100644
--- a/src/axolotl/integrations/aux_free_router/__init__.py
+++ b/src/axolotl/integrations/aux_free_router/__init__.py
@@ -1,2 +1,9 @@
 """Aux-loss-free (AFB) MoE router integration package."""
 
+from .args import AuxFreeRouterArgs
+from .plugin import AuxFreeMoEPlugin
+
+__all__ = [
+    "AuxFreeRouterArgs",
+    "AuxFreeMoEPlugin",
+]
diff --git a/src/axolotl/integrations/aux_free_router/args.py b/src/axolotl/integrations/aux_free_router/args.py
new file mode 100644
index 000000000..d284d4d66
--- /dev/null
+++ b/src/axolotl/integrations/aux_free_router/args.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Plugin args for the Aux-Loss-Free MoE router integration.
+"""
+
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class AuxFreeRouterArgs(BaseModel):
+    """
+    Input args for Aux-Loss-Free MoE routing.
+    """
+
+    moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "MoE load balancing strategy: 'gshard' for auxiliary loss, "
+            "'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. "
+            "Defaults to model's native behavior when unset."
+        },
+    )
+    moe_update_rate: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. "
+            "If unset, plugin default is 0.01."
+        },
+    )
+    moe_update_momentum: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "EMA momentum for expert load smoothing (0–1). "
+            "If unset, plugin default is 0.9."
+        },
+    )
+    moe_bias_cap: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Absolute clamp for expert bias magnitude. "
+            "If unset, plugin default is 2.0."
+        },
+    )
+    moe_afb_warmup_steps: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of initial steps to delay aux-free bias updates, "
+            "allowing routing to stabilize. If unset, plugin default is 0."
+        },
+    )
+    moe_bias_sync_group: Literal["world", "ep"] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Reduction group for expert load counts: 'world' (DP) or "
+            "'ep' (expert-parallel group if available). Defaults to 'world' when unset."
+        },
+    )
+
diff --git a/src/axolotl/integrations/aux_free_router/plugin.py b/src/axolotl/integrations/aux_free_router/plugin.py
index 4f026ed1c..31893e281 100644
--- a/src/axolotl/integrations/aux_free_router/plugin.py
+++ b/src/axolotl/integrations/aux_free_router/plugin.py
@@ -134,6 +134,9 @@ class AuxFreeMoEPlugin(BasePlugin):
         self._shim: Optional[AuxFreeShim] = None
         self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
 
+    def get_input_args(self):
+        return "axolotl.integrations.aux_free_router.AuxFreeRouterArgs"
+
     def post_model_build(self, cfg, model):
         # Enable only when explicitly requested
         if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index c5bad62de..e6405a45f 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -299,6 +299,7 @@ def validate_config(
     AxolotlInputConfig = AxolotlInputConfigBase
 
     if cfg.plugins:
+        prepare_plugins(cfg)
         (
             AxolotlConfigWCapabilities,
             AxolotlInputConfig,
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 1ba657b5c..1ef83cc90 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -758,44 +758,6 @@ class AxolotlInputConfig(
 
     llama4_linearized_experts: bool | None = None
 
-    # MoE aux-loss-free (AFB) toggles
-    moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": "MoE load balancing strategy: 'gshard' for auxiliary loss, 'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. Defaults to model's native behavior when unset.",
-        },
-    )
-    moe_update_rate: float | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. If unset, plugin default is 0.01.",
-        },
-    )
-    moe_update_momentum: float | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": "EMA momentum for expert load smoothing (0–1). If unset, plugin default is 0.9.",
-        },
-    )
-    moe_bias_cap: float | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": "Absolute clamp for expert bias magnitude. If unset, plugin default is 2.0.",
-        },
-    )
-    moe_afb_warmup_steps: int | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": "Number of initial steps to delay aux-free bias updates, allowing routing to stabilize. If unset, plugin default is 0.",
-        },
-    )
-    moe_bias_sync_group: Literal["world", "ep"] | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": "Reduction group for expert load counts: 'world' (DP) or 'ep' (expert-parallel group if available). Defaults to 'world' when unset.",
-        },
-    )
-
     deepspeed: str | dict[str, Any] | None = Field(
         default=None,
         json_schema_extra={