address PR code review
This commit is contained in:
@@ -4,6 +4,6 @@ from .args import AuxFreeRouterArgs
|
|||||||
from .plugin import AuxFreeMoEPlugin
|
from .plugin import AuxFreeMoEPlugin
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AuxFreeRouterArgs",
|
|
||||||
"AuxFreeMoEPlugin",
|
"AuxFreeMoEPlugin",
|
||||||
|
"AuxFreeRouterArgs",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -88,12 +88,22 @@ class BaseMoEAdapter:
|
|||||||
try:
|
try:
|
||||||
model_or_layer.router_aux_loss_coef = 0.0
|
model_or_layer.router_aux_loss_coef = 0.0
|
||||||
except Exception: # pragma: no cover - non-critical
|
except Exception: # pragma: no cover - non-critical
|
||||||
pass
|
LOG.debug(
|
||||||
|
"disable_aux_loss: failed to set router_aux_loss_coef on %s",
|
||||||
|
type(model_or_layer).__name__,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
def _register_aux_buffers(
|
def _register_aux_buffers(
|
||||||
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
) -> None:
|
) -> None:
|
||||||
device = next(moe_layer.parameters(), torch.tensor(0)).device
|
p = next(moe_layer.parameters(), None)
|
||||||
|
b = next(moe_layer.buffers(), None)
|
||||||
|
device = (
|
||||||
|
p.device
|
||||||
|
if p is not None
|
||||||
|
else (b.device if b is not None else torch.device("cpu"))
|
||||||
|
)
|
||||||
if not hasattr(moe_layer, "_afb_bias"):
|
if not hasattr(moe_layer, "_afb_bias"):
|
||||||
moe_layer.register_buffer(
|
moe_layer.register_buffer(
|
||||||
"_afb_bias", torch.zeros(handle.num_experts, device=device)
|
"_afb_bias", torch.zeros(handle.num_experts, device=device)
|
||||||
@@ -275,7 +285,7 @@ class BailingAdapter(BaseMoEAdapter):
|
|||||||
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
|
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
|
||||||
bias = moe_layer._afb_bias
|
bias = moe_layer._afb_bias
|
||||||
biased_scores = scores_unbiased + bias
|
biased_scores = scores_unbiased + bias
|
||||||
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
|
_, topk_idx = self.group_limited_topk(biased_scores)
|
||||||
weights = torch.gather(scores_unbiased, 1, topk_idx)
|
weights = torch.gather(scores_unbiased, 1, topk_idx)
|
||||||
if self.top_k > 1:
|
if self.top_k > 1:
|
||||||
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
|
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
|
||||||
@@ -376,6 +386,8 @@ def discover_and_prepare_layers(
|
|||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing"
|
"AuxFreeMoE: prepared %d %s layers for aux-free routing",
|
||||||
|
len(handles),
|
||||||
|
adapter.family,
|
||||||
)
|
)
|
||||||
return handles
|
return handles
|
||||||
|
|||||||
@@ -37,14 +37,14 @@ class AuxFreeRouterArgs(BaseModel):
|
|||||||
moe_update_rate: float | None = Field(
|
moe_update_rate: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. "
|
"description": "Per-step bias update rate (gamma). Recommended: 0.005-0.05. "
|
||||||
"If unset, plugin default is 0.01."
|
"If unset, plugin default is 0.01."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
moe_update_momentum: float | None = Field(
|
moe_update_momentum: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "EMA momentum for expert load smoothing (0–1). "
|
"description": "EMA momentum for expert load smoothing (0-1). "
|
||||||
"If unset, plugin default is 0.9."
|
"If unset, plugin default is 0.9."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ class AuxFreeShim:
|
|||||||
else:
|
else:
|
||||||
b = self.state.bias[layer_idx]
|
b = self.state.bias[layer_idx]
|
||||||
biased = logits + b # bias is a buffer
|
biased = logits + b # bias is a buffer
|
||||||
topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
|
_topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
|
||||||
chosen_logits = torch.gather(logits, -1, topk_idx)
|
chosen_logits = torch.gather(logits, -1, topk_idx)
|
||||||
weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
|
weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
|
||||||
return topk_idx, weights
|
return topk_idx, weights
|
||||||
|
|||||||
@@ -239,11 +239,17 @@ class AuxFreeMoEPlugin(BasePlugin):
|
|||||||
return dist.group.WORLD
|
return dist.group.WORLD
|
||||||
|
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
group_start = (rank // ep_size) * ep_size
|
# All ranks must collectively create all EP subgroups in the same order
|
||||||
ranks = tuple(range(group_start, group_start + ep_size))
|
# to avoid deadlocks (dist.new_group is a collective operation).
|
||||||
if ranks not in self._ep_group_cache:
|
world_size = world
|
||||||
self._ep_group_cache[ranks] = dist.new_group(ranks)
|
my_group = None
|
||||||
return self._ep_group_cache[ranks]
|
for group_start in range(0, world_size, ep_size):
|
||||||
|
ranks = tuple(range(group_start, group_start + ep_size))
|
||||||
|
if ranks not in self._ep_group_cache:
|
||||||
|
self._ep_group_cache[ranks] = dist.new_group(ranks)
|
||||||
|
if rank in ranks:
|
||||||
|
my_group = self._ep_group_cache[ranks]
|
||||||
|
return my_group
|
||||||
|
|
||||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||||
|
|||||||
@@ -177,7 +177,9 @@ def softmax_group_topk_routing(
|
|||||||
score_mask = (
|
score_mask = (
|
||||||
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
||||||
)
|
)
|
||||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
scores_for_choice = scores_for_choice.masked_fill(
|
||||||
|
~score_mask.bool(), -float("inf")
|
||||||
|
)
|
||||||
|
|
||||||
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
||||||
topk_weights = router_probs.gather(1, topk_indices)
|
topk_weights = router_probs.gather(1, topk_indices)
|
||||||
@@ -275,7 +277,9 @@ def sigmoid_topk_routing(
|
|||||||
score_mask = (
|
score_mask = (
|
||||||
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
||||||
)
|
)
|
||||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
scores_for_choice = scores_for_choice.masked_fill(
|
||||||
|
~score_mask.bool(), -float("inf")
|
||||||
|
)
|
||||||
|
|
||||||
# Final topk from (possibly masked) scores
|
# Final topk from (possibly masked) scores
|
||||||
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
||||||
@@ -316,6 +320,8 @@ def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
|
|||||||
``aux_free_router`` plugin). The counts are later consumed by the
|
``aux_free_router`` plugin). The counts are later consumed by the
|
||||||
``MoeAuxFreeBiasUpdateCallback`` at each training step.
|
``MoeAuxFreeBiasUpdateCallback`` at each training step.
|
||||||
"""
|
"""
|
||||||
|
if hasattr(moe_block, "training") and not moe_block.training:
|
||||||
|
return
|
||||||
afb_counts = getattr(moe_block, "_afb_counts", None)
|
afb_counts = getattr(moe_block, "_afb_counts", None)
|
||||||
if afb_counts is None:
|
if afb_counts is None:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
|
|||||||
{
|
{
|
||||||
"base_model": "yujiepan/llama-4-tiny-random",
|
"base_model": "yujiepan/llama-4-tiny-random",
|
||||||
"tokenizer_config": "yujiepan/llama-4-tiny-random",
|
"tokenizer_config": "yujiepan/llama-4-tiny-random",
|
||||||
"trust_remote_code": True,
|
"trust_remote_code": False,
|
||||||
"flash_attention": False,
|
"flash_attention": False,
|
||||||
"sequence_len": 512,
|
"sequence_len": 512,
|
||||||
"bf16": False,
|
"bf16": False,
|
||||||
|
|||||||
@@ -3,8 +3,11 @@ Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-t
|
|||||||
Checks that aux-free training loss does not degrade beyond a small tolerance.
|
Checks that aux-free training loss does not degrade beyond a small tolerance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
@@ -61,6 +64,11 @@ class TestMoeAuxParity(unittest.TestCase):
|
|||||||
loss0 = _last_logged_loss(trainer0)
|
loss0 = _last_logged_loss(trainer0)
|
||||||
assert loss0 is not None
|
assert loss0 is not None
|
||||||
|
|
||||||
|
# Release baseline resources before starting aux-free run
|
||||||
|
del model0, trainer0, dataset_meta0
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# Aux-free: plugin + noaux_tc
|
# Aux-free: plugin + noaux_tc
|
||||||
cfg1 = DictDefault(dict(base_cfg))
|
cfg1 = DictDefault(dict(base_cfg))
|
||||||
cfg1.output_dir = f"{temp_dir}/auxfree"
|
cfg1.output_dir = f"{temp_dir}/auxfree"
|
||||||
|
|||||||
@@ -390,7 +390,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
|||||||
|
|
||||||
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
dist.destroy_process_group()
|
self.skipTest(
|
||||||
|
"Cannot safely test deferred EP group resolution when a process group is already initialized"
|
||||||
|
)
|
||||||
|
|
||||||
model, block = _build_bailing_model()
|
model, block = _build_bailing_model()
|
||||||
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
|
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user