Compare commits
14 Commits
textui
...
6636e5de7e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6636e5de7e | ||
|
|
0a566d7a15 | ||
|
|
5acb1b0ade | ||
|
|
4009a2ba5f | ||
|
|
66b2ab8414 | ||
|
|
676d5e855d | ||
|
|
966a4555db | ||
|
|
ad0c825bcb | ||
|
|
46d677876e | ||
|
|
6eac9ac372 | ||
|
|
949cdf01eb | ||
|
|
a0019021dd | ||
|
|
2af7475fdf | ||
|
|
3e4688289c |
@@ -3,8 +3,7 @@ set -e
|
|||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
set -o pipefail
|
curl --silent -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||||
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
|
||||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||||
# hf download "microsoft/Phi-4-reasoning"
|
# hf download "microsoft/Phi-4-reasoning"
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ coverage:
|
|||||||
only_pulls: false
|
only_pulls: false
|
||||||
flags: null
|
flags: null
|
||||||
paths: null
|
paths: null
|
||||||
informational: true
|
|
||||||
|
|
||||||
parsers:
|
parsers:
|
||||||
gcov:
|
gcov:
|
||||||
|
|||||||
@@ -91,7 +91,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
|||||||
type=click.Path(exists=True, path_type=str),
|
type=click.Path(exists=True, path_type=str),
|
||||||
help="YAML config for sweeping hyperparameters",
|
help="YAML config for sweeping hyperparameters",
|
||||||
)
|
)
|
||||||
@click.option("--tui", is_flag=True, default=False, help="Enable TUI dashboard")
|
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
@filter_none_kwargs
|
@filter_none_kwargs
|
||||||
@@ -102,7 +101,6 @@ def train(
|
|||||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||||
cloud: str | None = None,
|
cloud: str | None = None,
|
||||||
sweep: str | None = None,
|
sweep: str | None = None,
|
||||||
tui: bool = False,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -120,10 +118,6 @@ def train(
|
|||||||
# Extract launcher args from extra args (after --)
|
# Extract launcher args from extra args (after --)
|
||||||
launcher_args = ctx.args if ctx.args else []
|
launcher_args = ctx.args if ctx.args else []
|
||||||
|
|
||||||
# Handle --tui flag: set env var so subprocess workers pick it up
|
|
||||||
if tui:
|
|
||||||
os.environ["AXOLOTL_TUI"] = "1"
|
|
||||||
|
|
||||||
# Handle Ray launcher override
|
# Handle Ray launcher override
|
||||||
_launcher = None if kwargs.get("use_ray") else launcher
|
_launcher = None if kwargs.get("use_ray") else launcher
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
import queue
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -35,101 +34,22 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
# Start TUI early (before data loading) so it captures preprocessing events
|
plugin_manager = PluginManager.get_instance()
|
||||||
tui_renderer = None
|
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
||||||
tui_queue: queue.Queue | None = None
|
if not dataset_meta:
|
||||||
is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
|
if cfg.rl:
|
||||||
if is_rank_0:
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
from axolotl.train import _is_tui_enabled
|
else:
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
if _is_tui_enabled(cfg):
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
import queue as _queue
|
|
||||||
|
|
||||||
from axolotl.train import _get_tui_config
|
del model, tokenizer, trainer
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
from axolotl.tui.renderer import TUIRenderer
|
|
||||||
|
|
||||||
tui_config_dict = _get_tui_config(cfg)
|
gc.collect()
|
||||||
tui_config = (
|
|
||||||
TUIConfig(**tui_config_dict)
|
|
||||||
if isinstance(tui_config_dict, dict)
|
|
||||||
else tui_config_dict
|
|
||||||
)
|
|
||||||
tui_queue = _queue.Queue(maxsize=4096)
|
|
||||||
tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)
|
|
||||||
|
|
||||||
# Send initial run info
|
plugin_manager = PluginManager.get_instance()
|
||||||
model_name = cfg.base_model or ""
|
plugin_manager.post_train_unload(cfg)
|
||||||
training_mode = str(cfg.rl) if cfg.rl else "sft"
|
|
||||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
|
||||||
try:
|
|
||||||
tui_queue.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "run_info",
|
|
||||||
"model_name": model_name,
|
|
||||||
"training_mode": training_mode,
|
|
||||||
"world_size": world_size,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except _queue.Full:
|
|
||||||
pass
|
|
||||||
|
|
||||||
tui_renderer.start()
|
|
||||||
|
|
||||||
# Attach logging handler early
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from axolotl.tui.callback import _TUILogHandler
|
|
||||||
|
|
||||||
_early_log_handler = _TUILogHandler(
|
|
||||||
tui_queue, min_level=tui_config.log_level
|
|
||||||
)
|
|
||||||
_early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
|
|
||||||
# Attach to BOTH root and axolotl loggers because axolotl logger
|
|
||||||
# has propagate=False so root handler never sees axolotl.* messages
|
|
||||||
root_logger = logging.getLogger()
|
|
||||||
root_logger.addHandler(_early_log_handler)
|
|
||||||
axolotl_logger = logging.getLogger("axolotl")
|
|
||||||
axolotl_logger.addHandler(_early_log_handler)
|
|
||||||
|
|
||||||
# Stash refs on cfg so train() can reuse the renderer
|
|
||||||
cfg._tui_renderer = tui_renderer
|
|
||||||
cfg._tui_queue = tui_queue
|
|
||||||
cfg._tui_early_log_handler = _early_log_handler
|
|
||||||
|
|
||||||
try:
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
|
||||||
if not dataset_meta:
|
|
||||||
if cfg.rl:
|
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
else:
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
|
|
||||||
del model, tokenizer, trainer
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
|
||||||
finally:
|
|
||||||
# If the TUI renderer started early but train() didn't get to stop it
|
|
||||||
# (e.g., error during data loading), clean up here
|
|
||||||
if tui_renderer is not None and not tui_renderer._stop_event.is_set():
|
|
||||||
try:
|
|
||||||
if tui_queue is not None:
|
|
||||||
tui_queue.put_nowait({"type": "done"})
|
|
||||||
except queue.Full:
|
|
||||||
pass
|
|
||||||
tui_renderer.stop()
|
|
||||||
# Remove early log handler from both root and axolotl loggers
|
|
||||||
if hasattr(cfg, "_tui_early_log_handler"):
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.getLogger().removeHandler(cfg._tui_early_log_handler)
|
|
||||||
logging.getLogger("axolotl").removeHandler(cfg._tui_early_log_handler)
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
|
|||||||
50
src/axolotl/integrations/aux_free_router/README.md
Normal file
50
src/axolotl/integrations/aux_free_router/README.md
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Aux-Loss-Free MoE Router Plugin
|
||||||
|
|
||||||
|
This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code.
|
||||||
|
|
||||||
|
Summary
|
||||||
|
- Bias only affects expert selection (top-k); mixture weights come from unbiased logits.
|
||||||
|
- Per-expert token loads are accumulated on device and reduced across DP or EP groups.
|
||||||
|
- Bias is updated post-optimizer step outside autograd using EMA-smoothed loads.
|
||||||
|
- Existing aux loss is disabled when aux-free is enabled to avoid double signals.
|
||||||
|
|
||||||
|
Enable
|
||||||
|
- Add the plugin to your YAML, then set the aux-free toggle:
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin
|
||||||
|
|
||||||
|
moe_balance_type: noaux_tc
|
||||||
|
moe_update_rate: 0.01 # default if unset
|
||||||
|
moe_update_momentum: 0.9 # default if unset
|
||||||
|
moe_bias_cap: 2.0 # default if unset
|
||||||
|
moe_afb_warmup_steps: 100 # optional
|
||||||
|
moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
|
||||||
|
expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
|
||||||
|
|
||||||
|
Config keys
|
||||||
|
- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
|
||||||
|
- moe_update_rate: bias update rate (gamma). Default: 0.01.
|
||||||
|
- moe_update_momentum: EMA momentum for load smoothing. Default: 0.9.
|
||||||
|
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
|
||||||
|
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
|
||||||
|
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
|
||||||
|
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
|
||||||
|
|
||||||
|
Compatibility
|
||||||
|
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
|
||||||
|
- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
- If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
|
||||||
|
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction using the Trainer’s `logging_steps` cadence.
|
||||||
|
- Sample packing: packed batches are compatible with aux-free routing. Because load counts are accumulated on-device per expert before reduction, packing tends to smooth token histograms and reduce bias oscillation. Keep `pad_to_sequence_len: true` when packing to preserve the target token budget per expert.
|
||||||
|
|
||||||
|
Telemetry metrics
|
||||||
|
- `moe_afb/l{idx}_load_min|mean|max`: token frequency per expert after reduction (0–1 range, sums to 1).
|
||||||
|
- `moe_afb/l{idx}_bias_abs_max`: absolute maximum of the learned bias for the layer.
|
||||||
|
- `moe_afb/l{idx}_bias_sign_flip_frac`: fraction of experts whose bias sign changed since the previous step (simple oscillation indicator).
|
||||||
|
|
||||||
|
Usage tips
|
||||||
|
- Increase `logging_steps` if router telemetry becomes noisy for large jobs—the plugin follows the Trainer’s logging cadence.
|
||||||
|
- Compare aux-free vs. aux-loss load metrics by plotting the `load_*` series; aux-free typically tightens min/max spread without the auxiliary loss term.
|
||||||
9
src/axolotl/integrations/aux_free_router/__init__.py
Normal file
9
src/axolotl/integrations/aux_free_router/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""Aux-loss-free (AFB) MoE router integration package."""
|
||||||
|
|
||||||
|
from .args import AuxFreeRouterArgs
|
||||||
|
from .plugin import AuxFreeMoEPlugin
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AuxFreeMoEPlugin",
|
||||||
|
"AuxFreeRouterArgs",
|
||||||
|
]
|
||||||
393
src/axolotl/integrations/aux_free_router/adapters.py
Normal file
393
src/axolotl/integrations/aux_free_router/adapters.py
Normal file
@@ -0,0 +1,393 @@
|
|||||||
|
"""Architecture-specific adapters for aux-loss-free MoE routing.
|
||||||
|
|
||||||
|
Each adapter discovers MoE layers for a model family and patches only the
|
||||||
|
router/gate to inject per-expert bias into expert selection while keeping
|
||||||
|
mixture weights from unbiased logits. Expert dispatch is left untouched so
|
||||||
|
the patching composes with any expert backend (eager, ScatterMoE, SonicMoE).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
from .core import AuxFreeShim
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LayerHandle:
|
||||||
|
layer: nn.Module
|
||||||
|
layer_idx: int
|
||||||
|
num_experts: int
|
||||||
|
top_k: int
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMoEAdapter:
|
||||||
|
"""Base adapter that discovers MoE layers and patches their routing.
|
||||||
|
|
||||||
|
Concrete adapters implement discovery, attribute extraction, and
|
||||||
|
architecture-specific router patching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
family: str = "generic"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
|
||||||
|
return False
|
||||||
|
|
||||||
|
def find_moe_layers(
|
||||||
|
self, model: nn.Module
|
||||||
|
) -> Iterable[nn.Module]: # pragma: no cover
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_top_k(self, moe_layer: nn.Module) -> int:
|
||||||
|
"""Resolve top_k from the MoE layer, checking common attribute paths."""
|
||||||
|
for attr_path in [
|
||||||
|
("top_k",),
|
||||||
|
("num_experts_per_tok",),
|
||||||
|
("gate", "top_k"),
|
||||||
|
("router", "top_k"),
|
||||||
|
]:
|
||||||
|
obj: object = moe_layer
|
||||||
|
for attr in attr_path:
|
||||||
|
obj = getattr(obj, attr, None)
|
||||||
|
if obj is None:
|
||||||
|
break
|
||||||
|
if isinstance(obj, int):
|
||||||
|
return obj
|
||||||
|
return 2
|
||||||
|
|
||||||
|
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
||||||
|
"""Resolve num_experts from the MoE layer, checking common attribute paths."""
|
||||||
|
for attr_path in [
|
||||||
|
("num_experts",),
|
||||||
|
("num_local_experts",),
|
||||||
|
("gate", "num_experts"),
|
||||||
|
("router", "num_experts"),
|
||||||
|
("experts", "num_experts"),
|
||||||
|
]:
|
||||||
|
obj: object = moe_layer
|
||||||
|
for attr in attr_path:
|
||||||
|
obj = getattr(obj, attr, None)
|
||||||
|
if obj is None:
|
||||||
|
break
|
||||||
|
if isinstance(obj, int):
|
||||||
|
return obj
|
||||||
|
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
|
||||||
|
|
||||||
|
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
|
||||||
|
# Best-effort: zero router aux loss coef if present
|
||||||
|
if hasattr(model_or_layer, "router_aux_loss_coef"):
|
||||||
|
try:
|
||||||
|
model_or_layer.router_aux_loss_coef = 0.0
|
||||||
|
except Exception: # pragma: no cover - non-critical
|
||||||
|
LOG.debug(
|
||||||
|
"disable_aux_loss: failed to set router_aux_loss_coef on %s",
|
||||||
|
type(model_or_layer).__name__,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _register_aux_buffers(
|
||||||
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
|
) -> None:
|
||||||
|
p = next(moe_layer.parameters(), None)
|
||||||
|
b = next(moe_layer.buffers(), None)
|
||||||
|
device = (
|
||||||
|
p.device
|
||||||
|
if p is not None
|
||||||
|
else (b.device if b is not None else torch.device("cpu"))
|
||||||
|
)
|
||||||
|
if not hasattr(moe_layer, "_afb_bias"):
|
||||||
|
moe_layer.register_buffer(
|
||||||
|
"_afb_bias", torch.zeros(handle.num_experts, device=device)
|
||||||
|
)
|
||||||
|
if not hasattr(moe_layer, "_afb_counts"):
|
||||||
|
moe_layer.register_buffer(
|
||||||
|
"_afb_counts", torch.zeros(handle.num_experts, device=device)
|
||||||
|
)
|
||||||
|
if not hasattr(moe_layer, "_afb_ema"):
|
||||||
|
moe_layer.register_buffer(
|
||||||
|
"_afb_ema", torch.zeros(handle.num_experts, device=device)
|
||||||
|
)
|
||||||
|
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
|
||||||
|
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
|
||||||
|
shim.register_layer_buffers(handle.layer_idx, moe_layer)
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
|
) -> None:
|
||||||
|
"""Attach per-layer buffers. Subclasses override to also patch routing."""
|
||||||
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
|
|
||||||
|
def uses_kernel_routing(self, moe_layer: nn.Module) -> bool:
|
||||||
|
"""Return True when a kernel backend (SonicMoE / ScatterMoE) has
|
||||||
|
already replaced the block forward, meaning the routing is handled
|
||||||
|
inside the kernel forward and we should NOT patch the router."""
|
||||||
|
cls = type(moe_layer)
|
||||||
|
# SonicMoE stores the original forward when it patches a class.
|
||||||
|
if hasattr(cls, "_original_forward"):
|
||||||
|
return True
|
||||||
|
# ScatterMoE replaces via kernels library; check for the marker.
|
||||||
|
if hasattr(cls, "_kernel_forward"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralAdapter(BaseMoEAdapter):
|
||||||
|
"""Patches the TopKRouter for Mixtral / Qwen-MoE style softmax→topk
|
||||||
|
routing so that biased logits drive expert *selection* while unbiased
|
||||||
|
softmax scores drive mixture *weights*.
|
||||||
|
|
||||||
|
Works with transformers v5 where experts are fused 3D tensors and
|
||||||
|
the router is ``MixtralTopKRouter`` (returns a 3-tuple).
|
||||||
|
"""
|
||||||
|
|
||||||
|
family = "mixtral"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool:
|
||||||
|
return (
|
||||||
|
getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
|
||||||
|
)
|
||||||
|
|
||||||
|
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||||
|
for m in model.modules():
|
||||||
|
if m.__class__.__name__.endswith("SparseMoeBlock"):
|
||||||
|
yield m
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
|
) -> None:
|
||||||
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
|
if not self.uses_kernel_routing(moe_layer):
|
||||||
|
self._patch_router(moe_layer)
|
||||||
|
else:
|
||||||
|
LOG.info(
|
||||||
|
"AuxFreeMoE: kernel backend detected on %s; "
|
||||||
|
"skipping router patch (kernel routing handles bias)",
|
||||||
|
type(moe_layer).__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _patch_router(self, moe_layer: nn.Module) -> None:
|
||||||
|
"""Patch the TopKRouter to inject aux-free bias into expert selection."""
|
||||||
|
gate = getattr(moe_layer, "gate", None)
|
||||||
|
if gate is None:
|
||||||
|
LOG.info("MixtralAdapter: layer missing gate; skipping aux-free patch")
|
||||||
|
return
|
||||||
|
if getattr(gate, "_afb_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Capture reference to the MoE block for bias / counts access.
|
||||||
|
block_ref = moe_layer
|
||||||
|
|
||||||
|
def afb_router_forward(self, hidden_states: torch.Tensor):
|
||||||
|
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
||||||
|
router_logits = F.linear(hidden_states, self.weight)
|
||||||
|
router_probs = F.softmax(router_logits.float(), dim=-1)
|
||||||
|
|
||||||
|
# Biased selection, unbiased weights
|
||||||
|
bias = block_ref._afb_bias
|
||||||
|
biased = router_probs + bias
|
||||||
|
_, router_indices = torch.topk(biased, self.top_k, dim=-1)
|
||||||
|
router_scores = torch.gather(router_probs, 1, router_indices)
|
||||||
|
|
||||||
|
# Renormalize (Mixtral always normalizes; Qwen checks config)
|
||||||
|
if getattr(self, "norm_topk_prob", True):
|
||||||
|
router_scores = router_scores / router_scores.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Accumulate counts for the bias-update callback
|
||||||
|
flat_idx = router_indices.reshape(-1)
|
||||||
|
counts = torch.bincount(flat_idx, minlength=self.num_experts)
|
||||||
|
block_ref._afb_counts.add_(counts.to(block_ref._afb_counts.dtype))
|
||||||
|
|
||||||
|
return router_probs, router_scores, router_indices
|
||||||
|
|
||||||
|
gate.forward = afb_router_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
||||||
|
gate._afb_patched = True
|
||||||
|
moe_layer._afb_patched = True
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3Adapter(MixtralAdapter):
|
||||||
|
family = "qwen3_moe"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool:
|
||||||
|
return getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||||
|
"qwen3_moe",
|
||||||
|
"qwen2_moe",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35MoeAdapter(MixtralAdapter):
|
||||||
|
"""Adapter for Qwen 3.5 MoE models.
|
||||||
|
|
||||||
|
Same softmax→topk router pattern as Mixtral/Qwen3. The shared expert
|
||||||
|
is handled by the block forward (untouched by router-level patching).
|
||||||
|
"""
|
||||||
|
|
||||||
|
family = "qwen3_5_moe"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool:
|
||||||
|
return getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||||
|
"qwen3_5_moe",
|
||||||
|
"qwen3_5_moe_text",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BailingAdapter(BaseMoEAdapter):
|
||||||
|
family = "bailing_moe"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool:
|
||||||
|
cfg = getattr(model, "config", None)
|
||||||
|
if cfg is None:
|
||||||
|
return False
|
||||||
|
model_type = getattr(cfg, "model_type", "") or ""
|
||||||
|
if model_type in ("bailing_moe", "bailing_moe_v2", "ring_moe", "ring"):
|
||||||
|
return True
|
||||||
|
cfg_name = cfg.__class__.__name__.lower()
|
||||||
|
return "bailingmoev2" in cfg_name or "ring" in cfg_name
|
||||||
|
|
||||||
|
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||||
|
for m in model.modules():
|
||||||
|
if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
|
||||||
|
yield m
|
||||||
|
|
||||||
|
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
||||||
|
if hasattr(moe_layer, "num_experts"):
|
||||||
|
return int(moe_layer.num_experts)
|
||||||
|
cfg = getattr(moe_layer, "config", None)
|
||||||
|
if cfg is None:
|
||||||
|
raise AttributeError(f"Cannot determine num_experts for {type(moe_layer)}")
|
||||||
|
return int(cfg.num_experts)
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
|
) -> None:
|
||||||
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
|
self._patch_bailing_gate(moe_layer)
|
||||||
|
|
||||||
|
def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
|
||||||
|
gate = getattr(moe_layer, "gate", None)
|
||||||
|
if gate is None:
|
||||||
|
LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
|
||||||
|
return
|
||||||
|
if getattr(gate, "_afb_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
def afb_gate_forward(self, hidden_states: torch.Tensor):
|
||||||
|
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
logits = F.linear(flat.float(), self.weight.float())
|
||||||
|
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
|
||||||
|
bias = moe_layer._afb_bias
|
||||||
|
biased_scores = scores_unbiased + bias
|
||||||
|
_, topk_idx = self.group_limited_topk(biased_scores)
|
||||||
|
weights = torch.gather(scores_unbiased, 1, topk_idx)
|
||||||
|
if self.top_k > 1:
|
||||||
|
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
|
||||||
|
weights = weights / denom
|
||||||
|
weights = weights * self.routed_scaling_factor
|
||||||
|
|
||||||
|
flat_topk = topk_idx.reshape(-1)
|
||||||
|
counts = torch.bincount(flat_topk, minlength=bias.numel())
|
||||||
|
moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
|
||||||
|
|
||||||
|
return topk_idx, weights.to(hidden_states.dtype), logits
|
||||||
|
|
||||||
|
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
||||||
|
gate._afb_patched = True
|
||||||
|
|
||||||
|
|
||||||
|
class Llama4Adapter(BaseMoEAdapter):
|
||||||
|
family = "llama4"
|
||||||
|
|
||||||
|
def matches(self, model: nn.Module) -> bool:
|
||||||
|
return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"
|
||||||
|
|
||||||
|
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||||
|
for m in model.modules():
|
||||||
|
if m.__class__.__name__ == "Llama4TextMoe":
|
||||||
|
yield m
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
|
||||||
|
) -> None:
|
||||||
|
self._register_aux_buffers(moe_layer, handle, shim)
|
||||||
|
self._patch_llama4_router(moe_layer)
|
||||||
|
|
||||||
|
def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
|
||||||
|
router = getattr(moe_layer, "router", None)
|
||||||
|
if router is None:
|
||||||
|
LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
|
||||||
|
return
|
||||||
|
if getattr(router, "_afb_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
def afb_router_forward(self, hidden_states: torch.Tensor):
|
||||||
|
flat = (
|
||||||
|
hidden_states
|
||||||
|
if hidden_states.dim() == 2
|
||||||
|
else hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
)
|
||||||
|
router_logits = F.linear(flat, self.weight, self.bias)
|
||||||
|
bias = moe_layer._afb_bias
|
||||||
|
biased_logits = router_logits + bias
|
||||||
|
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
|
||||||
|
unbiased_top = torch.gather(router_logits, 1, router_indices)
|
||||||
|
router_scores = torch.full_like(router_logits, float("-inf"))
|
||||||
|
router_scores.scatter_(1, router_indices, unbiased_top)
|
||||||
|
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
|
||||||
|
|
||||||
|
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
|
||||||
|
moe_layer._afb_counts.add_(counts.to(moe_layer._afb_counts.dtype))
|
||||||
|
|
||||||
|
return router_scores, router_logits
|
||||||
|
|
||||||
|
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
|
||||||
|
router._afb_patched = True
|
||||||
|
|
||||||
|
|
||||||
|
def discover_and_prepare_layers(
|
||||||
|
model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim
|
||||||
|
) -> list[LayerHandle]:
|
||||||
|
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
|
||||||
|
|
||||||
|
Returns a list of layer handles for later routing patching and updates.
|
||||||
|
"""
|
||||||
|
handles: list[LayerHandle] = []
|
||||||
|
adapter: Optional[BaseMoEAdapter] = None
|
||||||
|
for a in adapters:
|
||||||
|
if a.matches(model):
|
||||||
|
adapter = a
|
||||||
|
break
|
||||||
|
|
||||||
|
if adapter is None:
|
||||||
|
LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing")
|
||||||
|
return handles
|
||||||
|
|
||||||
|
# disable aux loss at model level if possible
|
||||||
|
adapter.disable_aux_loss(getattr(model, "config", model))
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
for layer in adapter.find_moe_layers(model):
|
||||||
|
try:
|
||||||
|
top_k = adapter.get_top_k(layer)
|
||||||
|
nE = adapter.get_num_experts(layer)
|
||||||
|
except (AttributeError, TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
|
||||||
|
adapter.prepare(layer, handle, shim)
|
||||||
|
handles.append(handle)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"AuxFreeMoE: prepared %d %s layers for aux-free routing",
|
||||||
|
len(handles),
|
||||||
|
adapter.family,
|
||||||
|
)
|
||||||
|
return handles
|
||||||
71
src/axolotl/integrations/aux_free_router/args.py
Normal file
71
src/axolotl/integrations/aux_free_router/args.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Plugin args for the Aux-Loss-Free MoE router integration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class AuxFreeRouterArgs(BaseModel):
|
||||||
|
"""
|
||||||
|
Input args for Aux-Loss-Free MoE routing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "MoE load balancing strategy: 'gshard' for auxiliary loss, "
|
||||||
|
"'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. "
|
||||||
|
"Defaults to model's native behavior when unset."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_update_rate: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Per-step bias update rate (gamma). Recommended: 0.005-0.05. "
|
||||||
|
"If unset, plugin default is 0.01."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_update_momentum: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "EMA momentum for expert load smoothing (0-1). "
|
||||||
|
"If unset, plugin default is 0.9."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_bias_cap: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Absolute clamp for expert bias magnitude. "
|
||||||
|
"If unset, plugin default is 2.0."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_afb_warmup_steps: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of initial steps to delay aux-free bias updates, "
|
||||||
|
"allowing routing to stabilize. If unset, plugin default is 0."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_bias_sync_group: Literal["world", "ep"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Reduction group for expert load counts: 'world' (DP) or "
|
||||||
|
"'ep' (expert-parallel group if available). Defaults to 'world' when unset."
|
||||||
|
},
|
||||||
|
)
|
||||||
166
src/axolotl/integrations/aux_free_router/core.py
Normal file
166
src/axolotl/integrations/aux_free_router/core.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AuxFreeConfig:
|
||||||
|
rate: float = 0.01
|
||||||
|
momentum: float = 0.9
|
||||||
|
bias_cap: float = 2.0
|
||||||
|
warmup_steps: int = 0
|
||||||
|
sync_group: str = "world" # or "ep"
|
||||||
|
|
||||||
|
|
||||||
|
class AuxFreeState:
|
||||||
|
"""Holds per-layer bias and EMA load buffers."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_layers: int,
|
||||||
|
num_experts: int,
|
||||||
|
device: torch.device,
|
||||||
|
cfg: AuxFreeConfig,
|
||||||
|
):
|
||||||
|
self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
||||||
|
self.ema_load = [
|
||||||
|
torch.zeros(num_experts, device=device) for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
self.cfg = cfg
|
||||||
|
self.steps = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AuxFreeShim:
|
||||||
|
"""Model-agnostic shim for aux-loss-free expert selection and bias updates."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
state: AuxFreeState,
|
||||||
|
ep_group: Optional[dist.ProcessGroup] = None,
|
||||||
|
ep_size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.state = state
|
||||||
|
self.ep_group = ep_group
|
||||||
|
self._ep_size = ep_size
|
||||||
|
self._ep_group_pending = (
|
||||||
|
self.state.cfg.sync_group == "ep" and self.ep_group is None
|
||||||
|
)
|
||||||
|
self._layer_modules: dict[int, torch.nn.Module] = {}
|
||||||
|
self._prev_bias_sign: dict[int, torch.Tensor] = {}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_experts(
|
||||||
|
self, layer_idx: int, logits: torch.Tensor, top_k: int
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Returns (topk_indices, weights) using biased selection and unbiased weights."""
|
||||||
|
module = self._layer_modules.get(layer_idx)
|
||||||
|
if module is not None and hasattr(module, "_afb_bias"):
|
||||||
|
b = module._afb_bias
|
||||||
|
else:
|
||||||
|
b = self.state.bias[layer_idx]
|
||||||
|
biased = logits + b # bias is a buffer
|
||||||
|
_topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
|
||||||
|
chosen_logits = torch.gather(logits, -1, topk_idx)
|
||||||
|
weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
|
||||||
|
return topk_idx, weights
|
||||||
|
|
||||||
|
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
|
||||||
|
"""Bind model buffers so shim updates stay in sync with patched layers."""
|
||||||
|
self._layer_modules[layer_idx] = module
|
||||||
|
bias = module._afb_bias
|
||||||
|
ema = module._afb_ema
|
||||||
|
# Keep state views pointing to the same tensors to avoid drift.
|
||||||
|
if layer_idx < len(self.state.bias):
|
||||||
|
self.state.bias[layer_idx] = bias
|
||||||
|
if layer_idx < len(self.state.ema_load):
|
||||||
|
self.state.ema_load[layer_idx] = ema
|
||||||
|
|
||||||
|
def begin_step(self) -> None:
|
||||||
|
"""Call once per optimizer step before per-layer updates."""
|
||||||
|
self.state.steps += 1
|
||||||
|
|
||||||
|
def get_prev_bias_sign(self, layer_idx: int) -> Optional[torch.Tensor]:
|
||||||
|
return self._prev_bias_sign.get(layer_idx)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
|
||||||
|
self._maybe_init_ep_group()
|
||||||
|
if not dist.is_available() or not dist.is_initialized():
|
||||||
|
return counts
|
||||||
|
group = self.ep_group if self.ep_group is not None else dist.group.WORLD
|
||||||
|
dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group)
|
||||||
|
return counts
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
|
||||||
|
"""Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
|
||||||
|
cfg = self.state.cfg
|
||||||
|
if self.state.steps <= cfg.warmup_steps:
|
||||||
|
return
|
||||||
|
|
||||||
|
nE = step_counts.numel()
|
||||||
|
if tokens_seen <= 0:
|
||||||
|
return
|
||||||
|
module = self._layer_modules.get(layer_idx)
|
||||||
|
if module is not None and hasattr(module, "_afb_ema"):
|
||||||
|
ema = module._afb_ema
|
||||||
|
bias = module._afb_bias
|
||||||
|
else:
|
||||||
|
ema = self.state.ema_load[layer_idx]
|
||||||
|
bias = self.state.bias[layer_idx]
|
||||||
|
counts = step_counts.to(ema.device)
|
||||||
|
freq = counts.float() / float(tokens_seen)
|
||||||
|
ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
|
||||||
|
target = 1.0 / float(nE)
|
||||||
|
delta = cfg.rate * (target - ema)
|
||||||
|
# optional mean-centering to keep sum(bias) ~ 0
|
||||||
|
delta = delta - delta.mean()
|
||||||
|
bias.add_(delta)
|
||||||
|
if cfg.bias_cap is not None and cfg.bias_cap > 0:
|
||||||
|
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
|
||||||
|
self._prev_bias_sign[layer_idx] = torch.sign(bias.detach())
|
||||||
|
|
||||||
|
def _maybe_init_ep_group(self) -> None:
|
||||||
|
if not self._ep_group_pending:
|
||||||
|
return
|
||||||
|
if not dist.is_available() or not dist.is_initialized():
|
||||||
|
return
|
||||||
|
ep_size = self._ep_size
|
||||||
|
if not ep_size or ep_size <= 1:
|
||||||
|
LOG.warning(
|
||||||
|
"AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
|
||||||
|
)
|
||||||
|
self.ep_group = dist.group.WORLD
|
||||||
|
self._ep_group_pending = False
|
||||||
|
return
|
||||||
|
world = dist.get_world_size()
|
||||||
|
if world % ep_size != 0:
|
||||||
|
LOG.warning(
|
||||||
|
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
|
||||||
|
ep_size,
|
||||||
|
world,
|
||||||
|
)
|
||||||
|
self.ep_group = dist.group.WORLD
|
||||||
|
self._ep_group_pending = False
|
||||||
|
return
|
||||||
|
if ep_size == world:
|
||||||
|
self.ep_group = dist.group.WORLD
|
||||||
|
else:
|
||||||
|
rank = dist.get_rank()
|
||||||
|
group_start = (rank // ep_size) * ep_size
|
||||||
|
ranks = tuple(range(group_start, group_start + ep_size))
|
||||||
|
self.ep_group = dist.new_group(ranks)
|
||||||
|
LOG.info(
|
||||||
|
"AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
|
||||||
|
ep_size,
|
||||||
|
dist.get_world_size(),
|
||||||
|
)
|
||||||
|
self._ep_group_pending = False
|
||||||
267
src/axolotl/integrations/aux_free_router/plugin.py
Normal file
267
src/axolotl/integrations/aux_free_router/plugin.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""Aux-loss-free MoE Router Plugin for Axolotl.
|
||||||
|
|
||||||
|
This plugin wires an aux-free gating option into compatible MoE models using
|
||||||
|
unbiased logits for mixture weights and per-expert biases for top-k selection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from transformers.trainer_callback import TrainerCallback
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
from .adapters import (
|
||||||
|
BailingAdapter,
|
||||||
|
BaseMoEAdapter,
|
||||||
|
Llama4Adapter,
|
||||||
|
MixtralAdapter,
|
||||||
|
Qwen3Adapter,
|
||||||
|
Qwen35MoeAdapter,
|
||||||
|
discover_and_prepare_layers,
|
||||||
|
)
|
||||||
|
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
||||||
|
"""Post-step callback to update aux-free biases from accumulated expert counts."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
shim: AuxFreeShim,
|
||||||
|
layer_modules: list[torch.nn.Module],
|
||||||
|
trainer: Any,
|
||||||
|
):
|
||||||
|
self.shim = shim
|
||||||
|
self.layer_modules = layer_modules
|
||||||
|
self.trainer = trainer
|
||||||
|
self._prev_bias_sign: dict[int, torch.Tensor] = {}
|
||||||
|
self._telemetry_buffer: dict[int, dict[str, float]] = {}
|
||||||
|
|
||||||
|
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
|
||||||
|
# Iterate prepared MoE layers and apply the bias update rule.
|
||||||
|
self.shim.begin_step()
|
||||||
|
for layer in self.layer_modules:
|
||||||
|
if not hasattr(layer, "_afb_counts") or not hasattr(
|
||||||
|
layer, "_afb_layer_idx"
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
counts = layer._afb_counts
|
||||||
|
if counts is None:
|
||||||
|
continue
|
||||||
|
counts = self.shim.all_reduce_counts(counts)
|
||||||
|
layer_idx = getattr(layer, "_afb_layer_idx", None)
|
||||||
|
if layer_idx is None:
|
||||||
|
counts.zero_()
|
||||||
|
continue
|
||||||
|
bias = layer._afb_bias
|
||||||
|
counts_for_update = counts.to(bias.device)
|
||||||
|
tokens_seen = int(counts_for_update.sum().item())
|
||||||
|
# local layer-state EMA and bias update
|
||||||
|
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
|
||||||
|
self._collect_telemetry(layer_idx, counts_for_update, tokens_seen, bias)
|
||||||
|
# reset step counts
|
||||||
|
counts.zero_()
|
||||||
|
|
||||||
|
if self._should_log(args, state) and self._telemetry_buffer:
|
||||||
|
logs: dict[str, float] = {}
|
||||||
|
for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
|
||||||
|
prefix = f"moe_afb/l{layer_idx}_"
|
||||||
|
for key, value in metrics.items():
|
||||||
|
logs[f"{prefix}{key}"] = value
|
||||||
|
if logs and hasattr(self.trainer, "log"):
|
||||||
|
self.trainer.log(logs)
|
||||||
|
self._telemetry_buffer.clear()
|
||||||
|
return control
|
||||||
|
|
||||||
|
def _collect_telemetry(
|
||||||
|
self,
|
||||||
|
layer_idx: int,
|
||||||
|
counts: torch.Tensor,
|
||||||
|
tokens_seen: int,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
if tokens_seen <= 0:
|
||||||
|
return
|
||||||
|
freq = counts.float() / float(tokens_seen)
|
||||||
|
load_min = freq.min().item()
|
||||||
|
load_mean = freq.mean().item()
|
||||||
|
load_max = freq.max().item()
|
||||||
|
bias_abs_max = bias.abs().max().item()
|
||||||
|
|
||||||
|
prev_sign = self._prev_bias_sign.get(layer_idx)
|
||||||
|
current_sign = torch.sign(bias.detach())
|
||||||
|
if prev_sign is None or prev_sign.shape != current_sign.shape:
|
||||||
|
oscillation = 0.0
|
||||||
|
else:
|
||||||
|
changed = (current_sign != prev_sign) & (
|
||||||
|
(current_sign != 0) | (prev_sign != 0)
|
||||||
|
)
|
||||||
|
if changed.numel() == 0:
|
||||||
|
oscillation = 0.0
|
||||||
|
else:
|
||||||
|
oscillation = changed.float().mean().item()
|
||||||
|
self._prev_bias_sign[layer_idx] = current_sign.clone()
|
||||||
|
|
||||||
|
self._telemetry_buffer[layer_idx] = {
|
||||||
|
"load_min": load_min,
|
||||||
|
"load_mean": load_mean,
|
||||||
|
"load_max": load_max,
|
||||||
|
"bias_abs_max": bias_abs_max,
|
||||||
|
"bias_sign_flip_frac": oscillation,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _should_log(self, args, state) -> bool:
|
||||||
|
interval = getattr(args, "logging_steps", 0)
|
||||||
|
if not interval:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
interval = max(1, int(interval))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return False
|
||||||
|
return interval > 0 and state.global_step % interval == 0
|
||||||
|
|
||||||
|
|
||||||
|
class AuxFreeMoEPlugin(BasePlugin):
|
||||||
|
"""Plugin that enables aux-loss-free routing when configured."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._handles: list = []
|
||||||
|
self._shim: Optional[AuxFreeShim] = None
|
||||||
|
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
|
||||||
|
|
||||||
|
def get_input_args(self):
|
||||||
|
return "axolotl.integrations.aux_free_router.AuxFreeRouterArgs"
|
||||||
|
|
||||||
|
def post_model_build(self, cfg, model):
|
||||||
|
# Enable only when explicitly requested
|
||||||
|
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Be conservative — skip known native aux-free families
|
||||||
|
native_auxfree = getattr(
|
||||||
|
getattr(model, "config", object()), "model_type", ""
|
||||||
|
) in (
|
||||||
|
"deepseek_v3",
|
||||||
|
"glm4_moe",
|
||||||
|
)
|
||||||
|
if native_auxfree:
|
||||||
|
LOG.info(
|
||||||
|
"AuxFreeMoE: model reports native aux-free routing; skipping patching"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build aux-free state and shim
|
||||||
|
rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01
|
||||||
|
momentum = (
|
||||||
|
cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9
|
||||||
|
)
|
||||||
|
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
|
||||||
|
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
|
||||||
|
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
|
||||||
|
af_cfg = AuxFreeConfig(
|
||||||
|
rate=rate,
|
||||||
|
momentum=momentum,
|
||||||
|
bias_cap=bias_cap,
|
||||||
|
warmup_steps=warmup,
|
||||||
|
sync_group=sync_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Discover layers to count the number and experts for state sizing
|
||||||
|
adapters: list[BaseMoEAdapter] = [
|
||||||
|
MixtralAdapter(),
|
||||||
|
Qwen3Adapter(),
|
||||||
|
Qwen35MoeAdapter(),
|
||||||
|
BailingAdapter(),
|
||||||
|
Llama4Adapter(),
|
||||||
|
]
|
||||||
|
|
||||||
|
# For initial state sizing, we conservatively assume the first discovered layer defines nE
|
||||||
|
n_layers = 0
|
||||||
|
n_experts = None
|
||||||
|
for _m in model.modules():
|
||||||
|
n_layers += 1 # upper bound — we will re-use bias slots sparsely
|
||||||
|
device = next(model.parameters(), torch.tensor(0)).device
|
||||||
|
if n_layers <= 0:
|
||||||
|
n_layers = 1
|
||||||
|
if n_experts is None:
|
||||||
|
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
|
||||||
|
n_experts = 2
|
||||||
|
state = AuxFreeState(
|
||||||
|
num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg
|
||||||
|
)
|
||||||
|
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||||
|
ep_group = None
|
||||||
|
if sync_group == "ep":
|
||||||
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
ep_group = self._resolve_ep_group(cfg)
|
||||||
|
else:
|
||||||
|
LOG.info(
|
||||||
|
"AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
|
||||||
|
)
|
||||||
|
self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)
|
||||||
|
|
||||||
|
# Discover and prepare layers (attach per-layer buffers)
|
||||||
|
self._handles = discover_and_prepare_layers(model, adapters, self._shim)
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
|
||||||
|
if not dist.is_available() or not dist.is_initialized():
|
||||||
|
LOG.warning(
|
||||||
|
"AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||||
|
if not ep_size or ep_size <= 1:
|
||||||
|
LOG.warning(
|
||||||
|
"AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
world = dist.get_world_size()
|
||||||
|
if world % ep_size != 0:
|
||||||
|
LOG.warning(
|
||||||
|
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
|
||||||
|
ep_size,
|
||||||
|
world,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
if ep_size == world:
|
||||||
|
return dist.group.WORLD
|
||||||
|
|
||||||
|
rank = dist.get_rank()
|
||||||
|
# All ranks must collectively create all EP subgroups in the same order
|
||||||
|
# to avoid deadlocks (dist.new_group is a collective operation).
|
||||||
|
world_size = world
|
||||||
|
my_group = None
|
||||||
|
for group_start in range(0, world_size, ep_size):
|
||||||
|
ranks = tuple(range(group_start, group_start + ep_size))
|
||||||
|
if ranks not in self._ep_group_cache:
|
||||||
|
self._ep_group_cache[ranks] = dist.new_group(ranks)
|
||||||
|
if rank in ranks:
|
||||||
|
my_group = self._ep_group_cache[ranks]
|
||||||
|
return my_group
|
||||||
|
|
||||||
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
|
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||||
|
return []
|
||||||
|
if self._shim is None:
|
||||||
|
return []
|
||||||
|
# gather concrete layer modules from handles
|
||||||
|
layers = [h.layer for h in self._handles]
|
||||||
|
cb = MoeAuxFreeBiasUpdateCallback(
|
||||||
|
self._shim,
|
||||||
|
layers,
|
||||||
|
trainer,
|
||||||
|
)
|
||||||
|
LOG.info("AuxFreeMoE: registering post-step bias update callback")
|
||||||
|
return [cb]
|
||||||
@@ -240,7 +240,16 @@ def _softmax_topk_route(
|
|||||||
|
|
||||||
top_k = base_gate.top_k
|
top_k = base_gate.top_k
|
||||||
num_experts = base_gate.num_experts
|
num_experts = base_gate.num_experts
|
||||||
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
|
||||||
|
# Aux-free bias: biased selection, unbiased weights
|
||||||
|
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||||
|
if afb_bias is not None:
|
||||||
|
scores_for_choice = routing_weights + afb_bias
|
||||||
|
_, selected_experts = torch.topk(scores_for_choice, top_k, dim=-1)
|
||||||
|
routing_weights = routing_weights.gather(1, selected_experts)
|
||||||
|
_accumulate_afb_counts(moe_block, selected_experts)
|
||||||
|
else:
|
||||||
|
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
||||||
|
|
||||||
if getattr(base_gate, "norm_topk_prob", True):
|
if getattr(base_gate, "norm_topk_prob", True):
|
||||||
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
|
||||||
@@ -282,6 +291,11 @@ def _sigmoid_topk_route(
|
|||||||
else:
|
else:
|
||||||
scores_for_choice = router_probs
|
scores_for_choice = router_probs
|
||||||
|
|
||||||
|
# Aux-free bias: stacks on top of e_score_correction_bias for selection
|
||||||
|
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||||
|
if afb_bias is not None:
|
||||||
|
scores_for_choice = scores_for_choice + afb_bias
|
||||||
|
|
||||||
# Group-based selection: pick top groups, mask the rest
|
# Group-based selection: pick top groups, mask the rest
|
||||||
n_group = getattr(moe_block, "n_group", 1)
|
n_group = getattr(moe_block, "n_group", 1)
|
||||||
if n_group > 1:
|
if n_group > 1:
|
||||||
@@ -307,6 +321,10 @@ def _sigmoid_topk_route(
|
|||||||
# Gather weights from original sigmoid scores (not bias-corrected)
|
# Gather weights from original sigmoid scores (not bias-corrected)
|
||||||
topk_weights = router_probs.gather(1, topk_indices)
|
topk_weights = router_probs.gather(1, topk_indices)
|
||||||
|
|
||||||
|
# Accumulate counts for aux-free bias update
|
||||||
|
if afb_bias is not None:
|
||||||
|
_accumulate_afb_counts(moe_block, topk_indices)
|
||||||
|
|
||||||
# Optional renormalization + scaling
|
# Optional renormalization + scaling
|
||||||
if getattr(moe_block, "norm_topk_prob", True):
|
if getattr(moe_block, "norm_topk_prob", True):
|
||||||
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
|
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
|
||||||
@@ -335,6 +353,16 @@ def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
|
||||||
|
"""Accumulate per-expert token counts for aux-free bias updates."""
|
||||||
|
afb_counts = getattr(moe_block, "_afb_counts", None)
|
||||||
|
if afb_counts is None:
|
||||||
|
return
|
||||||
|
flat_idx = topk_indices.reshape(-1)
|
||||||
|
counts = torch.bincount(flat_idx, minlength=afb_counts.numel())
|
||||||
|
afb_counts.add_(counts.to(afb_counts.dtype))
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Shared expert helpers
|
# Shared expert helpers
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
@@ -9,6 +9,12 @@ Different MoE architectures use different routing strategies:
|
|||||||
|
|
||||||
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
||||||
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
||||||
|
|
||||||
|
Aux-loss-free (AFB) bias integration: when the aux_free_router plugin is
|
||||||
|
active, ``moe_block._afb_bias`` and ``moe_block._afb_counts`` are registered
|
||||||
|
as buffers. The routing functions transparently inject the bias into expert
|
||||||
|
*selection* (biased topk) while keeping mixture *weights* from unbiased
|
||||||
|
scores, then accumulate per-expert token counts for the post-step bias update.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -101,17 +107,25 @@ def softmax_topk_routing(
|
|||||||
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
||||||
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
||||||
|
|
||||||
|
# Aux-free bias: biased selection, unbiased weights
|
||||||
|
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||||
|
scores_for_choice = router_probs
|
||||||
|
if afb_bias is not None:
|
||||||
|
scores_for_choice = router_probs + afb_bias
|
||||||
|
|
||||||
# Select top-k experts per token
|
# Select top-k experts per token
|
||||||
top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each
|
top_values, top_indices = torch.topk(scores_for_choice, K, dim=-1) # [T, K] each
|
||||||
|
|
||||||
|
# When aux-free bias is active, gather unbiased weights and accumulate counts
|
||||||
|
if afb_bias is not None:
|
||||||
|
top_values = router_probs.gather(1, top_indices)
|
||||||
|
_accumulate_afb_counts(moe_block, top_indices)
|
||||||
|
|
||||||
# Renormalize if configured (default True for models without the attribute,
|
# Renormalize if configured (default True for models without the attribute,
|
||||||
# e.g. Mixtral/MiniMax which always normalize)
|
# e.g. Mixtral/MiniMax which always normalize)
|
||||||
if getattr(gate, "norm_topk_prob", True):
|
if getattr(gate, "norm_topk_prob", True):
|
||||||
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
|
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
# no-op: matches transformers which casts to softmax output dtype (float32).
|
|
||||||
# top_values = top_values.to(router_probs.dtype)
|
|
||||||
|
|
||||||
# Flatten for moe_general_routing_inputs.
|
# Flatten for moe_general_routing_inputs.
|
||||||
# Token indices are naturally sorted ascending from the [T, K] layout:
|
# Token indices are naturally sorted ascending from the [T, K] layout:
|
||||||
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
|
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
|
||||||
@@ -142,7 +156,11 @@ def softmax_group_topk_routing(
|
|||||||
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
||||||
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
||||||
|
|
||||||
|
# Aux-free bias: inject before group selection / topk
|
||||||
|
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||||
scores_for_choice = router_probs
|
scores_for_choice = router_probs
|
||||||
|
if afb_bias is not None:
|
||||||
|
scores_for_choice = router_probs + afb_bias
|
||||||
|
|
||||||
# Group selection: pick top groups, mask the rest
|
# Group selection: pick top groups, mask the rest
|
||||||
if n_group > 1:
|
if n_group > 1:
|
||||||
@@ -159,11 +177,17 @@ def softmax_group_topk_routing(
|
|||||||
score_mask = (
|
score_mask = (
|
||||||
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
||||||
)
|
)
|
||||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
scores_for_choice = scores_for_choice.masked_fill(
|
||||||
|
~score_mask.bool(), -float("inf")
|
||||||
|
)
|
||||||
|
|
||||||
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
||||||
topk_weights = router_probs.gather(1, topk_indices)
|
topk_weights = router_probs.gather(1, topk_indices)
|
||||||
|
|
||||||
|
# Accumulate counts for aux-free bias update
|
||||||
|
if afb_bias is not None:
|
||||||
|
_accumulate_afb_counts(moe_block, topk_indices)
|
||||||
|
|
||||||
# Renormalization + scaling
|
# Renormalization + scaling
|
||||||
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
||||||
if norm_topk_prob:
|
if norm_topk_prob:
|
||||||
@@ -233,6 +257,11 @@ def sigmoid_topk_routing(
|
|||||||
)
|
)
|
||||||
scores_for_choice = router_probs + e_score_correction_bias
|
scores_for_choice = router_probs + e_score_correction_bias
|
||||||
|
|
||||||
|
# Aux-free bias: stacks on top of e_score_correction_bias for selection
|
||||||
|
afb_bias = getattr(moe_block, "_afb_bias", None)
|
||||||
|
if afb_bias is not None:
|
||||||
|
scores_for_choice = scores_for_choice + afb_bias
|
||||||
|
|
||||||
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
|
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
|
||||||
if n_group > 1:
|
if n_group > 1:
|
||||||
group_scores = (
|
group_scores = (
|
||||||
@@ -248,7 +277,9 @@ def sigmoid_topk_routing(
|
|||||||
score_mask = (
|
score_mask = (
|
||||||
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
||||||
)
|
)
|
||||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
scores_for_choice = scores_for_choice.masked_fill(
|
||||||
|
~score_mask.bool(), -float("inf")
|
||||||
|
)
|
||||||
|
|
||||||
# Final topk from (possibly masked) scores
|
# Final topk from (possibly masked) scores
|
||||||
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
||||||
@@ -256,6 +287,10 @@ def sigmoid_topk_routing(
|
|||||||
# Gather weights from original sigmoid scores (not bias-corrected)
|
# Gather weights from original sigmoid scores (not bias-corrected)
|
||||||
topk_weights = router_probs.gather(1, topk_indices)
|
topk_weights = router_probs.gather(1, topk_indices)
|
||||||
|
|
||||||
|
# Accumulate counts for aux-free bias update
|
||||||
|
if afb_bias is not None:
|
||||||
|
_accumulate_afb_counts(moe_block, topk_indices)
|
||||||
|
|
||||||
# Optional renormalization + scaling
|
# Optional renormalization + scaling
|
||||||
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
||||||
if norm_topk_prob:
|
if norm_topk_prob:
|
||||||
@@ -276,3 +311,21 @@ def sigmoid_topk_routing(
|
|||||||
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
|
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
|
||||||
|
|
||||||
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||||
|
|
||||||
|
|
||||||
|
def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
|
||||||
|
"""Accumulate per-expert token counts for the aux-free bias update.
|
||||||
|
|
||||||
|
Called when ``moe_block._afb_bias`` is present (registered by the
|
||||||
|
``aux_free_router`` plugin). The counts are later consumed by the
|
||||||
|
``MoeAuxFreeBiasUpdateCallback`` at each training step.
|
||||||
|
"""
|
||||||
|
if hasattr(moe_block, "training") and not moe_block.training:
|
||||||
|
return
|
||||||
|
afb_counts = getattr(moe_block, "_afb_counts", None)
|
||||||
|
if afb_counts is None:
|
||||||
|
return
|
||||||
|
num_experts = afb_counts.numel()
|
||||||
|
flat_idx = topk_indices.reshape(-1)
|
||||||
|
counts = torch.bincount(flat_idx, minlength=num_experts)
|
||||||
|
afb_counts.add_(counts.to(afb_counts.dtype))
|
||||||
|
|||||||
@@ -30,15 +30,6 @@ class LigerArgs(BaseModel):
|
|||||||
|
|
||||||
liger_rope: bool | None = None
|
liger_rope: bool | None = None
|
||||||
liger_rms_norm: bool | None = None
|
liger_rms_norm: bool | None = None
|
||||||
liger_rms_norm_gated: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": (
|
|
||||||
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
|
|
||||||
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
liger_layer_norm: bool | None = None
|
liger_layer_norm: bool | None = None
|
||||||
liger_swiglu: bool | None = None
|
liger_swiglu: bool | None = None
|
||||||
liger_glu_activation: bool | None = None
|
liger_glu_activation: bool | None = None
|
||||||
|
|||||||
@@ -1,175 +0,0 @@
|
|||||||
"""
|
|
||||||
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from copy import deepcopy
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
|
|
||||||
|
|
||||||
def lce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Cache] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**kwargs,
|
|
||||||
) -> CausalLMOutputWithPast:
|
|
||||||
r"""
|
|
||||||
Args:
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
"""
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
|
|
||||||
logits = None
|
|
||||||
loss = None
|
|
||||||
# if in training mode, don't materialize logits
|
|
||||||
if self.training and (labels is not None):
|
|
||||||
loss = LigerForCausalLMLoss(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
lm_head_weight=self.lm_head.weight,
|
|
||||||
labels=labels,
|
|
||||||
hidden_size=self.config.hidden_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
else: # if in inference mode materialize logits
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_liger_kernel_to_qwen3_5(
|
|
||||||
cross_entropy: bool = False,
|
|
||||||
fused_linear_cross_entropy: bool = False,
|
|
||||||
rms_norm: bool = False,
|
|
||||||
rms_norm_gated: bool = False,
|
|
||||||
glu_activation: bool = False,
|
|
||||||
layer_norm: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 models.
|
|
||||||
|
|
||||||
Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
|
||||||
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
||||||
fused_linear_cross_entropy (bool):
|
|
||||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
|
||||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
||||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
||||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
|
||||||
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
|
||||||
Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
|
|
||||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
|
||||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import transformers.models.qwen3_5.modeling_qwen3_5 # noqa: F401
|
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
|
||||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
||||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|
||||||
|
|
||||||
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
||||||
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
||||||
)
|
|
||||||
|
|
||||||
modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
|
|
||||||
|
|
||||||
if rms_norm:
|
|
||||||
# Qwen3_5RMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
|
||||||
class LigerRMSNormForQwen3_5(LigerRMSNorm):
|
|
||||||
def __init__(self, dim, eps=1e-6, **kwargs):
|
|
||||||
super().__init__(
|
|
||||||
dim,
|
|
||||||
eps=eps,
|
|
||||||
offset=1.0,
|
|
||||||
casting_mode="gemma",
|
|
||||||
init_fn="zeros",
|
|
||||||
in_place=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
|
|
||||||
|
|
||||||
if rms_norm_gated:
|
|
||||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
|
||||||
|
|
||||||
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
|
|
||||||
|
|
||||||
if glu_activation:
|
|
||||||
|
|
||||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
|
||||||
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
|
||||||
config = deepcopy(config)
|
|
||||||
if intermediate_size is not None:
|
|
||||||
config.intermediate_size = intermediate_size
|
|
||||||
return LigerSwiGLUMLP(config, **kwargs)
|
|
||||||
|
|
||||||
modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
|
|
||||||
|
|
||||||
if layer_norm:
|
|
||||||
modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
|
|
||||||
|
|
||||||
if cross_entropy:
|
|
||||||
from transformers.loss.loss_utils import nn
|
|
||||||
|
|
||||||
nn.functional.cross_entropy = liger_cross_entropy
|
|
||||||
|
|
||||||
if fused_linear_cross_entropy:
|
|
||||||
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward
|
|
||||||
@@ -1,198 +0,0 @@
|
|||||||
"""
|
|
||||||
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from copy import deepcopy
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
|
||||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
|
||||||
|
|
||||||
|
|
||||||
def lce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values=None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_router_logits: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**kwargs,
|
|
||||||
) -> MoeCausalLMOutputWithPast:
|
|
||||||
r"""
|
|
||||||
Args:
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
"""
|
|
||||||
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
|
||||||
load_balancing_loss_func,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_router_logits = (
|
|
||||||
output_router_logits
|
|
||||||
if output_router_logits is not None
|
|
||||||
else self.config.output_router_logits
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_router_logits=output_router_logits,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
|
|
||||||
logits = None
|
|
||||||
loss = None
|
|
||||||
# if in training mode, don't materialize logits
|
|
||||||
if self.training and (labels is not None):
|
|
||||||
loss = LigerForCausalLMLoss(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
lm_head_weight=self.lm_head.weight,
|
|
||||||
labels=labels,
|
|
||||||
hidden_size=self.config.hidden_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
else: # if in inference mode materialize logits
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits,
|
|
||||||
labels,
|
|
||||||
self.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
aux_loss = None
|
|
||||||
if output_router_logits:
|
|
||||||
aux_loss = load_balancing_loss_func(
|
|
||||||
outputs.router_logits,
|
|
||||||
self.num_experts,
|
|
||||||
self.num_experts_per_tok,
|
|
||||||
attention_mask,
|
|
||||||
)
|
|
||||||
if labels is not None:
|
|
||||||
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
|
||||||
|
|
||||||
return MoeCausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
aux_loss=aux_loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
router_logits=outputs.router_logits,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_liger_kernel_to_qwen3_5_moe(
|
|
||||||
cross_entropy: bool = False,
|
|
||||||
fused_linear_cross_entropy: bool = False,
|
|
||||||
rms_norm: bool = False,
|
|
||||||
rms_norm_gated: bool = False,
|
|
||||||
glu_activation: bool = False,
|
|
||||||
layer_norm: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
|
|
||||||
|
|
||||||
Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
|
||||||
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
|
||||||
fused_linear_cross_entropy (bool):
|
|
||||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
|
||||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
|
||||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
|
||||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
|
||||||
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
|
||||||
Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
|
|
||||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
|
||||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401
|
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
|
||||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
|
||||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|
||||||
|
|
||||||
assert not (cross_entropy and fused_linear_cross_entropy), (
|
|
||||||
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
|
||||||
)
|
|
||||||
|
|
||||||
modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]
|
|
||||||
|
|
||||||
if rms_norm:
|
|
||||||
# Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
|
||||||
class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
|
|
||||||
def __init__(self, dim, eps=1e-6, **kwargs):
|
|
||||||
super().__init__(
|
|
||||||
dim,
|
|
||||||
eps=eps,
|
|
||||||
offset=1.0,
|
|
||||||
casting_mode="gemma",
|
|
||||||
init_fn="zeros",
|
|
||||||
in_place=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe
|
|
||||||
|
|
||||||
if rms_norm_gated:
|
|
||||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
|
||||||
|
|
||||||
modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated
|
|
||||||
|
|
||||||
if glu_activation:
|
|
||||||
|
|
||||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
|
||||||
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
|
||||||
config = deepcopy(config)
|
|
||||||
if intermediate_size is not None:
|
|
||||||
config.intermediate_size = intermediate_size
|
|
||||||
return LigerSwiGLUMLP(config, **kwargs)
|
|
||||||
|
|
||||||
modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper
|
|
||||||
|
|
||||||
if layer_norm:
|
|
||||||
modeling_mod.nn.LayerNorm = LigerLayerNorm
|
|
||||||
|
|
||||||
if cross_entropy:
|
|
||||||
from transformers.loss.loss_utils import nn
|
|
||||||
|
|
||||||
nn.functional.cross_entropy = liger_cross_entropy
|
|
||||||
|
|
||||||
if fused_linear_cross_entropy:
|
|
||||||
modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward
|
|
||||||
@@ -174,19 +174,6 @@ class LigerPlugin(BasePlugin):
|
|||||||
rms_norm=cfg.liger_rms_norm,
|
rms_norm=cfg.liger_rms_norm,
|
||||||
layer_norm=cfg.liger_layer_norm,
|
layer_norm=cfg.liger_layer_norm,
|
||||||
)
|
)
|
||||||
elif cfg.model_config_type == "qwen3_5":
|
|
||||||
from axolotl.integrations.liger.models.qwen3_5 import (
|
|
||||||
apply_liger_kernel_to_qwen3_5,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_liger_kernel_to_qwen3_5(
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
glu_activation=cfg.liger_glu_activation,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
|
||||||
layer_norm=cfg.liger_layer_norm,
|
|
||||||
)
|
|
||||||
elif cfg.model_config_type == "qwen3_moe":
|
elif cfg.model_config_type == "qwen3_moe":
|
||||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||||
apply_liger_kernel_to_qwen3_moe,
|
apply_liger_kernel_to_qwen3_moe,
|
||||||
@@ -199,19 +186,6 @@ class LigerPlugin(BasePlugin):
|
|||||||
rms_norm=cfg.liger_rms_norm,
|
rms_norm=cfg.liger_rms_norm,
|
||||||
layer_norm=cfg.liger_layer_norm,
|
layer_norm=cfg.liger_layer_norm,
|
||||||
)
|
)
|
||||||
elif cfg.model_config_type == "qwen3_5_moe":
|
|
||||||
from axolotl.integrations.liger.models.qwen3_5_moe import (
|
|
||||||
apply_liger_kernel_to_qwen3_5_moe,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_liger_kernel_to_qwen3_5_moe(
|
|
||||||
cross_entropy=cfg.liger_cross_entropy,
|
|
||||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
|
||||||
glu_activation=cfg.liger_glu_activation,
|
|
||||||
rms_norm=cfg.liger_rms_norm,
|
|
||||||
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
|
||||||
layer_norm=cfg.liger_layer_norm,
|
|
||||||
)
|
|
||||||
elif cfg.model_config_type == "granitemoe":
|
elif cfg.model_config_type == "granitemoe":
|
||||||
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
||||||
|
|
||||||
|
|||||||
@@ -1,147 +0,0 @@
|
|||||||
"""
|
|
||||||
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
|
|
||||||
|
|
||||||
Fuses the weight norm computation and magnitude scaling to avoid
|
|
||||||
materializing the full [out_features, in_features] combined weight matrix.
|
|
||||||
The B@A product is computed row-by-row inside the kernel.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
from .quantize import dequantize
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _dora_fused_norm_kernel(
|
|
||||||
# Pointers
|
|
||||||
W_ptr, # base weight [out, in] (dequantized, row-major)
|
|
||||||
B_ptr, # LoRA B [out, rank] (row-major)
|
|
||||||
A_ptr, # LoRA A [rank, in] (row-major)
|
|
||||||
mag_ptr, # magnitude vector [out]
|
|
||||||
out_ptr, # output mag_norm_scale [out]
|
|
||||||
# Shapes
|
|
||||||
out_features,
|
|
||||||
in_features,
|
|
||||||
rank,
|
|
||||||
# Scaling
|
|
||||||
lora_scale, # float scaling factor
|
|
||||||
# Block sizes
|
|
||||||
BLOCK_IN: tl.constexpr,
|
|
||||||
BLOCK_R: tl.constexpr, # >= rank, power of 2
|
|
||||||
):
|
|
||||||
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2
|
|
||||||
|
|
||||||
Each program handles one output row. B[row,:] is loaded once (small),
|
|
||||||
then we tile over in_features computing the dot product with A[:,tile]
|
|
||||||
and accumulating the squared norm.
|
|
||||||
|
|
||||||
This avoids materializing the full [out, in] B@A matrix.
|
|
||||||
"""
|
|
||||||
row = tl.program_id(0)
|
|
||||||
if row >= out_features:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Accumulate squared norm across tiles of in_features
|
|
||||||
norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
|
||||||
|
|
||||||
for start in range(0, in_features, BLOCK_IN):
|
|
||||||
cols = start + tl.arange(0, BLOCK_IN)
|
|
||||||
col_mask = cols < in_features
|
|
||||||
|
|
||||||
# Load W[row, cols]
|
|
||||||
w_vals = tl.load(
|
|
||||||
W_ptr + row * in_features + cols,
|
|
||||||
mask=col_mask,
|
|
||||||
other=0.0,
|
|
||||||
).to(tl.float32)
|
|
||||||
|
|
||||||
# Compute (B[row,:] @ A[:, cols]) for this tile
|
|
||||||
# Load B[row, r] as scalar and A[r, cols] as vector for each r
|
|
||||||
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
|
||||||
for r in tl.static_range(BLOCK_R):
|
|
||||||
# Load scalar B[row, r]
|
|
||||||
b_val = tl.load(
|
|
||||||
B_ptr + row * rank + r,
|
|
||||||
mask=(r < rank),
|
|
||||||
other=0.0,
|
|
||||||
).to(tl.float32)
|
|
||||||
# Load vector A[r, cols]
|
|
||||||
a_vals = tl.load(
|
|
||||||
A_ptr + r * in_features + cols,
|
|
||||||
mask=(col_mask & (r < rank)),
|
|
||||||
other=0.0,
|
|
||||||
).to(tl.float32)
|
|
||||||
ba_vals += b_val * a_vals
|
|
||||||
|
|
||||||
# Combined: W + s * (B @ A)
|
|
||||||
combined = w_vals + lora_scale * ba_vals
|
|
||||||
|
|
||||||
# Accumulate squared values
|
|
||||||
norm_sq_acc += tl.where(col_mask, combined * combined, 0.0)
|
|
||||||
|
|
||||||
# Reduce to scalar norm
|
|
||||||
norm_sq = tl.sum(norm_sq_acc, axis=0)
|
|
||||||
norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability
|
|
||||||
|
|
||||||
# Load magnitude and compute scale
|
|
||||||
mag = tl.load(mag_ptr + row).to(tl.float32)
|
|
||||||
scale = mag / norm
|
|
||||||
|
|
||||||
tl.store(out_ptr + row, scale)
|
|
||||||
|
|
||||||
|
|
||||||
def triton_dora_scale(
|
|
||||||
W: torch.Tensor,
|
|
||||||
W_quant,
|
|
||||||
A: torch.Tensor,
|
|
||||||
B: torch.Tensor,
|
|
||||||
s: float,
|
|
||||||
magnitude: torch.Tensor,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Compute DoRA mag_norm_scale using fused Triton kernel.
|
|
||||||
|
|
||||||
Computes B@A row-by-row inside the kernel, avoiding the full
|
|
||||||
[out_features, in_features] materialization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
W: base weight [out, in] (possibly quantized)
|
|
||||||
W_quant: quantization state
|
|
||||||
A: LoRA A [rank, in]
|
|
||||||
B: LoRA B [out, rank]
|
|
||||||
s: LoRA scaling factor
|
|
||||||
magnitude: learned magnitude [out]
|
|
||||||
dtype: compute dtype
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2
|
|
||||||
"""
|
|
||||||
# Dequantize W to [out, in]
|
|
||||||
W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
|
|
||||||
|
|
||||||
out_features, in_features = W_full.shape
|
|
||||||
rank = A.shape[0]
|
|
||||||
|
|
||||||
out = torch.empty(out_features, dtype=dtype, device=W.device)
|
|
||||||
|
|
||||||
# Block sizes
|
|
||||||
BLOCK_IN = triton.next_power_of_2(min(in_features, 2048))
|
|
||||||
BLOCK_R = triton.next_power_of_2(rank)
|
|
||||||
|
|
||||||
_dora_fused_norm_kernel[(out_features,)](
|
|
||||||
W_full,
|
|
||||||
B.contiguous().to(dtype),
|
|
||||||
A.contiguous().to(dtype),
|
|
||||||
magnitude.contiguous(),
|
|
||||||
out,
|
|
||||||
out_features=out_features,
|
|
||||||
in_features=in_features,
|
|
||||||
rank=rank,
|
|
||||||
lora_scale=s,
|
|
||||||
BLOCK_IN=BLOCK_IN,
|
|
||||||
BLOCK_R=BLOCK_R,
|
|
||||||
)
|
|
||||||
|
|
||||||
return out.detach()
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -105,10 +105,6 @@ def dequantize(
|
|||||||
# Extract quantization state
|
# Extract quantization state
|
||||||
if not isinstance(quant_state, list):
|
if not isinstance(quant_state, list):
|
||||||
# New style quant_state class
|
# New style quant_state class
|
||||||
# Non-double-quantized models have offset=None and state2=None
|
|
||||||
if quant_state.offset is None or quant_state.state2 is None:
|
|
||||||
# Fall back to bitsandbytes standard dequantize
|
|
||||||
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
|
|
||||||
absmax = quant_state.absmax.to(target_device)
|
absmax = quant_state.absmax.to(target_device)
|
||||||
shape = quant_state.shape
|
shape = quant_state.shape
|
||||||
dtype = quant_state.dtype
|
dtype = quant_state.dtype
|
||||||
|
|||||||
@@ -1,333 +0,0 @@
|
|||||||
"""
|
|
||||||
Fused RMSNorm + SiLU Gate Triton kernel.
|
|
||||||
|
|
||||||
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
|
|
||||||
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
|
|
||||||
and silu(G) = G * sigmoid(G)
|
|
||||||
|
|
||||||
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
import operator
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
from liger_kernel.ops.utils import (
|
|
||||||
calculate_settings,
|
|
||||||
compare_version,
|
|
||||||
ensure_contiguous,
|
|
||||||
torch_to_triton_dtype,
|
|
||||||
)
|
|
||||||
from liger_kernel.utils import is_npu_available
|
|
||||||
|
|
||||||
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
||||||
try:
|
|
||||||
from triton.language.extra.libdevice import rsqrt
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
from triton.language.extra.cuda.libdevice import rsqrt
|
|
||||||
else:
|
|
||||||
from triton.language.math import rsqrt
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _rms_norm_gated_forward_kernel(
|
|
||||||
Y_ptr,
|
|
||||||
Y_row_stride,
|
|
||||||
X_ptr,
|
|
||||||
X_row_stride,
|
|
||||||
G_ptr,
|
|
||||||
G_row_stride,
|
|
||||||
W_ptr,
|
|
||||||
W_row_stride,
|
|
||||||
RSTD_ptr,
|
|
||||||
RSTD_row_stride,
|
|
||||||
n_cols,
|
|
||||||
eps,
|
|
||||||
offset,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Y = (W + offset) * (X / RMS(X)) * silu(G)
|
|
||||||
|
|
||||||
All computation done in fp32 (Gemma-style), result cast to input dtype.
|
|
||||||
"""
|
|
||||||
row_idx = tl.program_id(0).to(tl.int64)
|
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = col_offsets < n_cols
|
|
||||||
|
|
||||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
|
||||||
G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
|
|
||||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
|
||||||
|
|
||||||
X_row_dtype = X_row.dtype
|
|
||||||
|
|
||||||
# Cast everything to fp32
|
|
||||||
X_fp32 = X_row.to(tl.float32)
|
|
||||||
G_fp32 = G_row.to(tl.float32)
|
|
||||||
W_fp32 = W_row.to(tl.float32)
|
|
||||||
|
|
||||||
# RMS norm
|
|
||||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
|
||||||
rstd = rsqrt(mean_sq + eps)
|
|
||||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
|
||||||
|
|
||||||
X_norm = X_fp32 * rstd
|
|
||||||
|
|
||||||
# SiLU gate: silu(G) = G * sigmoid(G)
|
|
||||||
sig_G = tl.sigmoid(G_fp32)
|
|
||||||
silu_G = G_fp32 * sig_G
|
|
||||||
|
|
||||||
# Fused output
|
|
||||||
Y_row = (offset + W_fp32) * X_norm * silu_G
|
|
||||||
|
|
||||||
tl.store(
|
|
||||||
Y_ptr + row_idx * Y_row_stride + col_offsets,
|
|
||||||
Y_row.to(X_row_dtype),
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _rms_norm_gated_backward_kernel(
|
|
||||||
dY_ptr,
|
|
||||||
dY_row_stride,
|
|
||||||
dX_ptr,
|
|
||||||
dX_row_stride,
|
|
||||||
dG_ptr,
|
|
||||||
dG_row_stride,
|
|
||||||
X_ptr,
|
|
||||||
X_row_stride,
|
|
||||||
X_dtype: tl.constexpr,
|
|
||||||
G_ptr,
|
|
||||||
G_row_stride,
|
|
||||||
W_ptr,
|
|
||||||
W_row_stride,
|
|
||||||
RSTD_ptr,
|
|
||||||
RSTD_row_stride,
|
|
||||||
dW_ptr,
|
|
||||||
dW_row_stride,
|
|
||||||
n_rows,
|
|
||||||
n_cols,
|
|
||||||
offset,
|
|
||||||
rows_per_program,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Backward for Y = (W + offset) * (X * RSTD) * silu(G)
|
|
||||||
|
|
||||||
dW = sum_batch(dY * X_norm * silu(G))
|
|
||||||
dG = dY * (W + offset) * X_norm * silu'(G)
|
|
||||||
where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
|
||||||
dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
|
|
||||||
where m = dY * (W + offset) * silu(G)
|
|
||||||
"""
|
|
||||||
row_block_id = tl.program_id(0).to(tl.int64)
|
|
||||||
row_start = row_block_id * rows_per_program
|
|
||||||
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = col_offsets < n_cols
|
|
||||||
|
|
||||||
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
||||||
|
|
||||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
|
||||||
W_row = W_row.to(tl.float32) + offset
|
|
||||||
|
|
||||||
for row_idx in range(row_start, row_end):
|
|
||||||
dY_row = tl.load(
|
|
||||||
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
|
|
||||||
)
|
|
||||||
X_row = tl.load(
|
|
||||||
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
|
|
||||||
)
|
|
||||||
G_row = tl.load(
|
|
||||||
G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
|
|
||||||
)
|
|
||||||
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
|
||||||
|
|
||||||
# Cast to fp32
|
|
||||||
dY_fp32 = dY_row.to(tl.float32)
|
|
||||||
X_fp32 = X_row.to(tl.float32)
|
|
||||||
G_fp32 = G_row.to(tl.float32)
|
|
||||||
|
|
||||||
# Recompute intermediates
|
|
||||||
X_norm = X_fp32 * rstd_row
|
|
||||||
sig_G = tl.sigmoid(G_fp32)
|
|
||||||
silu_G = G_fp32 * sig_G
|
|
||||||
|
|
||||||
# dW: accumulate dY * X_norm * silu(G)
|
|
||||||
dW_acc += dY_fp32 * X_norm * silu_G
|
|
||||||
|
|
||||||
# dG: dY * (W + offset) * X_norm * silu'(G)
|
|
||||||
# silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
|
||||||
silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
|
|
||||||
dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
|
|
||||||
tl.store(
|
|
||||||
dG_ptr + row_idx * dG_row_stride + col_offsets,
|
|
||||||
dG_row.to(X_dtype),
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
# dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
|
|
||||||
m = dY_fp32 * W_row * silu_G
|
|
||||||
dX_row = rstd_row * m
|
|
||||||
dX_row += rstd_row * (
|
|
||||||
-(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
|
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
dX_ptr + row_idx * dX_row_stride + col_offsets,
|
|
||||||
dX_row.to(X_dtype),
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
tl.store(
|
|
||||||
dW_ptr + row_block_id * dW_row_stride + col_offsets,
|
|
||||||
dW_acc,
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def rms_norm_gated_forward(X, G, W, eps, offset):
|
|
||||||
shape = X.shape
|
|
||||||
dim = shape[-1]
|
|
||||||
X = X.view(-1, dim)
|
|
||||||
G = G.view(-1, dim)
|
|
||||||
n_rows, n_cols = X.shape
|
|
||||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
||||||
|
|
||||||
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
||||||
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
|
|
||||||
|
|
||||||
assert X.shape[1] == W.shape[0], (
|
|
||||||
f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
|
|
||||||
)
|
|
||||||
assert X.shape == G.shape, (
|
|
||||||
f"X and G must have same shape, got {X.shape} and {G.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
_rms_norm_gated_forward_kernel[(n_rows,)](
|
|
||||||
Y,
|
|
||||||
Y.stride(0),
|
|
||||||
X,
|
|
||||||
X.stride(0),
|
|
||||||
G,
|
|
||||||
G.stride(0),
|
|
||||||
W,
|
|
||||||
W.stride(0),
|
|
||||||
RSTD,
|
|
||||||
RSTD.stride(0),
|
|
||||||
n_cols,
|
|
||||||
eps,
|
|
||||||
offset,
|
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
|
||||||
num_warps=num_warps,
|
|
||||||
)
|
|
||||||
return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
|
|
||||||
|
|
||||||
|
|
||||||
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
|
|
||||||
shape = dY.shape
|
|
||||||
dim = shape[-1]
|
|
||||||
dY = dY.view(-1, dim)
|
|
||||||
n_rows, n_cols = dY.shape
|
|
||||||
|
|
||||||
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
||||||
|
|
||||||
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
||||||
dX = torch.empty_like(dY)
|
|
||||||
dG = torch.empty_like(dY)
|
|
||||||
|
|
||||||
rows_per_program = math.ceil(n_rows / sm_count)
|
|
||||||
grid = (sm_count,)
|
|
||||||
|
|
||||||
_rms_norm_gated_backward_kernel[grid](
|
|
||||||
dY,
|
|
||||||
dY.stride(0),
|
|
||||||
dX,
|
|
||||||
dX.stride(0),
|
|
||||||
dG,
|
|
||||||
dG.stride(0),
|
|
||||||
X,
|
|
||||||
X.stride(0),
|
|
||||||
torch_to_triton_dtype[X.dtype],
|
|
||||||
G,
|
|
||||||
G.stride(0),
|
|
||||||
W,
|
|
||||||
W.stride(0),
|
|
||||||
RSTD,
|
|
||||||
RSTD.stride(0),
|
|
||||||
_dW,
|
|
||||||
_dW.stride(0),
|
|
||||||
n_rows,
|
|
||||||
n_cols,
|
|
||||||
offset,
|
|
||||||
rows_per_program,
|
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
|
||||||
num_warps=num_warps,
|
|
||||||
)
|
|
||||||
|
|
||||||
dX = dX.view(*shape)
|
|
||||||
dG = dG.view(*shape)
|
|
||||||
dW = _dW.sum(dim=0).to(W.dtype)
|
|
||||||
return dX, dG, dW
|
|
||||||
|
|
||||||
|
|
||||||
class FusedRMSNormGatedFunction(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
@ensure_contiguous
|
|
||||||
def forward(ctx, X, G, W, eps, offset=0.0):
|
|
||||||
"""
|
|
||||||
X: (B, T, H) or (BxT, H) — input hidden states
|
|
||||||
G: (B, T, H) or (BxT, H) — gate tensor
|
|
||||||
W: (H,) — weight parameter
|
|
||||||
"""
|
|
||||||
Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
|
|
||||||
X, G, W, eps, offset
|
|
||||||
)
|
|
||||||
ctx.offset = offset
|
|
||||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
|
||||||
ctx.num_warps = num_warps
|
|
||||||
ctx.save_for_backward(X, G, W, RSTD)
|
|
||||||
return Y
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@ensure_contiguous
|
|
||||||
def backward(ctx, dY):
|
|
||||||
X, G, W, RSTD = ctx.saved_tensors
|
|
||||||
dX, dG, dW = rms_norm_gated_backward(
|
|
||||||
dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
|
|
||||||
)
|
|
||||||
return dX, dG, dW, None, None
|
|
||||||
|
|
||||||
|
|
||||||
class FusedRMSNormGated(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Fused RMSNorm + SiLU Gate.
|
|
||||||
|
|
||||||
Computes: Y = W * RMSNorm(X) * silu(G)
|
|
||||||
|
|
||||||
Drop-in replacement for Qwen3_5RMSNormGated with matching
|
|
||||||
init signature: __init__(hidden_size, eps=1e-6, **kwargs)
|
|
||||||
and forward signature: forward(hidden_states, gate=None)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
self.offset = offset
|
|
||||||
|
|
||||||
def forward(self, hidden_states, gate=None):
|
|
||||||
if gate is None:
|
|
||||||
raise ValueError("FusedRMSNormGated requires a gate tensor")
|
|
||||||
if hidden_states.device.type != "cuda":
|
|
||||||
raise ValueError(
|
|
||||||
f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
|
|
||||||
)
|
|
||||||
return FusedRMSNormGatedFunction.apply(
|
|
||||||
hidden_states, gate, self.weight, self.variance_epsilon, self.offset
|
|
||||||
)
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
||||||
@@ -12,7 +12,6 @@ from torch import nn
|
|||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
from axolotl.kernels.lora import (
|
from axolotl.kernels.lora import (
|
||||||
apply_lora_embedding,
|
|
||||||
apply_lora_mlp_geglu,
|
apply_lora_mlp_geglu,
|
||||||
apply_lora_mlp_swiglu,
|
apply_lora_mlp_swiglu,
|
||||||
apply_lora_o,
|
apply_lora_o,
|
||||||
@@ -371,13 +370,13 @@ def apply_lora_kernel_patches(
|
|||||||
active_adapter = model.active_adapter
|
active_adapter = model.active_adapter
|
||||||
lora_config = model.model.peft_config[active_adapter]
|
lora_config = model.model.peft_config[active_adapter]
|
||||||
|
|
||||||
# Log what features are active
|
# Only patch if conditions are met
|
||||||
if lora_config.lora_dropout > 0:
|
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
|
||||||
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
|
|
||||||
if lora_config.bias != "none":
|
if not can_patch:
|
||||||
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
|
LOG.warning("Cannot patch layers - requires no dropout and no bias")
|
||||||
if lora_config.use_dora:
|
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
||||||
LOG.info("LoRA kernels: DoRA enabled")
|
return model
|
||||||
|
|
||||||
# This needs to be reset after patching
|
# This needs to be reset after patching
|
||||||
original_level = LOG.getEffectiveLevel()
|
original_level = LOG.getEffectiveLevel()
|
||||||
@@ -420,33 +419,44 @@ def apply_lora_kernel_patches(
|
|||||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||||
]
|
]
|
||||||
can_patch_qkv = all(
|
can_patch_qkv = all(
|
||||||
hasattr(module, "lora_A") for module in layer_modules
|
hasattr(module, "lora_A")
|
||||||
|
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
|
for module in layer_modules
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_patch_qkv:
|
if can_patch_qkv:
|
||||||
|
# Add optimized implementation
|
||||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention QKV projections - requires LoRA adapters"
|
"Cannot patch some attention QKV projections - requires LoRA "
|
||||||
|
"adapters and no lora_magnitude_vector (DoRA)"
|
||||||
)
|
)
|
||||||
if cfg.lora_o_kernel:
|
if cfg.lora_o_kernel:
|
||||||
# Output patching
|
# Output patching
|
||||||
layer_modules = [
|
layer_modules = [
|
||||||
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||||
]
|
]
|
||||||
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
|
can_patch_o = all(
|
||||||
|
hasattr(module, "lora_A")
|
||||||
|
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
|
for module in layer_modules
|
||||||
|
)
|
||||||
|
|
||||||
if can_patch_o:
|
if can_patch_o:
|
||||||
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention output projection - requires LoRA adapters"
|
"Cannot patch some attention output projection - requires LoRA "
|
||||||
|
"adapters and no lora_magnitude_vector (DoRA)"
|
||||||
)
|
)
|
||||||
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
||||||
if cfg.lora_mlp_kernel:
|
if cfg.lora_mlp_kernel:
|
||||||
# MLP patching
|
# MLP patching
|
||||||
can_patch_mlp = all(
|
can_patch_mlp = all(
|
||||||
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
|
hasattr(proj, "lora_A")
|
||||||
|
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
||||||
|
for proj in (gate_proj, up_proj, down_proj)
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_patch_mlp:
|
if can_patch_mlp:
|
||||||
@@ -454,50 +464,15 @@ def apply_lora_kernel_patches(
|
|||||||
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some MLP layers - requires LoRA adapters"
|
"Cannot patch some MLP layers - requires LoRA adapters and no "
|
||||||
|
"lora_magnitude_vector (DoRA)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Patch embedding layers (model-level, not per-layer)
|
|
||||||
if cfg.lora_embedding_kernel:
|
|
||||||
_patch_embedding_layers(model, cfg)
|
|
||||||
|
|
||||||
LOG.setLevel(original_level)
|
LOG.setLevel(original_level)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
|
|
||||||
"""Patch embedding layers with fused LoRA kernel.
|
|
||||||
|
|
||||||
Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and
|
|
||||||
lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT).
|
|
||||||
"""
|
|
||||||
pretrained_model = model.model
|
|
||||||
patched = 0
|
|
||||||
|
|
||||||
# Find embedding modules - check common locations
|
|
||||||
for attr_path in [
|
|
||||||
("model", "embed_tokens"),
|
|
||||||
("model", "language_model", "embed_tokens"),
|
|
||||||
]:
|
|
||||||
parent = pretrained_model
|
|
||||||
for attr in attr_path:
|
|
||||||
parent = getattr(parent, attr, None)
|
|
||||||
if parent is None:
|
|
||||||
break
|
|
||||||
if parent is not None and hasattr(parent, "lora_embedding_A"):
|
|
||||||
LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}")
|
|
||||||
parent.forward = types.MethodType(apply_lora_embedding, parent)
|
|
||||||
patched += 1
|
|
||||||
|
|
||||||
# lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
|
|
||||||
# when included in target_modules. No special embedding handling needed since
|
|
||||||
# PEFT wraps it as a Linear (not Embedding) even for tied models.
|
|
||||||
|
|
||||||
if not patched:
|
|
||||||
LOG.debug("No embedding layers with LoRA found to patch")
|
|
||||||
|
|
||||||
|
|
||||||
class FakeMLP(nn.Module):
|
class FakeMLP(nn.Module):
|
||||||
"""
|
"""
|
||||||
placeholder MLP for triton patching
|
placeholder MLP for triton patching
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import typing
|
||||||
import weakref
|
import weakref
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
@@ -41,6 +42,9 @@ from axolotl.utils.schemas.enums import RLType
|
|||||||
from axolotl.utils.train import determine_last_checkpoint
|
from axolotl.utils.train import determine_last_checkpoint
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||||
@@ -483,7 +487,7 @@ def handle_untrained_tokens_fix(
|
|||||||
def setup_model_and_trainer(
|
def setup_model_and_trainer(
|
||||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
Trainer,
|
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
|
||||||
PeftModel | PreTrainedModel,
|
PeftModel | PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PeftConfig | None,
|
PeftConfig | None,
|
||||||
@@ -550,36 +554,6 @@ def setup_model_and_trainer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_tui_enabled(cfg: DictDefault) -> bool:
|
|
||||||
"""Check if TUI is enabled via config or environment variable."""
|
|
||||||
if os.environ.get("AXOLOTL_TUI", "").lower() in ("1", "true", "yes"):
|
|
||||||
return True
|
|
||||||
tui = cfg.get("tui")
|
|
||||||
if tui is None:
|
|
||||||
return False
|
|
||||||
if isinstance(tui, bool):
|
|
||||||
return tui
|
|
||||||
if isinstance(tui, dict):
|
|
||||||
return tui.get("enabled", False)
|
|
||||||
if hasattr(tui, "enabled"):
|
|
||||||
return tui.enabled
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tui_config(cfg: DictDefault) -> dict:
|
|
||||||
"""Extract TUI config dict from cfg."""
|
|
||||||
tui = cfg.get("tui")
|
|
||||||
if tui is None or isinstance(tui, bool):
|
|
||||||
return {"enabled": True}
|
|
||||||
if isinstance(tui, dict):
|
|
||||||
return {**tui, "enabled": True}
|
|
||||||
if hasattr(tui, "model_dump"):
|
|
||||||
d = tui.model_dump()
|
|
||||||
d["enabled"] = True
|
|
||||||
return d
|
|
||||||
return {"enabled": True}
|
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
@send_errors
|
||||||
def train(
|
def train(
|
||||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
@@ -603,37 +577,6 @@ def train(
|
|||||||
processor,
|
processor,
|
||||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||||
|
|
||||||
# Register TUI callback if enabled and rank 0
|
|
||||||
tui_enabled = _is_tui_enabled(cfg)
|
|
||||||
if tui_enabled and cfg.local_rank == 0:
|
|
||||||
from axolotl.tui import AxolotlTUICallback
|
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
|
|
||||||
tui_config = _get_tui_config(cfg)
|
|
||||||
tui_config_obj = (
|
|
||||||
TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reuse the early-started renderer if available (started in do_train)
|
|
||||||
early_renderer = getattr(cfg, "_tui_renderer", None)
|
|
||||||
early_queue = getattr(cfg, "_tui_queue", None)
|
|
||||||
|
|
||||||
tui_callback = AxolotlTUICallback(config=tui_config_obj)
|
|
||||||
if early_renderer is not None and early_queue is not None:
|
|
||||||
# Reuse the already-running renderer and queue
|
|
||||||
tui_callback._renderer = early_renderer
|
|
||||||
tui_callback._queue = early_queue
|
|
||||||
tui_callback._renderer_started_early = True
|
|
||||||
trainer.add_callback(tui_callback)
|
|
||||||
|
|
||||||
# Stash model info so on_train_begin can emit a single unified run_info event
|
|
||||||
tui_callback._pending_run_info = {
|
|
||||||
"model_name": cfg.base_model or "",
|
|
||||||
"training_mode": str(cfg.rl) if cfg.rl else "sft",
|
|
||||||
"world_size": int(os.environ.get("WORLD_SIZE", 1)),
|
|
||||||
}
|
|
||||||
LOG.info("TUI dashboard enabled")
|
|
||||||
|
|
||||||
# Handle untrained tokens if configured
|
# Handle untrained tokens if configured
|
||||||
train_dataset = dataset_meta.train_dataset
|
train_dataset = dataset_meta.train_dataset
|
||||||
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
|
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
"""Axolotl Training TUI — rich-based terminal dashboard for monitoring training runs."""
|
|
||||||
|
|
||||||
from axolotl.tui.callback import AxolotlTUICallback
|
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AxolotlTUICallback",
|
|
||||||
"BasePanel",
|
|
||||||
"LineParser",
|
|
||||||
"TUIConfig",
|
|
||||||
"TUIState",
|
|
||||||
"register_panel",
|
|
||||||
"register_parser",
|
|
||||||
]
|
|
||||||
@@ -1,142 +0,0 @@
|
|||||||
"""AxolotlTUICallback — HF TrainerCallback that feeds metrics to the TUI."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import queue
|
|
||||||
|
|
||||||
from transformers.trainer_callback import TrainerCallback
|
|
||||||
|
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
from axolotl.tui.renderer import TUIRenderer
|
|
||||||
|
|
||||||
|
|
||||||
class _TUILogHandler(logging.Handler):
|
|
||||||
"""Logging handler that pushes log records into the TUI metric queue."""
|
|
||||||
|
|
||||||
_LEVEL_MAP = {
|
|
||||||
logging.DEBUG: "debug",
|
|
||||||
logging.INFO: "info",
|
|
||||||
logging.WARNING: "warning",
|
|
||||||
logging.ERROR: "error",
|
|
||||||
logging.CRITICAL: "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, metric_queue: queue.Queue, min_level: str = "info"):
|
|
||||||
super().__init__()
|
|
||||||
level_name = min_level.upper()
|
|
||||||
self.setLevel(getattr(logging, level_name, logging.INFO))
|
|
||||||
self._queue = metric_queue
|
|
||||||
|
|
||||||
def emit(self, record: logging.LogRecord) -> None:
|
|
||||||
try:
|
|
||||||
level = self._LEVEL_MAP.get(record.levelno, "info")
|
|
||||||
msg = self.format(record)
|
|
||||||
self._queue.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": level,
|
|
||||||
"message": msg,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except queue.Full:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
self.handleError(record)
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTUICallback(TrainerCallback):
|
|
||||||
"""Pushes training metrics into a queue for the TUI renderer.
|
|
||||||
|
|
||||||
The callback never blocks on the render thread. The queue is bounded
|
|
||||||
(maxsize=512) with put_nowait; overflow is silently dropped.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: TUIConfig):
|
|
||||||
self._config = config
|
|
||||||
self._queue: queue.Queue = queue.Queue(maxsize=4096)
|
|
||||||
self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
|
|
||||||
self._log_handler: _TUILogHandler | None = None
|
|
||||||
self._renderer_started_early: bool = False
|
|
||||||
self._pending_run_info: dict | None = None
|
|
||||||
|
|
||||||
def _put(self, event: dict) -> None:
|
|
||||||
try:
|
|
||||||
self._queue.put_nowait(event)
|
|
||||||
except queue.Full:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
|
||||||
# Send a single unified run_info event with all fields
|
|
||||||
run_info = {
|
|
||||||
"type": "run_info",
|
|
||||||
"run_name": getattr(args, "run_name", "") or "",
|
|
||||||
"total_steps": state.max_steps,
|
|
||||||
"total_epochs": float(args.num_train_epochs)
|
|
||||||
if args.num_train_epochs
|
|
||||||
else 1.0,
|
|
||||||
}
|
|
||||||
# Merge in model_name/training_mode/world_size if stashed by train.py
|
|
||||||
if self._pending_run_info:
|
|
||||||
run_info.update(self._pending_run_info)
|
|
||||||
self._pending_run_info = None
|
|
||||||
self._put(run_info)
|
|
||||||
|
|
||||||
if not self._renderer_started_early:
|
|
||||||
# Attach a logging handler to feed log messages into the events panel
|
|
||||||
self._log_handler = _TUILogHandler(
|
|
||||||
self._queue, min_level=self._config.log_level
|
|
||||||
)
|
|
||||||
self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
|
|
||||||
# Attach to both root and axolotl loggers (axolotl has propagate=False)
|
|
||||||
logging.getLogger().addHandler(self._log_handler)
|
|
||||||
logging.getLogger("axolotl").addHandler(self._log_handler)
|
|
||||||
|
|
||||||
# Start the renderer background thread
|
|
||||||
self._renderer.start()
|
|
||||||
|
|
||||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
|
||||||
if logs is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Filter out non-numeric keys and internal keys
|
|
||||||
filtered = {}
|
|
||||||
for key, value in logs.items():
|
|
||||||
if key.startswith("_"):
|
|
||||||
continue
|
|
||||||
if isinstance(value, (int, float)):
|
|
||||||
filtered[key] = value
|
|
||||||
elif isinstance(value, str):
|
|
||||||
# HF Trainer sometimes passes string-encoded numbers
|
|
||||||
try:
|
|
||||||
filtered[key] = float(value)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
if filtered:
|
|
||||||
self._put({"type": "metrics", "logs": filtered})
|
|
||||||
|
|
||||||
def on_step_end(self, args, state, control, **kwargs):
|
|
||||||
self._put(
|
|
||||||
{
|
|
||||||
"type": "step",
|
|
||||||
"step": state.global_step,
|
|
||||||
"total_steps": state.max_steps,
|
|
||||||
"epoch": state.epoch if state.epoch else 0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_prediction_step(self, args, state, control, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_train_end(self, args, state, control, **kwargs):
|
|
||||||
self._put({"type": "done"})
|
|
||||||
# If renderer was started early, do_train's finally block handles stop
|
|
||||||
if not self._renderer_started_early:
|
|
||||||
self._renderer.stop()
|
|
||||||
|
|
||||||
# Remove the logging handler (only if we added it)
|
|
||||||
if self._log_handler:
|
|
||||||
logging.getLogger().removeHandler(self._log_handler)
|
|
||||||
logging.getLogger("axolotl").removeHandler(self._log_handler)
|
|
||||||
self._log_handler = None
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
"""TUI configuration — Pydantic model for TUI settings."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class TUIConfig(BaseModel):
|
|
||||||
"""Configuration for the Axolotl Training TUI dashboard."""
|
|
||||||
|
|
||||||
enabled: bool = Field(
|
|
||||||
default=False,
|
|
||||||
json_schema_extra={"description": "Enable the TUI dashboard"},
|
|
||||||
)
|
|
||||||
refresh_rate: int = Field(
|
|
||||||
default=4,
|
|
||||||
json_schema_extra={"description": "Renders per second"},
|
|
||||||
)
|
|
||||||
log_level: str = Field(
|
|
||||||
default="debug",
|
|
||||||
json_schema_extra={"description": "Minimum log level shown in events panel"},
|
|
||||||
)
|
|
||||||
panels: list[str] = Field(
|
|
||||||
default_factory=lambda: ["progress", "training", "hardware", "events", "debug"],
|
|
||||||
json_schema_extra={"description": "Ordered list of panels to display"},
|
|
||||||
)
|
|
||||||
hardware_poll_interval: int = Field(
|
|
||||||
default=2,
|
|
||||||
json_schema_extra={"description": "Seconds between pynvml GPU queries"},
|
|
||||||
)
|
|
||||||
stdout_log_path: str = Field(
|
|
||||||
default="axolotl_stdout.log",
|
|
||||||
json_schema_extra={"description": "File path for captured stdout/stderr log"},
|
|
||||||
)
|
|
||||||
parser_plugins: list[str] = Field(
|
|
||||||
default_factory=list,
|
|
||||||
json_schema_extra={"description": "List of extra parser classes to load"},
|
|
||||||
)
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
"""GPU polling wrapper around pynvml with graceful fallback."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from axolotl.tui.state import GPUStats
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_nvml_available = False
|
|
||||||
try:
|
|
||||||
import pynvml
|
|
||||||
|
|
||||||
pynvml.nvmlInit()
|
|
||||||
_nvml_available = True
|
|
||||||
except Exception:
|
|
||||||
LOG.debug("pynvml unavailable — GPU stats will not be shown")
|
|
||||||
|
|
||||||
|
|
||||||
class GPUPoller:
|
|
||||||
"""Polls local GPU stats via pynvml. Falls back gracefully if unavailable."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._device_count = 0
|
|
||||||
if _nvml_available:
|
|
||||||
try:
|
|
||||||
self._device_count = pynvml.nvmlDeviceGetCount()
|
|
||||||
except Exception:
|
|
||||||
self._device_count = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available(self) -> bool:
|
|
||||||
return _nvml_available and self._device_count > 0
|
|
||||||
|
|
||||||
def poll(self) -> list[GPUStats]:
|
|
||||||
if not self.available:
|
|
||||||
return []
|
|
||||||
|
|
||||||
stats = []
|
|
||||||
for i in range(self._device_count):
|
|
||||||
try:
|
|
||||||
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
|
|
||||||
name = pynvml.nvmlDeviceGetName(handle)
|
|
||||||
if isinstance(name, bytes):
|
|
||||||
name = name.decode("utf-8")
|
|
||||||
|
|
||||||
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
|
|
||||||
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
|
||||||
temp = pynvml.nvmlDeviceGetTemperature(
|
|
||||||
handle, pynvml.NVML_TEMPERATURE_GPU
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
|
|
||||||
except Exception:
|
|
||||||
power = None
|
|
||||||
|
|
||||||
stats.append(
|
|
||||||
GPUStats(
|
|
||||||
id=i,
|
|
||||||
name=name,
|
|
||||||
util_pct=util.gpu,
|
|
||||||
vram_used_gb=mem.used / (1024**3),
|
|
||||||
vram_total_gb=mem.total / (1024**3),
|
|
||||||
temp_c=temp,
|
|
||||||
power_w=power,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
LOG.debug("Error polling GPU device %d", i, exc_info=True)
|
|
||||||
return stats
|
|
||||||
@@ -1,196 +0,0 @@
|
|||||||
"""I/O capture: OS-level stdout/stderr redirect, line parser chain, and parser registry."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import queue
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import IO
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Parser registry
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_parser_registry: list[type[LineParser]] = []
|
|
||||||
|
|
||||||
|
|
||||||
def register_parser(cls: type[LineParser]) -> type[LineParser]:
|
|
||||||
"""Decorator to register a LineParser subclass."""
|
|
||||||
if cls not in _parser_registry:
|
|
||||||
_parser_registry.append(cls)
|
|
||||||
return cls
|
|
||||||
|
|
||||||
|
|
||||||
def get_registered_parsers() -> list[type[LineParser]]:
|
|
||||||
return list(_parser_registry)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Base LineParser
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class LineParser(ABC):
|
|
||||||
"""Base class for stdout/stderr line parsers."""
|
|
||||||
|
|
||||||
priority: int = 50
|
|
||||||
name: str = ""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
"""Parse a single captured line.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
line: one line of captured output, trailing newline stripped.
|
|
||||||
source: "stdout" or "stderr".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of event dicts to push onto the metric queue.
|
|
||||||
Return [] if this line is not relevant.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# ParserChain
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class ParserChain:
|
|
||||||
def __init__(self):
|
|
||||||
self._parsers: list[LineParser] = []
|
|
||||||
|
|
||||||
def register(self, parser: LineParser) -> None:
|
|
||||||
self._parsers.append(parser)
|
|
||||||
self._parsers.sort(key=lambda p: p.priority)
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str = "stdout") -> list[dict]:
|
|
||||||
events: list[dict] = []
|
|
||||||
for parser in self._parsers:
|
|
||||||
events.extend(parser.parse(line, source))
|
|
||||||
return events
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# IOCapture — OS-level fd redirect to pipe
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class IOCapture:
|
|
||||||
"""Redirects fd 1 and fd 2 into an OS pipe, drains via a reader thread,
|
|
||||||
passes lines through a ParserChain, and tees to a log file."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, log_path: str, parser_chain: ParserChain, metric_queue: queue.Queue
|
|
||||||
):
|
|
||||||
self._parser_chain = parser_chain
|
|
||||||
self._queue = metric_queue
|
|
||||||
self._log_path = log_path
|
|
||||||
self._log_file: IO[str] | None = None
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
self._read_fd: int | None = None
|
|
||||||
self._write_fd: int | None = None
|
|
||||||
self._saved_stdout_fd: int | None = None
|
|
||||||
self._saved_stderr_fd: int | None = None
|
|
||||||
|
|
||||||
def start(self) -> None:
|
|
||||||
# Write run-start separator
|
|
||||||
self._log_file = open(self._log_path, "a", buffering=1) # noqa: SIM115
|
|
||||||
self._log_file.write(
|
|
||||||
f"\n=== axolotl run started {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n"
|
|
||||||
)
|
|
||||||
self._log_file.flush()
|
|
||||||
|
|
||||||
# OS-level pipe
|
|
||||||
self._read_fd, self._write_fd = os.pipe()
|
|
||||||
|
|
||||||
# Save originals
|
|
||||||
self._saved_stdout_fd = os.dup(1)
|
|
||||||
self._saved_stderr_fd = os.dup(2)
|
|
||||||
|
|
||||||
# Redirect both stdout and stderr into the write end
|
|
||||||
os.dup2(self._write_fd, 1)
|
|
||||||
os.dup2(self._write_fd, 2)
|
|
||||||
os.close(self._write_fd) # write end now held by fds 1 and 2
|
|
||||||
|
|
||||||
# Also redirect Python-level handles
|
|
||||||
sys.stdout = open(1, "w", buffering=1, closefd=False) # noqa: SIM115
|
|
||||||
sys.stderr = open(2, "w", buffering=1, closefd=False) # noqa: SIM115
|
|
||||||
|
|
||||||
# Drain thread
|
|
||||||
self._thread = threading.Thread(target=self._drain, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
# Restore fds — closes the write end, causing reader to see EOF
|
|
||||||
if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None:
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
sys.stderr = sys.__stderr__
|
|
||||||
os.dup2(self._saved_stdout_fd, 1)
|
|
||||||
os.dup2(self._saved_stderr_fd, 2)
|
|
||||||
os.close(self._saved_stdout_fd)
|
|
||||||
os.close(self._saved_stderr_fd)
|
|
||||||
self._saved_stdout_fd = None
|
|
||||||
self._saved_stderr_fd = None
|
|
||||||
|
|
||||||
if self._thread is not None:
|
|
||||||
self._thread.join(timeout=2.0)
|
|
||||||
if self._thread.is_alive():
|
|
||||||
logging.getLogger(__name__).warning(
|
|
||||||
"IO capture thread did not exit after 2s"
|
|
||||||
)
|
|
||||||
self._thread = None
|
|
||||||
|
|
||||||
if self._log_file is not None:
|
|
||||||
self._log_file.close()
|
|
||||||
self._log_file = None
|
|
||||||
|
|
||||||
def _drain(self) -> None:
|
|
||||||
# Read raw bytes and split on both \n and \r to handle tqdm progress bars
|
|
||||||
# which use \r for in-place updates without \n
|
|
||||||
assert self._read_fd is not None, "_drain called before start()"
|
|
||||||
with os.fdopen(self._read_fd, "rb") as pipe:
|
|
||||||
buf = b""
|
|
||||||
while True:
|
|
||||||
chunk = pipe.read(4096)
|
|
||||||
if not chunk:
|
|
||||||
# EOF — process remaining buffer
|
|
||||||
if buf:
|
|
||||||
self._process_line(buf.decode("utf-8", errors="replace"))
|
|
||||||
break
|
|
||||||
buf += chunk
|
|
||||||
# Split on \n or \r
|
|
||||||
while b"\n" in buf or b"\r" in buf:
|
|
||||||
# Find the earliest delimiter
|
|
||||||
idx_n = buf.find(b"\n")
|
|
||||||
idx_r = buf.find(b"\r")
|
|
||||||
if idx_n == -1:
|
|
||||||
idx = idx_r
|
|
||||||
elif idx_r == -1:
|
|
||||||
idx = idx_n
|
|
||||||
else:
|
|
||||||
idx = min(idx_n, idx_r)
|
|
||||||
line = buf[:idx].decode("utf-8", errors="replace")
|
|
||||||
buf = buf[idx + 1 :]
|
|
||||||
# Handle \r\n as single delimiter
|
|
||||||
if buf.startswith(b"\n"):
|
|
||||||
buf = buf[1:]
|
|
||||||
if line:
|
|
||||||
self._process_line(line)
|
|
||||||
|
|
||||||
def _process_line(self, line: str) -> None:
|
|
||||||
line = line.rstrip()
|
|
||||||
if not line:
|
|
||||||
return
|
|
||||||
if self._log_file:
|
|
||||||
self._log_file.write(line + "\n")
|
|
||||||
self._log_file.flush()
|
|
||||||
for event in self._parser_chain.parse(line):
|
|
||||||
try:
|
|
||||||
self._queue.put_nowait(event)
|
|
||||||
except queue.Full:
|
|
||||||
pass
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
"""Panel registry and base class for TUI panels."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Panel registry
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_panel_registry: dict[str, type[BasePanel]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def register_panel(position: str = "bottom", weight: int = 50):
|
|
||||||
"""Decorator to register a panel class with position and weight."""
|
|
||||||
|
|
||||||
def decorator(cls: type[BasePanel]) -> type[BasePanel]:
|
|
||||||
cls.position = position
|
|
||||||
cls.weight = weight
|
|
||||||
_panel_registry[cls.name] = cls
|
|
||||||
return cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def get_registered_panels() -> dict[str, type[BasePanel]]:
|
|
||||||
return dict(_panel_registry)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# BasePanel
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class BasePanel(ABC):
|
|
||||||
name: str = ""
|
|
||||||
position: str = "bottom"
|
|
||||||
weight: int = 50
|
|
||||||
min_height: int = 4
|
|
||||||
max_height: int | None = None
|
|
||||||
modes: list[str] = ["*"]
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
"""Return a rich renderable. Called every tick."""
|
|
||||||
...
|
|
||||||
|
|
||||||
def on_event(self, event: dict) -> None: # noqa: B027
|
|
||||||
"""Optional: react to raw metric events before state is merged."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# Auto-import built-in panels to trigger registration
|
|
||||||
from axolotl.tui.panels.completions import CompletionsPanel # noqa: E402, F401
|
|
||||||
from axolotl.tui.panels.debug import DebugPanel # noqa: E402, F401
|
|
||||||
from axolotl.tui.panels.events import EventsPanel # noqa: E402, F401
|
|
||||||
from axolotl.tui.panels.hardware import HardwarePanel # noqa: E402, F401
|
|
||||||
from axolotl.tui.panels.progress import ProgressPanel # noqa: E402, F401
|
|
||||||
from axolotl.tui.panels.training import TrainingPanel # noqa: E402, F401
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
"""CompletionsPanel — shows recent RL/log_completions samples."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
|
|
||||||
def _truncate(s: str, maxlen: int = 60) -> str:
|
|
||||||
return s[:maxlen] + "…" if len(s) > maxlen else s
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="bottom", weight=20)
|
|
||||||
class CompletionsPanel(BasePanel):
|
|
||||||
name = "completions"
|
|
||||||
min_height = 6
|
|
||||||
modes = ["grpo", "dpo"]
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
if "*" not in self.modes and state.training_mode not in self.modes:
|
|
||||||
return Text("")
|
|
||||||
|
|
||||||
if not state.completions:
|
|
||||||
return Panel(
|
|
||||||
Text("No completions yet...", style="dim"),
|
|
||||||
title="Completions",
|
|
||||||
border_style="magenta",
|
|
||||||
)
|
|
||||||
|
|
||||||
table = Table(
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold",
|
|
||||||
expand=True,
|
|
||||||
box=None,
|
|
||||||
pad_edge=False,
|
|
||||||
)
|
|
||||||
table.add_column("step", justify="right", width=6)
|
|
||||||
table.add_column("prompt", no_wrap=False, max_width=40)
|
|
||||||
table.add_column("completion", no_wrap=False, max_width=40)
|
|
||||||
table.add_column("reward", justify="right", width=8)
|
|
||||||
table.add_column("adv", justify="right", width=8)
|
|
||||||
|
|
||||||
for sample in list(state.completions)[-5:]:
|
|
||||||
reward_str = f"{sample.reward:.2f}" if sample.reward is not None else "--"
|
|
||||||
adv_str = (
|
|
||||||
f"{sample.advantage:+.2f}" if sample.advantage is not None else "--"
|
|
||||||
)
|
|
||||||
table.add_row(
|
|
||||||
str(sample.step),
|
|
||||||
_truncate(sample.prompt),
|
|
||||||
_truncate(sample.completion),
|
|
||||||
reward_str,
|
|
||||||
adv_str,
|
|
||||||
)
|
|
||||||
|
|
||||||
return Panel(table, title="Completions", border_style="magenta")
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
"""DebugPanel — scrolling log of debug-level messages, separate from main events."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="bottom", weight=30)
|
|
||||||
class DebugPanel(BasePanel):
|
|
||||||
name = "debug"
|
|
||||||
min_height = 6
|
|
||||||
max_height = 10
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
lines = Text()
|
|
||||||
# Show last 8 debug-level log lines
|
|
||||||
debug_lines = [
|
|
||||||
log_entry for log_entry in state.log_lines if log_entry.level == "debug"
|
|
||||||
][-8:]
|
|
||||||
for log_line in debug_lines:
|
|
||||||
ts = log_line.timestamp.strftime("%H:%M:%S")
|
|
||||||
lines.append(f"[{ts}] ", style="dim")
|
|
||||||
lines.append(log_line.message[:200], style="dim")
|
|
||||||
lines.append("\n")
|
|
||||||
|
|
||||||
if not debug_lines:
|
|
||||||
lines = Text("No debug messages yet...", style="dim")
|
|
||||||
|
|
||||||
return Panel(lines, title="Debug", border_style="dim")
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
"""EventsPanel — scrolling log of recent events, color-coded by level."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
_LEVEL_STYLES = {
|
|
||||||
"debug": "dim",
|
|
||||||
"info": "",
|
|
||||||
"warning": "yellow",
|
|
||||||
"error": "red bold",
|
|
||||||
"critical": "red bold",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="bottom", weight=10)
|
|
||||||
class EventsPanel(BasePanel):
|
|
||||||
name = "events"
|
|
||||||
min_height = 8
|
|
||||||
max_height = 20
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
lines = Text()
|
|
||||||
# Show last 15 non-debug log lines (debug goes to DebugPanel)
|
|
||||||
recent = [
|
|
||||||
log_entry for log_entry in state.log_lines if log_entry.level != "debug"
|
|
||||||
][-15:]
|
|
||||||
for log_line in recent:
|
|
||||||
ts = log_line.timestamp.strftime("%H:%M:%S")
|
|
||||||
level = log_line.level.upper()
|
|
||||||
style = _LEVEL_STYLES.get(log_line.level, "")
|
|
||||||
lines.append(f"[{ts}] ", style="dim")
|
|
||||||
lines.append(f"[{level}] ", style=style or "")
|
|
||||||
lines.append(log_line.message[:200], style=style or "")
|
|
||||||
lines.append("\n")
|
|
||||||
|
|
||||||
if not recent:
|
|
||||||
lines = Text("No events yet...", style="dim")
|
|
||||||
|
|
||||||
return Panel(lines, title="Events", border_style="yellow")
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
"""HardwarePanel — per-GPU stats via pynvml."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
_BAR_FULL = "█"
|
|
||||||
_BAR_EMPTY = "░"
|
|
||||||
|
|
||||||
|
|
||||||
def _util_bar(pct: float, width: int = 6) -> Text:
|
|
||||||
filled = int(pct / 100 * width)
|
|
||||||
bar = _BAR_FULL * filled + _BAR_EMPTY * (width - filled)
|
|
||||||
color = "green" if pct < 70 else ("yellow" if pct < 90 else "red")
|
|
||||||
return Text.assemble((bar, color), f" {pct:3.0f}%")
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="right", weight=10)
|
|
||||||
class HardwarePanel(BasePanel):
|
|
||||||
name = "hardware"
|
|
||||||
min_height = 6
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
if not state.gpus:
|
|
||||||
return Panel(
|
|
||||||
Text("GPU stats unavailable", style="dim"),
|
|
||||||
title="Hardware",
|
|
||||||
border_style="green",
|
|
||||||
)
|
|
||||||
|
|
||||||
table = Table(
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold",
|
|
||||||
expand=True,
|
|
||||||
box=None,
|
|
||||||
pad_edge=False,
|
|
||||||
)
|
|
||||||
table.add_column("id", justify="right", width=3)
|
|
||||||
table.add_column("util", no_wrap=True)
|
|
||||||
table.add_column("vram", no_wrap=True)
|
|
||||||
table.add_column("°C", justify="right", width=4)
|
|
||||||
table.add_column("W", justify="right", width=5)
|
|
||||||
|
|
||||||
total_vram_used = 0.0
|
|
||||||
total_vram_total = 0.0
|
|
||||||
total_util = 0.0
|
|
||||||
|
|
||||||
for gpu in state.gpus:
|
|
||||||
total_vram_used += gpu.vram_used_gb
|
|
||||||
total_vram_total += gpu.vram_total_gb
|
|
||||||
total_util += gpu.util_pct
|
|
||||||
|
|
||||||
power_str = f"{gpu.power_w:.0f}" if gpu.power_w is not None else "--"
|
|
||||||
table.add_row(
|
|
||||||
str(gpu.id),
|
|
||||||
_util_bar(gpu.util_pct),
|
|
||||||
f"{gpu.vram_used_gb:.1f}/{gpu.vram_total_gb:.1f} GB",
|
|
||||||
str(gpu.temp_c),
|
|
||||||
power_str,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Footer with aggregates
|
|
||||||
n = len(state.gpus)
|
|
||||||
if n > 1:
|
|
||||||
avg_util = total_util / n
|
|
||||||
table.add_row(
|
|
||||||
"Σ",
|
|
||||||
Text(f"avg {avg_util:.0f}%", style="dim"),
|
|
||||||
Text(f"{total_vram_used:.1f}/{total_vram_total:.1f} GB", style="dim"),
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
|
|
||||||
return Panel(table, title="Hardware", border_style="green")
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
"""ProgressPanel — top-bar progress display with step count, elapsed, ETA."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.progress import BarColumn, Progress, TextColumn
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
|
|
||||||
def _fmt_time(seconds: float | None) -> str:
|
|
||||||
if seconds is None or seconds < 0:
|
|
||||||
return "--:--:--"
|
|
||||||
h = int(seconds) // 3600
|
|
||||||
m = (int(seconds) % 3600) // 60
|
|
||||||
s = int(seconds) % 60
|
|
||||||
return f"{h}:{m:02d}:{s:02d}"
|
|
||||||
|
|
||||||
|
|
||||||
def _fmt_eta(seconds: float | None) -> str:
|
|
||||||
if seconds is None or seconds < 0:
|
|
||||||
return "eta --"
|
|
||||||
h = int(seconds) // 3600
|
|
||||||
m = (int(seconds) % 3600) // 60
|
|
||||||
if h > 0:
|
|
||||||
return f"eta {h}h{m:02d}m"
|
|
||||||
return f"eta {m}m{int(seconds) % 60:02d}s"
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="top", weight=10)
|
|
||||||
class ProgressPanel(BasePanel):
|
|
||||||
name = "progress"
|
|
||||||
min_height = 3
|
|
||||||
max_height = 3
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
pct = (
|
|
||||||
(state.current_step / state.total_steps * 100)
|
|
||||||
if state.total_steps > 0
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Header line
|
|
||||||
mode_upper = state.training_mode.upper() if state.training_mode else "SFT"
|
|
||||||
model_short = state.model_name.split("/")[-1] if state.model_name else "model"
|
|
||||||
header = Text.assemble(
|
|
||||||
("● ", "bold green"),
|
|
||||||
("AXOLOTL", "bold cyan"),
|
|
||||||
f" {mode_upper} · {model_short} ",
|
|
||||||
(
|
|
||||||
f"{state.current_step} / {state.total_steps}",
|
|
||||||
"bold",
|
|
||||||
),
|
|
||||||
f" · {_fmt_time(state.elapsed_seconds)} elapsed · {_fmt_eta(state.eta_seconds)} · {pct:.1f}%",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Progress bar
|
|
||||||
progress = Progress(
|
|
||||||
TextColumn(""),
|
|
||||||
BarColumn(bar_width=None),
|
|
||||||
TextColumn("{task.percentage:>3.0f}%"),
|
|
||||||
expand=True,
|
|
||||||
)
|
|
||||||
task = progress.add_task("", total=state.total_steps or 1)
|
|
||||||
progress.update(task, completed=state.current_step)
|
|
||||||
|
|
||||||
table = Table.grid(expand=True)
|
|
||||||
table.add_row(header)
|
|
||||||
table.add_row(progress)
|
|
||||||
return table
|
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
"""TrainingPanel — live scalar metrics table with loss sparkline."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from rich.console import RenderableType
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from axolotl.tui.panels import BasePanel, register_panel
|
|
||||||
from axolotl.tui.state import TUIState
|
|
||||||
|
|
||||||
# Braille sparkline characters (8 levels)
|
|
||||||
_SPARK_CHARS = "▁▂▃▄▅▆▇█"
|
|
||||||
|
|
||||||
|
|
||||||
def _sparkline(values: list[float] | None, width: int = 20) -> str:
|
|
||||||
if not values or len(values) < 2:
|
|
||||||
return ""
|
|
||||||
vals = list(values)[-width:]
|
|
||||||
lo, hi = min(vals), max(vals)
|
|
||||||
rng = hi - lo if hi != lo else 1.0
|
|
||||||
return "".join(_SPARK_CHARS[min(int((v - lo) / rng * 7), 7)] for v in vals)
|
|
||||||
|
|
||||||
|
|
||||||
# Known key ordering and formatting
|
|
||||||
_KNOWN_KEYS: list[tuple[str, str, str]] = [
|
|
||||||
("loss", "loss", ".4f"),
|
|
||||||
("grad_norm", "grad norm", ".3f"),
|
|
||||||
("learning_rate", "lr", ".2e"),
|
|
||||||
("tokens_per_second", "tok/s", ".1f"),
|
|
||||||
("samples_per_second", "samples/s", ".1f"),
|
|
||||||
("mfu", "MFU", ".1f"),
|
|
||||||
# RL-specific
|
|
||||||
("rewards_mean", "rewards/mean", ".4f"),
|
|
||||||
("rewards_std", "rewards/std", ".4f"),
|
|
||||||
("kl_divergence", "KL", ".4f"),
|
|
||||||
("clip_ratio", "clip ratio", ".3f"),
|
|
||||||
("queue_size", "queue", "d"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@register_panel(position="left", weight=10)
|
|
||||||
class TrainingPanel(BasePanel):
|
|
||||||
name = "training"
|
|
||||||
min_height = 8
|
|
||||||
|
|
||||||
def render(self, state: TUIState) -> RenderableType:
|
|
||||||
table = Table(
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold",
|
|
||||||
expand=True,
|
|
||||||
box=None,
|
|
||||||
pad_edge=False,
|
|
||||||
)
|
|
||||||
table.add_column("metric", style="cyan", no_wrap=True)
|
|
||||||
table.add_column("value", justify="right")
|
|
||||||
table.add_column("trend", justify="left", no_wrap=True)
|
|
||||||
|
|
||||||
for attr, label, fmt in _KNOWN_KEYS:
|
|
||||||
val = getattr(state, attr, None)
|
|
||||||
if val is None:
|
|
||||||
# Also check extra dict
|
|
||||||
val = state.extra.get(attr)
|
|
||||||
if val is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
formatted = f"{val:{fmt}}"
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
formatted = str(val)
|
|
||||||
|
|
||||||
trend = ""
|
|
||||||
if attr == "loss":
|
|
||||||
trend = _sparkline(list(state.loss_history))
|
|
||||||
|
|
||||||
table.add_row(label, formatted, trend)
|
|
||||||
|
|
||||||
# Any extra keys not in _KNOWN_KEYS
|
|
||||||
known_attrs = {k for k, _, _ in _KNOWN_KEYS}
|
|
||||||
for key, val in sorted(state.extra.items()):
|
|
||||||
if key in known_attrs or val is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
formatted = f"{val:.4f}"
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
formatted = str(val)
|
|
||||||
table.add_row(key, formatted, "")
|
|
||||||
|
|
||||||
if table.row_count == 0:
|
|
||||||
return Panel(
|
|
||||||
Text("Waiting for first log step...", style="dim"),
|
|
||||||
title="Training",
|
|
||||||
border_style="blue",
|
|
||||||
)
|
|
||||||
|
|
||||||
return Panel(table, title="Training", border_style="blue")
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
"""Built-in line parsers — auto-imported to trigger @register_parser decorators."""
|
|
||||||
|
|
||||||
from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401
|
|
||||||
from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401
|
|
||||||
from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401
|
|
||||||
from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401
|
|
||||||
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
"""DeepSpeedParser — extracts DeepSpeed stage info and throughput metrics."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser
|
|
||||||
class DeepSpeedParser(LineParser):
|
|
||||||
priority = 20
|
|
||||||
name = "deepspeed"
|
|
||||||
|
|
||||||
_SAMPLES_RE = re.compile(r"samples/sec=([0-9.]+)")
|
|
||||||
_STAGE_RE = re.compile(r"ZeRO Stage (\d)")
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
events: list[dict] = []
|
|
||||||
if m := self._SAMPLES_RE.search(line):
|
|
||||||
events.append(
|
|
||||||
{
|
|
||||||
"type": "metrics",
|
|
||||||
"logs": {"samples_per_second": float(m.group(1))},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if m := self._STAGE_RE.search(line):
|
|
||||||
events.append({"type": "run_info", "zero_stage": int(m.group(1))})
|
|
||||||
return events
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
"""NCCLErrorParser — surfaces NCCL errors as red alert events."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser
|
|
||||||
class NCCLErrorParser(LineParser):
|
|
||||||
priority = 10
|
|
||||||
name = "nccl_error"
|
|
||||||
|
|
||||||
_RE = re.compile(r"NCCL error|Unhandled NCCL", re.IGNORECASE)
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
if self._RE.search(line):
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "error",
|
|
||||||
"message": f"⚠ NCCL: {line}",
|
|
||||||
},
|
|
||||||
{"type": "alert", "severity": "error", "message": line},
|
|
||||||
]
|
|
||||||
return []
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""RawLogParser — catches every line as a log_line event."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser
|
|
||||||
class RawLogParser(LineParser):
|
|
||||||
priority = 99
|
|
||||||
name = "raw_log"
|
|
||||||
|
|
||||||
_LOG_RE = re.compile(
|
|
||||||
r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
|
|
||||||
r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
|
|
||||||
r"\s*[-]\s*(?P<msg>.+)$",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter out tqdm progress bar lines and other noisy output
|
|
||||||
_TQDM_RE = re.compile(r"^\s*\d+%\|.*\|")
|
|
||||||
_EMPTY_RE = re.compile(r"^\s*$")
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
# Skip empty lines and tqdm progress bar updates
|
|
||||||
if self._EMPTY_RE.match(line) or self._TQDM_RE.match(line):
|
|
||||||
return []
|
|
||||||
|
|
||||||
m = self._LOG_RE.match(line)
|
|
||||||
level = (
|
|
||||||
m.group("level").lower()
|
|
||||||
if m
|
|
||||||
else ("error" if source == "stderr" else "info")
|
|
||||||
)
|
|
||||||
return [{"type": "log_line", "level": level, "message": line}]
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
"""TorchCompileParser — detects torch.compile graph breaks and recompilations."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser
|
|
||||||
class TorchCompileParser(LineParser):
|
|
||||||
priority = 20
|
|
||||||
name = "torch_compile"
|
|
||||||
|
|
||||||
_RE = re.compile(r"Graph break|Recompiling|torch\.compile", re.IGNORECASE)
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
if self._RE.search(line):
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "warning",
|
|
||||||
"message": f"⚡ compile: {line}",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
return []
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
"""TqdmParser — captures tqdm progress bar output and surfaces as structured events."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.tui.io_capture import LineParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser
|
|
||||||
class TqdmParser(LineParser):
|
|
||||||
priority = 15
|
|
||||||
name = "tqdm"
|
|
||||||
|
|
||||||
# Match tqdm-style progress lines, e.g.:
|
|
||||||
# Tokenizing Prompts (num_proc=24): 35%|███▍ | 19008/54568 [00:02<00:02, 17417.65 examples/s]
|
|
||||||
# Loading weights: 53%|█████▎ | 77/146 [00:00<00:00, 396.39it/s]
|
|
||||||
# 0%| | 0/30 [00:00<?, ?it/s]
|
|
||||||
_TQDM_RE = re.compile(
|
|
||||||
r"(?P<desc>.*?)\s*"
|
|
||||||
r"(?P<pct>\d+)%\|[▏▎▍▌▋▊▉█░▓▒# ]*\|\s*"
|
|
||||||
r"(?P<current>[\d,]+)/(?P<total>[\d,]+)"
|
|
||||||
r"\s*\[(?P<elapsed>[^\]]*)\]"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Also match simpler forms like:
|
|
||||||
# Fetching 0 files: 0it [00:00, ?it/s]
|
|
||||||
_FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]")
|
|
||||||
|
|
||||||
def parse(self, line: str, source: str) -> list[dict]:
|
|
||||||
m = self._TQDM_RE.search(line)
|
|
||||||
if m:
|
|
||||||
desc = m.group("desc").strip().rstrip(":")
|
|
||||||
pct = int(m.group("pct"))
|
|
||||||
current = int(m.group("current").replace(",", ""))
|
|
||||||
total = int(m.group("total").replace(",", ""))
|
|
||||||
|
|
||||||
events: list[dict] = []
|
|
||||||
|
|
||||||
# Surface as a log line with progress info
|
|
||||||
if pct == 100 or pct == 0 or pct % 25 == 0:
|
|
||||||
msg = (
|
|
||||||
f"[{desc}] {pct}% ({current}/{total})"
|
|
||||||
if desc
|
|
||||||
else f"{pct}% ({current}/{total})"
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "info",
|
|
||||||
"message": msg,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Also emit as a progress metric
|
|
||||||
cleaned_desc = desc.strip().lower().replace(" ", "_")
|
|
||||||
if not cleaned_desc:
|
|
||||||
cleaned_desc = "progress"
|
|
||||||
events.append(
|
|
||||||
{
|
|
||||||
"type": "metrics",
|
|
||||||
"logs": {
|
|
||||||
f"progress/{cleaned_desc}": pct / 100.0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return events
|
|
||||||
|
|
||||||
# Fallback: try simpler fetch-style progress lines
|
|
||||||
m = self._FETCH_RE.search(line)
|
|
||||||
if m:
|
|
||||||
desc = m.group("desc").strip().rstrip(":")
|
|
||||||
current = int(m.group("current"))
|
|
||||||
cleaned_desc = desc.strip().lower().replace(" ", "_")
|
|
||||||
if not cleaned_desc:
|
|
||||||
cleaned_desc = "fetch"
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "info",
|
|
||||||
"message": f"[{desc}] {current}" if desc else f"{current}",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
return []
|
|
||||||
@@ -1,449 +0,0 @@
|
|||||||
"""TUIRenderer — background daemon thread that drives the rich.live.Live display."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import queue
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.layout import Layout
|
|
||||||
from rich.live import Live
|
|
||||||
|
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
from axolotl.tui.gpu import GPUPoller
|
|
||||||
from axolotl.tui.io_capture import (
|
|
||||||
IOCapture,
|
|
||||||
ParserChain,
|
|
||||||
get_registered_parsers,
|
|
||||||
)
|
|
||||||
from axolotl.tui.panels import BasePanel, get_registered_panels
|
|
||||||
from axolotl.tui.state import CompletionSample, LogLine, TUIState
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TUIRenderer:
|
|
||||||
"""Background thread that renders the TUI dashboard using rich.live.Live."""
|
|
||||||
|
|
||||||
def __init__(self, config: TUIConfig, metric_queue: queue.Queue):
|
|
||||||
self._config = config
|
|
||||||
self._queue = metric_queue
|
|
||||||
self._state = TUIState()
|
|
||||||
self._gpu_poller = GPUPoller()
|
|
||||||
self._panels: list[BasePanel] = []
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
self._stop_event = threading.Event()
|
|
||||||
self._io_capture: IOCapture | None = None
|
|
||||||
self._parser_chain: ParserChain | None = None
|
|
||||||
|
|
||||||
def _init_panels(self) -> None:
|
|
||||||
registry = get_registered_panels()
|
|
||||||
for panel_name in self._config.panels:
|
|
||||||
if panel_name in registry:
|
|
||||||
self._panels.append(registry[panel_name]())
|
|
||||||
|
|
||||||
def _init_parser_chain(self) -> None:
|
|
||||||
# Ensure built-in parsers are imported so @register_parser decorators fire
|
|
||||||
import axolotl.tui.parsers # noqa: F401
|
|
||||||
|
|
||||||
self._parser_chain = ParserChain()
|
|
||||||
# Register all built-in parsers
|
|
||||||
for parser_cls in get_registered_parsers():
|
|
||||||
self._parser_chain.register(parser_cls())
|
|
||||||
|
|
||||||
# Load plugin parsers
|
|
||||||
for plugin_spec in self._config.parser_plugins:
|
|
||||||
try:
|
|
||||||
if "::" in plugin_spec:
|
|
||||||
# file path :: class name
|
|
||||||
file_path, class_name = plugin_spec.split("::", 1)
|
|
||||||
import importlib.util
|
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location(
|
|
||||||
"custom_parser", file_path
|
|
||||||
)
|
|
||||||
if spec is None or spec.loader is None:
|
|
||||||
raise ImportError(f"Cannot load spec for {file_path}")
|
|
||||||
mod = importlib.util.module_from_spec(spec)
|
|
||||||
spec.loader.exec_module(mod)
|
|
||||||
parser_cls = getattr(mod, class_name)
|
|
||||||
else:
|
|
||||||
# dotted module path
|
|
||||||
module_path, class_name = plugin_spec.rsplit(".", 1)
|
|
||||||
mod = importlib.import_module(module_path)
|
|
||||||
parser_cls = getattr(mod, class_name)
|
|
||||||
self._parser_chain.register(parser_cls())
|
|
||||||
except Exception as exc:
|
|
||||||
LOG.warning(f"Failed to load parser plugin {plugin_spec}: {exc}")
|
|
||||||
|
|
||||||
def _build_layout(self) -> Layout:
|
|
||||||
layout = Layout()
|
|
||||||
|
|
||||||
top_panels = [p for p in self._panels if p.position == "top"]
|
|
||||||
left_panels = [p for p in self._panels if p.position == "left"]
|
|
||||||
right_panels = [p for p in self._panels if p.position == "right"]
|
|
||||||
bottom_panels = [p for p in self._panels if p.position == "bottom"]
|
|
||||||
|
|
||||||
sections = []
|
|
||||||
|
|
||||||
if top_panels:
|
|
||||||
layout_top = Layout(name="top", size=3)
|
|
||||||
sections.append(layout_top)
|
|
||||||
|
|
||||||
if left_panels or right_panels:
|
|
||||||
layout_middle = Layout(name="middle", ratio=3)
|
|
||||||
middle_parts = []
|
|
||||||
if left_panels:
|
|
||||||
middle_parts.append(Layout(name="left", ratio=1))
|
|
||||||
if right_panels:
|
|
||||||
middle_parts.append(Layout(name="right", ratio=1))
|
|
||||||
if middle_parts:
|
|
||||||
layout_middle.split_row(*middle_parts)
|
|
||||||
sections.append(layout_middle)
|
|
||||||
|
|
||||||
if bottom_panels:
|
|
||||||
layout_bottom = Layout(name="bottom", ratio=2)
|
|
||||||
if len(bottom_panels) > 1:
|
|
||||||
layout_bottom.split_row(
|
|
||||||
*[
|
|
||||||
Layout(name=f"bottom_{i}", ratio=1)
|
|
||||||
for i in range(len(bottom_panels))
|
|
||||||
]
|
|
||||||
)
|
|
||||||
sections.append(layout_bottom)
|
|
||||||
|
|
||||||
if sections:
|
|
||||||
layout.split_column(*sections)
|
|
||||||
|
|
||||||
return layout
|
|
||||||
|
|
||||||
def _update_layout(self, layout: Layout) -> None:
|
|
||||||
top_panels = [p for p in self._panels if p.position == "top"]
|
|
||||||
left_panels = [p for p in self._panels if p.position == "left"]
|
|
||||||
right_panels = [p for p in self._panels if p.position == "right"]
|
|
||||||
bottom_panels = [p for p in self._panels if p.position == "bottom"]
|
|
||||||
|
|
||||||
if top_panels:
|
|
||||||
layout["top"].update(top_panels[0].render(self._state))
|
|
||||||
|
|
||||||
if left_panels:
|
|
||||||
layout["left"].update(left_panels[0].render(self._state))
|
|
||||||
|
|
||||||
if right_panels:
|
|
||||||
layout["right"].update(right_panels[0].render(self._state))
|
|
||||||
|
|
||||||
if bottom_panels:
|
|
||||||
if len(bottom_panels) == 1:
|
|
||||||
layout["bottom"].update(bottom_panels[0].render(self._state))
|
|
||||||
else:
|
|
||||||
for i, panel in enumerate(bottom_panels):
|
|
||||||
layout[f"bottom_{i}"].update(panel.render(self._state))
|
|
||||||
|
|
||||||
def _drain_queue(self) -> None:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
event = self._queue.get_nowait()
|
|
||||||
except queue.Empty:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Dispatch event to panels first
|
|
||||||
for panel in self._panels:
|
|
||||||
panel.on_event(event)
|
|
||||||
|
|
||||||
event_type = event.get("type")
|
|
||||||
|
|
||||||
if event_type == "metrics":
|
|
||||||
logs = event.get("logs", {})
|
|
||||||
self._apply_metrics(logs)
|
|
||||||
|
|
||||||
elif event_type == "step":
|
|
||||||
self._state.current_step = event.get("step", self._state.current_step)
|
|
||||||
self._state.total_steps = event.get(
|
|
||||||
"total_steps", self._state.total_steps
|
|
||||||
)
|
|
||||||
self._state.current_epoch = event.get(
|
|
||||||
"epoch", self._state.current_epoch
|
|
||||||
)
|
|
||||||
now = time.time()
|
|
||||||
self._state.elapsed_seconds = now - self._state.start_time.timestamp()
|
|
||||||
if self._state.current_step > 0 and self._state.total_steps > 0:
|
|
||||||
rate = self._state.elapsed_seconds / self._state.current_step
|
|
||||||
remaining = self._state.total_steps - self._state.current_step
|
|
||||||
self._state.eta_seconds = rate * remaining
|
|
||||||
|
|
||||||
elif event_type == "log_line":
|
|
||||||
level = event.get("level", "info")
|
|
||||||
message = event.get("message", "")
|
|
||||||
self._state.log_lines.append(
|
|
||||||
LogLine(
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
level=level,
|
|
||||||
message=message,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif event_type == "completion":
|
|
||||||
self._state.completions.append(
|
|
||||||
CompletionSample(
|
|
||||||
step=event.get("step", 0),
|
|
||||||
prompt=event.get("prompt", ""),
|
|
||||||
completion=event.get("completion", ""),
|
|
||||||
reward=event.get("reward"),
|
|
||||||
advantage=event.get("advantage"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif event_type == "run_info":
|
|
||||||
if "run_name" in event:
|
|
||||||
self._state.run_name = event["run_name"]
|
|
||||||
if "model_name" in event:
|
|
||||||
self._state.model_name = event["model_name"]
|
|
||||||
if "training_mode" in event:
|
|
||||||
self._state.training_mode = event["training_mode"]
|
|
||||||
if "world_size" in event:
|
|
||||||
self._state.world_size = event["world_size"]
|
|
||||||
if "total_steps" in event:
|
|
||||||
self._state.total_steps = event["total_steps"]
|
|
||||||
if "total_epochs" in event:
|
|
||||||
self._state.total_epochs = event["total_epochs"]
|
|
||||||
if "zero_stage" in event:
|
|
||||||
self._state.zero_stage = event["zero_stage"]
|
|
||||||
|
|
||||||
elif event_type == "done":
|
|
||||||
self._stop_event.set()
|
|
||||||
|
|
||||||
def _apply_metrics(self, logs: dict[str, Any]) -> None:
|
|
||||||
metric_map = {
|
|
||||||
"loss": "loss",
|
|
||||||
"grad_norm": "grad_norm",
|
|
||||||
"learning_rate": "learning_rate",
|
|
||||||
"tokens_per_second": "tokens_per_second",
|
|
||||||
"samples_per_second": "samples_per_second",
|
|
||||||
"mfu": "mfu",
|
|
||||||
"rewards/mean": "rewards_mean",
|
|
||||||
"rewards_mean": "rewards_mean",
|
|
||||||
"rewards/std": "rewards_std",
|
|
||||||
"rewards_std": "rewards_std",
|
|
||||||
"kl": "kl_divergence",
|
|
||||||
"kl_divergence": "kl_divergence",
|
|
||||||
"clip_ratio": "clip_ratio",
|
|
||||||
"queue_size": "queue_size",
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, value in logs.items():
|
|
||||||
if key in metric_map:
|
|
||||||
setattr(self._state, metric_map[key], value)
|
|
||||||
else:
|
|
||||||
self._state.extra[key] = value
|
|
||||||
|
|
||||||
if "loss" in logs and logs["loss"] is not None:
|
|
||||||
self._state.loss_history.append(logs["loss"])
|
|
||||||
|
|
||||||
def start(self) -> None:
|
|
||||||
self._init_panels()
|
|
||||||
self._init_parser_chain()
|
|
||||||
|
|
||||||
# Set up I/O capture
|
|
||||||
assert self._parser_chain is not None, "_init_parser_chain must be called first"
|
|
||||||
self._io_capture = IOCapture(
|
|
||||||
log_path=self._config.stdout_log_path,
|
|
||||||
parser_chain=self._parser_chain,
|
|
||||||
metric_queue=self._queue,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Monkeypatch tqdm to suppress terminal output and route through our queue.
|
|
||||||
# This prevents tqdm progress bars from flickering through the TUI and
|
|
||||||
# ensures all progress events appear in the Events panel.
|
|
||||||
self._install_tqdm_hook()
|
|
||||||
|
|
||||||
self._io_capture_ready = threading.Event()
|
|
||||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
self._io_capture_ready.wait(timeout=5.0)
|
|
||||||
|
|
||||||
def _install_tqdm_hook(self) -> None:
|
|
||||||
"""Replace tqdm's display method to route updates through TUI queue."""
|
|
||||||
try:
|
|
||||||
import io
|
|
||||||
|
|
||||||
import tqdm
|
|
||||||
import tqdm.auto
|
|
||||||
|
|
||||||
q = self._queue
|
|
||||||
self._tqdm_parser = None
|
|
||||||
# Find our tqdm parser in the chain
|
|
||||||
for p in self._parser_chain._parsers if self._parser_chain else []:
|
|
||||||
if p.name == "tqdm":
|
|
||||||
self._tqdm_parser = p
|
|
||||||
break
|
|
||||||
|
|
||||||
# Save originals for restore
|
|
||||||
self._orig_tqdm_class_auto = tqdm.auto.tqdm
|
|
||||||
self._orig_tqdm_class_tqdm = tqdm.tqdm
|
|
||||||
self._orig_tqdm_class_std = tqdm.std.tqdm
|
|
||||||
|
|
||||||
class TUITqdm(tqdm.tqdm):
|
|
||||||
"""tqdm subclass that sends progress to TUI instead of terminal."""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
# Force output to devnull so nothing reaches the terminal
|
|
||||||
kwargs["file"] = io.StringIO()
|
|
||||||
kwargs["dynamic_ncols"] = False
|
|
||||||
kwargs["ncols"] = 80
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def display(self, msg=None, pos=None):
|
|
||||||
# Build a progress string and push to queue
|
|
||||||
if self.total and self.total > 0:
|
|
||||||
pct = self.n / self.total * 100
|
|
||||||
desc = self.desc.rstrip(": ") if self.desc else ""
|
|
||||||
# Emit events at milestones or at low frequency
|
|
||||||
is_milestone = (
|
|
||||||
self.n == 0 or self.n >= self.total or int(pct) % 25 == 0
|
|
||||||
)
|
|
||||||
if is_milestone:
|
|
||||||
try:
|
|
||||||
q.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "info",
|
|
||||||
"message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
|
|
||||||
if desc
|
|
||||||
else f"{pct:.0f}% ({self.n}/{self.total})",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
metric_key = (
|
|
||||||
f"progress/{desc.lower().replace(' ', '_')}"
|
|
||||||
if desc
|
|
||||||
else "progress/unknown"
|
|
||||||
)
|
|
||||||
q.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "metrics",
|
|
||||||
"logs": {metric_key: pct / 100.0},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
# Emit final completion event
|
|
||||||
if self.total and self.total > 0 and self.n > 0:
|
|
||||||
desc = self.desc.rstrip(": ") if self.desc else ""
|
|
||||||
try:
|
|
||||||
q.put_nowait(
|
|
||||||
{
|
|
||||||
"type": "log_line",
|
|
||||||
"level": "info",
|
|
||||||
"message": f"[{desc}] 100% ({self.total}/{self.total}) done"
|
|
||||||
if desc
|
|
||||||
else f"100% ({self.total}/{self.total}) done",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
super().close()
|
|
||||||
|
|
||||||
# Replace tqdm globally
|
|
||||||
tqdm.auto.tqdm = TUITqdm
|
|
||||||
tqdm.tqdm = TUITqdm
|
|
||||||
# Also patch tqdm.std which some libraries use directly
|
|
||||||
tqdm.std.tqdm = TUITqdm
|
|
||||||
self._tui_tqdm_cls = TUITqdm
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
LOG.debug(f"Failed to install tqdm hook: {exc}")
|
|
||||||
|
|
||||||
def _uninstall_tqdm_hook(self) -> None:
|
|
||||||
"""Restore original tqdm."""
|
|
||||||
try:
|
|
||||||
import tqdm
|
|
||||||
import tqdm.auto
|
|
||||||
|
|
||||||
if hasattr(self, "_orig_tqdm_class_auto"):
|
|
||||||
tqdm.auto.tqdm = self._orig_tqdm_class_auto
|
|
||||||
if hasattr(self, "_orig_tqdm_class_tqdm"):
|
|
||||||
tqdm.tqdm = self._orig_tqdm_class_tqdm
|
|
||||||
if hasattr(self, "_orig_tqdm_class_std"):
|
|
||||||
tqdm.std.tqdm = self._orig_tqdm_class_std
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
self._stop_event.set()
|
|
||||||
self._uninstall_tqdm_hook()
|
|
||||||
if self._thread is not None:
|
|
||||||
self._thread.join(timeout=5.0)
|
|
||||||
|
|
||||||
def _run(self) -> None:
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Save a handle to the REAL terminal BEFORE IO capture redirects fds.
|
|
||||||
# This ensures rich.live.Live writes to the terminal, not the pipe.
|
|
||||||
saved_tty_fd = os.dup(1)
|
|
||||||
tty_file = os.fdopen(saved_tty_fd, "w", buffering=1, closefd=True)
|
|
||||||
console = Console(file=tty_file)
|
|
||||||
|
|
||||||
layout = self._build_layout()
|
|
||||||
tick_interval = 1.0 / max(self._config.refresh_rate, 1)
|
|
||||||
gpu_poll_counter = 0
|
|
||||||
gpu_poll_ticks = max(
|
|
||||||
1, int(self._config.hardware_poll_interval / tick_interval)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start I/O capture — redirects fd 1/2 to pipe AFTER we saved the tty fd
|
|
||||||
if self._io_capture:
|
|
||||||
self._io_capture.start()
|
|
||||||
|
|
||||||
# Signal that IO capture is live so start() can return
|
|
||||||
if hasattr(self, "_io_capture_ready"):
|
|
||||||
self._io_capture_ready.set()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with Live(
|
|
||||||
layout,
|
|
||||||
console=console,
|
|
||||||
refresh_per_second=self._config.refresh_rate,
|
|
||||||
screen=True,
|
|
||||||
redirect_stdout=False,
|
|
||||||
redirect_stderr=False,
|
|
||||||
) as live:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
self._drain_queue()
|
|
||||||
|
|
||||||
# Poll GPU stats periodically
|
|
||||||
gpu_poll_counter += 1
|
|
||||||
if gpu_poll_counter >= gpu_poll_ticks:
|
|
||||||
gpu_poll_counter = 0
|
|
||||||
if self._gpu_poller.available:
|
|
||||||
self._state.gpus = self._gpu_poller.poll()
|
|
||||||
|
|
||||||
# Update elapsed time
|
|
||||||
self._state.elapsed_seconds = (
|
|
||||||
time.time() - self._state.start_time.timestamp()
|
|
||||||
)
|
|
||||||
|
|
||||||
self._update_layout(layout)
|
|
||||||
live.update(layout)
|
|
||||||
|
|
||||||
time.sleep(tick_interval)
|
|
||||||
|
|
||||||
# Final drain
|
|
||||||
self._drain_queue()
|
|
||||||
self._update_layout(layout)
|
|
||||||
live.update(layout)
|
|
||||||
finally:
|
|
||||||
if self._io_capture:
|
|
||||||
self._io_capture.stop()
|
|
||||||
try:
|
|
||||||
tty_file.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
"""TUI shared data model — dataclasses for the dashboard state."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import deque
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GPUStats:
|
|
||||||
id: int
|
|
||||||
name: str
|
|
||||||
util_pct: float
|
|
||||||
vram_used_gb: float
|
|
||||||
vram_total_gb: float
|
|
||||||
temp_c: int
|
|
||||||
power_w: float | None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LogLine:
|
|
||||||
timestamp: datetime
|
|
||||||
level: str # "info" | "debug" | "warning" | "error"
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CompletionSample:
|
|
||||||
step: int
|
|
||||||
prompt: str
|
|
||||||
completion: str
|
|
||||||
reward: float | None
|
|
||||||
advantage: float | None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TUIState:
|
|
||||||
# Run metadata
|
|
||||||
run_name: str = ""
|
|
||||||
model_name: str = ""
|
|
||||||
training_mode: str = "sft"
|
|
||||||
world_size: int = 1
|
|
||||||
start_time: datetime = field(default_factory=datetime.now)
|
|
||||||
|
|
||||||
# Progress
|
|
||||||
current_step: int = 0
|
|
||||||
total_steps: int = 0
|
|
||||||
current_epoch: float = 0.0
|
|
||||||
total_epochs: float = 1.0
|
|
||||||
elapsed_seconds: float = 0.0
|
|
||||||
eta_seconds: float | None = None
|
|
||||||
|
|
||||||
# Training metrics (rolling window + current)
|
|
||||||
loss: float | None = None
|
|
||||||
grad_norm: float | None = None
|
|
||||||
learning_rate: float | None = None
|
|
||||||
tokens_per_second: float | None = None
|
|
||||||
samples_per_second: float | None = None
|
|
||||||
mfu: float | None = None
|
|
||||||
|
|
||||||
# RL-specific (None for non-RL modes)
|
|
||||||
rewards_mean: float | None = None
|
|
||||||
rewards_std: float | None = None
|
|
||||||
kl_divergence: float | None = None
|
|
||||||
clip_ratio: float | None = None
|
|
||||||
queue_size: int | None = None
|
|
||||||
|
|
||||||
# Per-GPU hardware (list indexed by local rank)
|
|
||||||
gpus: list[GPUStats] = field(default_factory=list)
|
|
||||||
|
|
||||||
# Recent log lines
|
|
||||||
log_lines: deque[LogLine] = field(default_factory=lambda: deque(maxlen=200))
|
|
||||||
|
|
||||||
# Recent completions (GRPO/SFT with log_completions)
|
|
||||||
completions: deque[CompletionSample] = field(
|
|
||||||
default_factory=lambda: deque(maxlen=20)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Loss history for sparkline
|
|
||||||
loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))
|
|
||||||
|
|
||||||
# DeepSpeed zero stage (None if not using DeepSpeed)
|
|
||||||
zero_stage: int | None = None
|
|
||||||
|
|
||||||
# Arbitrary plugin state
|
|
||||||
extra: dict[str, Any] = field(default_factory=dict)
|
|
||||||
@@ -299,6 +299,7 @@ def validate_config(
|
|||||||
AxolotlInputConfig = AxolotlInputConfigBase
|
AxolotlInputConfig = AxolotlInputConfigBase
|
||||||
|
|
||||||
if cfg.plugins:
|
if cfg.plugins:
|
||||||
|
prepare_plugins(cfg)
|
||||||
(
|
(
|
||||||
AxolotlConfigWCapabilities,
|
AxolotlConfigWCapabilities,
|
||||||
AxolotlInputConfig,
|
AxolotlInputConfig,
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from pydantic import (
|
|||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.tui.config import TUIConfig
|
|
||||||
from axolotl.utils.datasets import get_default_process_count
|
from axolotl.utils.datasets import get_default_process_count
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.datasets import (
|
from axolotl.utils.schemas.datasets import (
|
||||||
@@ -141,12 +140,6 @@ class AxolotlInputConfig(
|
|||||||
vllm: VllmConfig | None = Field(
|
vllm: VllmConfig | None = Field(
|
||||||
default_factory=lambda: VllmConfig(),
|
default_factory=lambda: VllmConfig(),
|
||||||
)
|
)
|
||||||
tui: TUIConfig | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "TUI dashboard configuration. Set enabled: true to activate."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
qat: QATConfig | None = None
|
qat: QATConfig | None = None
|
||||||
quantization: PTQConfig | None = None
|
quantization: PTQConfig | None = None
|
||||||
reward_model: bool | None = Field(
|
reward_model: bool | None = Field(
|
||||||
@@ -710,12 +703,6 @@ class AxolotlInputConfig(
|
|||||||
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
lora_embedding_kernel: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
chunked_cross_entropy: bool | None = Field(
|
chunked_cross_entropy: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -849,6 +836,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
expert_parallel_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
|
||||||
|
},
|
||||||
|
)
|
||||||
special_tokens: SpecialTokensConfig | None = Field(
|
special_tokens: SpecialTokensConfig | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -1326,7 +1319,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
data.get("lora_mlp_kernel")
|
data.get("lora_mlp_kernel")
|
||||||
or data.get("lora_qkv_kernel")
|
or data.get("lora_qkv_kernel")
|
||||||
or data.get("lora_o_kernel")
|
or data.get("lora_o_kernel")
|
||||||
or data.get("lora_embedding_kernel")
|
|
||||||
):
|
):
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_fsdp = data.get("fsdp_config") is not None
|
is_fsdp = data.get("fsdp_config") is not None
|
||||||
@@ -1374,12 +1366,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("adapter") in ["lora", "qlora"]:
|
if data.get("adapter") in ["lora", "qlora"]:
|
||||||
# Skip if already set, using unsloth optimizations, or using 8-bit
|
# Skip if already set, using unsloth optimizations, or using 8-bit
|
||||||
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||||
kernel_fields = [
|
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
||||||
"lora_mlp_kernel",
|
|
||||||
"lora_qkv_kernel",
|
|
||||||
"lora_o_kernel",
|
|
||||||
"lora_embedding_kernel",
|
|
||||||
]
|
|
||||||
if (
|
if (
|
||||||
any(data.get(k) is not None for k in kernel_fields)
|
any(data.get(k) is not None for k in kernel_fields)
|
||||||
or any(data.get(k) for k in unsloth_fields)
|
or any(data.get(k) for k in unsloth_fields)
|
||||||
@@ -1392,6 +1379,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("trust_remote_code"):
|
if data.get("trust_remote_code"):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
|
||||||
|
if data.get("lora_dropout") != 0:
|
||||||
|
return data
|
||||||
|
|
||||||
# Check multi-GPU compatibility
|
# Check multi-GPU compatibility
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
||||||
@@ -1413,9 +1404,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("lora_o_kernel") is None:
|
if data.get("lora_o_kernel") is None:
|
||||||
data["lora_o_kernel"] = True
|
data["lora_o_kernel"] = True
|
||||||
|
|
||||||
if data.get("lora_embedding_kernel") is None:
|
|
||||||
data["lora_embedding_kernel"] = True
|
|
||||||
|
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Auto-enabling LoRA kernel optimizations for faster training. "
|
"Auto-enabling LoRA kernel optimizations for faster training. "
|
||||||
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
|
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
|
||||||
|
|||||||
@@ -681,7 +681,15 @@ class LoRAValidationMixin:
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_kernels_dora(cls, data):
|
def check_lora_kernels_dora(cls, data):
|
||||||
# DoRA is now supported by lora kernels
|
if (
|
||||||
|
data.get("lora_mlp_kernel")
|
||||||
|
or data.get("lora_qkv_kernel")
|
||||||
|
or data.get("lora_o_kernel")
|
||||||
|
) and data.get("peft_use_dora"):
|
||||||
|
raise ValueError(
|
||||||
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||||
|
"compatible with DoRA at the moment."
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1378,6 +1386,14 @@ class ComplexValidationMixin:
|
|||||||
self.tensor_parallel_size = 1
|
self.tensor_parallel_size = 1
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_expert_parallel_size(self):
|
||||||
|
if not getattr(self, "expert_parallel_size", None):
|
||||||
|
self.expert_parallel_size = 1
|
||||||
|
elif self.expert_parallel_size < 1:
|
||||||
|
raise ValueError("expert_parallel_size must be >= 1")
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_context_parallel_size(self):
|
def check_context_parallel_size(self):
|
||||||
if self.sequence_parallel_degree and not self.context_parallel_size:
|
if self.sequence_parallel_degree and not self.context_parallel_size:
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
|||||||
|
|
||||||
proj.base_layer = base_layer
|
proj.base_layer = base_layer
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
||||||
# quant_state should be None since weight is bf16, not FP8
|
# quant_state should be None since weight is bf16, not FP8
|
||||||
self.assertIsNone(quant_state)
|
self.assertIsNone(quant_state)
|
||||||
|
|
||||||
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
|||||||
scale_inv = torch.ones(1)
|
scale_inv = torch.ones(1)
|
||||||
base_layer.weight_scale_inv = scale_inv
|
base_layer.weight_scale_inv = scale_inv
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
||||||
self.assertIs(quant_state, scale_inv)
|
self.assertIs(quant_state, scale_inv)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ def mock_proj():
|
|||||||
def test_get_lora_parameters(mock_proj):
|
def test_get_lora_parameters(mock_proj):
|
||||||
"""Tests get_lora_parameters function"""
|
"""Tests get_lora_parameters function"""
|
||||||
# Test with LoRA enabled
|
# Test with LoRA enabled
|
||||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||||
|
|
||||||
assert isinstance(W, torch.Tensor)
|
assert isinstance(W, torch.Tensor)
|
||||||
assert W.shape == (128, 64)
|
assert W.shape == (128, 64)
|
||||||
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
|
|||||||
|
|
||||||
# Test with LoRA disabled
|
# Test with LoRA disabled
|
||||||
mock_proj.disable_adapters = True
|
mock_proj.disable_adapters = True
|
||||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||||
assert A is None and B is None and s is None
|
assert A is None and B is None and s is None
|
||||||
|
|
||||||
# Test with merged state
|
# Test with merged state
|
||||||
mock_proj.disable_adapters = False
|
mock_proj.disable_adapters = False
|
||||||
mock_proj.merged = True
|
mock_proj.merged = True
|
||||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||||
assert A is None and B is None and s is None
|
assert A is None and B is None and s is None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,120 +0,0 @@
|
|||||||
"""Test LoRA kernels under FSDP2 multi-GPU training.
|
|
||||||
|
|
||||||
Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and
|
|
||||||
lora_embedding_kernel work correctly with FSDP2 sharding, including
|
|
||||||
with bias, dropout, and DoRA enabled.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from tests.e2e.utils import require_torch_2_7_0
|
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
|
||||||
|
|
||||||
|
|
||||||
def _run_training(temp_dir, cfg):
|
|
||||||
"""Write config and launch multi-GPU training."""
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"axolotl",
|
|
||||||
"train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"--main-process-port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _base_lora_fsdp2_config(temp_dir, **overrides):
|
|
||||||
"""Base config for LoRA + FSDP2 + kernel tests."""
|
|
||||||
cfg = {
|
|
||||||
"base_model": "Qwen/Qwen3-0.6B",
|
|
||||||
"sequence_len": 512,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "tatsu-lab/alpaca",
|
|
||||||
"type": "alpaca",
|
|
||||||
"split": "train[:1%]",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 3,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 1e-4,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"bf16": True,
|
|
||||||
"fsdp_version": 2,
|
|
||||||
"fsdp_config": {
|
|
||||||
"offload_params": False,
|
|
||||||
"cpu_ram_efficient_loading": False,
|
|
||||||
"transformer_layer_cls_to_wrap": "Qwen3DecoderLayer",
|
|
||||||
"state_dict_type": "FULL_STATE_DICT",
|
|
||||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
|
||||||
"reshard_after_forward": True,
|
|
||||||
},
|
|
||||||
# Enable all LoRA kernels
|
|
||||||
"lora_mlp_kernel": True,
|
|
||||||
"lora_qkv_kernel": True,
|
|
||||||
"lora_o_kernel": True,
|
|
||||||
"lora_embedding_kernel": True,
|
|
||||||
"save_safetensors": True,
|
|
||||||
}
|
|
||||||
cfg.update(overrides)
|
|
||||||
return DictDefault(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
class TestFSDP2LoRAKernels:
|
|
||||||
"""Test LoRA kernels under FSDP2."""
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_lora_kernels_basic(self, temp_dir):
|
|
||||||
"""Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA."""
|
|
||||||
cfg = _base_lora_fsdp2_config(temp_dir)
|
|
||||||
_run_training(temp_dir, cfg)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_lora_kernels_with_dropout(self, temp_dir):
|
|
||||||
"""LoRA kernels + dropout + FSDP2."""
|
|
||||||
cfg = _base_lora_fsdp2_config(temp_dir, lora_dropout=0.1)
|
|
||||||
_run_training(temp_dir, cfg)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_lora_kernels_with_dora(self, temp_dir):
|
|
||||||
"""LoRA kernels + DoRA + FSDP2."""
|
|
||||||
cfg = _base_lora_fsdp2_config(temp_dir, peft_use_dora=True)
|
|
||||||
_run_training(temp_dir, cfg)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_lora_kernels_with_dora_and_dropout(self, temp_dir):
|
|
||||||
"""LoRA kernels + DoRA + dropout + FSDP2."""
|
|
||||||
cfg = _base_lora_fsdp2_config(
|
|
||||||
temp_dir,
|
|
||||||
peft_use_dora=True,
|
|
||||||
lora_dropout=0.05,
|
|
||||||
)
|
|
||||||
_run_training(temp_dir, cfg)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
|
||||||
@@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation):
|
|||||||
|
|
||||||
|
|
||||||
def test_kernel_patch_conditions():
|
def test_kernel_patch_conditions():
|
||||||
"""Test that kernels ARE patched even with dropout and bias (now supported)."""
|
"""Test various conditions that should prevent kernel patching."""
|
||||||
test_configs = [
|
test_configs = [
|
||||||
# Dropout — kernels now support this
|
# Dropout prevents patching
|
||||||
{
|
{
|
||||||
"peft_type": "LORA",
|
"peft_type": "LORA",
|
||||||
"task_type": "CAUSAL_LM",
|
"task_type": "CAUSAL_LM",
|
||||||
@@ -234,7 +234,7 @@ def test_kernel_patch_conditions():
|
|||||||
"lora_dropout": 0.1,
|
"lora_dropout": 0.1,
|
||||||
"bias": "none",
|
"bias": "none",
|
||||||
},
|
},
|
||||||
# Bias — kernels now support this
|
# Bias prevents patching
|
||||||
{
|
{
|
||||||
"peft_type": "LORA",
|
"peft_type": "LORA",
|
||||||
"task_type": "CAUSAL_LM",
|
"task_type": "CAUSAL_LM",
|
||||||
@@ -252,14 +252,13 @@ def test_kernel_patch_conditions():
|
|||||||
model = PeftModelForCausalLM(model, peft_config)
|
model = PeftModelForCausalLM(model, peft_config)
|
||||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||||
|
|
||||||
|
# Should not patch
|
||||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||||
layer = patched_model.model.model.layers[0].mlp
|
layer = patched_model.model.model.layers[0].mlp
|
||||||
|
|
||||||
# Verify patches ARE applied (dropout and bias are now supported)
|
# Verify no patches applied
|
||||||
assert (
|
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
|
||||||
layer.forward.__func__ is apply_lora_mlp_swiglu
|
assert layer.forward.__func__ is not apply_lora_mlp_geglu
|
||||||
or layer.forward.__func__ is apply_lora_mlp_geglu
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_kernel_config_options():
|
def test_kernel_config_options():
|
||||||
@@ -512,7 +511,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
|||||||
|
|
||||||
|
|
||||||
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||||
"""Test model loading with dropout non-zero DOES patch (now supported)."""
|
"""Test model loading with dropout non-zero should not patch."""
|
||||||
|
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
|
||||||
@@ -547,18 +546,31 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
|||||||
# Load config
|
# Load config
|
||||||
cfg = load_cfg(str(path))
|
cfg = load_cfg(str(path))
|
||||||
|
|
||||||
|
# Get original attention class
|
||||||
|
attention_cls = get_attention_cls_from_config(cfg)
|
||||||
|
|
||||||
|
# Store original state before patching
|
||||||
|
original_forward_method = attention_cls.forward
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
||||||
|
|
||||||
|
# We call modelloader as that's where the patches are applied
|
||||||
|
# despite the fact that we're not using it to load the model
|
||||||
model_loader = ModelLoader(cfg, tokenizer)
|
model_loader = ModelLoader(cfg, tokenizer)
|
||||||
|
|
||||||
# Apply patches — should succeed even with dropout > 0
|
# Apply patch
|
||||||
model_loader.patch_manager._apply_self_attention_lora_patch()
|
model_loader.patch_manager._apply_self_attention_lora_patch()
|
||||||
|
|
||||||
|
# Verify patch was not applied
|
||||||
|
assert attention_cls.forward == original_forward_method
|
||||||
|
|
||||||
|
# Apply apply_lora_kernel_patches
|
||||||
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
||||||
|
|
||||||
# Verify patches WERE applied (dropout is now supported by kernels)
|
# Verify patch was not applied
|
||||||
layers = get_layers(model)
|
layers = get_layers(model)
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
for self_attn in find_self_attn_in_layer(layer):
|
for self_attn in find_self_attn_in_layer(layer):
|
||||||
assert hasattr(self_attn, "apply_qkv")
|
assert not hasattr(self_attn, "apply_qkv")
|
||||||
assert hasattr(self_attn, "apply_o")
|
assert not hasattr(self_attn, "apply_o")
|
||||||
|
|||||||
75
tests/e2e/test_llama4_moe_aux_free.py
Normal file
75
tests/e2e/test_llama4_moe_aux_free.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""
|
||||||
|
E2E smoke test for Llama 4 aux-loss-free routing via plugin
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
|
|
||||||
|
class TestLlama4MoeAuxFree(unittest.TestCase):
|
||||||
|
"""Smoke test to ensure aux-free plugin patches Llama 4 MoE correctly."""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_llama4_aux_free_smoke(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "yujiepan/llama-4-tiny-random",
|
||||||
|
"tokenizer_config": "yujiepan/llama-4-tiny-random",
|
||||||
|
"trust_remote_code": False,
|
||||||
|
"flash_attention": False,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"bf16": False,
|
||||||
|
"fp16": False,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 1e-5,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_steps": 0,
|
||||||
|
"eval_steps": 0,
|
||||||
|
"save_first_step": False,
|
||||||
|
"plugins": [
|
||||||
|
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||||
|
],
|
||||||
|
"moe_balance_type": "noaux_tc",
|
||||||
|
"moe_update_rate": 0.01,
|
||||||
|
"moe_update_momentum": 0.9,
|
||||||
|
"moe_bias_cap": 2.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
|
||||||
|
assert patched is not None, (
|
||||||
|
"Llama 4 MoE layer was not patched by aux-free plugin"
|
||||||
|
)
|
||||||
|
assert patched._afb_bias.ndim == 1
|
||||||
|
assert patched._afb_counts.ndim == 1
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
79
tests/e2e/test_moe_aux_free.py
Normal file
79
tests/e2e/test_moe_aux_free.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""
|
||||||
|
E2E smoke tests for Aux-Loss-Free MoE routing via plugin
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoeAuxFree(unittest.TestCase):
|
||||||
|
"""Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny."""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_mixtral_aux_free_smoke(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||||
|
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||||
|
"flash_attention": False,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"bf16": False,
|
||||||
|
"fp16": False,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 1e-5,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_steps": 0,
|
||||||
|
"eval_steps": 0,
|
||||||
|
"save_first_step": False,
|
||||||
|
# Aux-free plugin and toggles
|
||||||
|
"plugins": [
|
||||||
|
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||||
|
],
|
||||||
|
"moe_balance_type": "noaux_tc",
|
||||||
|
"moe_update_rate": 0.01,
|
||||||
|
"moe_update_momentum": 0.9,
|
||||||
|
"moe_bias_cap": 2.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
# Inspect model modules for a patched MoE layer
|
||||||
|
patched = None
|
||||||
|
for m in model.modules():
|
||||||
|
if hasattr(m, "_afb_patched") and m._afb_patched is True:
|
||||||
|
patched = m
|
||||||
|
break
|
||||||
|
assert patched is not None, "No MoE layer patched by aux-free plugin"
|
||||||
|
assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1
|
||||||
|
assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1
|
||||||
|
# ensure counts buffer got reset by callback (best effort)
|
||||||
|
assert torch.all(patched._afb_counts == 0)
|
||||||
|
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
91
tests/e2e/test_moe_aux_parity.py
Normal file
91
tests/e2e/test_moe_aux_parity.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
"""
|
||||||
|
Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny.
|
||||||
|
Checks that aux-free training loss does not degrade beyond a small tolerance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
|
||||||
|
def _last_logged_loss(trainer) -> float | None:
|
||||||
|
# Scan log_history for the most recent entry with a 'loss' key
|
||||||
|
for entry in reversed(trainer.state.log_history):
|
||||||
|
if isinstance(entry, dict) and "loss" in entry:
|
||||||
|
return float(entry["loss"])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoeAuxParity(unittest.TestCase):
|
||||||
|
@with_temp_dir
|
||||||
|
def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir):
|
||||||
|
base_cfg = {
|
||||||
|
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||||
|
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||||
|
"flash_attention": False,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"bf16": False,
|
||||||
|
"fp16": False,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {},
|
||||||
|
"datasets": [
|
||||||
|
{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"learning_rate": 1e-5,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 8,
|
||||||
|
"save_steps": 0,
|
||||||
|
"eval_steps": 0,
|
||||||
|
"save_first_step": False,
|
||||||
|
"seed": 42,
|
||||||
|
"logging_steps": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Baseline: aux-loss (gshard)
|
||||||
|
cfg0 = DictDefault(dict(base_cfg))
|
||||||
|
cfg0.output_dir = f"{temp_dir}/baseline"
|
||||||
|
cfg0 = validate_config(cfg0)
|
||||||
|
normalize_config(cfg0)
|
||||||
|
# baseline uses default aux-loss routing; no plugin registration
|
||||||
|
dataset_meta0 = load_datasets(cfg=cfg0)
|
||||||
|
model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0)
|
||||||
|
loss0 = _last_logged_loss(trainer0)
|
||||||
|
assert loss0 is not None
|
||||||
|
|
||||||
|
# Release baseline resources before starting aux-free run
|
||||||
|
del model0, trainer0, dataset_meta0
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Aux-free: plugin + noaux_tc
|
||||||
|
cfg1 = DictDefault(dict(base_cfg))
|
||||||
|
cfg1.output_dir = f"{temp_dir}/auxfree"
|
||||||
|
cfg1.plugins = [
|
||||||
|
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||||
|
]
|
||||||
|
cfg1.moe_balance_type = "noaux_tc"
|
||||||
|
cfg1.moe_update_rate = 0.01
|
||||||
|
cfg1.moe_update_momentum = 0.9
|
||||||
|
cfg1.moe_bias_cap = 2.0
|
||||||
|
prepare_plugins(cfg1)
|
||||||
|
cfg1 = validate_config(cfg1)
|
||||||
|
normalize_config(cfg1)
|
||||||
|
dataset_meta1 = load_datasets(cfg=cfg1)
|
||||||
|
model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1)
|
||||||
|
loss1 = _last_logged_loss(trainer1)
|
||||||
|
assert loss1 is not None
|
||||||
|
|
||||||
|
# Assert aux-free loss is within 10% of aux-loss baseline
|
||||||
|
assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}"
|
||||||
76
tests/e2e/test_qwen3_moe_aux_free.py
Normal file
76
tests/e2e/test_qwen3_moe_aux_free.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""
|
||||||
|
E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
|
|
||||||
|
class TestQwen3MoeAuxFree(unittest.TestCase):
|
||||||
|
@with_temp_dir
|
||||||
|
def test_qwen3_moe_aux_free_smoke(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
|
||||||
|
"tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
|
||||||
|
"flash_attention": False,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"bf16": False,
|
||||||
|
"fp16": False,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 1e-5,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_steps": 0,
|
||||||
|
"eval_steps": 0,
|
||||||
|
"save_first_step": False,
|
||||||
|
# Aux-free plugin and toggles
|
||||||
|
"plugins": [
|
||||||
|
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||||
|
],
|
||||||
|
"moe_balance_type": "noaux_tc",
|
||||||
|
"moe_update_rate": 0.01,
|
||||||
|
"moe_update_momentum": 0.9,
|
||||||
|
"moe_bias_cap": 2.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
# check that at least one sparse MoE block has been patched
|
||||||
|
found = False
|
||||||
|
for m in model.modules():
|
||||||
|
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(
|
||||||
|
m, "_afb_patched"
|
||||||
|
):
|
||||||
|
assert m._afb_patched is True
|
||||||
|
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
|
||||||
|
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
assert found, "No Qwen3-MoE sparse block patched by aux-free plugin"
|
||||||
|
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
74
tests/e2e/test_ring_moe_aux_free.py
Normal file
74
tests/e2e/test_ring_moe_aux_free.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""
|
||||||
|
E2E smoke test for Ring 2.0 aux-loss-free routing via plugin
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
|
|
||||||
|
class TestRingMoeAuxFree(unittest.TestCase):
|
||||||
|
"""Smoke test to ensure aux-free plugin patches Ring Mini 2.0 correctly."""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_ring_aux_free_smoke(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "yujiepan/ring-tiny-random",
|
||||||
|
"tokenizer_config": "yujiepan/ring-tiny-random",
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"flash_attention": False,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"bf16": False,
|
||||||
|
"fp16": False,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 1e-5,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_steps": 0,
|
||||||
|
"eval_steps": 0,
|
||||||
|
"save_first_step": False,
|
||||||
|
# Aux-free plugin config
|
||||||
|
"plugins": [
|
||||||
|
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||||
|
],
|
||||||
|
"moe_balance_type": "noaux_tc",
|
||||||
|
"moe_update_rate": 0.01,
|
||||||
|
"moe_update_momentum": 0.9,
|
||||||
|
"moe_bias_cap": 2.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
|
||||||
|
assert patched is not None, "Ring MoE layer was not patched by aux-free plugin"
|
||||||
|
assert patched._afb_bias.ndim == 1
|
||||||
|
assert patched._afb_counts.ndim == 1
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -12,7 +12,11 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from tbparse import SummaryReader
|
|
||||||
|
try:
|
||||||
|
from tbparse import SummaryReader
|
||||||
|
except ImportError: # pragma: no cover - optional dependency
|
||||||
|
SummaryReader = None
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -179,12 +183,16 @@ def check_tensorboard(
|
|||||||
tag: str,
|
tag: str,
|
||||||
lt_val: float,
|
lt_val: float,
|
||||||
assertion_err: str,
|
assertion_err: str,
|
||||||
rtol: float = 0.05,
|
rtol: float = 0.02,
|
||||||
gt_zero: bool = True,
|
gt_zero: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
helper function to parse and check tensorboard logs
|
helper function to parse and check tensorboard logs
|
||||||
"""
|
"""
|
||||||
|
if SummaryReader is None:
|
||||||
|
raise unittest.SkipTest(
|
||||||
|
"tbparse is not installed; skipping tensorboard assertions"
|
||||||
|
)
|
||||||
tb_log_path = most_recent_subdir(temp_run_dir)
|
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||||
reader = SummaryReader(event_file)
|
reader = SummaryReader(event_file)
|
||||||
|
|||||||
@@ -1,229 +0,0 @@
|
|||||||
"""
|
|
||||||
Correctness tests for fused RMSNorm + SiLU Gate kernel.
|
|
||||||
|
|
||||||
Tests against the eager Qwen3_5RMSNormGated implementation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
pytest.importorskip("triton", reason="triton required for fused kernels")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)
|
|
||||||
|
|
||||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
|
||||||
|
|
||||||
|
|
||||||
class EagerRMSNormGated(torch.nn.Module):
|
|
||||||
"""Reference implementation matching Qwen3_5RMSNormGated exactly."""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, hidden_states, gate=None):
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
||||||
hidden_states = self.weight * hidden_states.to(input_dtype)
|
|
||||||
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
|
|
||||||
return hidden_states.to(input_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def _sync_weights(eager_mod, fused_mod):
|
|
||||||
"""Copy weights from eager to fused module."""
|
|
||||||
fused_mod.weight.data.copy_(eager_mod.weight.data)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"shape",
|
|
||||||
[
|
|
||||||
(2, 128, 256),
|
|
||||||
(4, 64, 512),
|
|
||||||
(1, 32, 1024),
|
|
||||||
(2, 16, 2560), # Qwen3.5-4B hidden_size
|
|
||||||
(2, 16, 4096), # Qwen3.5-9B hidden_size
|
|
||||||
(1, 8, 5120), # Qwen3.5-27B hidden_size
|
|
||||||
(4, 16, 2048), # Qwen3.5-35B-A3B (MoE) hidden_size
|
|
||||||
(4, 16, 3072), # Qwen3.5-122B-A10B (MoE) hidden_size
|
|
||||||
],
|
|
||||||
)
|
|
||||||
class TestRMSNormGatedForward:
|
|
||||||
def test_output_matches_eager(self, dtype, shape):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
B, T, H = shape
|
|
||||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
|
||||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X, gate=G)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-5, rtol=1e-5)
|
|
||||||
else:
|
|
||||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
|
||||||
|
|
||||||
def test_output_shape(self, dtype, shape):
|
|
||||||
B, T, H = shape
|
|
||||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
|
||||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
|
||||||
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
y = fused(X, gate=G)
|
|
||||||
assert y.shape == (B, T, H)
|
|
||||||
assert y.dtype == dtype
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"shape",
|
|
||||||
[
|
|
||||||
(2, 32, 256),
|
|
||||||
(2, 16, 512),
|
|
||||||
(2, 16, 2560), # Qwen3.5-4B
|
|
||||||
(1, 8, 4096), # Qwen3.5-9B
|
|
||||||
(1, 8, 5120), # Qwen3.5-27B
|
|
||||||
(2, 16, 2048), # Qwen3.5-35B-A3B (MoE)
|
|
||||||
(2, 16, 3072), # Qwen3.5-122B-A10B (MoE)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
class TestRMSNormGatedBackward:
|
|
||||||
def test_grad_x(self, dtype, shape):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
B, T, H = shape
|
|
||||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
X_ref = X.detach().clone().requires_grad_(True)
|
|
||||||
G_ref = G.detach().clone().requires_grad_(True)
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X_ref, gate=G_ref)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
|
|
||||||
grad_out = torch.randn_like(y_eager)
|
|
||||||
y_eager.backward(grad_out)
|
|
||||||
y_fused.backward(grad_out)
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
atol, rtol = 1e-4, 1e-4
|
|
||||||
else:
|
|
||||||
atol, rtol = 5e-2, 5e-2
|
|
||||||
|
|
||||||
torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)
|
|
||||||
|
|
||||||
def test_grad_gate(self, dtype, shape):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
B, T, H = shape
|
|
||||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
X_ref = X.detach().clone().requires_grad_(True)
|
|
||||||
G_ref = G.detach().clone().requires_grad_(True)
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X_ref, gate=G_ref)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
|
|
||||||
grad_out = torch.randn_like(y_eager)
|
|
||||||
y_eager.backward(grad_out)
|
|
||||||
y_fused.backward(grad_out)
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
atol, rtol = 1e-4, 1e-4
|
|
||||||
else:
|
|
||||||
atol, rtol = 5e-2, 5e-2
|
|
||||||
|
|
||||||
torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)
|
|
||||||
|
|
||||||
def test_grad_weight(self, dtype, shape):
|
|
||||||
torch.manual_seed(42)
|
|
||||||
B, T, H = shape
|
|
||||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
|
||||||
X_ref = X.detach().clone().requires_grad_(True)
|
|
||||||
G_ref = G.detach().clone().requires_grad_(True)
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X_ref, gate=G_ref)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
|
|
||||||
grad_out = torch.randn_like(y_eager)
|
|
||||||
y_eager.backward(grad_out)
|
|
||||||
y_fused.backward(grad_out)
|
|
||||||
|
|
||||||
if dtype == torch.float32:
|
|
||||||
atol, rtol = 1e-4, 1e-4
|
|
||||||
else:
|
|
||||||
atol, rtol = 5e-2, 5e-2
|
|
||||||
|
|
||||||
torch.testing.assert_close(
|
|
||||||
fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRMSNormGatedEdgeCases:
|
|
||||||
def test_gate_none_raises(self):
|
|
||||||
fused = FusedRMSNormGated(256).cuda()
|
|
||||||
X = torch.randn(2, 4, 256, device="cuda")
|
|
||||||
with pytest.raises(ValueError, match="requires a gate tensor"):
|
|
||||||
fused(X, gate=None)
|
|
||||||
|
|
||||||
def test_2d_input(self):
|
|
||||||
"""Test with (BxT, H) shaped input instead of (B, T, H)."""
|
|
||||||
torch.manual_seed(42)
|
|
||||||
H = 512
|
|
||||||
X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
|
||||||
G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
|
||||||
X_ref = X.detach().clone().requires_grad_(True)
|
|
||||||
G_ref = G.detach().clone().requires_grad_(True)
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X_ref, gate=G_ref)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
|
|
||||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
|
||||||
|
|
||||||
grad_out = torch.randn_like(y_eager)
|
|
||||||
y_eager.backward(grad_out)
|
|
||||||
y_fused.backward(grad_out)
|
|
||||||
|
|
||||||
torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2)
|
|
||||||
torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2)
|
|
||||||
|
|
||||||
def test_random_weight_init(self):
|
|
||||||
"""Test with non-default weight values."""
|
|
||||||
torch.manual_seed(123)
|
|
||||||
H = 256
|
|
||||||
X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
|
||||||
G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
|
||||||
|
|
||||||
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
|
||||||
# Randomize weights
|
|
||||||
eager.weight.data = torch.randn_like(eager.weight.data)
|
|
||||||
|
|
||||||
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
|
||||||
_sync_weights(eager, fused)
|
|
||||||
|
|
||||||
y_eager = eager(X, gate=G)
|
|
||||||
y_fused = fused(X, gate=G)
|
|
||||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
|
||||||
666
tests/unit/test_aux_free_adapters.py
Normal file
666
tests/unit/test_aux_free_adapters.py
Normal file
@@ -0,0 +1,666 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from importlib import util as importlib_util
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(**overrides):
|
||||||
|
defaults = dict(
|
||||||
|
moe_balance_type="noaux_tc",
|
||||||
|
moe_update_rate=0.1,
|
||||||
|
moe_update_momentum=0.9,
|
||||||
|
moe_bias_cap=2.0,
|
||||||
|
moe_afb_warmup_steps=0,
|
||||||
|
moe_bias_sync_group="world",
|
||||||
|
expert_parallel_size=1,
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
return SimpleNamespace(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_bailing_modules():
|
||||||
|
repo_dir = snapshot_download(
|
||||||
|
repo_id="inclusionAI/Ring-mini-2.0",
|
||||||
|
allow_patterns=[
|
||||||
|
"configuration_bailing_moe_v2.py",
|
||||||
|
"modeling_bailing_moe_v2.py",
|
||||||
|
"__init__.py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
repo = Path(repo_dir)
|
||||||
|
config_path = repo / "configuration_bailing_moe_v2.py"
|
||||||
|
modeling_path = repo / "modeling_bailing_moe_v2.py"
|
||||||
|
|
||||||
|
config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
|
||||||
|
if config_name not in sys.modules:
|
||||||
|
spec = importlib_util.spec_from_file_location(config_name, config_path)
|
||||||
|
module = importlib_util.module_from_spec(spec)
|
||||||
|
sys.modules[config_name] = module
|
||||||
|
sys.modules["configuration_bailing_moe_v2"] = module
|
||||||
|
assert spec.loader is not None
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
config_module = sys.modules[config_name]
|
||||||
|
|
||||||
|
modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
|
||||||
|
if modeling_name not in sys.modules:
|
||||||
|
spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
|
||||||
|
module = importlib_util.module_from_spec(spec)
|
||||||
|
sys.modules[modeling_name] = module
|
||||||
|
sys.modules["modeling_bailing_moe_v2"] = module
|
||||||
|
assert spec.loader is not None
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
modeling_module = sys.modules[modeling_name]
|
||||||
|
|
||||||
|
BailingMoeV2Config = config_module.BailingMoeV2Config
|
||||||
|
BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock
|
||||||
|
|
||||||
|
return BailingMoeV2Config, BailingMoeV2SparseMoeBlock
|
||||||
|
|
||||||
|
|
||||||
|
def _build_bailing_model():
|
||||||
|
BailingConfig, BailingBlock = _load_bailing_modules()
|
||||||
|
config = BailingConfig(
|
||||||
|
hidden_size=16,
|
||||||
|
intermediate_size=32,
|
||||||
|
moe_intermediate_size=32,
|
||||||
|
num_experts=4,
|
||||||
|
num_shared_experts=None,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
n_group=1,
|
||||||
|
topk_group=1,
|
||||||
|
routed_scaling_factor=1.0,
|
||||||
|
)
|
||||||
|
block = BailingBlock(config)
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
|
||||||
|
def __init__(self, layer):
|
||||||
|
super().__init__()
|
||||||
|
self.block = layer
|
||||||
|
self.config = SimpleNamespace(model_type="bailing_moe")
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return self.block(hidden_states)
|
||||||
|
|
||||||
|
return DummyModel(block), block
|
||||||
|
|
||||||
|
|
||||||
|
def _build_llama4_model():
|
||||||
|
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
|
||||||
|
|
||||||
|
# Build config without __post_init__ validation (works around a
|
||||||
|
# huggingface_hub strict-dataclass type mismatch for layer_types).
|
||||||
|
config = object.__new__(__import__("transformers").Llama4TextConfig)
|
||||||
|
config.__dict__.update(
|
||||||
|
hidden_size=16,
|
||||||
|
intermediate_size=32,
|
||||||
|
num_local_experts=4,
|
||||||
|
num_attention_heads=2,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
hidden_act="silu",
|
||||||
|
layer_types=None,
|
||||||
|
)
|
||||||
|
layer = Llama4TextMoe(config)
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
|
||||||
|
def __init__(self, moe_layer):
|
||||||
|
super().__init__()
|
||||||
|
self.moe = moe_layer
|
||||||
|
self.config = SimpleNamespace(model_type="llama4")
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return self.moe(hidden_states)
|
||||||
|
|
||||||
|
return DummyModel(layer), layer
|
||||||
|
|
||||||
|
|
||||||
|
def _build_mixtral_model():
|
||||||
|
from transformers import MixtralConfig
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
|
|
||||||
|
config = MixtralConfig(
|
||||||
|
hidden_size=16,
|
||||||
|
intermediate_size=32,
|
||||||
|
num_local_experts=4,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
num_attention_heads=2,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
)
|
||||||
|
layer = MixtralSparseMoeBlock(config)
|
||||||
|
layer.config = config
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
|
||||||
|
def __init__(self, moe_layer):
|
||||||
|
super().__init__()
|
||||||
|
self.moe = moe_layer
|
||||||
|
self.config = SimpleNamespace(model_type="mixtral")
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return self.moe(hidden_states)
|
||||||
|
|
||||||
|
return DummyModel(layer), layer
|
||||||
|
|
||||||
|
|
||||||
|
def _build_qwen35_moe_model():
|
||||||
|
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
|
||||||
|
Qwen3_5MoeTextConfig,
|
||||||
|
)
|
||||||
|
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
||||||
|
Qwen3_5MoeSparseMoeBlock,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = Qwen3_5MoeTextConfig(
|
||||||
|
hidden_size=16,
|
||||||
|
moe_intermediate_size=32,
|
||||||
|
shared_expert_intermediate_size=32,
|
||||||
|
num_experts=4,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
num_attention_heads=2,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
)
|
||||||
|
layer = Qwen3_5MoeSparseMoeBlock(config)
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
|
||||||
|
def __init__(self, moe_layer):
|
||||||
|
super().__init__()
|
||||||
|
self.moe = moe_layer
|
||||||
|
self.config = SimpleNamespace(model_type="qwen3_5_moe")
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
return self.moe(hidden_states)
|
||||||
|
|
||||||
|
return DummyModel(layer), layer
|
||||||
|
|
||||||
|
|
||||||
|
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
||||||
|
if args is None:
|
||||||
|
args = SimpleNamespace(logging_steps=1)
|
||||||
|
if state is None:
|
||||||
|
state = SimpleNamespace(global_step=1, log_history=[])
|
||||||
|
if control is None:
|
||||||
|
control = SimpleNamespace(
|
||||||
|
should_log=False,
|
||||||
|
should_evaluate=False,
|
||||||
|
should_save=False,
|
||||||
|
should_training_stop=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
class DummyTrainer:
|
||||||
|
def __init__(self, state_obj, control_obj):
|
||||||
|
self.state = state_obj
|
||||||
|
self.control = control_obj
|
||||||
|
|
||||||
|
def log(self, logs):
|
||||||
|
output = dict(logs)
|
||||||
|
output["step"] = self.state.global_step
|
||||||
|
self.state.log_history.append(output)
|
||||||
|
self.control.should_log = True
|
||||||
|
|
||||||
|
dummy_trainer = DummyTrainer(state, control)
|
||||||
|
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=dummy_trainer)
|
||||||
|
assert callbacks, "expected aux-free callback to be registered"
|
||||||
|
callback = callbacks[0]
|
||||||
|
callback.on_step_end(args=args, state=state, control=control)
|
||||||
|
return state, control
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuxFreeAdapters(unittest.TestCase):
|
||||||
|
def test_bailing_adapter_updates_counts_and_bias(self):
|
||||||
|
model, block = _build_bailing_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
self.assertTrue(hasattr(block, "_afb_bias"))
|
||||||
|
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||||
|
block(hidden)
|
||||||
|
self.assertGreater(torch.count_nonzero(block._afb_counts), 0)
|
||||||
|
|
||||||
|
_run_callback(plugin, cfg)
|
||||||
|
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
|
||||||
|
self.assertFalse(
|
||||||
|
torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_llama4_adapter_biases_router_selection(self):
|
||||||
|
model, layer = _build_llama4_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
self.assertTrue(hasattr(layer, "_afb_bias"))
|
||||||
|
hidden = torch.randn(2, 4, layer.hidden_dim)
|
||||||
|
layer(hidden)
|
||||||
|
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
|
|
||||||
|
_run_callback(plugin, cfg)
|
||||||
|
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
|
self.assertFalse(
|
||||||
|
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_bias_warmup_respected(self):
|
||||||
|
model, block = _build_bailing_model()
|
||||||
|
cfg = _cfg(moe_afb_warmup_steps=2)
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
def _step():
|
||||||
|
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||||
|
block(hidden)
|
||||||
|
_run_callback(plugin, cfg)
|
||||||
|
|
||||||
|
# Warmup steps should leave bias untouched.
|
||||||
|
_step()
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
|
||||||
|
)
|
||||||
|
|
||||||
|
_step()
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Third step exceeds warmup -> bias should update.
|
||||||
|
_step()
|
||||||
|
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
|
||||||
|
|
||||||
|
def test_mixtral_adapter_patches_router_not_forward(self):
|
||||||
|
"""Verify that aux-free patches the router (gate) only, and the
|
||||||
|
v5 block forward signature (single tensor return) is preserved."""
|
||||||
|
model, layer = _build_mixtral_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
# Gate should be patched, not the block forward
|
||||||
|
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
|
||||||
|
self.assertTrue(getattr(layer, "_afb_patched", False))
|
||||||
|
|
||||||
|
# v5 block forward returns a single tensor (not a tuple with logits)
|
||||||
|
hidden = torch.randn(2, 3, layer.config.hidden_size)
|
||||||
|
out = layer(hidden)
|
||||||
|
self.assertIsInstance(out, torch.Tensor)
|
||||||
|
self.assertEqual(out.shape, hidden.shape)
|
||||||
|
|
||||||
|
# Counts should have been accumulated
|
||||||
|
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
|
_run_callback(plugin, cfg)
|
||||||
|
|
||||||
|
def test_mixtral_adapter_bias_affects_selection(self):
|
||||||
|
"""When bias is large for one expert, it should be selected more often."""
|
||||||
|
model, layer = _build_mixtral_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
# Set a large bias for expert 0 to force its selection
|
||||||
|
layer._afb_bias.zero_()
|
||||||
|
layer._afb_bias[0] = 10.0
|
||||||
|
|
||||||
|
hidden = torch.randn(2, 8, layer.config.hidden_size)
|
||||||
|
num_tokens = 2 * 8 # batch * seq
|
||||||
|
layer(hidden)
|
||||||
|
|
||||||
|
# With top_k=2, expert 0 should appear in every token's selection
|
||||||
|
# (once per token = num_tokens counts, not num_tokens * top_k)
|
||||||
|
counts = layer._afb_counts.clone()
|
||||||
|
self.assertEqual(
|
||||||
|
int(counts[0].item()),
|
||||||
|
num_tokens,
|
||||||
|
msg="Expert 0 should be selected for every token when heavily biased",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
|
||||||
|
"""Verify Qwen 3.5 MoE: router is patched, shared expert is untouched,
|
||||||
|
output includes shared expert contribution."""
|
||||||
|
model, layer = _build_qwen35_moe_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
# Gate should be patched
|
||||||
|
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
|
||||||
|
self.assertTrue(getattr(layer, "_afb_patched", False))
|
||||||
|
# Shared expert should be unmodified
|
||||||
|
self.assertTrue(hasattr(layer, "shared_expert"))
|
||||||
|
self.assertTrue(hasattr(layer, "shared_expert_gate"))
|
||||||
|
|
||||||
|
# Forward should return a single tensor (shared + routed)
|
||||||
|
hidden_size = layer.gate.hidden_dim
|
||||||
|
hidden = torch.randn(2, 3, hidden_size)
|
||||||
|
out = layer(hidden)
|
||||||
|
self.assertIsInstance(out, torch.Tensor)
|
||||||
|
self.assertEqual(out.shape, hidden.shape)
|
||||||
|
|
||||||
|
# Counts should have been accumulated
|
||||||
|
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
|
|
||||||
|
def test_qwen35_moe_adapter_bias_updates(self):
|
||||||
|
"""Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE."""
|
||||||
|
model, layer = _build_qwen35_moe_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
hidden_size = layer.gate.hidden_dim
|
||||||
|
hidden = torch.randn(2, 4, hidden_size)
|
||||||
|
layer(hidden)
|
||||||
|
|
||||||
|
# Bias should start at zero
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias))
|
||||||
|
)
|
||||||
|
|
||||||
|
_run_callback(plugin, cfg)
|
||||||
|
|
||||||
|
# After callback: counts reset, EMA updated, bias updated
|
||||||
|
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||||
|
self.assertFalse(
|
||||||
|
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_qwen35_moe_adapter_model_type_matching(self):
|
||||||
|
"""Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text."""
|
||||||
|
from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter
|
||||||
|
|
||||||
|
adapter = Qwen35MoeAdapter()
|
||||||
|
|
||||||
|
model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe"))
|
||||||
|
model_text = SimpleNamespace(
|
||||||
|
config=SimpleNamespace(model_type="qwen3_5_moe_text")
|
||||||
|
)
|
||||||
|
model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))
|
||||||
|
|
||||||
|
self.assertTrue(adapter.matches(model_moe))
|
||||||
|
self.assertTrue(adapter.matches(model_text))
|
||||||
|
self.assertFalse(adapter.matches(model_other))
|
||||||
|
|
||||||
|
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
||||||
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
self.skipTest(
|
||||||
|
"Cannot safely test deferred EP group resolution when a process group is already initialized"
|
||||||
|
)
|
||||||
|
|
||||||
|
model, block = _build_bailing_model()
|
||||||
|
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
self.assertIsNotNone(plugin._shim)
|
||||||
|
self.assertIsNone(plugin._shim.ep_group)
|
||||||
|
|
||||||
|
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
||||||
|
tmp_init.close()
|
||||||
|
init_method = f"file://{tmp_init.name}"
|
||||||
|
dist.init_process_group(
|
||||||
|
backend="gloo", init_method=init_method, world_size=1, rank=0
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||||
|
block(hidden)
|
||||||
|
_run_callback(
|
||||||
|
plugin,
|
||||||
|
cfg,
|
||||||
|
args=SimpleNamespace(logging_steps=1),
|
||||||
|
state=SimpleNamespace(global_step=1, log_history=[]),
|
||||||
|
control=SimpleNamespace(
|
||||||
|
should_log=False,
|
||||||
|
should_evaluate=False,
|
||||||
|
should_save=False,
|
||||||
|
should_training_stop=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
|
||||||
|
finally:
|
||||||
|
dist.destroy_process_group()
|
||||||
|
os.unlink(tmp_init.name)
|
||||||
|
|
||||||
|
def test_telemetry_logging(self):
|
||||||
|
model, layer = _build_mixtral_model()
|
||||||
|
cfg = _cfg()
|
||||||
|
plugin = AuxFreeMoEPlugin()
|
||||||
|
plugin.post_model_build(cfg, model)
|
||||||
|
|
||||||
|
hidden_dim = layer.config.hidden_size
|
||||||
|
hidden = torch.randn(2, 3, hidden_dim)
|
||||||
|
layer(hidden)
|
||||||
|
|
||||||
|
args = SimpleNamespace(logging_steps=1)
|
||||||
|
state = SimpleNamespace(global_step=1, log_history=[])
|
||||||
|
control = SimpleNamespace(
|
||||||
|
should_log=False,
|
||||||
|
should_evaluate=False,
|
||||||
|
should_save=False,
|
||||||
|
should_training_stop=False,
|
||||||
|
)
|
||||||
|
_run_callback(plugin, cfg, args=args, state=state, control=control)
|
||||||
|
|
||||||
|
self.assertTrue(control.should_log)
|
||||||
|
self.assertTrue(state.log_history)
|
||||||
|
telemetry = state.log_history[-1]
|
||||||
|
self.assertEqual(telemetry["step"], state.global_step)
|
||||||
|
self.assertIn("moe_afb/l0_load_min", telemetry)
|
||||||
|
self.assertIn("moe_afb/l0_load_max", telemetry)
|
||||||
|
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
|
||||||
|
|
||||||
|
def test_get_num_experts_v5_attribute_paths(self):
|
||||||
|
"""Verify get_num_experts works with v5 attribute layout where
|
||||||
|
num_experts is on gate/experts sub-modules, not the block."""
|
||||||
|
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
|
||||||
|
|
||||||
|
adapter = MixtralAdapter()
|
||||||
|
|
||||||
|
# Simulates v5 MixtralSparseMoeBlock (num_experts on gate, not block)
|
||||||
|
block = SimpleNamespace(
|
||||||
|
gate=SimpleNamespace(num_experts=8),
|
||||||
|
experts=SimpleNamespace(num_experts=8),
|
||||||
|
)
|
||||||
|
self.assertEqual(adapter.get_num_experts(block), 8)
|
||||||
|
|
||||||
|
# Also works when num_experts is directly on block
|
||||||
|
block2 = SimpleNamespace(num_experts=4)
|
||||||
|
self.assertEqual(adapter.get_num_experts(block2), 4)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuxFreeKernelComposition(unittest.TestCase):
|
||||||
|
"""Tests that aux-free bias composes correctly with kernel routing."""
|
||||||
|
|
||||||
|
def test_sonicmoe_softmax_routing_with_afb_bias(self):
|
||||||
|
"""SonicMoE softmax routing should use biased selection / unbiased weights."""
|
||||||
|
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||||
|
|
||||||
|
num_experts = 4
|
||||||
|
top_k = 2
|
||||||
|
hidden_dim = 16
|
||||||
|
T = 6
|
||||||
|
|
||||||
|
# Build a mock MoE block with gate attributes
|
||||||
|
gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||||
|
gate.top_k = top_k
|
||||||
|
gate.num_experts = num_experts
|
||||||
|
gate.norm_topk_prob = True
|
||||||
|
|
||||||
|
moe_block = SimpleNamespace(gate=gate)
|
||||||
|
hidden = torch.randn(T, hidden_dim)
|
||||||
|
|
||||||
|
# Baseline: no bias
|
||||||
|
scores_base, tok_base, exp_base, logits_base = softmax_topk_routing(
|
||||||
|
hidden, moe_block
|
||||||
|
)
|
||||||
|
self.assertEqual(scores_base.shape[0], T * top_k)
|
||||||
|
|
||||||
|
# Now register aux-free buffers and set heavy bias on expert 0
|
||||||
|
moe_block._afb_bias = torch.zeros(num_experts)
|
||||||
|
moe_block._afb_bias[0] = 100.0
|
||||||
|
moe_block._afb_counts = torch.zeros(num_experts)
|
||||||
|
|
||||||
|
scores_biased, tok_biased, exp_biased, logits_biased = softmax_topk_routing(
|
||||||
|
hidden, moe_block
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expert 0 should be selected for every token
|
||||||
|
self.assertTrue(
|
||||||
|
(exp_biased == 0).any(),
|
||||||
|
"Expert 0 should appear in selections when heavily biased",
|
||||||
|
)
|
||||||
|
# Counts should have been accumulated
|
||||||
|
self.assertGreater(moe_block._afb_counts[0].item(), 0)
|
||||||
|
# Total counts should equal T * top_k
|
||||||
|
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
|
||||||
|
|
||||||
|
def test_sonicmoe_routing_without_bias_unchanged(self):
|
||||||
|
"""Without _afb_bias, routing should produce identical results."""
|
||||||
|
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||||
|
|
||||||
|
num_experts = 4
|
||||||
|
top_k = 2
|
||||||
|
hidden_dim = 16
|
||||||
|
|
||||||
|
gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||||
|
gate.top_k = top_k
|
||||||
|
gate.num_experts = num_experts
|
||||||
|
gate.norm_topk_prob = True
|
||||||
|
|
||||||
|
moe_block = SimpleNamespace(gate=gate)
|
||||||
|
hidden = torch.randn(4, hidden_dim)
|
||||||
|
|
||||||
|
# Without _afb_bias attribute
|
||||||
|
scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
# With _afb_bias = zeros (should be equivalent)
|
||||||
|
moe_block._afb_bias = torch.zeros(num_experts)
|
||||||
|
moe_block._afb_counts = torch.zeros(num_experts)
|
||||||
|
scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block)
|
||||||
|
|
||||||
|
torch.testing.assert_close(scores1, scores2)
|
||||||
|
torch.testing.assert_close(exp1, exp2)
|
||||||
|
|
||||||
|
@unittest.skipUnless(
|
||||||
|
importlib_util.find_spec("triton") is not None,
|
||||||
|
"triton not installed (required by scattermoe)",
|
||||||
|
)
|
||||||
|
def test_scattermoe_softmax_routing_with_afb_bias(self):
|
||||||
|
"""ScatterMoE softmax routing should use biased selection / unbiased weights."""
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||||
|
_softmax_topk_route,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_experts = 4
|
||||||
|
top_k = 2
|
||||||
|
hidden_dim = 16
|
||||||
|
T = 6
|
||||||
|
|
||||||
|
gate_weight = torch.randn(num_experts, hidden_dim)
|
||||||
|
base_gate = SimpleNamespace(
|
||||||
|
top_k=top_k,
|
||||||
|
num_experts=num_experts,
|
||||||
|
norm_topk_prob=True,
|
||||||
|
weight=gate_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
moe_block = SimpleNamespace()
|
||||||
|
hidden = torch.randn(T, hidden_dim)
|
||||||
|
|
||||||
|
# Baseline without bias
|
||||||
|
w_base, e_base, _, _ = _softmax_topk_route(
|
||||||
|
moe_block, base_gate, hidden, gate_weight, None
|
||||||
|
)
|
||||||
|
|
||||||
|
# With heavy bias on expert 0
|
||||||
|
moe_block._afb_bias = torch.zeros(num_experts)
|
||||||
|
moe_block._afb_bias[0] = 100.0
|
||||||
|
moe_block._afb_counts = torch.zeros(num_experts)
|
||||||
|
|
||||||
|
w_biased, e_biased, _, _ = _softmax_topk_route(
|
||||||
|
moe_block, base_gate, hidden, gate_weight, None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expert 0 should appear in all selections
|
||||||
|
self.assertTrue((e_biased == 0).any())
|
||||||
|
# Counts accumulated
|
||||||
|
self.assertGreater(moe_block._afb_counts[0].item(), 0)
|
||||||
|
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
|
||||||
|
|
||||||
|
def test_kernel_routing_skips_router_patch(self):
|
||||||
|
"""When a kernel backend has patched the block class, the adapter
|
||||||
|
should skip patching the router (buffers are still registered)."""
|
||||||
|
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
|
||||||
|
|
||||||
|
adapter = MixtralAdapter()
|
||||||
|
|
||||||
|
# Create a mock layer whose class has _original_forward (SonicMoE marker)
|
||||||
|
class PatchedBlock(nn.Module):
|
||||||
|
_original_forward = True # SonicMoE marker
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.gate = nn.Linear(16, 4, bias=False)
|
||||||
|
self.gate.top_k = 2
|
||||||
|
self.gate.num_experts = 4
|
||||||
|
self.gate.hidden_dim = 16
|
||||||
|
self.experts = nn.Linear(16, 16) # placeholder
|
||||||
|
|
||||||
|
layer = PatchedBlock()
|
||||||
|
self.assertTrue(adapter.uses_kernel_routing(layer))
|
||||||
|
|
||||||
|
# Gate should NOT be patched (kernel handles routing)
|
||||||
|
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
|
||||||
|
|
||||||
|
def test_adapter_buffers_registered_even_with_kernel(self):
|
||||||
|
"""Even when kernel routing is active, aux-free buffers must be
|
||||||
|
registered on the MoE block so the kernel routing can find them."""
|
||||||
|
from axolotl.integrations.aux_free_router.adapters import (
|
||||||
|
LayerHandle,
|
||||||
|
MixtralAdapter,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.aux_free_router.core import (
|
||||||
|
AuxFreeConfig,
|
||||||
|
AuxFreeShim,
|
||||||
|
AuxFreeState,
|
||||||
|
)
|
||||||
|
|
||||||
|
class PatchedBlock(nn.Module):
|
||||||
|
_original_forward = True
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.gate = nn.Linear(16, 4, bias=False)
|
||||||
|
self.gate.top_k = 2
|
||||||
|
self.gate.num_experts = 4
|
||||||
|
self.gate.hidden_dim = 16
|
||||||
|
self.experts = nn.Linear(16, 16)
|
||||||
|
|
||||||
|
layer = PatchedBlock()
|
||||||
|
adapter = MixtralAdapter()
|
||||||
|
cfg = AuxFreeConfig()
|
||||||
|
state = AuxFreeState(
|
||||||
|
num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg
|
||||||
|
)
|
||||||
|
shim = AuxFreeShim(state=state)
|
||||||
|
handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2)
|
||||||
|
|
||||||
|
adapter.prepare(layer, handle, shim)
|
||||||
|
|
||||||
|
# Buffers should be registered for kernel routing to use
|
||||||
|
self.assertTrue(hasattr(layer, "_afb_bias"))
|
||||||
|
self.assertTrue(hasattr(layer, "_afb_counts"))
|
||||||
|
self.assertTrue(hasattr(layer, "_afb_ema"))
|
||||||
|
# But gate should NOT be patched
|
||||||
|
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -28,22 +28,20 @@ class TestLoRAConfigValidation:
|
|||||||
result = validate_config(valid_config)
|
result = validate_config(valid_config)
|
||||||
assert result["adapter"] == "lora"
|
assert result["adapter"] == "lora"
|
||||||
|
|
||||||
# DoRA is now compatible with lora kernels
|
with pytest.raises(ValueError, match="not compatible with DoRA"):
|
||||||
dora_kernel_config = DictDefault(
|
invalid_config = DictDefault(
|
||||||
{
|
{
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_mlp_kernel": True,
|
"lora_mlp_kernel": True,
|
||||||
"peft_use_dora": True,
|
"peft_use_dora": True,
|
||||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"learning_rate": 1e-5,
|
"learning_rate": 1e-5,
|
||||||
"base_model": "dummy_model",
|
"base_model": "dummy_model",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = validate_config(dora_kernel_config)
|
validate_config(invalid_config)
|
||||||
assert result["lora_mlp_kernel"] is True
|
|
||||||
assert result["peft_use_dora"] is True
|
|
||||||
|
|
||||||
def test_qlora_4bit_validation(self):
|
def test_qlora_4bit_validation(self):
|
||||||
"""Test QLoRA 4-bit configuration validation"""
|
"""Test QLoRA 4-bit configuration validation"""
|
||||||
|
|||||||
@@ -38,11 +38,6 @@ class TestLoRAParameterFreezing:
|
|||||||
|
|
||||||
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
|
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
|
||||||
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
|
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
|
||||||
mock_layer.lora_B["default"].bias = None
|
|
||||||
|
|
||||||
# Required by get_lora_parameters for dropout/DoRA extraction
|
|
||||||
mock_layer.lora_dropout = {}
|
|
||||||
mock_layer.lora_magnitude_vector = None
|
|
||||||
else:
|
else:
|
||||||
mock_layer.weight = base_layer.weight
|
mock_layer.weight = base_layer.weight
|
||||||
mock_layer.bias = base_layer.bias
|
mock_layer.bias = base_layer.bias
|
||||||
@@ -53,7 +48,7 @@ class TestLoRAParameterFreezing:
|
|||||||
"""Test that LoRA parameters are None when adapters are disabled."""
|
"""Test that LoRA parameters are None when adapters are disabled."""
|
||||||
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
|
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Base parameters should be returned
|
# Base parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -67,7 +62,7 @@ class TestLoRAParameterFreezing:
|
|||||||
"""Test that LoRA parameters are None when adapters are merged."""
|
"""Test that LoRA parameters are None when adapters are merged."""
|
||||||
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
|
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Base parameters should be returned
|
# Base parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -82,7 +77,7 @@ class TestLoRAParameterFreezing:
|
|||||||
"""Test parameter behavior when no adapters are present."""
|
"""Test parameter behavior when no adapters are present."""
|
||||||
layer = self.create_mock_lora_layer(has_adapters=False)
|
layer = self.create_mock_lora_layer(has_adapters=False)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Base parameters should be returned
|
# Base parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -99,7 +94,7 @@ class TestLoRAParameterFreezing:
|
|||||||
has_adapters=True, adapters_disabled=False, merged=False
|
has_adapters=True, adapters_disabled=False, merged=False
|
||||||
)
|
)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||||
|
|
||||||
# All parameters should be returned
|
# All parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -115,7 +110,7 @@ class TestLoRAParameterFreezing:
|
|||||||
has_adapters=True, adapters_disabled=False, merged=False
|
has_adapters=True, adapters_disabled=False, merged=False
|
||||||
)
|
)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Check shape consistency
|
# Check shape consistency
|
||||||
assert W.shape == (512, 256)
|
assert W.shape == (512, 256)
|
||||||
@@ -129,7 +124,7 @@ class TestLoRAParameterFreezing:
|
|||||||
has_adapters=True, adapters_disabled=False, merged=False
|
has_adapters=True, adapters_disabled=False, merged=False
|
||||||
)
|
)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||||
|
|
||||||
assert W.dtype == self.dtype
|
assert W.dtype == self.dtype
|
||||||
assert b.dtype == self.dtype
|
assert b.dtype == self.dtype
|
||||||
@@ -143,7 +138,7 @@ class TestLoRAParameterFreezing:
|
|||||||
quant_state_mock = Mock()
|
quant_state_mock = Mock()
|
||||||
layer.base_layer.weight.quant_state = quant_state_mock
|
layer.base_layer.weight.quant_state = quant_state_mock
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||||
|
|
||||||
assert quant_state == quant_state_mock
|
assert quant_state == quant_state_mock
|
||||||
|
|
||||||
@@ -162,7 +157,7 @@ class TestLoRAParameterFreezing:
|
|||||||
|
|
||||||
layer.active_adapters = ["adapter2"]
|
layer.active_adapters = ["adapter2"]
|
||||||
|
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||||
|
|
||||||
assert s == 0.2
|
assert s == 0.2
|
||||||
assert torch.equal(A, layer.lora_A["adapter2"].weight)
|
assert torch.equal(A, layer.lora_A["adapter2"].weight)
|
||||||
@@ -197,13 +192,13 @@ class TestLoRAParameterFreezingIntegration:
|
|||||||
model = get_peft_model(base_model, lora_config)
|
model = get_peft_model(base_model, lora_config)
|
||||||
lora_layer = model.base_model.model.linear
|
lora_layer = model.base_model.model.linear
|
||||||
# Test with adapters enabled
|
# Test with adapters enabled
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
|
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
||||||
assert A is not None
|
assert A is not None
|
||||||
assert B is not None
|
assert B is not None
|
||||||
assert s is not None
|
assert s is not None
|
||||||
# Test with adapters disabled
|
# Test with adapters disabled
|
||||||
model.disable_adapter_layers()
|
model.disable_adapter_layers()
|
||||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
|
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
||||||
assert A is None
|
assert A is None
|
||||||
assert B is None
|
assert B is None
|
||||||
assert s is None
|
assert s is None
|
||||||
|
|||||||
Reference in New Issue
Block a user