feat: Add SwanLab integration for experiment tracking (#3334)

* feat(swanlab): add SwanLab integration for experiment tracking

SwanLab integration provides comprehensive experiment tracking and monitoring for Axolotl training.

Features:
- Hyperparameter logging
- Training metrics tracking
- RLHF completion logging
- Performance profiling
- Configuration validation and conflict detection

Includes:
- Plugin in src/axolotl/integrations/swanlab/
- Callback in src/axolotl/utils/callbacks/swanlab.py
- Tests in tests/integrations/test_swanlab.py
- Examples in examples/swanlab/

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix(swanlab): address PR #3334 review feedback from winglian and CodeRabbit

- Change use_swanlab default to True (winglian)
- Clear buffer after periodic logging to prevent duplicates (CodeRabbit Major)
- Add safe exception handling in config fallback (CodeRabbit)
- Use context managers for file operations (CodeRabbit)
- Replace LOG.error with LOG.exception for better debugging (CodeRabbit)
- Sort __all__ alphabetically (CodeRabbit)
- Add language specifiers to README code blocks (CodeRabbit)
- Fix end-of-file newline in README (pre-commit)

Resolves actionable comments and nitpicks from CodeRabbit review.
Addresses reviewer feedback from @winglian.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* only run swanlab integration tests if package is available

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
PraMamba
2026-01-06 22:19:18 +08:00
committed by GitHub
parent ee59e4de97
commit 8aab807e67
14 changed files with 5438 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
"""SwanLab integration plugin for Axolotl"""
from axolotl.integrations.swanlab.args import SwanLabConfig
from axolotl.integrations.swanlab.plugins import SwanLabPlugin
__all__ = ["SwanLabConfig", "SwanLabPlugin"]

View File

@@ -0,0 +1,140 @@
"""SwanLab configuration arguments"""
from pydantic import BaseModel, Field, field_validator, model_validator
class SwanLabConfig(BaseModel):
"""SwanLab configuration subset"""
use_swanlab: bool | None = Field(
default=True,
json_schema_extra={
"description": "Enable SwanLab experiment tracking and visualization"
},
)
swanlab_project: str | None = Field(
default=None,
json_schema_extra={"description": "Your SwanLab project name"},
)
swanlab_experiment_name: str | None = Field(
default=None,
json_schema_extra={"description": "Set the name of your SwanLab experiment"},
)
swanlab_description: str | None = Field(
default=None,
json_schema_extra={"description": "Description for your SwanLab experiment"},
)
swanlab_mode: str | None = Field(
default=None,
json_schema_extra={
"description": '"cloud" to sync to SwanLab cloud, "local" for local only, "offline" to save metadata locally, "disabled" to turn off SwanLab'
},
)
swanlab_workspace: str | None = Field(
default=None,
json_schema_extra={
"description": "SwanLab workspace name (organization or username)"
},
)
swanlab_api_key: str | None = Field(
default=None,
json_schema_extra={
"description": "SwanLab API key for authentication. Can also be set via SWANLAB_API_KEY environment variable"
},
)
swanlab_log_model: bool | None = Field(
default=False,
json_schema_extra={
"description": "Whether to log model checkpoints to SwanLab (feature coming soon)"
},
)
swanlab_web_host: str | None = Field(
default=None,
json_schema_extra={
"description": "Web address for SwanLab cloud environment (for private deployment)"
},
)
swanlab_api_host: str | None = Field(
default=None,
json_schema_extra={
"description": "API address for SwanLab cloud environment (for private deployment)"
},
)
swanlab_lark_webhook_url: str | None = Field(
default=None,
json_schema_extra={
"description": "Lark (Feishu) webhook URL for sending training notifications to team chat"
},
)
swanlab_lark_secret: str | None = Field(
default=None,
json_schema_extra={
"description": "Secret for Lark webhook HMAC signature authentication (optional)"
},
)
swanlab_log_completions: bool | None = Field(
default=True,
json_schema_extra={
"description": "Enable logging RLHF completions to SwanLab for qualitative analysis (DPO/KTO/ORPO/GRPO)"
},
)
swanlab_completion_log_interval: int | None = Field(
default=100,
json_schema_extra={
"description": "Number of training steps between completion table logging to SwanLab"
},
)
swanlab_completion_max_buffer: int | None = Field(
default=128,
json_schema_extra={
"description": "Maximum number of completions to buffer before logging (prevents memory leaks)"
},
)
@field_validator("swanlab_mode")
@classmethod
def validate_swanlab_mode(cls, v):
"""Validate swanlab_mode is one of the allowed values."""
if v is None:
return v
valid_modes = ["cloud", "local", "offline", "disabled"]
if v not in valid_modes:
raise ValueError(
f"Invalid swanlab_mode: '{v}'.\n\n"
f"Valid options: {', '.join(valid_modes)}\n\n"
f"Examples:\n"
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
f" swanlab_mode: local # Local only, no cloud sync\n"
f" swanlab_mode: offline # Save metadata locally\n"
f" swanlab_mode: disabled # Turn off SwanLab\n"
)
return v
@field_validator("swanlab_project")
@classmethod
def validate_swanlab_project(cls, v):
"""Validate swanlab_project is non-empty when provided."""
if v is not None and isinstance(v, str) and len(v.strip()) == 0:
raise ValueError(
"swanlab_project cannot be an empty string.\n\n"
"Either:\n"
" 1. Provide a valid project name: swanlab_project: my-project\n"
" 2. Remove the swanlab_project field entirely\n"
)
return v
@model_validator(mode="after")
def validate_swanlab_enabled_requires_project(self):
"""Validate that if use_swanlab is True, swanlab_project must be set."""
if self.use_swanlab is True and not self.swanlab_project:
raise ValueError(
"SwanLab enabled (use_swanlab: true) but 'swanlab_project' is not set.\n\n"
"Solutions:\n"
" 1. Add 'swanlab_project: your-project-name' to your config\n"
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
"Example:\n"
" use_swanlab: true\n"
" swanlab_project: my-llm-training\n"
)
return self

View File

@@ -0,0 +1,179 @@
"""SwanLab callbacks for Axolotl trainers.
This module provides HuggingFace Trainer callbacks for logging
RLHF completions to SwanLab.
"""
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.integrations.swanlab.completion_logger import CompletionLogger
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class SwanLabRLHFCompletionCallback(TrainerCallback):
"""Callback for logging RLHF completions to SwanLab.
This callback periodically logs model completions (prompts, chosen/rejected
responses, rewards) to SwanLab during RLHF training for qualitative analysis.
Supports DPO, KTO, ORPO, and GRPO trainers.
Example usage:
>>> callback = SwanLabRLHFCompletionCallback(
... log_interval=100, # Log every 100 steps
... max_completions=128, # Keep last 128 completions
... )
>>> trainer.add_callback(callback)
Attributes:
logger: CompletionLogger instance
log_interval: Number of steps between SwanLab logging
trainer_type: Auto-detected trainer type (dpo/kto/orpo/grpo)
"""
def __init__(
self,
log_interval: int = 100,
max_completions: int = 128,
table_name: str = "rlhf_completions",
):
"""Initialize SwanLab RLHF completion callback.
Args:
log_interval: Log to SwanLab every N steps. Default: 100
max_completions: Maximum completions to buffer. Default: 128
table_name: SwanLab table name. Default: "rlhf_completions"
"""
super().__init__()
self.logger = CompletionLogger(maxlen=max_completions)
self.log_interval = log_interval
self.table_name = table_name
self.trainer_type: str | None = None # Auto-detected
self._last_logged_step = 0
def on_init_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Detect trainer type on initialization."""
trainer = kwargs.get("trainer")
if trainer is not None:
trainer_name = trainer.__class__.__name__
if "DPO" in trainer_name:
self.trainer_type = "dpo"
elif "KTO" in trainer_name:
self.trainer_type = "kto"
elif "ORPO" in trainer_name:
self.trainer_type = "orpo"
elif "GRPO" in trainer_name:
self.trainer_type = "grpo"
else:
self.trainer_type = "unknown"
LOG.info(
f"SwanLab RLHF completion logging enabled for {trainer_name} "
f"(type: {self.trainer_type})"
)
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs: dict | None = None,
**kwargs,
):
"""Capture completions from logs and buffer them.
Different trainers log completions in different formats:
- DPO: logs['dpo/chosen'], logs['dpo/rejected'], logs['dpo/reward_diff']
- KTO: logs['kto/completion'], logs['kto/label'], logs['kto/reward']
- ORPO: logs['orpo/chosen'], logs['orpo/rejected']
- GRPO: logs['grpo/completion'], logs['grpo/reward']
Note: This is a placeholder implementation. Actual log keys depend
on the TRL trainer implementation. You may need to patch the trainers
to expose completion data in logs.
"""
if logs is None or self.trainer_type is None:
return
step = state.global_step
# DPO completions
if self.trainer_type == "dpo":
if all(key in logs for key in ["dpo/prompt", "dpo/chosen", "dpo/rejected"]):
self.logger.add_dpo_completion(
step=step,
prompt=logs.get("dpo/prompt", ""),
chosen=logs.get("dpo/chosen", ""),
rejected=logs.get("dpo/rejected", ""),
reward_diff=logs.get("dpo/reward_diff"),
)
# KTO completions
elif self.trainer_type == "kto":
if all(key in logs for key in ["kto/prompt", "kto/completion"]):
self.logger.add_kto_completion(
step=step,
prompt=logs.get("kto/prompt", ""),
completion=logs.get("kto/completion", ""),
label=logs.get("kto/label", False),
reward=logs.get("kto/reward"),
)
# ORPO completions
elif self.trainer_type == "orpo":
if all(
key in logs for key in ["orpo/prompt", "orpo/chosen", "orpo/rejected"]
):
self.logger.add_orpo_completion(
step=step,
prompt=logs.get("orpo/prompt", ""),
chosen=logs.get("orpo/chosen", ""),
rejected=logs.get("orpo/rejected", ""),
log_odds_ratio=logs.get("orpo/log_odds_ratio"),
)
# GRPO completions
elif self.trainer_type == "grpo":
if all(key in logs for key in ["grpo/prompt", "grpo/completion"]):
self.logger.add_grpo_completion(
step=step,
prompt=logs.get("grpo/prompt", ""),
completion=logs.get("grpo/completion", ""),
reward=logs.get("grpo/reward"),
advantage=logs.get("grpo/advantage"),
)
# Periodically log to SwanLab
if step - self._last_logged_step >= self.log_interval:
if len(self.logger) > 0:
self.logger.log_to_swanlab(table_name=self.table_name)
self.logger.clear()
self._last_logged_step = step
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Log remaining completions at end of training."""
if len(self.logger) > 0:
LOG.info(
f"Training complete, logging final {len(self.logger)} completions to SwanLab"
)
self.logger.log_to_swanlab(table_name=self.table_name)
self._last_logged_step = state.global_step

View File

@@ -0,0 +1,228 @@
"""SwanLab completion logger for RLHF/DPO/KTO/ORPO/GRPO training.
This module provides utilities for logging model completions during
preference training to SwanLab for qualitative analysis.
"""
from collections import deque
from collections.abc import Mapping
from typing import Any
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class CompletionLogger:
"""Memory-bounded logger for RLHF completions.
Stores prompts, completions, and rewards in fixed-size deques to prevent
memory leaks during long training runs. Logs completion tables to SwanLab
for qualitative analysis of model outputs.
Example usage:
>>> logger = CompletionLogger(maxlen=128)
>>> logger.add_dpo_completion(
... step=0,
... prompt="What is AI?",
... chosen="Artificial Intelligence is...",
... rejected="AI means...",
... reward_diff=0.5
... )
>>> logger.log_to_swanlab()
Attributes:
maxlen: Maximum number of completions to store (older ones are dropped)
data: Deque storing completion dictionaries
"""
def __init__(self, maxlen: int = 128):
"""Initialize completion logger with bounded buffer.
Args:
maxlen: Maximum number of completions to store. When the buffer
is full, oldest completions are automatically discarded.
Default: 128 (sufficient for most RLHF runs without memory issues)
"""
self.maxlen = maxlen
self.data: deque[Mapping[str, Any]] = deque(maxlen=maxlen)
def add_dpo_completion(
self,
step: int,
prompt: str,
chosen: str,
rejected: str,
reward_diff: float | None = None,
) -> None:
"""Add a DPO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
chosen: Chosen (preferred) completion
rejected: Rejected (non-preferred) completion
reward_diff: Reward difference (chosen - rejected), if available
"""
entry = {
"step": step,
"prompt": prompt,
"chosen": chosen,
"rejected": rejected,
}
if reward_diff is not None:
entry["reward_diff"] = reward_diff
self.data.append(entry)
def add_kto_completion(
self,
step: int,
prompt: str,
completion: str,
label: bool,
reward: float | None = None,
) -> None:
"""Add a KTO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
completion: Model-generated completion
label: True if desirable, False if undesirable
reward: Reward score, if available
"""
entry = {
"step": step,
"prompt": prompt,
"completion": completion,
"label": "desirable" if label else "undesirable",
}
if reward is not None:
entry["reward"] = reward
self.data.append(entry)
def add_orpo_completion(
self,
step: int,
prompt: str,
chosen: str,
rejected: str,
log_odds_ratio: float | None = None,
) -> None:
"""Add an ORPO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
chosen: Chosen (preferred) completion
rejected: Rejected (non-preferred) completion
log_odds_ratio: Log odds ratio between chosen and rejected
"""
entry = {
"step": step,
"prompt": prompt,
"chosen": chosen,
"rejected": rejected,
}
if log_odds_ratio is not None:
entry["log_odds_ratio"] = log_odds_ratio
self.data.append(entry)
def add_grpo_completion(
self,
step: int,
prompt: str,
completion: str,
reward: float | None = None,
advantage: float | None = None,
) -> None:
"""Add a GRPO completion to the buffer.
Args:
step: Training step number
prompt: Input prompt
completion: Model-generated completion
reward: Reward score from reward model
advantage: Advantage estimate (reward - baseline)
"""
entry = {
"step": step,
"prompt": prompt,
"completion": completion,
}
if reward is not None:
entry["reward"] = reward
if advantage is not None:
entry["advantage"] = advantage
self.data.append(entry)
def log_to_swanlab(self, table_name: str = "completions") -> bool:
"""Log buffered completions to SwanLab as a table.
Creates a SwanLab echarts Table with all buffered completions.
Only logs if SwanLab is initialized and data is available.
Args:
table_name: Name of the table in SwanLab dashboard.
Default: "completions"
Returns:
True if logging succeeded, False otherwise
"""
if not self.data:
LOG.debug("No completions to log to SwanLab")
return False
try:
import swanlab
if swanlab.get_run() is None:
LOG.debug("SwanLab not initialized, skipping completion logging")
return False
# Convert deque to list of dicts
completions = list(self.data)
# Extract headers from first entry (all entries should have same structure)
headers = list(completions[0].keys())
# Build rows: each completion becomes one row
rows = []
for completion in completions:
row = [completion.get(header, "") for header in headers]
rows.append(row)
# Log to SwanLab as echarts Table
swanlab.log({table_name: swanlab.echarts.Table().add(headers, rows)})
LOG.info(f"Logged {len(rows)} completions to SwanLab table '{table_name}'")
return True
except ImportError:
LOG.warning(
"SwanLab not installed, cannot log completions. "
"Install with: pip install swanlab"
)
return False
except Exception as err: # pylint: disable=broad-except
LOG.exception("Failed to log completions to SwanLab: %s", err)
return False
def clear(self) -> None:
"""Clear all buffered completions."""
self.data.clear()
def __len__(self) -> int:
"""Return number of buffered completions."""
return len(self.data)
def __repr__(self) -> str:
"""String representation showing buffer status."""
return (
f"CompletionLogger(maxlen={self.maxlen}, "
f"buffered={len(self.data)}/{self.maxlen})"
)

View File

@@ -0,0 +1,554 @@
"""SwanLab Plugin for Axolotl"""
from __future__ import annotations
from typing import TYPE_CHECKING
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from transformers import TrainerCallback
from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__)
class SwanLabPlugin(BasePlugin):
"""
SwanLab integration plugin for Axolotl.
Provides experiment tracking, visualization, and logging capabilities
using SwanLab (https://swanlab.cn).
Usage in config.yaml:
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-project
swanlab_experiment_name: my-experiment
swanlab_mode: cloud # or 'local', 'offline', 'disabled'
"""
def __init__(self):
super().__init__()
self.swanlab_initialized = False
LOG.info("SwanLab plugin initialized")
def get_input_args(self) -> str:
"""Returns the configuration model for SwanLab integration."""
return "axolotl.integrations.swanlab.SwanLabConfig"
def register(self, cfg: dict):
"""Register SwanLab plugin with configuration and conflict detection."""
LOG.info("Registering SwanLab plugin")
# === Conflict Detection: Required Fields ===
# Check if SwanLab is enabled
if cfg.get("use_swanlab"):
# 1. Validate project name is set
if not cfg.get("swanlab_project"):
raise ValueError(
"SwanLab enabled but 'swanlab_project' is not set.\n\n"
"Solutions:\n"
" 1. Add 'swanlab_project: your-project-name' to your config\n"
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
"See: src/axolotl/integrations/swanlab/README.md for examples"
)
# 2. Validate swanlab_mode value
valid_modes = ["cloud", "local", "offline", "disabled"]
mode = cfg.get("swanlab_mode")
if mode and mode not in valid_modes:
raise ValueError(
f"Invalid swanlab_mode: '{mode}'.\n\n"
f"Valid options: {', '.join(valid_modes)}\n\n"
f"Example:\n"
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
f" swanlab_mode: local # Local only, no cloud sync\n"
)
# 3. Check API key for cloud mode
import os
mode = cfg.get("swanlab_mode", "cloud") # Default is cloud
if mode == "cloud":
api_key = cfg.get("swanlab_api_key") or os.environ.get(
"SWANLAB_API_KEY"
)
if not api_key:
LOG.warning(
"SwanLab cloud mode enabled but no API key found.\n"
"SwanLab may fail to initialize during training.\n\n"
"Solutions:\n"
" 1. Set SWANLAB_API_KEY environment variable:\n"
" export SWANLAB_API_KEY=your-api-key\n"
" 2. Add 'swanlab_api_key: your-api-key' to config (less secure)\n"
" 3. Run 'swanlab login' before training\n"
" 4. Use 'swanlab_mode: local' for offline tracking\n"
)
# === Conflict Detection: Multi-Logger Performance Warning ===
# Detect all active logging tools
active_loggers = []
if cfg.get("use_wandb"):
active_loggers.append("WandB")
if cfg.get("use_mlflow"):
active_loggers.append("MLflow")
if cfg.get("comet_api_key") or cfg.get("comet_project_name"):
active_loggers.append("Comet")
if cfg.get("use_swanlab"):
active_loggers.append("SwanLab")
if len(active_loggers) > 1:
LOG.warning(
f"\n{'=' * 70}\n"
f"Multiple logging tools enabled: {', '.join(active_loggers)}\n"
f"{'=' * 70}\n"
f"This may cause:\n"
f" - Performance overhead (~1-2% per logger, cumulative)\n"
f" - Increased memory usage\n"
f" - Longer training time per step\n"
f" - Potential config/callback conflicts\n\n"
f"Recommendations:\n"
f" - Choose ONE primary logging tool for production training\n"
f" - Use multiple loggers only for:\n"
f" * Migration period (transitioning between tools)\n"
f" * Short comparison runs\n"
f" * Debugging specific tool issues\n"
f" - Monitor system resources (CPU, memory) during training\n"
f"{'=' * 70}\n"
)
if len(active_loggers) >= 3:
LOG.error(
f"\n{'!' * 70}\n"
f"WARNING: {len(active_loggers)} logging tools enabled simultaneously!\n"
f"{'!' * 70}\n"
f"This is likely unintentional and WILL significantly impact performance.\n"
f"Expected overhead: ~{len(active_loggers) * 1.5:.1f}% per training step.\n\n"
f"STRONGLY RECOMMEND:\n"
f" - Disable all but ONE logging tool\n"
f" - Use config inheritance to manage multiple configs\n"
f"{'!' * 70}\n"
)
# === Auto-Enable Logic ===
# Enable SwanLab if project is specified
if cfg.get("swanlab_project") and not cfg.get("use_swanlab"):
cfg["use_swanlab"] = True
LOG.info("Automatically enabled use_swanlab because swanlab_project is set")
def pre_model_load(self, cfg: DictDefault):
"""Initialize SwanLab before model loading with runtime checks."""
if not cfg.use_swanlab:
return
# === Runtime Check: Import Availability ===
try:
import swanlab
except ImportError as err:
raise ImportError(
"SwanLab is not installed.\n\n"
"Install with:\n"
" pip install swanlab\n\n"
"Or add to requirements:\n"
" swanlab>=0.3.0\n\n"
f"Original error: {err}"
) from err
# Log SwanLab version
try:
swanlab_version = swanlab.__version__
LOG.info(f"SwanLab version: {swanlab_version}")
except AttributeError:
LOG.warning("Could not determine SwanLab version")
# === Runtime Check: Distributed Training Setup ===
from axolotl.utils.distributed import get_world_size, is_main_process
world_size = get_world_size()
if world_size > 1:
mode = getattr(cfg, "swanlab_mode", "cloud")
LOG.info(
f"\n{'=' * 70}\n"
f"Distributed training detected (world_size={world_size})\n"
f"SwanLab mode: {mode}\n"
f"{'=' * 70}\n"
f"Behavior:\n"
f" - Only rank 0 will initialize SwanLab\n"
f" - Other ranks will skip SwanLab to avoid conflicts\n"
)
if mode == "cloud":
LOG.info(
f" - Only rank 0 will upload to SwanLab cloud\n"
f" - Other ranks run without SwanLab overhead\n"
f"{'=' * 70}\n"
)
# Only initialize SwanLab on the main process (rank 0)
# to avoid creating multiple runs in distributed training
if not is_main_process():
LOG.debug("Skipping SwanLab initialization on non-main process")
return
# Initialize SwanLab run (passing all params directly to init)
try:
init_kwargs = self._get_swanlab_init_kwargs(cfg)
swanlab.init(**init_kwargs)
self.swanlab_initialized = True
LOG.info(f"SwanLab initialized with project: {cfg.swanlab_project}")
# Register Lark notification callback (if configured)
self._register_lark_callback(cfg)
# Log configuration (with error handling)
try:
config_dict = self._prepare_config_for_logging(cfg)
swanlab.config.update(config_dict)
LOG.debug("Successfully logged config to SwanLab")
except Exception as config_err: # pylint: disable=broad-except
LOG.warning(
f"Failed to log config to SwanLab: {config_err}. Continuing anyway."
)
except Exception as err: # pylint: disable=broad-except
LOG.exception("Failed to initialize SwanLab: %s", err)
self.swanlab_initialized = False
def add_callbacks_pre_trainer(self, cfg: DictDefault, model):
"""Add SwanLab callbacks before trainer creation."""
callbacks: list[TrainerCallback] = []
if not cfg.use_swanlab:
return callbacks
if not self.swanlab_initialized:
LOG.warning("SwanLab not initialized, skipping callback registration")
return callbacks
try:
from axolotl.utils.callbacks.swanlab import (
CustomSwanLabCallback,
SaveAxolotlConfigtoSwanLabCallback,
)
# Add our custom lightweight SwanLabCallback
# (avoids omegaconf/antlr4 version conflicts)
swanlab_callback = CustomSwanLabCallback()
callbacks.append(swanlab_callback)
LOG.info("Added CustomSwanLabCallback for metrics logging")
# Add Axolotl config logging callback
if cfg.axolotl_config_path:
config_callback = SaveAxolotlConfigtoSwanLabCallback(
cfg.axolotl_config_path
)
callbacks.append(config_callback)
LOG.info("Added SaveAxolotlConfigtoSwanLabCallback")
except ImportError as err:
LOG.exception("Failed to import SwanLab callbacks: %s", err)
return callbacks
def post_trainer_create(self, cfg: DictDefault, trainer):
"""Post-trainer creation hook."""
if cfg.use_swanlab and self.swanlab_initialized:
try:
import swanlab
# Log additional trainer information (with safe conversion)
trainer_config = {
"total_steps": int(trainer.state.max_steps)
if trainer.state.max_steps
else None,
"num_train_epochs": float(trainer.args.num_train_epochs)
if trainer.args.num_train_epochs
else None,
"train_batch_size": int(trainer.args.train_batch_size)
if hasattr(trainer.args, "train_batch_size")
else None,
"gradient_accumulation_steps": int(
trainer.args.gradient_accumulation_steps
)
if trainer.args.gradient_accumulation_steps
else None,
}
# Remove None values
trainer_config = {
k: v for k, v in trainer_config.items() if v is not None
}
if trainer_config:
swanlab.config.update(trainer_config)
LOG.info("Logged trainer configuration to SwanLab")
except Exception as err: # pylint: disable=broad-except
LOG.debug(f"Failed to log trainer config to SwanLab: {err}")
# Register RLHF completion logging callback if enabled
self._register_completion_callback(cfg, trainer)
def _get_swanlab_init_kwargs(self, cfg: DictDefault) -> dict:
"""Prepare kwargs for swanlab.init().
Passes all configuration parameters directly to swanlab.init()
instead of using environment variables as an intermediate layer.
Returns:
dict: Keyword arguments for swanlab.init()
"""
init_kwargs = {}
# Project name (required)
if cfg.swanlab_project:
init_kwargs["project"] = cfg.swanlab_project
# Experiment name
if cfg.swanlab_experiment_name:
init_kwargs["experiment_name"] = cfg.swanlab_experiment_name
# Description
if cfg.swanlab_description:
init_kwargs["description"] = cfg.swanlab_description
# Workspace (organization)
if cfg.swanlab_workspace:
init_kwargs["workspace"] = cfg.swanlab_workspace
# Mode: cloud, local, offline, disabled
if cfg.swanlab_mode:
init_kwargs["mode"] = cfg.swanlab_mode
# API key (pass directly instead of via env var)
if cfg.swanlab_api_key:
init_kwargs["api_key"] = cfg.swanlab_api_key
# Private deployment hosts (pass directly instead of via env var)
if cfg.swanlab_web_host:
init_kwargs["web_host"] = cfg.swanlab_web_host
if cfg.swanlab_api_host:
init_kwargs["api_host"] = cfg.swanlab_api_host
# Log model checkpoints (coming soon in SwanLab)
if cfg.swanlab_log_model:
init_kwargs["log_model"] = cfg.swanlab_log_model
# Custom branding - adds Axolotl identifier to SwanLab UI
# This helps identify runs from Axolotl vs other frameworks
init_kwargs["config"] = {"UPPERFRAME": "🦎 Axolotl"}
return init_kwargs
def _prepare_config_for_logging(self, cfg: DictDefault) -> dict:
"""Prepare configuration dict for logging to SwanLab."""
def safe_convert(value):
"""Convert value to JSON-serializable type."""
if value is None:
return None
if isinstance(value, (int, float, bool)):
return value
if isinstance(value, str):
return value
# Convert everything else to string
return str(value)
try:
# Extract important training parameters with safe conversion
config_dict = {
"base_model": safe_convert(getattr(cfg, "base_model", "")),
"model_type": safe_convert(getattr(cfg, "model_type", "")),
"sequence_len": safe_convert(getattr(cfg, "sequence_len", None)),
"micro_batch_size": safe_convert(
getattr(cfg, "micro_batch_size", None)
),
"gradient_accumulation_steps": safe_convert(
getattr(cfg, "gradient_accumulation_steps", None)
),
"num_epochs": safe_convert(getattr(cfg, "num_epochs", None)),
"max_steps": safe_convert(getattr(cfg, "max_steps", None)),
"learning_rate": safe_convert(getattr(cfg, "learning_rate", None)),
"lr_scheduler": safe_convert(getattr(cfg, "lr_scheduler", "")),
"optimizer": safe_convert(getattr(cfg, "optimizer", "")),
"warmup_ratio": safe_convert(getattr(cfg, "warmup_ratio", None)),
"weight_decay": safe_convert(getattr(cfg, "weight_decay", None)),
"seed": safe_convert(getattr(cfg, "seed", None)),
"bf16": safe_convert(getattr(cfg, "bf16", None)),
"tf32": safe_convert(getattr(cfg, "tf32", None)),
"flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
}
# Add FSDP/parallel config - only boolean flags
if hasattr(cfg, "fsdp_config") and cfg.fsdp_config:
config_dict["fsdp_enabled"] = True
config_dict["fsdp_version"] = safe_convert(
getattr(cfg, "fsdp_version", None)
)
if hasattr(cfg, "deepspeed") and cfg.deepspeed:
config_dict["deepspeed_enabled"] = True
# Add context parallel info
if hasattr(cfg, "context_parallel_size"):
config_dict["context_parallel_size"] = safe_convert(
getattr(cfg, "context_parallel_size", None)
)
if hasattr(cfg, "tensor_parallel_size"):
config_dict["tensor_parallel_size"] = safe_convert(
getattr(cfg, "tensor_parallel_size", None)
)
if hasattr(cfg, "dp_shard_size"):
config_dict["dp_shard_size"] = safe_convert(
getattr(cfg, "dp_shard_size", None)
)
# Remove None values and empty strings
config_dict = {
k: v
for k, v in config_dict.items()
if v is not None and v != "" and v != "None"
}
return config_dict
except Exception as err: # pylint: disable=broad-except
LOG.warning(f"Failed to prepare config for logging: {err}")
# Return minimal config
try:
lr = getattr(cfg, "learning_rate", None)
lr_value = float(lr) if lr is not None else None
except (TypeError, ValueError):
lr_value = None
return {
"base_model": str(getattr(cfg, "base_model", "unknown")),
"learning_rate": lr_value,
}
def _register_lark_callback(self, cfg: DictDefault):
"""Register Lark (Feishu) notification callback if configured.
Lark notifications enable sending training updates to team chat channels,
useful for production monitoring and team collaboration.
Args:
cfg: Configuration object with Lark webhook settings
"""
# Check if Lark webhook URL is configured
lark_webhook_url = getattr(cfg, "swanlab_lark_webhook_url", None)
if not lark_webhook_url:
return # Lark not configured, skip
try:
import swanlab
from swanlab.plugin.notification import LarkCallback
# Get optional secret for HMAC signature authentication
lark_secret = getattr(cfg, "swanlab_lark_secret", None)
# Create Lark callback with webhook URL and optional secret
lark_callback = LarkCallback(
webhook_url=lark_webhook_url,
secret=lark_secret,
)
# Register callback with SwanLab
swanlab.register_callbacks([lark_callback])
if lark_secret:
LOG.info(
"Registered Lark notification callback with HMAC authentication"
)
else:
LOG.info("Registered Lark notification callback (no HMAC secret)")
LOG.warning(
"Lark webhook has no secret configured. "
"For production use, set 'swanlab_lark_secret' to enable HMAC signature verification."
)
except ImportError as err:
LOG.warning(
f"Failed to import SwanLab Lark plugin: {err}\n\n"
"Lark notifications require SwanLab >= 0.3.0 with plugin support.\n"
"Install with: pip install 'swanlab>=0.3.0'\n\n"
"Continuing without Lark notifications..."
)
except Exception as err: # pylint: disable=broad-except
LOG.exception(
"Failed to register Lark callback: %s\n\n"
"Check your Lark webhook URL and secret configuration.\n"
"Continuing without Lark notifications...",
err,
)
def _register_completion_callback(self, cfg: DictDefault, trainer):
"""Register RLHF completion logging callback if enabled and applicable.
This callback logs model completions (prompts, chosen/rejected responses,
rewards) to SwanLab during RLHF training for qualitative analysis.
Args:
cfg: Configuration object with completion logging settings
trainer: The trainer instance to add callback to
"""
# Check if completion logging is enabled
log_completions = getattr(cfg, "swanlab_log_completions", True)
if not log_completions:
LOG.debug("SwanLab completion logging disabled by config")
return
# Check if trainer is an RLHF trainer
trainer_name = trainer.__class__.__name__
rlhf_trainers = ["DPO", "KTO", "ORPO", "GRPO", "CPO"]
is_rlhf_trainer = any(name in trainer_name for name in rlhf_trainers)
if not is_rlhf_trainer:
LOG.debug(
f"Trainer {trainer_name} is not an RLHF trainer, "
"skipping completion logging callback"
)
return
try:
from axolotl.integrations.swanlab.callbacks import (
SwanLabRLHFCompletionCallback,
)
# Get configuration parameters
log_interval = getattr(cfg, "swanlab_completion_log_interval", 100)
max_buffer = getattr(cfg, "swanlab_completion_max_buffer", 128)
# Create and register callback
completion_callback = SwanLabRLHFCompletionCallback(
log_interval=log_interval,
max_completions=max_buffer,
table_name="rlhf_completions",
)
trainer.add_callback(completion_callback)
LOG.info(
f"Registered SwanLab RLHF completion logging callback for {trainer_name} "
f"(log_interval={log_interval}, max_buffer={max_buffer})"
)
except ImportError as err:
LOG.warning(
f"Failed to import SwanLab completion callback: {err}\n\n"
"This is a bug - the callback should be available.\n"
"Please report this issue.\n\n"
"Continuing without completion logging..."
)
except Exception as err: # pylint: disable=broad-except
LOG.exception(
"Failed to register SwanLab completion callback: %s\n\n"
"Continuing without completion logging...",
err,
)

View File

@@ -0,0 +1,203 @@
"""SwanLab profiling utilities for Axolotl trainers.
This module provides decorators and context managers for profiling
trainer methods and logging execution times to SwanLab.
"""
import time
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@contextmanager
def swanlab_profiling_context(trainer: Any, func_name: str):
"""Context manager for profiling trainer methods.
Measures execution time and logs to SwanLab if enabled.
Example usage:
>>> with swanlab_profiling_context(self, "training_step"):
... result = do_expensive_computation()
Args:
trainer: Trainer instance (must have cfg attribute with use_swanlab flag)
func_name: Name of the function being profiled
Yields:
None
"""
start_time = time.perf_counter()
try:
yield
finally:
duration = time.perf_counter() - start_time
# Check if SwanLab is enabled and initialized
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
if use_swanlab:
try:
import swanlab
if swanlab.get_run() is not None:
# Log profiling metric
trainer_class = trainer.__class__.__name__
metric_name = f"profiling/Time taken: {trainer_class}.{func_name}"
swanlab.log({metric_name: duration})
except ImportError:
# SwanLab not installed, silently skip
pass
except Exception as err: # pylint: disable=broad-except
# Log error but don't fail training
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")
def swanlab_profile(func: Callable) -> Callable:
"""Decorator to profile and log function execution time to SwanLab.
Automatically measures execution time of trainer methods and logs
to SwanLab as profiling metrics.
Example usage:
>>> class MyTrainer:
... @swanlab_profile
... def training_step(self, model, inputs):
... return super().training_step(model, inputs)
Args:
func: Function to profile (must be a method of a trainer instance)
Returns:
Wrapped function with profiling
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
with swanlab_profiling_context(self, func.__name__):
return func(self, *args, **kwargs)
return wrapper
class ProfilingConfig:
"""Configuration for SwanLab profiling.
This class provides a centralized way to control profiling behavior.
Attributes:
enabled: Whether profiling is enabled globally
min_duration_ms: Minimum duration (in ms) to log (filters out very fast ops)
log_interval: Log every N function calls (to reduce overhead)
"""
def __init__(
self,
enabled: bool = True,
min_duration_ms: float = 0.1,
log_interval: int = 1,
):
"""Initialize profiling configuration.
Args:
enabled: Enable profiling. Default: True
min_duration_ms: Minimum duration to log (ms). Default: 0.1
log_interval: Log every N calls. Default: 1 (log all)
"""
self.enabled = enabled
self.min_duration_ms = min_duration_ms
self.log_interval = log_interval
self._call_counts: dict[str, int] = {}
def should_log(self, func_name: str, duration_seconds: float) -> bool:
"""Check if a profiling measurement should be logged.
Args:
func_name: Name of the profiled function
duration_seconds: Execution duration in seconds
Returns:
True if should log, False otherwise
"""
if not self.enabled:
return False
# Check minimum duration threshold
duration_ms = duration_seconds * 1000
if duration_ms < self.min_duration_ms:
return False
# Check log interval
self._call_counts.setdefault(func_name, 0)
self._call_counts[func_name] += 1
# Always log on first call OR at intervals
count = self._call_counts[func_name]
if count == 1 or count % self.log_interval == 0:
return True
return False
# Global profiling config (can be modified by users)
DEFAULT_PROFILING_CONFIG = ProfilingConfig()
@contextmanager
def swanlab_profiling_context_advanced(
trainer: Any,
func_name: str,
config: ProfilingConfig | None = None,
):
"""Advanced profiling context with configurable behavior.
Similar to swanlab_profiling_context but with additional configuration
options for filtering and throttling profiling logs.
Example usage:
>>> config = ProfilingConfig(min_duration_ms=1.0, log_interval=10)
>>> with swanlab_profiling_context_advanced(self, "forward", config):
... output = model(inputs)
Args:
trainer: Trainer instance
func_name: Function name
config: Profiling configuration. If None, uses DEFAULT_PROFILING_CONFIG
Yields:
None
"""
if config is None:
config = DEFAULT_PROFILING_CONFIG
start_time = time.perf_counter()
try:
yield
finally:
duration = time.perf_counter() - start_time
# Check if should log based on config
if config.should_log(func_name, duration):
# Check if SwanLab is enabled
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
if use_swanlab:
try:
import swanlab
if swanlab.get_run() is not None:
trainer_class = trainer.__class__.__name__
metric_name = (
f"profiling/Time taken: {trainer_class}.{func_name}"
)
swanlab.log({metric_name: duration})
except ImportError:
pass
except Exception as err: # pylint: disable=broad-except
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")

View File

@@ -0,0 +1,248 @@
"""Callbacks for SwanLab integration"""
from __future__ import annotations
import json
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from axolotl.core.training_args import AxolotlTrainingArguments
LOG = get_logger(__name__)
class CustomSwanLabCallback(TrainerCallback):
"""
Lightweight SwanLab callback that directly logs metrics without using
SwanLab's transformers integration (which requires omegaconf).
This avoids the antlr4 version conflict between omegaconf and axolotl.
"""
def __init__(self):
self._initialized = False
self.swanlab = None
def setup(self):
"""Lazy initialization of SwanLab"""
if self._initialized:
return
try:
import swanlab
self.swanlab = swanlab
# Check if SwanLab run is initialized
if swanlab.get_run() is None:
LOG.warning("SwanLab run is not initialized")
return
self._initialized = True
LOG.info("CustomSwanLabCallback initialized successfully")
except ImportError:
LOG.error("SwanLab is not installed")
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the beginning of training"""
if not state.is_world_process_zero:
return control
self.setup()
if not self._initialized:
return control
# Log training configuration
try:
self.swanlab.config.update(
{
"train_batch_size": args.per_device_train_batch_size,
"eval_batch_size": args.per_device_eval_batch_size,
"learning_rate": args.learning_rate,
"num_train_epochs": args.num_train_epochs,
"max_steps": args.max_steps,
"warmup_steps": args.warmup_steps,
"logging_steps": args.logging_steps,
"save_steps": args.save_steps,
"gradient_accumulation_steps": args.gradient_accumulation_steps,
}
)
LOG.debug("Training configuration logged to SwanLab")
except Exception as err:
LOG.warning(f"Failed to log training config: {err}")
return control
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs=None,
**kwargs,
):
"""Called when logging metrics"""
if not state.is_world_process_zero:
return control
if not self._initialized:
self.setup()
if not self._initialized or logs is None:
return control
# Log metrics to SwanLab
try:
# Filter out non-numeric values and prepare for logging
metrics = {}
for key, value in logs.items():
if isinstance(value, (int, float)):
# Use step from state
metrics[key] = value
if metrics and state.global_step is not None:
self.swanlab.log(metrics, step=state.global_step)
except Exception as err:
LOG.warning(f"Failed to log metrics to SwanLab: {err}")
return control
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Called at the end of training"""
if not state.is_world_process_zero:
return control
if self._initialized:
LOG.info("Training completed. SwanLab logs are available.")
return control
class SaveAxolotlConfigtoSwanLabCallback(TrainerCallback):
"""Callback to save axolotl config to SwanLab"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: AxolotlTrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if state.is_world_process_zero:
try:
import swanlab
# Check if SwanLab is initialized
if swanlab.get_run() is None:
LOG.warning(
"SwanLab run is not initialized. Please initialize SwanLab before training."
)
return control
# Log Axolotl config as artifact
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
# Log config file to SwanLab
with open(temp_file.name, "r", encoding="utf-8") as config_file:
swanlab.log(
{
"axolotl_config": swanlab.Text(
config_file.read(), caption="Axolotl Config"
)
}
)
LOG.info(
"The Axolotl config has been saved to the SwanLab run under logs."
)
# Clean up temp file
os.unlink(temp_file.name)
except ImportError:
LOG.warning(
"SwanLab is not installed. Install it with: pip install swanlab"
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to SwanLab: {err}")
# Log DeepSpeed config if available
if args.deepspeed:
try:
import swanlab
with NamedTemporaryFile(
mode="w",
delete=False,
suffix=".json",
prefix="deepspeed_config_",
) as temp_file:
skip_upload = False
if isinstance(args.deepspeed, dict):
json.dump(args.deepspeed, temp_file, indent=4)
elif isinstance(args.deepspeed, str) and os.path.exists(
args.deepspeed
):
copyfile(args.deepspeed, temp_file.name)
else:
skip_upload = True
if not skip_upload:
temp_file.flush()
with open(
temp_file.name, "r", encoding="utf-8"
) as ds_config_file:
swanlab.log(
{
"deepspeed_config": swanlab.Text(
ds_config_file.read(),
caption="DeepSpeed Config",
)
}
)
LOG.info(
"The DeepSpeed config has been saved to the SwanLab run under logs."
)
# Clean up temp file
os.unlink(temp_file.name)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(
f"Error while saving DeepSpeed config to SwanLab: {err}"
)
except ImportError:
pass
return control