Compare commits


5 Commits

Author SHA1 Message Date
Wing Lian
db6af43f3b chore: lint 2026-03-23 04:54:00 +00:00
Wing Lian
35d06c8087 add textui 2026-03-23 04:54:00 +00:00
Wing Lian
0e583efeaa increase rtol, codecov informational only, don't silently fail errors w curl (#3534) [skip ci] 2026-03-22 13:54:03 -04:00
Wing Lian
b3289fd190 feat: LoRA kernel support for bias, dropout, dora, embeddings (#3528) [skip ci]
* feat: LoRA kernel support for bias, dropout, dora, embeddings

* chore: lint

* chore: lint

* address PR feedback, add regression tests, add fsdp2 tests for lora kernels

* update tests for new sigs

* update tests now that bias and dropout are supported
2026-03-22 13:53:19 -04:00
Wing Lian
a67392c427 liger support for qwen 3.5 and fused rmsnorm+gated (#3531) [skip ci]
* liger support for qwen 3.5 and fused rmsnorm+gated

* support for qwen 3.5 moe

* fix version ref

* fixups for PR code review
2026-03-22 13:19:21 -04:00
63 changed files with 5692 additions and 502 deletions

View File

@@ -91,6 +91,7 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
type=click.Path(exists=True, path_type=str),
help="YAML config for sweeping hyperparameters",
)
@click.option("--tui", is_flag=True, default=False, help="Enable TUI dashboard")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@@ -101,6 +102,7 @@ def train(
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
cloud: str | None = None,
sweep: str | None = None,
tui: bool = False,
**kwargs,
):
"""
@@ -118,6 +120,10 @@ def train(
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
# Handle --tui flag: set env var so subprocess workers pick it up
if tui:
os.environ["AXOLOTL_TUI"] = "1"
# Handle Ray launcher override
_launcher = None if kwargs.get("use_ray") else launcher

View File

@@ -2,6 +2,7 @@
import gc
import os
import queue
from pathlib import Path
from typing import Union
@@ -34,22 +35,101 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
plugin_manager = PluginManager.get_instance()
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
if not dataset_meta:
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
# Start TUI early (before data loading) so it captures preprocessing events
tui_renderer = None
tui_queue: queue.Queue | None = None
is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
if is_rank_0:
from axolotl.train import _is_tui_enabled
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
if _is_tui_enabled(cfg):
import queue as _queue
del model, tokenizer, trainer
from axolotl.train import _get_tui_config
from axolotl.tui.config import TUIConfig
from axolotl.tui.renderer import TUIRenderer
gc.collect()
tui_config_dict = _get_tui_config(cfg)
tui_config = (
TUIConfig(**tui_config_dict)
if isinstance(tui_config_dict, dict)
else tui_config_dict
)
tui_queue = _queue.Queue(maxsize=4096)
tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train_unload(cfg)
# Send initial run info
model_name = cfg.base_model or ""
training_mode = str(cfg.rl) if cfg.rl else "sft"
world_size = int(os.environ.get("WORLD_SIZE", 1))
try:
tui_queue.put_nowait(
{
"type": "run_info",
"model_name": model_name,
"training_mode": training_mode,
"world_size": world_size,
}
)
except _queue.Full:
pass
tui_renderer.start()
# Attach logging handler early
import logging
from axolotl.tui.callback import _TUILogHandler
_early_log_handler = _TUILogHandler(
tui_queue, min_level=tui_config.log_level
)
_early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
# Attach to BOTH root and axolotl loggers because axolotl logger
# has propagate=False so root handler never sees axolotl.* messages
root_logger = logging.getLogger()
root_logger.addHandler(_early_log_handler)
axolotl_logger = logging.getLogger("axolotl")
axolotl_logger.addHandler(_early_log_handler)
# Stash refs on cfg so train() can reuse the renderer
cfg._tui_renderer = tui_renderer
cfg._tui_queue = tui_queue
cfg._tui_early_log_handler = _early_log_handler
try:
plugin_manager = PluginManager.get_instance()
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
if not dataset_meta:
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
del model, tokenizer, trainer
gc.collect()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train_unload(cfg)
finally:
# If the TUI renderer started early but train() didn't get to stop it
# (e.g., error during data loading), clean up here
if tui_renderer is not None and not tui_renderer._stop_event.is_set():
try:
if tui_queue is not None:
tui_queue.put_nowait({"type": "done"})
except queue.Full:
pass
tui_renderer.stop()
# Remove early log handler from both root and axolotl loggers
if hasattr(cfg, "_tui_early_log_handler"):
import logging
logging.getLogger().removeHandler(cfg._tui_early_log_handler)
logging.getLogger("axolotl").removeHandler(cfg._tui_early_log_handler)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):

View File

@@ -30,6 +30,15 @@ class LigerArgs(BaseModel):
liger_rope: bool | None = None
liger_rms_norm: bool | None = None
liger_rms_norm_gated: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
)
},
)
liger_layer_norm: bool | None = None
liger_swiglu: bool | None = None
liger_glu_activation: bool | None = None

View File

@@ -0,0 +1,175 @@
"""
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
"""
import sys
from copy import deepcopy
from typing import Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3_5(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
rms_norm_gated: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 models.
Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_5.modeling_qwen3_5 # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
if rms_norm:
# Qwen3_5RMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
class LigerRMSNormForQwen3_5(LigerRMSNorm):
def __init__(self, dim, eps=1e-6, **kwargs):
super().__init__(
dim,
eps=eps,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
if rms_norm_gated:
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
config = deepcopy(config)
if intermediate_size is not None:
config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs)
modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward
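For reference, the patcher above can in principle be called directly before the model is instantiated; a minimal sketch (not part of this diff), assuming a transformers build that ships the qwen3_5 module:

from axolotl.integrations.liger.models.qwen3_5 import apply_liger_kernel_to_qwen3_5

# Patch the module-level classes before AutoModelForCausalLM.from_pretrained(...),
# so the replaced Qwen3_5RMSNorm / Qwen3_5RMSNormGated / Qwen3_5MLP / forward are used.
apply_liger_kernel_to_qwen3_5(
    fused_linear_cross_entropy=True,  # skip materializing logits during training
    rms_norm=True,                    # offset=1.0, zero-init (Gemma-style) variant
    rms_norm_gated=True,              # fused RMSNorm+SiLU gate from axolotl.kernels
    glu_activation=True,              # LigerSwiGLUMLP in place of Qwen3_5MLP
)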

View File

@@ -0,0 +1,198 @@
"""
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
"""
import sys
from copy import deepcopy
from typing import Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values=None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
load_balancing_loss_func,
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits,
labels,
self.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
def apply_liger_kernel_to_qwen3_5_moe(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
rms_norm_gated: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]
if rms_norm:
# Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
def __init__(self, dim, eps=1e-6, **kwargs):
super().__init__(
dim,
eps=eps,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe
if rms_norm_gated:
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
config = deepcopy(config)
if intermediate_size is not None:
config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs)
modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_mod.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward

View File

@@ -174,6 +174,19 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_5":
from axolotl.integrations.liger.models.qwen3_5 import (
apply_liger_kernel_to_qwen3_5,
)
apply_liger_kernel_to_qwen3_5(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
@@ -186,6 +199,19 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_5_moe":
from axolotl.integrations.liger.models.qwen3_5_moe import (
apply_liger_kernel_to_qwen3_5_moe,
)
apply_liger_kernel_to_qwen3_5_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite
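A hedged sketch of the config values that reach this dispatch; in practice they come from the axolotl YAML config, shown here as a plain dict for illustration only:

# Illustrative cfg fragment; keys mirror the LigerArgs fields referenced above.
liger_cfg_example = {
    "model_config_type": "qwen3_5",            # or "qwen3_5_moe" for the MoE branch
    "liger_fused_linear_cross_entropy": True,
    "liger_cross_entropy": False,              # mutually exclusive with FLCE
    "liger_rms_norm": True,
    "liger_rms_norm_gated": True,              # new flag added in this change
    "liger_glu_activation": True,
    "liger_layer_norm": False,
}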

src/axolotl/kernels/dora.py Normal file
View File

@@ -0,0 +1,147 @@
"""
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
Fuses the weight norm computation and magnitude scaling to avoid
materializing the full [out_features, in_features] combined weight matrix.
The B@A product is computed row-by-row inside the kernel.
"""
import torch
import triton
import triton.language as tl
from .quantize import dequantize
@triton.jit
def _dora_fused_norm_kernel(
# Pointers
W_ptr, # base weight [out, in] (dequantized, row-major)
B_ptr, # LoRA B [out, rank] (row-major)
A_ptr, # LoRA A [rank, in] (row-major)
mag_ptr, # magnitude vector [out]
out_ptr, # output mag_norm_scale [out]
# Shapes
out_features,
in_features,
rank,
# Scaling
lora_scale, # float scaling factor
# Block sizes
BLOCK_IN: tl.constexpr,
BLOCK_R: tl.constexpr, # >= rank, power of 2
):
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2
Each program handles one output row. B[row,:] is loaded once (small),
then we tile over in_features computing the dot product with A[:,tile]
and accumulating the squared norm.
This avoids materializing the full [out, in] B@A matrix.
"""
row = tl.program_id(0)
if row >= out_features:
return
# Accumulate squared norm across tiles of in_features
norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32)
for start in range(0, in_features, BLOCK_IN):
cols = start + tl.arange(0, BLOCK_IN)
col_mask = cols < in_features
# Load W[row, cols]
w_vals = tl.load(
W_ptr + row * in_features + cols,
mask=col_mask,
other=0.0,
).to(tl.float32)
# Compute (B[row,:] @ A[:, cols]) for this tile
# Load B[row, r] as scalar and A[r, cols] as vector for each r
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
for r in tl.static_range(BLOCK_R):
# Load scalar B[row, r]
b_val = tl.load(
B_ptr + row * rank + r,
mask=(r < rank),
other=0.0,
).to(tl.float32)
# Load vector A[r, cols]
a_vals = tl.load(
A_ptr + r * in_features + cols,
mask=(col_mask & (r < rank)),
other=0.0,
).to(tl.float32)
ba_vals += b_val * a_vals
# Combined: W + s * (B @ A)
combined = w_vals + lora_scale * ba_vals
# Accumulate squared values
norm_sq_acc += tl.where(col_mask, combined * combined, 0.0)
# Reduce to scalar norm
norm_sq = tl.sum(norm_sq_acc, axis=0)
norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability
# Load magnitude and compute scale
mag = tl.load(mag_ptr + row).to(tl.float32)
scale = mag / norm
tl.store(out_ptr + row, scale)
def triton_dora_scale(
W: torch.Tensor,
W_quant,
A: torch.Tensor,
B: torch.Tensor,
s: float,
magnitude: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
"""Compute DoRA mag_norm_scale using fused Triton kernel.
Computes B@A row-by-row inside the kernel, avoiding the full
[out_features, in_features] materialization.
Args:
W: base weight [out, in] (possibly quantized)
W_quant: quantization state
A: LoRA A [rank, in]
B: LoRA B [out, rank]
s: LoRA scaling factor
magnitude: learned magnitude [out]
dtype: compute dtype
Returns:
mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2
"""
# Dequantize W to [out, in]
W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
out_features, in_features = W_full.shape
rank = A.shape[0]
out = torch.empty(out_features, dtype=dtype, device=W.device)
# Block sizes
BLOCK_IN = triton.next_power_of_2(min(in_features, 2048))
BLOCK_R = triton.next_power_of_2(rank)
_dora_fused_norm_kernel[(out_features,)](
W_full,
B.contiguous().to(dtype),
A.contiguous().to(dtype),
magnitude.contiguous(),
out,
out_features=out_features,
in_features=in_features,
rank=rank,
lora_scale=s,
BLOCK_IN=BLOCK_IN,
BLOCK_R=BLOCK_R,
)
return out.detach()
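For testing, the unfused computation the kernel replaces fits in a few lines of eager PyTorch; a reference sketch (not part of the diff), using the same names as triton_dora_scale:

import torch

def dora_scale_reference(W_full, A, B, s, magnitude):
    """Unfused reference: materializes the full [out, in] combined weight."""
    combined = W_full.float() + s * (B.float() @ A.float())       # [out, in]
    norm = torch.sqrt((combined * combined).sum(dim=1) + 1e-12)   # row-wise L2 norm
    return magnitude.float() / norm                               # [out]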

File diff suppressed because it is too large.

View File

@@ -105,6 +105,10 @@ def dequantize(
# Extract quantization state
if not isinstance(quant_state, list):
# New style quant_state class
# Non-double-quantized models have offset=None and state2=None
if quant_state.offset is None or quant_state.state2 is None:
# Fall back to bitsandbytes standard dequantize
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape
dtype = quant_state.dtype

View File

@@ -0,0 +1,333 @@
"""
Fused RMSNorm + SiLU Gate Triton kernel.
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
and silu(G) = G * sigmoid(G)
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
"""
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
torch_to_triton_dtype,
)
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _rms_norm_gated_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
G_ptr,
G_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
eps,
offset,
BLOCK_SIZE: tl.constexpr,
):
"""
Y = (W + offset) * (X / RMS(X)) * silu(G)
All computation done in fp32 (Gemma-style), result cast to input dtype.
"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
X_row_dtype = X_row.dtype
# Cast everything to fp32
X_fp32 = X_row.to(tl.float32)
G_fp32 = G_row.to(tl.float32)
W_fp32 = W_row.to(tl.float32)
# RMS norm
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
X_norm = X_fp32 * rstd
# SiLU gate: silu(G) = G * sigmoid(G)
sig_G = tl.sigmoid(G_fp32)
silu_G = G_fp32 * sig_G
# Fused output
Y_row = (offset + W_fp32) * X_norm * silu_G
tl.store(
Y_ptr + row_idx * Y_row_stride + col_offsets,
Y_row.to(X_row_dtype),
mask=mask,
)
@triton.jit
def _rms_norm_gated_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
dG_ptr,
dG_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
G_ptr,
G_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
offset,
rows_per_program,
BLOCK_SIZE: tl.constexpr,
):
"""
Backward for Y = (W + offset) * (X * RSTD) * silu(G)
dW = sum_batch(dY * X_norm * silu(G))
dG = dY * (W + offset) * X_norm * silu'(G)
where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
where m = dY * (W + offset) * silu(G)
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
W_row = W_row.to(tl.float32) + offset
for row_idx in range(row_start, row_end):
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
)
G_row = tl.load(
G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
)
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
# Cast to fp32
dY_fp32 = dY_row.to(tl.float32)
X_fp32 = X_row.to(tl.float32)
G_fp32 = G_row.to(tl.float32)
# Recompute intermediates
X_norm = X_fp32 * rstd_row
sig_G = tl.sigmoid(G_fp32)
silu_G = G_fp32 * sig_G
# dW: accumulate dY * X_norm * silu(G)
dW_acc += dY_fp32 * X_norm * silu_G
# dG: dY * (W + offset) * X_norm * silu'(G)
# silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
tl.store(
dG_ptr + row_idx * dG_row_stride + col_offsets,
dG_row.to(X_dtype),
mask=mask,
)
# dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
m = dY_fp32 * W_row * silu_G
dX_row = rstd_row * m
dX_row += rstd_row * (
-(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets,
dX_row.to(X_dtype),
mask=mask,
)
tl.store(
dW_ptr + row_block_id * dW_row_stride + col_offsets,
dW_acc,
mask=mask,
)
def rms_norm_gated_forward(X, G, W, eps, offset):
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
G = G.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
assert X.shape[1] == W.shape[0], (
f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
)
assert X.shape == G.shape, (
f"X and G must have same shape, got {X.shape} and {G.shape}"
)
_rms_norm_gated_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
G,
G.stride(0),
W,
W.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
eps,
offset,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
dX = torch.empty_like(dY)
dG = torch.empty_like(dY)
rows_per_program = math.ceil(n_rows / sm_count)
grid = (sm_count,)
_rms_norm_gated_backward_kernel[grid](
dY,
dY.stride(0),
dX,
dX.stride(0),
dG,
dG.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
G,
G.stride(0),
W,
W.stride(0),
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
offset,
rows_per_program,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dX = dX.view(*shape)
dG = dG.view(*shape)
dW = _dW.sum(dim=0).to(W.dtype)
return dX, dG, dW
class FusedRMSNormGatedFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, G, W, eps, offset=0.0):
"""
X: (B, T, H) or (BxT, H) — input hidden states
G: (B, T, H) or (BxT, H) — gate tensor
W: (H,) — weight parameter
"""
Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
X, G, W, eps, offset
)
ctx.offset = offset
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(X, G, W, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, G, W, RSTD = ctx.saved_tensors
dX, dG, dW = rms_norm_gated_backward(
dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
)
return dX, dG, dW, None, None
class FusedRMSNormGated(torch.nn.Module):
"""
Fused RMSNorm + SiLU Gate.
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
Drop-in replacement for Qwen3_5RMSNormGated with matching
init signature: __init__(hidden_size, eps=1e-6, **kwargs)
and forward signature: forward(hidden_states, gate=None)
"""
def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.offset = offset
def forward(self, hidden_states, gate=None):
if gate is None:
raise ValueError("FusedRMSNormGated requires a gate tensor")
if hidden_states.device.type != "cuda":
raise ValueError(
f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
)
return FusedRMSNormGatedFunction.apply(
hidden_states, gate, self.weight, self.variance_epsilon, self.offset
)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

View File

@@ -12,6 +12,7 @@ from torch import nn
from transformers import AutoConfig
from axolotl.kernels.lora import (
apply_lora_embedding,
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
@@ -370,13 +371,13 @@ def apply_lora_kernel_patches(
active_adapter = model.active_adapter
lora_config = model.model.peft_config[active_adapter]
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
if not can_patch:
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model
# Log what features are active
if lora_config.lora_dropout > 0:
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
if lora_config.bias != "none":
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
if lora_config.use_dora:
LOG.info("LoRA kernels: DoRA enabled")
# This needs to be reset after patching
original_level = LOG.getEffectiveLevel()
@@ -419,44 +420,33 @@ def apply_lora_kernel_patches(
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
hasattr(module, "lora_A") for module in layer_modules
)
if can_patch_qkv:
# Add optimized implementation
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention QKV projections - requires LoRA adapters"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
if can_patch_o:
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention output projection - requires LoRA adapters"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_mlp:
@@ -464,15 +454,50 @@ def apply_lora_kernel_patches(
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters and no "
"lora_magnitude_vector (DoRA)"
"Cannot patch some MLP layers - requires LoRA adapters"
)
# Patch embedding layers (model-level, not per-layer)
if cfg.lora_embedding_kernel:
_patch_embedding_layers(model, cfg)
LOG.setLevel(original_level)
return model
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
"""Patch embedding layers with fused LoRA kernel.
Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and
lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT).
"""
pretrained_model = model.model
patched = 0
# Find embedding modules - check common locations
for attr_path in [
("model", "embed_tokens"),
("model", "language_model", "embed_tokens"),
]:
parent = pretrained_model
for attr in attr_path:
parent = getattr(parent, attr, None)
if parent is None:
break
if parent is not None and hasattr(parent, "lora_embedding_A"):
LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}")
parent.forward = types.MethodType(apply_lora_embedding, parent)
patched += 1
# lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
# when included in target_modules. No special embedding handling needed since
# PEFT wraps it as a Linear (not Embedding) even for tied models.
if not patched:
LOG.debug("No embedding layers with LoRA found to patch")
class FakeMLP(nn.Module):
"""
placeholder MLP for triton patching

View File

@@ -9,7 +9,6 @@ import os
import shutil
import signal
import sys
import typing
import weakref
from collections import OrderedDict
from contextlib import ExitStack
@@ -42,9 +41,6 @@ from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer
if typing.TYPE_CHECKING:
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -487,7 +483,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
Trainer,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
@@ -554,6 +550,36 @@ def setup_model_and_trainer(
)
def _is_tui_enabled(cfg: DictDefault) -> bool:
"""Check if TUI is enabled via config or environment variable."""
if os.environ.get("AXOLOTL_TUI", "").lower() in ("1", "true", "yes"):
return True
tui = cfg.get("tui")
if tui is None:
return False
if isinstance(tui, bool):
return tui
if isinstance(tui, dict):
return tui.get("enabled", False)
if hasattr(tui, "enabled"):
return tui.enabled
return False
def _get_tui_config(cfg: DictDefault) -> dict:
"""Extract TUI config dict from cfg."""
tui = cfg.get("tui")
if tui is None or isinstance(tui, bool):
return {"enabled": True}
if isinstance(tui, dict):
return {**tui, "enabled": True}
if hasattr(tui, "model_dump"):
d = tui.model_dump()
d["enabled"] = True
return d
return {"enabled": True}
@send_errors
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
@@ -577,6 +603,37 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
# Register TUI callback if enabled and rank 0
tui_enabled = _is_tui_enabled(cfg)
if tui_enabled and cfg.local_rank == 0:
from axolotl.tui import AxolotlTUICallback
from axolotl.tui.config import TUIConfig
tui_config = _get_tui_config(cfg)
tui_config_obj = (
TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
)
# Reuse the early-started renderer if available (started in do_train)
early_renderer = getattr(cfg, "_tui_renderer", None)
early_queue = getattr(cfg, "_tui_queue", None)
tui_callback = AxolotlTUICallback(config=tui_config_obj)
if early_renderer is not None and early_queue is not None:
# Reuse the already-running renderer and queue
tui_callback._renderer = early_renderer
tui_callback._queue = early_queue
tui_callback._renderer_started_early = True
trainer.add_callback(tui_callback)
# Stash model info so on_train_begin can emit a single unified run_info event
tui_callback._pending_run_info = {
"model_name": cfg.base_model or "",
"training_mode": str(cfg.rl) if cfg.rl else "sft",
"world_size": int(os.environ.get("WORLD_SIZE", 1)),
}
LOG.info("TUI dashboard enabled")
# Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)

View File

@@ -0,0 +1,17 @@
"""Axolotl Training TUI — rich-based terminal dashboard for monitoring training runs."""
from axolotl.tui.callback import AxolotlTUICallback
from axolotl.tui.config import TUIConfig
from axolotl.tui.io_capture import LineParser, register_parser
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
__all__ = [
"AxolotlTUICallback",
"BasePanel",
"LineParser",
"TUIConfig",
"TUIState",
"register_panel",
"register_parser",
]

src/axolotl/tui/callback.py Normal file
View File

@@ -0,0 +1,142 @@
"""AxolotlTUICallback — HF TrainerCallback that feeds metrics to the TUI."""
from __future__ import annotations
import logging
import queue
from transformers.trainer_callback import TrainerCallback
from axolotl.tui.config import TUIConfig
from axolotl.tui.renderer import TUIRenderer
class _TUILogHandler(logging.Handler):
"""Logging handler that pushes log records into the TUI metric queue."""
_LEVEL_MAP = {
logging.DEBUG: "debug",
logging.INFO: "info",
logging.WARNING: "warning",
logging.ERROR: "error",
logging.CRITICAL: "error",
}
def __init__(self, metric_queue: queue.Queue, min_level: str = "info"):
super().__init__()
level_name = min_level.upper()
self.setLevel(getattr(logging, level_name, logging.INFO))
self._queue = metric_queue
def emit(self, record: logging.LogRecord) -> None:
try:
level = self._LEVEL_MAP.get(record.levelno, "info")
msg = self.format(record)
self._queue.put_nowait(
{
"type": "log_line",
"level": level,
"message": msg,
}
)
except queue.Full:
pass
except Exception:
self.handleError(record)
class AxolotlTUICallback(TrainerCallback):
"""Pushes training metrics into a queue for the TUI renderer.
The callback never blocks on the render thread. The queue is bounded
(maxsize=4096) with put_nowait; overflow is silently dropped.
"""
def __init__(self, config: TUIConfig):
self._config = config
self._queue: queue.Queue = queue.Queue(maxsize=4096)
self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
self._log_handler: _TUILogHandler | None = None
self._renderer_started_early: bool = False
self._pending_run_info: dict | None = None
def _put(self, event: dict) -> None:
try:
self._queue.put_nowait(event)
except queue.Full:
pass
def on_train_begin(self, args, state, control, model=None, **kwargs):
# Send a single unified run_info event with all fields
run_info = {
"type": "run_info",
"run_name": getattr(args, "run_name", "") or "",
"total_steps": state.max_steps,
"total_epochs": float(args.num_train_epochs)
if args.num_train_epochs
else 1.0,
}
# Merge in model_name/training_mode/world_size if stashed by train.py
if self._pending_run_info:
run_info.update(self._pending_run_info)
self._pending_run_info = None
self._put(run_info)
if not self._renderer_started_early:
# Attach a logging handler to feed log messages into the events panel
self._log_handler = _TUILogHandler(
self._queue, min_level=self._config.log_level
)
self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
# Attach to both root and axolotl loggers (axolotl has propagate=False)
logging.getLogger().addHandler(self._log_handler)
logging.getLogger("axolotl").addHandler(self._log_handler)
# Start the renderer background thread
self._renderer.start()
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
# Filter out non-numeric keys and internal keys
filtered = {}
for key, value in logs.items():
if key.startswith("_"):
continue
if isinstance(value, (int, float)):
filtered[key] = value
elif isinstance(value, str):
# HF Trainer sometimes passes string-encoded numbers
try:
filtered[key] = float(value)
except (ValueError, TypeError):
pass
if filtered:
self._put({"type": "metrics", "logs": filtered})
def on_step_end(self, args, state, control, **kwargs):
self._put(
{
"type": "step",
"step": state.global_step,
"total_steps": state.max_steps,
"epoch": state.epoch if state.epoch else 0,
}
)
def on_prediction_step(self, args, state, control, **kwargs):
pass
def on_train_end(self, args, state, control, **kwargs):
self._put({"type": "done"})
# If renderer was started early, do_train's finally block handles stop
if not self._renderer_started_early:
self._renderer.stop()
# Remove the logging handler (only if we added it)
if self._log_handler:
logging.getLogger().removeHandler(self._log_handler)
logging.getLogger("axolotl").removeHandler(self._log_handler)
self._log_handler = None
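Outside axolotl's own train() wiring, the callback can in principle be attached to any HF Trainer by hand; a minimal sketch, assuming a transformers.Trainer instance named trainer already exists:

from axolotl.tui.callback import AxolotlTUICallback
from axolotl.tui.config import TUIConfig

tui_callback = AxolotlTUICallback(config=TUIConfig(enabled=True, refresh_rate=4))
trainer.add_callback(tui_callback)  # renderer starts in on_train_begin, stops in on_train_end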

src/axolotl/tui/config.py Normal file
View File

@@ -0,0 +1,38 @@
"""TUI configuration — Pydantic model for TUI settings."""
from __future__ import annotations
from pydantic import BaseModel, Field
class TUIConfig(BaseModel):
"""Configuration for the Axolotl Training TUI dashboard."""
enabled: bool = Field(
default=False,
json_schema_extra={"description": "Enable the TUI dashboard"},
)
refresh_rate: int = Field(
default=4,
json_schema_extra={"description": "Renders per second"},
)
log_level: str = Field(
default="debug",
json_schema_extra={"description": "Minimum log level shown in events panel"},
)
panels: list[str] = Field(
default_factory=lambda: ["progress", "training", "hardware", "events", "debug"],
json_schema_extra={"description": "Ordered list of panels to display"},
)
hardware_poll_interval: int = Field(
default=2,
json_schema_extra={"description": "Seconds between pynvml GPU queries"},
)
stdout_log_path: str = Field(
default="axolotl_stdout.log",
json_schema_extra={"description": "File path for captured stdout/stderr log"},
)
parser_plugins: list[str] = Field(
default_factory=list,
json_schema_extra={"description": "List of extra parser classes to load"},
)

src/axolotl/tui/gpu.py Normal file
View File

@@ -0,0 +1,72 @@
"""GPU polling wrapper around pynvml with graceful fallback."""
from __future__ import annotations
import logging
from axolotl.tui.state import GPUStats
LOG = logging.getLogger(__name__)
_nvml_available = False
try:
import pynvml
pynvml.nvmlInit()
_nvml_available = True
except Exception:
LOG.debug("pynvml unavailable — GPU stats will not be shown")
class GPUPoller:
"""Polls local GPU stats via pynvml. Falls back gracefully if unavailable."""
def __init__(self):
self._device_count = 0
if _nvml_available:
try:
self._device_count = pynvml.nvmlDeviceGetCount()
except Exception:
self._device_count = 0
@property
def available(self) -> bool:
return _nvml_available and self._device_count > 0
def poll(self) -> list[GPUStats]:
if not self.available:
return []
stats = []
for i in range(self._device_count):
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
name = pynvml.nvmlDeviceGetName(handle)
if isinstance(name, bytes):
name = name.decode("utf-8")
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
temp = pynvml.nvmlDeviceGetTemperature(
handle, pynvml.NVML_TEMPERATURE_GPU
)
try:
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
except Exception:
power = None
stats.append(
GPUStats(
id=i,
name=name,
util_pct=util.gpu,
vram_used_gb=mem.used / (1024**3),
vram_total_gb=mem.total / (1024**3),
temp_c=temp,
power_w=power,
)
)
except Exception:
LOG.debug("Error polling GPU device %d", i, exc_info=True)
return stats
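Usage is a simple polling loop; a sketch for illustration (the renderer presumably drives this on TUIConfig.hardware_poll_interval):

poller = GPUPoller()
if poller.available:
    for gpu in poller.poll():
        print(
            f"gpu{gpu.id} ({gpu.name}): {gpu.util_pct}% util, "
            f"{gpu.vram_used_gb:.1f}/{gpu.vram_total_gb:.1f} GB, {gpu.temp_c}C"
        )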

View File

@@ -0,0 +1,196 @@
"""I/O capture: OS-level stdout/stderr redirect, line parser chain, and parser registry."""
from __future__ import annotations
import logging
import os
import queue
import sys
import threading
from abc import ABC, abstractmethod
from datetime import datetime
from typing import IO
# ---------------------------------------------------------------------------
# Parser registry
# ---------------------------------------------------------------------------
_parser_registry: list[type[LineParser]] = []
def register_parser(cls: type[LineParser]) -> type[LineParser]:
"""Decorator to register a LineParser subclass."""
if cls not in _parser_registry:
_parser_registry.append(cls)
return cls
def get_registered_parsers() -> list[type[LineParser]]:
return list(_parser_registry)
# ---------------------------------------------------------------------------
# Base LineParser
# ---------------------------------------------------------------------------
class LineParser(ABC):
"""Base class for stdout/stderr line parsers."""
priority: int = 50
name: str = ""
@abstractmethod
def parse(self, line: str, source: str) -> list[dict]:
"""Parse a single captured line.
Args:
line: one line of captured output, trailing newline stripped.
source: "stdout" or "stderr".
Returns:
List of event dicts to push onto the metric queue.
Return [] if this line is not relevant.
"""
...
# ---------------------------------------------------------------------------
# ParserChain
# ---------------------------------------------------------------------------
class ParserChain:
def __init__(self):
self._parsers: list[LineParser] = []
def register(self, parser: LineParser) -> None:
self._parsers.append(parser)
self._parsers.sort(key=lambda p: p.priority)
def parse(self, line: str, source: str = "stdout") -> list[dict]:
events: list[dict] = []
for parser in self._parsers:
events.extend(parser.parse(line, source))
return events
# ---------------------------------------------------------------------------
# IOCapture — OS-level fd redirect to pipe
# ---------------------------------------------------------------------------
class IOCapture:
"""Redirects fd 1 and fd 2 into an OS pipe, drains via a reader thread,
passes lines through a ParserChain, and tees to a log file."""
def __init__(
self, log_path: str, parser_chain: ParserChain, metric_queue: queue.Queue
):
self._parser_chain = parser_chain
self._queue = metric_queue
self._log_path = log_path
self._log_file: IO[str] | None = None
self._thread: threading.Thread | None = None
self._read_fd: int | None = None
self._write_fd: int | None = None
self._saved_stdout_fd: int | None = None
self._saved_stderr_fd: int | None = None
def start(self) -> None:
# Write run-start separator
self._log_file = open(self._log_path, "a", buffering=1) # noqa: SIM115
self._log_file.write(
f"\n=== axolotl run started {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n"
)
self._log_file.flush()
# OS-level pipe
self._read_fd, self._write_fd = os.pipe()
# Save originals
self._saved_stdout_fd = os.dup(1)
self._saved_stderr_fd = os.dup(2)
# Redirect both stdout and stderr into the write end
os.dup2(self._write_fd, 1)
os.dup2(self._write_fd, 2)
os.close(self._write_fd) # write end now held by fds 1 and 2
# Also redirect Python-level handles
sys.stdout = open(1, "w", buffering=1, closefd=False) # noqa: SIM115
sys.stderr = open(2, "w", buffering=1, closefd=False) # noqa: SIM115
# Drain thread
self._thread = threading.Thread(target=self._drain, daemon=True)
self._thread.start()
def stop(self) -> None:
# Restore fds — closes the write end, causing reader to see EOF
if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None:
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
os.dup2(self._saved_stdout_fd, 1)
os.dup2(self._saved_stderr_fd, 2)
os.close(self._saved_stdout_fd)
os.close(self._saved_stderr_fd)
self._saved_stdout_fd = None
self._saved_stderr_fd = None
if self._thread is not None:
self._thread.join(timeout=2.0)
if self._thread.is_alive():
logging.getLogger(__name__).warning(
"IO capture thread did not exit after 2s"
)
self._thread = None
if self._log_file is not None:
self._log_file.close()
self._log_file = None
def _drain(self) -> None:
# Read raw bytes and split on both \n and \r to handle tqdm progress bars
# which use \r for in-place updates without \n
assert self._read_fd is not None, "_drain called before start()"
with os.fdopen(self._read_fd, "rb") as pipe:
buf = b""
while True:
chunk = pipe.read(4096)
if not chunk:
# EOF — process remaining buffer
if buf:
self._process_line(buf.decode("utf-8", errors="replace"))
break
buf += chunk
# Split on \n or \r
while b"\n" in buf or b"\r" in buf:
# Find the earliest delimiter
idx_n = buf.find(b"\n")
idx_r = buf.find(b"\r")
if idx_n == -1:
idx = idx_r
elif idx_r == -1:
idx = idx_n
else:
idx = min(idx_n, idx_r)
line = buf[:idx].decode("utf-8", errors="replace")
buf = buf[idx + 1 :]
# Handle \r\n as single delimiter
if buf.startswith(b"\n"):
buf = buf[1:]
if line:
self._process_line(line)
def _process_line(self, line: str) -> None:
line = line.rstrip()
if not line:
return
if self._log_file:
self._log_file.write(line + "\n")
self._log_file.flush()
for event in self._parser_chain.parse(line):
try:
self._queue.put_nowait(event)
except queue.Full:
pass
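The registry is the extension point for custom line parsers; a hypothetical example (the regex and event payload are illustrative, not part of the diff) that turns a matching stdout line into a metrics event:

import re

@register_parser
class ExampleLossParser(LineParser):
    """Hypothetical parser: pulls `loss=<float>` out of captured stdout lines."""

    priority = 40
    name = "example_loss"
    _pattern = re.compile(r"loss=([0-9]*\.?[0-9]+)")

    def parse(self, line: str, source: str) -> list[dict]:
        match = self._pattern.search(line)
        if not match:
            return []
        # Same event shape the TrainerCallback uses for scalar metrics.
        return [{"type": "metrics", "logs": {"loss": float(match.group(1))}}]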

View File

@@ -0,0 +1,63 @@
"""Panel registry and base class for TUI panels."""
from __future__ import annotations
from abc import ABC, abstractmethod
from rich.console import RenderableType
from axolotl.tui.state import TUIState
# ---------------------------------------------------------------------------
# Panel registry
# ---------------------------------------------------------------------------
_panel_registry: dict[str, type[BasePanel]] = {}
def register_panel(position: str = "bottom", weight: int = 50):
"""Decorator to register a panel class with position and weight."""
def decorator(cls: type[BasePanel]) -> type[BasePanel]:
cls.position = position
cls.weight = weight
_panel_registry[cls.name] = cls
return cls
return decorator
def get_registered_panels() -> dict[str, type[BasePanel]]:
return dict(_panel_registry)
# ---------------------------------------------------------------------------
# BasePanel
# ---------------------------------------------------------------------------
class BasePanel(ABC):
name: str = ""
position: str = "bottom"
weight: int = 50
min_height: int = 4
max_height: int | None = None
modes: list[str] = ["*"]
@abstractmethod
def render(self, state: TUIState) -> RenderableType:
"""Return a rich renderable. Called every tick."""
...
def on_event(self, event: dict) -> None: # noqa: B027
"""Optional: react to raw metric events before state is merged."""
pass
# Auto-import built-in panels to trigger registration
from axolotl.tui.panels.completions import CompletionsPanel # noqa: E402, F401
from axolotl.tui.panels.debug import DebugPanel # noqa: E402, F401
from axolotl.tui.panels.events import EventsPanel # noqa: E402, F401
from axolotl.tui.panels.hardware import HardwarePanel # noqa: E402, F401
from axolotl.tui.panels.progress import ProgressPanel # noqa: E402, F401
from axolotl.tui.panels.training import TrainingPanel # noqa: E402, F401
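Custom panels register through the same decorator; a hypothetical minimal panel (name and content are illustrative, and whether it is shown presumably depends on it being listed in TUIConfig.panels):

from rich.console import RenderableType
from rich.panel import Panel
from rich.text import Text

from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState

@register_panel(position="bottom", weight=60)
class NotesPanel(BasePanel):
    """Hypothetical panel rendering the current step every tick."""

    name = "notes"
    min_height = 3

    def render(self, state: TUIState) -> RenderableType:
        return Panel(Text(f"step {state.current_step}"), title="Notes", border_style="blue")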

View File

@@ -0,0 +1,61 @@
"""CompletionsPanel — shows recent RL/log_completions samples."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
def _truncate(s: str, maxlen: int = 60) -> str:
return s[:maxlen] + "…" if len(s) > maxlen else s
@register_panel(position="bottom", weight=20)
class CompletionsPanel(BasePanel):
name = "completions"
min_height = 6
modes = ["grpo", "dpo"]
def render(self, state: TUIState) -> RenderableType:
if "*" not in self.modes and state.training_mode not in self.modes:
return Text("")
if not state.completions:
return Panel(
Text("No completions yet...", style="dim"),
title="Completions",
border_style="magenta",
)
table = Table(
show_header=True,
header_style="bold",
expand=True,
box=None,
pad_edge=False,
)
table.add_column("step", justify="right", width=6)
table.add_column("prompt", no_wrap=False, max_width=40)
table.add_column("completion", no_wrap=False, max_width=40)
table.add_column("reward", justify="right", width=8)
table.add_column("adv", justify="right", width=8)
for sample in list(state.completions)[-5:]:
reward_str = f"{sample.reward:.2f}" if sample.reward is not None else "--"
adv_str = (
f"{sample.advantage:+.2f}" if sample.advantage is not None else "--"
)
table.add_row(
str(sample.step),
_truncate(sample.prompt),
_truncate(sample.completion),
reward_str,
adv_str,
)
return Panel(table, title="Completions", border_style="magenta")

View File

@@ -0,0 +1,34 @@
"""DebugPanel — scrolling log of debug-level messages, separate from main events."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
@register_panel(position="bottom", weight=30)
class DebugPanel(BasePanel):
name = "debug"
min_height = 6
max_height = 10
def render(self, state: TUIState) -> RenderableType:
lines = Text()
# Show last 8 debug-level log lines
debug_lines = [
log_entry for log_entry in state.log_lines if log_entry.level == "debug"
][-8:]
for log_line in debug_lines:
ts = log_line.timestamp.strftime("%H:%M:%S")
lines.append(f"[{ts}] ", style="dim")
lines.append(log_line.message[:200], style="dim")
lines.append("\n")
if not debug_lines:
lines = Text("No debug messages yet...", style="dim")
return Panel(lines, title="Debug", border_style="dim")

View File

@@ -0,0 +1,45 @@
"""EventsPanel — scrolling log of recent events, color-coded by level."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
_LEVEL_STYLES = {
"debug": "dim",
"info": "",
"warning": "yellow",
"error": "red bold",
"critical": "red bold",
}
@register_panel(position="bottom", weight=10)
class EventsPanel(BasePanel):
name = "events"
min_height = 8
max_height = 20
def render(self, state: TUIState) -> RenderableType:
lines = Text()
# Show last 15 non-debug log lines (debug goes to DebugPanel)
recent = [
log_entry for log_entry in state.log_lines if log_entry.level != "debug"
][-15:]
for log_line in recent:
ts = log_line.timestamp.strftime("%H:%M:%S")
level = log_line.level.upper()
style = _LEVEL_STYLES.get(log_line.level, "")
lines.append(f"[{ts}] ", style="dim")
lines.append(f"[{level}] ", style=style or "")
lines.append(log_line.message[:200], style=style or "")
lines.append("\n")
if not recent:
lines = Text("No events yet...", style="dim")
return Panel(lines, title="Events", border_style="yellow")

View File

@@ -0,0 +1,80 @@
"""HardwarePanel — per-GPU stats via pynvml."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
_BAR_FULL = ""
_BAR_EMPTY = ""
def _util_bar(pct: float, width: int = 6) -> Text:
filled = int(pct / 100 * width)
bar = _BAR_FULL * filled + _BAR_EMPTY * (width - filled)
color = "green" if pct < 70 else ("yellow" if pct < 90 else "red")
return Text.assemble((bar, color), f" {pct:3.0f}%")
@register_panel(position="right", weight=10)
class HardwarePanel(BasePanel):
name = "hardware"
min_height = 6
def render(self, state: TUIState) -> RenderableType:
if not state.gpus:
return Panel(
Text("GPU stats unavailable", style="dim"),
title="Hardware",
border_style="green",
)
table = Table(
show_header=True,
header_style="bold",
expand=True,
box=None,
pad_edge=False,
)
table.add_column("id", justify="right", width=3)
table.add_column("util", no_wrap=True)
table.add_column("vram", no_wrap=True)
table.add_column("°C", justify="right", width=4)
table.add_column("W", justify="right", width=5)
total_vram_used = 0.0
total_vram_total = 0.0
total_util = 0.0
for gpu in state.gpus:
total_vram_used += gpu.vram_used_gb
total_vram_total += gpu.vram_total_gb
total_util += gpu.util_pct
power_str = f"{gpu.power_w:.0f}" if gpu.power_w is not None else "--"
table.add_row(
str(gpu.id),
_util_bar(gpu.util_pct),
f"{gpu.vram_used_gb:.1f}/{gpu.vram_total_gb:.1f} GB",
str(gpu.temp_c),
power_str,
)
# Footer with aggregates
n = len(state.gpus)
if n > 1:
avg_util = total_util / n
table.add_row(
"Σ",
Text(f"avg {avg_util:.0f}%", style="dim"),
Text(f"{total_vram_used:.1f}/{total_vram_total:.1f} GB", style="dim"),
"",
"",
)
return Panel(table, title="Hardware", border_style="green")

View File

@@ -0,0 +1,73 @@
"""ProgressPanel — top-bar progress display with step count, elapsed, ETA."""
from __future__ import annotations
from rich.console import RenderableType
from rich.progress import BarColumn, Progress, TextColumn
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
def _fmt_time(seconds: float | None) -> str:
if seconds is None or seconds < 0:
return "--:--:--"
h = int(seconds) // 3600
m = (int(seconds) % 3600) // 60
s = int(seconds) % 60
return f"{h}:{m:02d}:{s:02d}"
def _fmt_eta(seconds: float | None) -> str:
if seconds is None or seconds < 0:
return "eta --"
h = int(seconds) // 3600
m = (int(seconds) % 3600) // 60
if h > 0:
return f"eta {h}h{m:02d}m"
return f"eta {m}m{int(seconds) % 60:02d}s"
@register_panel(position="top", weight=10)
class ProgressPanel(BasePanel):
name = "progress"
min_height = 3
max_height = 3
def render(self, state: TUIState) -> RenderableType:
pct = (
(state.current_step / state.total_steps * 100)
if state.total_steps > 0
else 0
)
# Header line
mode_upper = state.training_mode.upper() if state.training_mode else "SFT"
model_short = state.model_name.split("/")[-1] if state.model_name else "model"
header = Text.assemble(
("", "bold green"),
("AXOLOTL", "bold cyan"),
f" {mode_upper} · {model_short} ",
(
f"{state.current_step} / {state.total_steps}",
"bold",
),
f" · {_fmt_time(state.elapsed_seconds)} elapsed · {_fmt_eta(state.eta_seconds)} · {pct:.1f}%",
)
# Progress bar
progress = Progress(
TextColumn(""),
BarColumn(bar_width=None),
TextColumn("{task.percentage:>3.0f}%"),
expand=True,
)
task = progress.add_task("", total=state.total_steps or 1)
progress.update(task, completed=state.current_step)
table = Table.grid(expand=True)
table.add_row(header)
table.add_row(progress)
return table

View File

@@ -0,0 +1,97 @@
"""TrainingPanel — live scalar metrics table with loss sparkline."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
# Block-element sparkline characters (8 levels)
_SPARK_CHARS = "▁▂▃▄▅▆▇█"
def _sparkline(values: list[float] | None, width: int = 20) -> str:
if not values or len(values) < 2:
return ""
vals = list(values)[-width:]
lo, hi = min(vals), max(vals)
rng = hi - lo if hi != lo else 1.0
return "".join(_SPARK_CHARS[min(int((v - lo) / rng * 7), 7)] for v in vals)
# Known key ordering and formatting
_KNOWN_KEYS: list[tuple[str, str, str]] = [
("loss", "loss", ".4f"),
("grad_norm", "grad norm", ".3f"),
("learning_rate", "lr", ".2e"),
("tokens_per_second", "tok/s", ".1f"),
("samples_per_second", "samples/s", ".1f"),
("mfu", "MFU", ".1f"),
# RL-specific
("rewards_mean", "rewards/mean", ".4f"),
("rewards_std", "rewards/std", ".4f"),
("kl_divergence", "KL", ".4f"),
("clip_ratio", "clip ratio", ".3f"),
("queue_size", "queue", "d"),
]
@register_panel(position="left", weight=10)
class TrainingPanel(BasePanel):
name = "training"
min_height = 8
def render(self, state: TUIState) -> RenderableType:
table = Table(
show_header=True,
header_style="bold",
expand=True,
box=None,
pad_edge=False,
)
table.add_column("metric", style="cyan", no_wrap=True)
table.add_column("value", justify="right")
table.add_column("trend", justify="left", no_wrap=True)
for attr, label, fmt in _KNOWN_KEYS:
val = getattr(state, attr, None)
if val is None:
# Also check extra dict
val = state.extra.get(attr)
if val is None:
continue
try:
formatted = f"{val:{fmt}}"
except (ValueError, TypeError):
formatted = str(val)
trend = ""
if attr == "loss":
trend = _sparkline(list(state.loss_history))
table.add_row(label, formatted, trend)
# Any extra keys not in _KNOWN_KEYS
known_attrs = {k for k, _, _ in _KNOWN_KEYS}
for key, val in sorted(state.extra.items()):
if key in known_attrs or val is None:
continue
try:
formatted = f"{val:.4f}"
except (ValueError, TypeError):
formatted = str(val)
table.add_row(key, formatted, "")
if table.row_count == 0:
return Panel(
Text("Waiting for first log step...", style="dim"),
title="Training",
border_style="blue",
)
return Panel(table, title="Training", border_style="blue")

View File

@@ -0,0 +1,7 @@
"""Built-in line parsers — auto-imported to trigger @register_parser decorators."""
from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401
from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401
from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401
from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401
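
Because this package is imported purely for its side effects, adding another built-in parser means defining a LineParser subclass decorated with @register_parser and importing it here. A minimal sketch of such a parser (the class name, regex, and event it emits are illustrative, not part of this change):

from __future__ import annotations

import re

from axolotl.tui.io_capture import LineParser, register_parser

@register_parser
class OOMParser(LineParser):
    """Hypothetical parser that surfaces CUDA OOM messages as error events."""

    priority = 10
    name = "cuda_oom"
    _RE = re.compile(r"CUDA out of memory", re.IGNORECASE)

    def parse(self, line: str, source: str) -> list[dict]:
        if self._RE.search(line):
            return [{"type": "log_line", "level": "error", "message": line}]
        return []

Out-of-tree parsers do not need to live here; the renderer below also loads them from the TUI config's parser_plugins entries, given either as a dotted module path or a file_path::ClassName pair.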

View File

@@ -0,0 +1,29 @@
"""DeepSpeedParser — extracts DeepSpeed stage info and throughput metrics."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class DeepSpeedParser(LineParser):
priority = 20
name = "deepspeed"
_SAMPLES_RE = re.compile(r"samples/sec=([0-9.]+)")
_STAGE_RE = re.compile(r"ZeRO Stage (\d)")
def parse(self, line: str, source: str) -> list[dict]:
events: list[dict] = []
if m := self._SAMPLES_RE.search(line):
events.append(
{
"type": "metrics",
"logs": {"samples_per_second": float(m.group(1))},
}
)
if m := self._STAGE_RE.search(line):
events.append({"type": "run_info", "zero_stage": int(m.group(1))})
return events

View File

@@ -0,0 +1,27 @@
"""NCCLErrorParser — surfaces NCCL errors as red alert events."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class NCCLErrorParser(LineParser):
priority = 10
name = "nccl_error"
_RE = re.compile(r"NCCL error|Unhandled NCCL", re.IGNORECASE)
def parse(self, line: str, source: str) -> list[dict]:
if self._RE.search(line):
return [
{
"type": "log_line",
"level": "error",
"message": f"⚠ NCCL: {line}",
},
{"type": "alert", "severity": "error", "message": line},
]
return []

View File

@@ -0,0 +1,37 @@
"""RawLogParser — catches every line as a log_line event."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class RawLogParser(LineParser):
priority = 99
name = "raw_log"
_LOG_RE = re.compile(
r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
r"\s*[-]\s*(?P<msg>.+)$",
re.IGNORECASE,
)
# Filter out tqdm progress bar lines and other noisy output
_TQDM_RE = re.compile(r"^\s*\d+%\|.*\|")
_EMPTY_RE = re.compile(r"^\s*$")
def parse(self, line: str, source: str) -> list[dict]:
# Skip empty lines and tqdm progress bar updates
if self._EMPTY_RE.match(line) or self._TQDM_RE.match(line):
return []
m = self._LOG_RE.match(line)
level = (
m.group("level").lower()
if m
else ("error" if source == "stderr" else "info")
)
return [{"type": "log_line", "level": level, "message": line}]

View File

@@ -0,0 +1,26 @@
"""TorchCompileParser — detects torch.compile graph breaks and recompilations."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class TorchCompileParser(LineParser):
priority = 20
name = "torch_compile"
_RE = re.compile(r"Graph break|Recompiling|torch\.compile", re.IGNORECASE)
def parse(self, line: str, source: str) -> list[dict]:
if self._RE.search(line):
return [
{
"type": "log_line",
"level": "warning",
"message": f"⚡ compile: {line}",
}
]
return []

View File

@@ -0,0 +1,86 @@
"""TqdmParser — captures tqdm progress bar output and surfaces as structured events."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class TqdmParser(LineParser):
priority = 15
name = "tqdm"
# Match tqdm-style progress lines, e.g.:
# Tokenizing Prompts (num_proc=24): 35%|███▍ | 19008/54568 [00:02<00:02, 17417.65 examples/s]
# Loading weights: 53%|█████▎ | 77/146 [00:00<00:00, 396.39it/s]
# 0%| | 0/30 [00:00<?, ?it/s]
_TQDM_RE = re.compile(
r"(?P<desc>.*?)\s*"
r"(?P<pct>\d+)%\|[▏▎▍▌▋▊▉█░▓▒# ]*\|\s*"
r"(?P<current>[\d,]+)/(?P<total>[\d,]+)"
r"\s*\[(?P<elapsed>[^\]]*)\]"
)
# Also match simpler forms like:
# Fetching 0 files: 0it [00:00, ?it/s]
_FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]")
def parse(self, line: str, source: str) -> list[dict]:
m = self._TQDM_RE.search(line)
if m:
desc = m.group("desc").strip().rstrip(":")
pct = int(m.group("pct"))
current = int(m.group("current").replace(",", ""))
total = int(m.group("total").replace(",", ""))
events: list[dict] = []
# Surface as a log line with progress info
if pct == 100 or pct == 0 or pct % 25 == 0:
msg = (
f"[{desc}] {pct}% ({current}/{total})"
if desc
else f"{pct}% ({current}/{total})"
)
events.append(
{
"type": "log_line",
"level": "info",
"message": msg,
}
)
# Also emit as a progress metric
cleaned_desc = desc.strip().lower().replace(" ", "_")
if not cleaned_desc:
cleaned_desc = "progress"
events.append(
{
"type": "metrics",
"logs": {
f"progress/{cleaned_desc}": pct / 100.0,
},
}
)
return events
# Fallback: try simpler fetch-style progress lines
m = self._FETCH_RE.search(line)
if m:
desc = m.group("desc").strip().rstrip(":")
current = int(m.group("current"))
cleaned_desc = desc.strip().lower().replace(" ", "_")
if not cleaned_desc:
cleaned_desc = "fetch"
return [
{
"type": "log_line",
"level": "info",
"message": f"[{desc}] {current}" if desc else f"{current}",
}
]
return []
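
For the "Loading weights" sample line quoted in the comment above, 53% is not a 0/25/50/75/100 milestone, so only the metric event comes back; milestone percentages additionally produce a log_line event.

parser = TqdmParser()
events = parser.parse(
    "Loading weights:  53%|█████▎    | 77/146 [00:00<00:00, 396.39it/s]", "stdout"
)
# events == [{"type": "metrics", "logs": {"progress/loading_weights": 0.53}}]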

449
src/axolotl/tui/renderer.py Normal file
View File

@@ -0,0 +1,449 @@
"""TUIRenderer — background daemon thread that drives the rich.live.Live display."""
from __future__ import annotations
import logging
import queue
import threading
import time
from datetime import datetime
from typing import Any
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from axolotl.tui.config import TUIConfig
from axolotl.tui.gpu import GPUPoller
from axolotl.tui.io_capture import (
IOCapture,
ParserChain,
get_registered_parsers,
)
from axolotl.tui.panels import BasePanel, get_registered_panels
from axolotl.tui.state import CompletionSample, LogLine, TUIState
LOG = logging.getLogger(__name__)
class TUIRenderer:
"""Background thread that renders the TUI dashboard using rich.live.Live."""
def __init__(self, config: TUIConfig, metric_queue: queue.Queue):
self._config = config
self._queue = metric_queue
self._state = TUIState()
self._gpu_poller = GPUPoller()
self._panels: list[BasePanel] = []
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._io_capture: IOCapture | None = None
self._parser_chain: ParserChain | None = None
def _init_panels(self) -> None:
registry = get_registered_panels()
for panel_name in self._config.panels:
if panel_name in registry:
self._panels.append(registry[panel_name]())
def _init_parser_chain(self) -> None:
# Ensure built-in parsers are imported so @register_parser decorators fire
import axolotl.tui.parsers # noqa: F401
self._parser_chain = ParserChain()
# Register all built-in parsers
for parser_cls in get_registered_parsers():
self._parser_chain.register(parser_cls())
# Load plugin parsers
for plugin_spec in self._config.parser_plugins:
try:
if "::" in plugin_spec:
# file path :: class name
file_path, class_name = plugin_spec.split("::", 1)
import importlib.util
spec = importlib.util.spec_from_file_location(
"custom_parser", file_path
)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load spec for {file_path}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
parser_cls = getattr(mod, class_name)
else:
# dotted module path
module_path, class_name = plugin_spec.rsplit(".", 1)
mod = importlib.import_module(module_path)
parser_cls = getattr(mod, class_name)
self._parser_chain.register(parser_cls())
except Exception as exc:
LOG.warning(f"Failed to load parser plugin {plugin_spec}: {exc}")
def _build_layout(self) -> Layout:
layout = Layout()
top_panels = [p for p in self._panels if p.position == "top"]
left_panels = [p for p in self._panels if p.position == "left"]
right_panels = [p for p in self._panels if p.position == "right"]
bottom_panels = [p for p in self._panels if p.position == "bottom"]
sections = []
if top_panels:
layout_top = Layout(name="top", size=3)
sections.append(layout_top)
if left_panels or right_panels:
layout_middle = Layout(name="middle", ratio=3)
middle_parts = []
if left_panels:
middle_parts.append(Layout(name="left", ratio=1))
if right_panels:
middle_parts.append(Layout(name="right", ratio=1))
if middle_parts:
layout_middle.split_row(*middle_parts)
sections.append(layout_middle)
if bottom_panels:
layout_bottom = Layout(name="bottom", ratio=2)
if len(bottom_panels) > 1:
layout_bottom.split_row(
*[
Layout(name=f"bottom_{i}", ratio=1)
for i in range(len(bottom_panels))
]
)
sections.append(layout_bottom)
if sections:
layout.split_column(*sections)
return layout
def _update_layout(self, layout: Layout) -> None:
top_panels = [p for p in self._panels if p.position == "top"]
left_panels = [p for p in self._panels if p.position == "left"]
right_panels = [p for p in self._panels if p.position == "right"]
bottom_panels = [p for p in self._panels if p.position == "bottom"]
if top_panels:
layout["top"].update(top_panels[0].render(self._state))
if left_panels:
layout["left"].update(left_panels[0].render(self._state))
if right_panels:
layout["right"].update(right_panels[0].render(self._state))
if bottom_panels:
if len(bottom_panels) == 1:
layout["bottom"].update(bottom_panels[0].render(self._state))
else:
for i, panel in enumerate(bottom_panels):
layout[f"bottom_{i}"].update(panel.render(self._state))
def _drain_queue(self) -> None:
while True:
try:
event = self._queue.get_nowait()
except queue.Empty:
break
# Dispatch event to panels first
for panel in self._panels:
panel.on_event(event)
event_type = event.get("type")
if event_type == "metrics":
logs = event.get("logs", {})
self._apply_metrics(logs)
elif event_type == "step":
self._state.current_step = event.get("step", self._state.current_step)
self._state.total_steps = event.get(
"total_steps", self._state.total_steps
)
self._state.current_epoch = event.get(
"epoch", self._state.current_epoch
)
now = time.time()
self._state.elapsed_seconds = now - self._state.start_time.timestamp()
if self._state.current_step > 0 and self._state.total_steps > 0:
rate = self._state.elapsed_seconds / self._state.current_step
remaining = self._state.total_steps - self._state.current_step
self._state.eta_seconds = rate * remaining
elif event_type == "log_line":
level = event.get("level", "info")
message = event.get("message", "")
self._state.log_lines.append(
LogLine(
timestamp=datetime.now(),
level=level,
message=message,
)
)
elif event_type == "completion":
self._state.completions.append(
CompletionSample(
step=event.get("step", 0),
prompt=event.get("prompt", ""),
completion=event.get("completion", ""),
reward=event.get("reward"),
advantage=event.get("advantage"),
)
)
elif event_type == "run_info":
if "run_name" in event:
self._state.run_name = event["run_name"]
if "model_name" in event:
self._state.model_name = event["model_name"]
if "training_mode" in event:
self._state.training_mode = event["training_mode"]
if "world_size" in event:
self._state.world_size = event["world_size"]
if "total_steps" in event:
self._state.total_steps = event["total_steps"]
if "total_epochs" in event:
self._state.total_epochs = event["total_epochs"]
if "zero_stage" in event:
self._state.zero_stage = event["zero_stage"]
elif event_type == "done":
self._stop_event.set()
def _apply_metrics(self, logs: dict[str, Any]) -> None:
metric_map = {
"loss": "loss",
"grad_norm": "grad_norm",
"learning_rate": "learning_rate",
"tokens_per_second": "tokens_per_second",
"samples_per_second": "samples_per_second",
"mfu": "mfu",
"rewards/mean": "rewards_mean",
"rewards_mean": "rewards_mean",
"rewards/std": "rewards_std",
"rewards_std": "rewards_std",
"kl": "kl_divergence",
"kl_divergence": "kl_divergence",
"clip_ratio": "clip_ratio",
"queue_size": "queue_size",
}
for key, value in logs.items():
if key in metric_map:
setattr(self._state, metric_map[key], value)
else:
self._state.extra[key] = value
if "loss" in logs and logs["loss"] is not None:
self._state.loss_history.append(logs["loss"])
def start(self) -> None:
self._init_panels()
self._init_parser_chain()
# Set up I/O capture
assert self._parser_chain is not None, "_init_parser_chain must be called first"
self._io_capture = IOCapture(
log_path=self._config.stdout_log_path,
parser_chain=self._parser_chain,
metric_queue=self._queue,
)
# Monkeypatch tqdm to suppress terminal output and route through our queue.
# This prevents tqdm progress bars from flickering through the TUI and
# ensures all progress events appear in the Events panel.
self._install_tqdm_hook()
self._io_capture_ready = threading.Event()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
self._io_capture_ready.wait(timeout=5.0)
def _install_tqdm_hook(self) -> None:
"""Replace tqdm's display method to route updates through TUI queue."""
try:
import io
import tqdm
import tqdm.auto
q = self._queue
self._tqdm_parser = None
# Find our tqdm parser in the chain
for p in self._parser_chain._parsers if self._parser_chain else []:
if p.name == "tqdm":
self._tqdm_parser = p
break
# Save originals for restore
self._orig_tqdm_class_auto = tqdm.auto.tqdm
self._orig_tqdm_class_tqdm = tqdm.tqdm
self._orig_tqdm_class_std = tqdm.std.tqdm
class TUITqdm(tqdm.tqdm):
"""tqdm subclass that sends progress to TUI instead of terminal."""
def __init__(self, *args, **kwargs):
# Route output to an in-memory buffer so nothing reaches the terminal
kwargs["file"] = io.StringIO()
kwargs["dynamic_ncols"] = False
kwargs["ncols"] = 80
super().__init__(*args, **kwargs)
def display(self, msg=None, pos=None):
# Build a progress string and push to queue
if self.total and self.total > 0:
pct = self.n / self.total * 100
desc = self.desc.rstrip(": ") if self.desc else ""
# Emit events at milestones or at low frequency
is_milestone = (
self.n == 0 or self.n >= self.total or int(pct) % 25 == 0
)
if is_milestone:
try:
q.put_nowait(
{
"type": "log_line",
"level": "info",
"message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
if desc
else f"{pct:.0f}% ({self.n}/{self.total})",
}
)
except Exception:
pass
try:
metric_key = (
f"progress/{desc.lower().replace(' ', '_')}"
if desc
else "progress/unknown"
)
q.put_nowait(
{
"type": "metrics",
"logs": {metric_key: pct / 100.0},
}
)
except Exception:
pass
def close(self):
# Emit final completion event
if self.total and self.total > 0 and self.n > 0:
desc = self.desc.rstrip(": ") if self.desc else ""
try:
q.put_nowait(
{
"type": "log_line",
"level": "info",
"message": f"[{desc}] 100% ({self.total}/{self.total}) done"
if desc
else f"100% ({self.total}/{self.total}) done",
}
)
except Exception:
pass
super().close()
# Replace tqdm globally
tqdm.auto.tqdm = TUITqdm
tqdm.tqdm = TUITqdm
# Also patch tqdm.std which some libraries use directly
tqdm.std.tqdm = TUITqdm
self._tui_tqdm_cls = TUITqdm
except Exception as exc:
LOG.debug(f"Failed to install tqdm hook: {exc}")
def _uninstall_tqdm_hook(self) -> None:
"""Restore original tqdm."""
try:
import tqdm
import tqdm.auto
if hasattr(self, "_orig_tqdm_class_auto"):
tqdm.auto.tqdm = self._orig_tqdm_class_auto
if hasattr(self, "_orig_tqdm_class_tqdm"):
tqdm.tqdm = self._orig_tqdm_class_tqdm
if hasattr(self, "_orig_tqdm_class_std"):
tqdm.std.tqdm = self._orig_tqdm_class_std
except Exception:
pass
def stop(self) -> None:
self._stop_event.set()
self._uninstall_tqdm_hook()
if self._thread is not None:
self._thread.join(timeout=5.0)
def _run(self) -> None:
import os
# Save a handle to the REAL terminal BEFORE IO capture redirects fds.
# This ensures rich.live.Live writes to the terminal, not the pipe.
saved_tty_fd = os.dup(1)
tty_file = os.fdopen(saved_tty_fd, "w", buffering=1, closefd=True)
console = Console(file=tty_file)
layout = self._build_layout()
tick_interval = 1.0 / max(self._config.refresh_rate, 1)
gpu_poll_counter = 0
gpu_poll_ticks = max(
1, int(self._config.hardware_poll_interval / tick_interval)
)
# Start I/O capture — redirects fd 1/2 to pipe AFTER we saved the tty fd
if self._io_capture:
self._io_capture.start()
# Signal that IO capture is live so start() can return
if hasattr(self, "_io_capture_ready"):
self._io_capture_ready.set()
try:
with Live(
layout,
console=console,
refresh_per_second=self._config.refresh_rate,
screen=True,
redirect_stdout=False,
redirect_stderr=False,
) as live:
while not self._stop_event.is_set():
self._drain_queue()
# Poll GPU stats periodically
gpu_poll_counter += 1
if gpu_poll_counter >= gpu_poll_ticks:
gpu_poll_counter = 0
if self._gpu_poller.available:
self._state.gpus = self._gpu_poller.poll()
# Update elapsed time
self._state.elapsed_seconds = (
time.time() - self._state.start_time.timestamp()
)
self._update_layout(layout)
live.update(layout)
time.sleep(tick_interval)
# Final drain
self._drain_queue()
self._update_layout(layout)
live.update(layout)
finally:
if self._io_capture:
self._io_capture.stop()
try:
tty_file.close()
except Exception:
pass
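
The renderer only ever consumes plain dict events from the queue; _drain_queue above defines the full set of types (metrics, step, log_line, completion, run_info, done). A rough sketch of the producer side, for example a trainer callback forwarding one logging step (the helper is illustrative, not part of this change):

import queue

def push_step_events(
    metric_queue: queue.Queue, step: int, total_steps: int, logs: dict
) -> None:
    """Forward one logging step to the TUI without ever blocking training."""
    for event in (
        {"type": "step", "step": step, "total_steps": total_steps},
        {"type": "metrics", "logs": logs},
    ):
        try:
            metric_queue.put_nowait(event)
        except queue.Full:
            pass  # drop the event rather than stall the training loop

# e.g. push_step_events(tui_queue, 10, 100, {"loss": 1.23, "learning_rate": 2e-5})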

88
src/axolotl/tui/state.py Normal file
View File

@@ -0,0 +1,88 @@
"""TUI shared data model — dataclasses for the dashboard state."""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class GPUStats:
id: int
name: str
util_pct: float
vram_used_gb: float
vram_total_gb: float
temp_c: int
power_w: float | None
@dataclass
class LogLine:
timestamp: datetime
level: str # "info" | "debug" | "warning" | "error"
message: str
@dataclass
class CompletionSample:
step: int
prompt: str
completion: str
reward: float | None
advantage: float | None
@dataclass
class TUIState:
# Run metadata
run_name: str = ""
model_name: str = ""
training_mode: str = "sft"
world_size: int = 1
start_time: datetime = field(default_factory=datetime.now)
# Progress
current_step: int = 0
total_steps: int = 0
current_epoch: float = 0.0
total_epochs: float = 1.0
elapsed_seconds: float = 0.0
eta_seconds: float | None = None
# Training metrics (rolling window + current)
loss: float | None = None
grad_norm: float | None = None
learning_rate: float | None = None
tokens_per_second: float | None = None
samples_per_second: float | None = None
mfu: float | None = None
# RL-specific (None for non-RL modes)
rewards_mean: float | None = None
rewards_std: float | None = None
kl_divergence: float | None = None
clip_ratio: float | None = None
queue_size: int | None = None
# Per-GPU hardware (list indexed by local rank)
gpus: list[GPUStats] = field(default_factory=list)
# Recent log lines
log_lines: deque[LogLine] = field(default_factory=lambda: deque(maxlen=200))
# Recent completions (GRPO/SFT with log_completions)
completions: deque[CompletionSample] = field(
default_factory=lambda: deque(maxlen=20)
)
# Loss history for sparkline
loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))
# DeepSpeed zero stage (None if not using DeepSpeed)
zero_stage: int | None = None
# Arbitrary plugin state
extra: dict[str, Any] = field(default_factory=dict)

View File

@@ -13,6 +13,7 @@ from pydantic import (
model_validator,
)
from axolotl.tui.config import TUIConfig
from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import (
@@ -140,6 +141,12 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
tui: TUIConfig | None = Field(
default=None,
json_schema_extra={
"description": "TUI dashboard configuration. Set enabled: true to activate."
},
)
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(
@@ -703,6 +710,12 @@ class AxolotlInputConfig(
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
},
)
lora_embedding_kernel: bool | None = Field(
default=None,
json_schema_extra={
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
},
)
chunked_cross_entropy: bool | None = Field(
default=None,
@@ -1313,6 +1326,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
or data.get("lora_embedding_kernel")
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp_config") is not None
@@ -1360,7 +1374,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("adapter") in ["lora", "qlora"]:
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
kernel_fields = [
"lora_mlp_kernel",
"lora_qkv_kernel",
"lora_o_kernel",
"lora_embedding_kernel",
]
if (
any(data.get(k) is not None for k in kernel_fields)
or any(data.get(k) for k in unsloth_fields)
@@ -1373,10 +1392,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("trust_remote_code"):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0:
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
@@ -1398,6 +1413,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("lora_o_kernel") is None:
data["lora_o_kernel"] = True
if data.get("lora_embedding_kernel") is None:
data["lora_embedding_kernel"] = True
LOG.warning(
"Auto-enabling LoRA kernel optimizations for faster training. "
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "

View File

@@ -681,15 +681,7 @@ class LoRAValidationMixin:
@model_validator(mode="before")
@classmethod
def check_lora_kernels_dora(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("peft_use_dora"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with DoRA at the moment."
)
# DoRA is now supported by lora kernels
return data
@model_validator(mode="before")

View File

@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
proj.base_layer = base_layer
W, b, quant_state, A, B, s = get_lora_parameters(proj)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
# quant_state should be None since weight is bf16, not FP8
self.assertIsNone(quant_state)
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
scale_inv = torch.ones(1)
base_layer.weight_scale_inv = scale_inv
W, b, quant_state, A, B, s = get_lora_parameters(proj)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
self.assertIs(quant_state, scale_inv)
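
The *_ in these unpackings is the visible consequence of get_lora_parameters now returning more than six values. Judging by the mock attributes added to the kernel tests further down (lora_dropout, lora_magnitude_vector), the extra items plausibly carry the per-adapter dropout and DoRA magnitude needed by the updated kernels; that is an inference from this compare, not something the hunk states. Call sites that only care about the original six values keep working by ignoring the rest:

W, b, quant_state, A, B, s, *extras = get_lora_parameters(proj)
# extras: additional kernel inputs, e.g. dropout / DoRA magnitude (assumption)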

View File

@@ -102,7 +102,7 @@ def mock_proj():
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None

File diff suppressed because it is too large

View File

@@ -86,5 +86,5 @@ class TestPackedFlex:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
)

View File

@@ -37,7 +37,7 @@ def verify_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/loss"]
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -37,7 +37,7 @@ def verify_fp8_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/loss"]
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -38,7 +38,7 @@ def verify_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/loss"]
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -38,7 +38,7 @@ def verify_training_success(temp_dir):
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/loss"]
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (

View File

@@ -0,0 +1,120 @@
"""Test LoRA kernels under FSDP2 multi-GPU training.
Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and
lora_embedding_kernel work correctly with FSDP2 sharding, including
with bias, dropout, and DoRA enabled.
"""
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def _run_training(temp_dir, cfg):
"""Write config and launch multi-GPU training."""
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
def _base_lora_fsdp2_config(temp_dir, **overrides):
"""Base config for LoRA + FSDP2 + kernel tests."""
cfg = {
"base_model": "Qwen/Qwen3-0.6B",
"sequence_len": 512,
"val_set_size": 0.0,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:1%]",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"bf16": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen3DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
# Enable all LoRA kernels
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"lora_embedding_kernel": True,
"save_safetensors": True,
}
cfg.update(overrides)
return DictDefault(cfg)
class TestFSDP2LoRAKernels:
"""Test LoRA kernels under FSDP2."""
@require_torch_2_7_0
def test_lora_kernels_basic(self, temp_dir):
"""Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA."""
cfg = _base_lora_fsdp2_config(temp_dir)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@require_torch_2_7_0
def test_lora_kernels_with_dropout(self, temp_dir):
"""LoRA kernels + dropout + FSDP2."""
cfg = _base_lora_fsdp2_config(temp_dir, lora_dropout=0.1)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@require_torch_2_7_0
def test_lora_kernels_with_dora(self, temp_dir):
"""LoRA kernels + DoRA + FSDP2."""
cfg = _base_lora_fsdp2_config(temp_dir, peft_use_dora=True)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@require_torch_2_7_0
def test_lora_kernels_with_dora_and_dropout(self, temp_dir):
"""LoRA kernels + DoRA + dropout + FSDP2."""
cfg = _base_lora_fsdp2_config(
temp_dir,
peft_use_dora=True,
lora_dropout=0.05,
)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()

View File

@@ -94,5 +94,5 @@ class TestMultiGPUGemma3:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.8, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss (%s) is too high"
)

View File

@@ -90,7 +90,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.8, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.8, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -156,7 +156,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
def test_dpo_lora_ddp(self, temp_dir):
@@ -233,7 +233,7 @@ class TestMultiGPULlama:
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs",
"train/loss",
"train/train_loss",
loss_threshold,
"Train Loss (%s) is too high",
)
@@ -312,7 +312,7 @@ class TestMultiGPULlama:
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs",
"train/loss",
"train/train_loss",
loss_threshold,
"Train Loss (%s) is too high",
)
@@ -385,7 +385,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -461,7 +461,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_6_0
@@ -543,7 +543,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):
@@ -623,7 +623,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -708,7 +708,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.45, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -784,7 +784,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -859,7 +859,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
)
@pytest.mark.skip(
@@ -925,5 +925,5 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 4.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss (%s) is too high"
)

View File

@@ -79,7 +79,7 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_7_0
@@ -138,7 +138,7 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_7_0
@@ -205,5 +205,5 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.3, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)

View File

@@ -64,5 +64,5 @@ class TestTensorParallel:
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
)

View File

@@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation):
def test_kernel_patch_conditions():
"""Test various conditions that should prevent kernel patching."""
"""Test that kernels ARE patched even with dropout and bias (now supported)."""
test_configs = [
# Dropout prevents patching
# Dropout — kernels now support this
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
@@ -234,7 +234,7 @@ def test_kernel_patch_conditions():
"lora_dropout": 0.1,
"bias": "none",
},
# Bias prevents patching
# Bias — kernels now support this
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
@@ -252,13 +252,14 @@ def test_kernel_patch_conditions():
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
# Should not patch
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify no patches applied
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
# Verify patches ARE applied (dropout and bias are now supported)
assert (
layer.forward.__func__ is apply_lora_mlp_swiglu
or layer.forward.__func__ is apply_lora_mlp_geglu
)
def test_kernel_config_options():
@@ -511,7 +512,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
def test_kernel_training_integration_dropout_non_zero(temp_dir):
"""Test model loading with dropout non-zero should not patch."""
"""Test model loading with dropout non-zero DOES patch (now supported)."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -546,31 +547,18 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
# Load config
cfg = load_cfg(str(path))
# Get original attention class
attention_cls = get_attention_cls_from_config(cfg)
# Store original state before patching
original_forward_method = attention_cls.forward
# Load model
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
# We call ModelLoader as that's where the patches are applied
# despite the fact that we're not using it to load the model
model_loader = ModelLoader(cfg, tokenizer)
# Apply patch
# Apply patches — should succeed even with dropout > 0
model_loader.patch_manager._apply_self_attention_lora_patch()
# Verify patch was not applied
assert attention_cls.forward == original_forward_method
# Apply apply_lora_kernel_patches
model_loader.patch_manager._apply_lora_kernel_patch(model)
# Verify patch was not applied
# Verify patches WERE applied (dropout is now supported by kernels)
layers = get_layers(model)
for layer in layers:
for self_attn in find_self_attn_in_layer(layer):
assert not hasattr(self_attn, "apply_qkv")
assert not hasattr(self_attn, "apply_o")
assert hasattr(self_attn, "apply_qkv")
assert hasattr(self_attn, "apply_o")

View File

@@ -78,5 +78,5 @@ class TestFAXentropyLlama:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
)

View File

@@ -77,5 +77,5 @@ class TestFAFlattening:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
)

View File

@@ -73,7 +73,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
@@ -124,7 +124,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -180,5 +180,5 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -63,5 +63,5 @@ class TestPackedFlex(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.1, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
)

View File

@@ -57,7 +57,9 @@ class TestEmbeddingsLrScale(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(temp_dir + "/runs", "train/loss", 2.0, "Loss is too high")
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
)
@with_temp_dir
def test_train_w_embedding_lr(self, temp_dir):
@@ -98,4 +100,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(temp_dir + "/runs", "train/loss", 2.0, "Loss is too high")
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
)

View File

@@ -66,7 +66,7 @@ class TestPretrainLlama:
loss_threshold = 6.5
check_tensorboard(
temp_dir + "/runs",
"train/loss",
"train/train_loss",
loss_threshold,
"Train Loss (%s) is too high",
)

View File

@@ -62,5 +62,5 @@ class TestPackedLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.0, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -57,7 +57,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.7, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.7, "Train Loss (%s) is too high"
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -128,7 +128,7 @@ class TestQATLlama:
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs",
"train/loss",
"train/train_loss",
loss_threshold,
"Train Loss (%s) is too high",
)

View File

@@ -66,6 +66,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/loss", 2.5, "Train Loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -66,7 +66,7 @@ class TestStreamingDatasets:
# Verify training actually happened by checking loss decrease
check_tensorboard(
temp_dir + "/runs",
"train/loss",
"train/train_loss",
3.0,
"Train Loss (%s) is too high",
)

View File

@@ -179,7 +179,7 @@ def check_tensorboard(
tag: str,
lt_val: float,
assertion_err: str,
rtol: float = 0.02,
rtol: float = 0.05,
gt_zero: bool = True,
) -> None:
"""

View File

@@ -0,0 +1,229 @@
"""
Correctness tests for fused RMSNorm + SiLU Gate kernel.
Tests against the eager Qwen3_5RMSNormGated implementation.
"""
import pytest
import torch
import torch.nn.functional as F
pytest.importorskip("triton", reason="triton required for fused kernels")
if not torch.cuda.is_available():
pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
class EagerRMSNormGated(torch.nn.Module):
"""Reference implementation matching Qwen3_5RMSNormGated exactly."""
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight * hidden_states.to(input_dtype)
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
return hidden_states.to(input_dtype)
def _sync_weights(eager_mod, fused_mod):
"""Copy weights from eager to fused module."""
fused_mod.weight.data.copy_(eager_mod.weight.data)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"shape",
[
(2, 128, 256),
(4, 64, 512),
(1, 32, 1024),
(2, 16, 2560), # Qwen3.5-4B hidden_size
(2, 16, 4096), # Qwen3.5-9B hidden_size
(1, 8, 5120), # Qwen3.5-27B hidden_size
(4, 16, 2048), # Qwen3.5-35B-A3B (MoE) hidden_size
(4, 16, 3072), # Qwen3.5-122B-A10B (MoE) hidden_size
],
)
class TestRMSNormGatedForward:
def test_output_matches_eager(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X, gate=G)
y_fused = fused(X, gate=G)
if dtype == torch.float32:
torch.testing.assert_close(y_fused, y_eager, atol=1e-5, rtol=1e-5)
else:
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
def test_output_shape(self, dtype, shape):
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
y = fused(X, gate=G)
assert y.shape == (B, T, H)
assert y.dtype == dtype
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"shape",
[
(2, 32, 256),
(2, 16, 512),
(2, 16, 2560), # Qwen3.5-4B
(1, 8, 4096), # Qwen3.5-9B
(1, 8, 5120), # Qwen3.5-27B
(2, 16, 2048), # Qwen3.5-35B-A3B (MoE)
(2, 16, 3072), # Qwen3.5-122B-A10B (MoE)
],
)
class TestRMSNormGatedBackward:
def test_grad_x(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
if dtype == torch.float32:
atol, rtol = 1e-4, 1e-4
else:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)
def test_grad_gate(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
if dtype == torch.float32:
atol, rtol = 1e-4, 1e-4
else:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)
def test_grad_weight(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
if dtype == torch.float32:
atol, rtol = 1e-4, 1e-4
else:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(
fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
)
class TestRMSNormGatedEdgeCases:
def test_gate_none_raises(self):
fused = FusedRMSNormGated(256).cuda()
X = torch.randn(2, 4, 256, device="cuda")
with pytest.raises(ValueError, match="requires a gate tensor"):
fused(X, gate=None)
def test_2d_input(self):
"""Test with (BxT, H) shaped input instead of (B, T, H)."""
torch.manual_seed(42)
H = 512
X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2)
torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2)
def test_random_weight_init(self):
"""Test with non-default weight values."""
torch.manual_seed(123)
H = 256
X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
# Randomize weights
eager.weight.data = torch.randn_like(eager.weight.data)
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X, gate=G)
y_fused = fused(X, gate=G)
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
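
For reference, the operation both the eager module and the fused kernel are expected to compute, read directly off EagerRMSNormGated above (x the hidden states, g the gate, w the learned weight, eps the variance epsilon, with the mean taken over the hidden dimension):

y = \left( w \odot \frac{x}{\sqrt{\operatorname{mean}(x^{2}) + \varepsilon}} \right) \odot \operatorname{silu}(g), \qquad \operatorname{silu}(g) = g \cdot \sigma(g)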

View File

@@ -28,20 +28,22 @@ class TestLoRAConfigValidation:
result = validate_config(valid_config)
assert result["adapter"] == "lora"
with pytest.raises(ValueError, match="not compatible with DoRA"):
invalid_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"peft_use_dora": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
# DoRA is now compatible with lora kernels
dora_kernel_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"peft_use_dora": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(dora_kernel_config)
assert result["lora_mlp_kernel"] is True
assert result["peft_use_dora"] is True
def test_qlora_4bit_validation(self):
"""Test QLoRA 4-bit configuration validation"""

View File

@@ -38,6 +38,11 @@ class TestLoRAParameterFreezing:
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
mock_layer.lora_B["default"].bias = None
# Required by get_lora_parameters for dropout/DoRA extraction
mock_layer.lora_dropout = {}
mock_layer.lora_magnitude_vector = None
else:
mock_layer.weight = base_layer.weight
mock_layer.bias = base_layer.bias
@@ -48,7 +53,7 @@ class TestLoRAParameterFreezing:
"""Test that LoRA parameters are None when adapters are disabled."""
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -62,7 +67,7 @@ class TestLoRAParameterFreezing:
"""Test that LoRA parameters are None when adapters are merged."""
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -77,7 +82,7 @@ class TestLoRAParameterFreezing:
"""Test parameter behavior when no adapters are present."""
layer = self.create_mock_lora_layer(has_adapters=False)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -94,7 +99,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# All parameters should be returned
assert W is not None
@@ -110,7 +115,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Check shape consistency
assert W.shape == (512, 256)
@@ -124,7 +129,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
assert W.dtype == self.dtype
assert b.dtype == self.dtype
@@ -138,7 +143,7 @@ class TestLoRAParameterFreezing:
quant_state_mock = Mock()
layer.base_layer.weight.quant_state = quant_state_mock
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
assert quant_state == quant_state_mock
@@ -157,7 +162,7 @@ class TestLoRAParameterFreezing:
layer.active_adapters = ["adapter2"]
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
assert s == 0.2
assert torch.equal(A, layer.lora_A["adapter2"].weight)
@@ -192,13 +197,13 @@ class TestLoRAParameterFreezingIntegration:
model = get_peft_model(base_model, lora_config)
lora_layer = model.base_model.model.linear
# Test with adapters enabled
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
assert A is not None
assert B is not None
assert s is not None
# Test with adapters disabled
model.disable_adapter_layers()
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
assert A is None
assert B is None
assert s is None