move configs from global config to plugin specific args
This commit is contained in:
@@ -1,2 +1,9 @@
|
|||||||
"""Aux-loss-free (AFB) MoE router integration package."""
|
"""Aux-loss-free (AFB) MoE router integration package."""
|
||||||
|
|
||||||
|
from .args import AuxFreeRouterArgs
|
||||||
|
from .plugin import AuxFreeMoEPlugin
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AuxFreeRouterArgs",
|
||||||
|
"AuxFreeMoEPlugin",
|
||||||
|
]
|
||||||
|
|||||||
72
src/axolotl/integrations/aux_free_router/args.py
Normal file
72
src/axolotl/integrations/aux_free_router/args.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Plugin args for the Aux-Loss-Free MoE router integration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class AuxFreeRouterArgs(BaseModel):
|
||||||
|
"""
|
||||||
|
Input args for Aux-Loss-Free MoE routing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "MoE load balancing strategy: 'gshard' for auxiliary loss, "
|
||||||
|
"'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. "
|
||||||
|
"Defaults to model's native behavior when unset."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_update_rate: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. "
|
||||||
|
"If unset, plugin default is 0.01."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_update_momentum: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "EMA momentum for expert load smoothing (0–1). "
|
||||||
|
"If unset, plugin default is 0.9."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_bias_cap: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Absolute clamp for expert bias magnitude. "
|
||||||
|
"If unset, plugin default is 2.0."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_afb_warmup_steps: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of initial steps to delay aux-free bias updates, "
|
||||||
|
"allowing routing to stabilize. If unset, plugin default is 0."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_bias_sync_group: Literal["world", "ep"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Reduction group for expert load counts: 'world' (DP) or "
|
||||||
|
"'ep' (expert-parallel group if available). Defaults to 'world' when unset."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@@ -134,6 +134,9 @@ class AuxFreeMoEPlugin(BasePlugin):
|
|||||||
self._shim: Optional[AuxFreeShim] = None
|
self._shim: Optional[AuxFreeShim] = None
|
||||||
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
|
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
|
||||||
|
|
||||||
|
def get_input_args(self):
|
||||||
|
return "axolotl.integrations.aux_free_router.AuxFreeRouterArgs"
|
||||||
|
|
||||||
def post_model_build(self, cfg, model):
|
def post_model_build(self, cfg, model):
|
||||||
# Enable only when explicitly requested
|
# Enable only when explicitly requested
|
||||||
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||||
|
|||||||
@@ -299,6 +299,7 @@ def validate_config(
|
|||||||
AxolotlInputConfig = AxolotlInputConfigBase
|
AxolotlInputConfig = AxolotlInputConfigBase
|
||||||
|
|
||||||
if cfg.plugins:
|
if cfg.plugins:
|
||||||
|
prepare_plugins(cfg)
|
||||||
(
|
(
|
||||||
AxolotlConfigWCapabilities,
|
AxolotlConfigWCapabilities,
|
||||||
AxolotlInputConfig,
|
AxolotlInputConfig,
|
||||||
|
|||||||
@@ -758,44 +758,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
llama4_linearized_experts: bool | None = None
|
llama4_linearized_experts: bool | None = None
|
||||||
|
|
||||||
# MoE aux-loss-free (AFB) toggles
|
|
||||||
moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "MoE load balancing strategy: 'gshard' for auxiliary loss, 'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. Defaults to model's native behavior when unset.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
moe_update_rate: float | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. If unset, plugin default is 0.01.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
moe_update_momentum: float | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "EMA momentum for expert load smoothing (0–1). If unset, plugin default is 0.9.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
moe_bias_cap: float | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Absolute clamp for expert bias magnitude. If unset, plugin default is 2.0.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
moe_afb_warmup_steps: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Number of initial steps to delay aux-free bias updates, allowing routing to stabilize. If unset, plugin default is 0.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
moe_bias_sync_group: Literal["world", "ep"] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Reduction group for expert load counts: 'world' (DP) or 'ep' (expert-parallel group if available). Defaults to 'world' when unset.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
deepspeed: str | dict[str, Any] | None = Field(
|
deepspeed: str | dict[str, Any] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
Reference in New Issue
Block a user