Move configs from the global config to plugin-specific args

This commit is contained in:
lhl
2025-11-13 04:06:27 +09:00
committed by Wing Lian
parent 676d5e855d
commit 66b2ab8414
5 changed files with 83 additions and 38 deletions

View File

@@ -1,2 +1,9 @@
"""Aux-loss-free (AFB) MoE router integration package."""
from .args import AuxFreeRouterArgs
from .plugin import AuxFreeMoEPlugin
__all__ = [
"AuxFreeRouterArgs",
"AuxFreeMoEPlugin",
]

View File

@@ -0,0 +1,72 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Plugin args for the Aux-Loss-Free MoE router integration.
"""
from typing import Literal
from pydantic import BaseModel, Field
class AuxFreeRouterArgs(BaseModel):
    """
    Input args for Aux-Loss-Free MoE routing.

    Every field defaults to ``None``; an unset value falls back to the
    plugin-internal default noted in each field's description, so the model
    only overrides behavior that was explicitly configured.
    """

    # Which load-balancing strategy to apply; None keeps the model's native
    # behavior.
    moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
        default=None,
        json_schema_extra={
            "description": "MoE load balancing strategy: 'gshard' for auxiliary loss, "
            "'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. "
            "Defaults to model's native behavior when unset."
        },
    )
    # NOTE: recommended range restored from mojibake ("0.0050.05" had lost its
    # dash) — presumably 0.005-0.05; confirm against the plugin defaults.
    moe_update_rate: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Per-step bias update rate (gamma). Recommended: 0.005-0.05. "
            "If unset, plugin default is 0.01."
        },
    )
    # NOTE: "(01)" had lost its dash — restored as the (0-1) momentum range.
    moe_update_momentum: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "EMA momentum for expert load smoothing (0-1). "
            "If unset, plugin default is 0.9."
        },
    )
    moe_bias_cap: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Absolute clamp for expert bias magnitude. "
            "If unset, plugin default is 2.0."
        },
    )
    moe_afb_warmup_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of initial steps to delay aux-free bias updates, "
            "allowing routing to stabilize. If unset, plugin default is 0."
        },
    )
    moe_bias_sync_group: Literal["world", "ep"] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Reduction group for expert load counts: 'world' (DP) or "
            "'ep' (expert-parallel group if available). Defaults to 'world' when unset."
        },
    )

View File

@@ -134,6 +134,9 @@ class AuxFreeMoEPlugin(BasePlugin):
self._shim: Optional[AuxFreeShim] = None
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
def get_input_args(self):
    """Return the dotted import path of this plugin's pydantic args model."""
    args_path = "axolotl.integrations.aux_free_router.AuxFreeRouterArgs"
    return args_path
def post_model_build(self, cfg, model):
# Enable only when explicitly requested
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":

View File

@@ -299,6 +299,7 @@ def validate_config(
AxolotlInputConfig = AxolotlInputConfigBase
if cfg.plugins:
prepare_plugins(cfg)
(
AxolotlConfigWCapabilities,
AxolotlInputConfig,

View File

@@ -758,44 +758,6 @@ class AxolotlInputConfig(
llama4_linearized_experts: bool | None = None
# MoE aux-loss-free (AFB) toggles
moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
default=None,
json_schema_extra={
"description": "MoE load balancing strategy: 'gshard' for auxiliary loss, 'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. Defaults to model's native behavior when unset.",
},
)
moe_update_rate: float | None = Field(
default=None,
json_schema_extra={
"description": "Per-step bias update rate (gamma). Recommended: 0.0050.05. If unset, plugin default is 0.01.",
},
)
moe_update_momentum: float | None = Field(
default=None,
json_schema_extra={
"description": "EMA momentum for expert load smoothing (01). If unset, plugin default is 0.9.",
},
)
moe_bias_cap: float | None = Field(
default=None,
json_schema_extra={
"description": "Absolute clamp for expert bias magnitude. If unset, plugin default is 2.0.",
},
)
moe_afb_warmup_steps: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of initial steps to delay aux-free bias updates, allowing routing to stabilize. If unset, plugin default is 0.",
},
)
moe_bias_sync_group: Literal["world", "ep"] | None = Field(
default=None,
json_schema_extra={
"description": "Reduction group for expert load counts: 'world' (DP) or 'ep' (expert-parallel group if available). Defaults to 'world' when unset.",
},
)
deepspeed: str | dict[str, Any] | None = Field(
default=None,
json_schema_extra={