feat: Add SwanLab integration for experiment tracking (#3334)
* feat(swanlab): add SwanLab integration for experiment tracking SwanLab integration provides comprehensive experiment tracking and monitoring for Axolotl training. Features: - Hyperparameter logging - Training metrics tracking - RLHF completion logging - Performance profiling - Configuration validation and conflict detection Includes: - Plugin in src/axolotl/integrations/swanlab/ - Callback in src/axolotl/utils/callbacks/swanlab.py - Tests in tests/integrations/test_swanlab.py - Examples in examples/swanlab/ 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> * fix(swanlab): address PR #3334 review feedback from winglian and CodeRabbit - Change use_swanlab default to True (winglian) - Clear buffer after periodic logging to prevent duplicates (CodeRabbit Major) - Add safe exception handling in config fallback (CodeRabbit) - Use context managers for file operations (CodeRabbit) - Replace LOG.error with LOG.exception for better debugging (CodeRabbit) - Sort __all__ alphabetically (CodeRabbit) - Add language specifiers to README code blocks (CodeRabbit) - Fix end-of-file newline in README (pre-commit) Resolves actionable comments and nitpicks from CodeRabbit review. Addresses reviewer feedback from @winglian. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> * only run swanlab integration tests if package is available --------- Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
1284
src/axolotl/integrations/swanlab/README.md
Normal file
1284
src/axolotl/integrations/swanlab/README.md
Normal file
File diff suppressed because it is too large
Load Diff
6
src/axolotl/integrations/swanlab/__init__.py
Normal file
6
src/axolotl/integrations/swanlab/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""SwanLab integration plugin for Axolotl"""
|
||||
|
||||
from axolotl.integrations.swanlab.args import SwanLabConfig
|
||||
from axolotl.integrations.swanlab.plugins import SwanLabPlugin
|
||||
|
||||
__all__ = ["SwanLabConfig", "SwanLabPlugin"]
|
||||
140
src/axolotl/integrations/swanlab/args.py
Normal file
140
src/axolotl/integrations/swanlab/args.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""SwanLab configuration arguments"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
|
||||
class SwanLabConfig(BaseModel):
|
||||
"""SwanLab configuration subset"""
|
||||
|
||||
use_swanlab: bool | None = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Enable SwanLab experiment tracking and visualization"
|
||||
},
|
||||
)
|
||||
swanlab_project: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Your SwanLab project name"},
|
||||
)
|
||||
swanlab_experiment_name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Set the name of your SwanLab experiment"},
|
||||
)
|
||||
swanlab_description: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Description for your SwanLab experiment"},
|
||||
)
|
||||
swanlab_mode: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": '"cloud" to sync to SwanLab cloud, "local" for local only, "offline" to save metadata locally, "disabled" to turn off SwanLab'
|
||||
},
|
||||
)
|
||||
swanlab_workspace: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "SwanLab workspace name (organization or username)"
|
||||
},
|
||||
)
|
||||
swanlab_api_key: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "SwanLab API key for authentication. Can also be set via SWANLAB_API_KEY environment variable"
|
||||
},
|
||||
)
|
||||
swanlab_log_model: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Whether to log model checkpoints to SwanLab (feature coming soon)"
|
||||
},
|
||||
)
|
||||
swanlab_web_host: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Web address for SwanLab cloud environment (for private deployment)"
|
||||
},
|
||||
)
|
||||
swanlab_api_host: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "API address for SwanLab cloud environment (for private deployment)"
|
||||
},
|
||||
)
|
||||
swanlab_lark_webhook_url: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Lark (Feishu) webhook URL for sending training notifications to team chat"
|
||||
},
|
||||
)
|
||||
swanlab_lark_secret: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Secret for Lark webhook HMAC signature authentication (optional)"
|
||||
},
|
||||
)
|
||||
swanlab_log_completions: bool | None = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Enable logging RLHF completions to SwanLab for qualitative analysis (DPO/KTO/ORPO/GRPO)"
|
||||
},
|
||||
)
|
||||
swanlab_completion_log_interval: int | None = Field(
|
||||
default=100,
|
||||
json_schema_extra={
|
||||
"description": "Number of training steps between completion table logging to SwanLab"
|
||||
},
|
||||
)
|
||||
swanlab_completion_max_buffer: int | None = Field(
|
||||
default=128,
|
||||
json_schema_extra={
|
||||
"description": "Maximum number of completions to buffer before logging (prevents memory leaks)"
|
||||
},
|
||||
)
|
||||
|
||||
@field_validator("swanlab_mode")
|
||||
@classmethod
|
||||
def validate_swanlab_mode(cls, v):
|
||||
"""Validate swanlab_mode is one of the allowed values."""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
valid_modes = ["cloud", "local", "offline", "disabled"]
|
||||
if v not in valid_modes:
|
||||
raise ValueError(
|
||||
f"Invalid swanlab_mode: '{v}'.\n\n"
|
||||
f"Valid options: {', '.join(valid_modes)}\n\n"
|
||||
f"Examples:\n"
|
||||
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
|
||||
f" swanlab_mode: local # Local only, no cloud sync\n"
|
||||
f" swanlab_mode: offline # Save metadata locally\n"
|
||||
f" swanlab_mode: disabled # Turn off SwanLab\n"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("swanlab_project")
|
||||
@classmethod
|
||||
def validate_swanlab_project(cls, v):
|
||||
"""Validate swanlab_project is non-empty when provided."""
|
||||
if v is not None and isinstance(v, str) and len(v.strip()) == 0:
|
||||
raise ValueError(
|
||||
"swanlab_project cannot be an empty string.\n\n"
|
||||
"Either:\n"
|
||||
" 1. Provide a valid project name: swanlab_project: my-project\n"
|
||||
" 2. Remove the swanlab_project field entirely\n"
|
||||
)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_swanlab_enabled_requires_project(self):
|
||||
"""Validate that if use_swanlab is True, swanlab_project must be set."""
|
||||
if self.use_swanlab is True and not self.swanlab_project:
|
||||
raise ValueError(
|
||||
"SwanLab enabled (use_swanlab: true) but 'swanlab_project' is not set.\n\n"
|
||||
"Solutions:\n"
|
||||
" 1. Add 'swanlab_project: your-project-name' to your config\n"
|
||||
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
|
||||
"Example:\n"
|
||||
" use_swanlab: true\n"
|
||||
" swanlab_project: my-llm-training\n"
|
||||
)
|
||||
return self
|
||||
179
src/axolotl/integrations/swanlab/callbacks.py
Normal file
179
src/axolotl/integrations/swanlab/callbacks.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""SwanLab callbacks for Axolotl trainers.
|
||||
|
||||
This module provides HuggingFace Trainer callbacks for logging
|
||||
RLHF completions to SwanLab.
|
||||
"""
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from axolotl.integrations.swanlab.completion_logger import CompletionLogger
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class SwanLabRLHFCompletionCallback(TrainerCallback):
|
||||
"""Callback for logging RLHF completions to SwanLab.
|
||||
|
||||
This callback periodically logs model completions (prompts, chosen/rejected
|
||||
responses, rewards) to SwanLab during RLHF training for qualitative analysis.
|
||||
|
||||
Supports DPO, KTO, ORPO, and GRPO trainers.
|
||||
|
||||
Example usage:
|
||||
>>> callback = SwanLabRLHFCompletionCallback(
|
||||
... log_interval=100, # Log every 100 steps
|
||||
... max_completions=128, # Keep last 128 completions
|
||||
... )
|
||||
>>> trainer.add_callback(callback)
|
||||
|
||||
Attributes:
|
||||
logger: CompletionLogger instance
|
||||
log_interval: Number of steps between SwanLab logging
|
||||
trainer_type: Auto-detected trainer type (dpo/kto/orpo/grpo)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_interval: int = 100,
|
||||
max_completions: int = 128,
|
||||
table_name: str = "rlhf_completions",
|
||||
):
|
||||
"""Initialize SwanLab RLHF completion callback.
|
||||
|
||||
Args:
|
||||
log_interval: Log to SwanLab every N steps. Default: 100
|
||||
max_completions: Maximum completions to buffer. Default: 128
|
||||
table_name: SwanLab table name. Default: "rlhf_completions"
|
||||
"""
|
||||
super().__init__()
|
||||
self.logger = CompletionLogger(maxlen=max_completions)
|
||||
self.log_interval = log_interval
|
||||
self.table_name = table_name
|
||||
self.trainer_type: str | None = None # Auto-detected
|
||||
self._last_logged_step = 0
|
||||
|
||||
def on_init_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Detect trainer type on initialization."""
|
||||
trainer = kwargs.get("trainer")
|
||||
if trainer is not None:
|
||||
trainer_name = trainer.__class__.__name__
|
||||
if "DPO" in trainer_name:
|
||||
self.trainer_type = "dpo"
|
||||
elif "KTO" in trainer_name:
|
||||
self.trainer_type = "kto"
|
||||
elif "ORPO" in trainer_name:
|
||||
self.trainer_type = "orpo"
|
||||
elif "GRPO" in trainer_name:
|
||||
self.trainer_type = "grpo"
|
||||
else:
|
||||
self.trainer_type = "unknown"
|
||||
|
||||
LOG.info(
|
||||
f"SwanLab RLHF completion logging enabled for {trainer_name} "
|
||||
f"(type: {self.trainer_type})"
|
||||
)
|
||||
|
||||
def on_log(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
logs: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Capture completions from logs and buffer them.
|
||||
|
||||
Different trainers log completions in different formats:
|
||||
- DPO: logs['dpo/chosen'], logs['dpo/rejected'], logs['dpo/reward_diff']
|
||||
- KTO: logs['kto/completion'], logs['kto/label'], logs['kto/reward']
|
||||
- ORPO: logs['orpo/chosen'], logs['orpo/rejected']
|
||||
- GRPO: logs['grpo/completion'], logs['grpo/reward']
|
||||
|
||||
Note: This is a placeholder implementation. Actual log keys depend
|
||||
on the TRL trainer implementation. You may need to patch the trainers
|
||||
to expose completion data in logs.
|
||||
"""
|
||||
if logs is None or self.trainer_type is None:
|
||||
return
|
||||
|
||||
step = state.global_step
|
||||
|
||||
# DPO completions
|
||||
if self.trainer_type == "dpo":
|
||||
if all(key in logs for key in ["dpo/prompt", "dpo/chosen", "dpo/rejected"]):
|
||||
self.logger.add_dpo_completion(
|
||||
step=step,
|
||||
prompt=logs.get("dpo/prompt", ""),
|
||||
chosen=logs.get("dpo/chosen", ""),
|
||||
rejected=logs.get("dpo/rejected", ""),
|
||||
reward_diff=logs.get("dpo/reward_diff"),
|
||||
)
|
||||
|
||||
# KTO completions
|
||||
elif self.trainer_type == "kto":
|
||||
if all(key in logs for key in ["kto/prompt", "kto/completion"]):
|
||||
self.logger.add_kto_completion(
|
||||
step=step,
|
||||
prompt=logs.get("kto/prompt", ""),
|
||||
completion=logs.get("kto/completion", ""),
|
||||
label=logs.get("kto/label", False),
|
||||
reward=logs.get("kto/reward"),
|
||||
)
|
||||
|
||||
# ORPO completions
|
||||
elif self.trainer_type == "orpo":
|
||||
if all(
|
||||
key in logs for key in ["orpo/prompt", "orpo/chosen", "orpo/rejected"]
|
||||
):
|
||||
self.logger.add_orpo_completion(
|
||||
step=step,
|
||||
prompt=logs.get("orpo/prompt", ""),
|
||||
chosen=logs.get("orpo/chosen", ""),
|
||||
rejected=logs.get("orpo/rejected", ""),
|
||||
log_odds_ratio=logs.get("orpo/log_odds_ratio"),
|
||||
)
|
||||
|
||||
# GRPO completions
|
||||
elif self.trainer_type == "grpo":
|
||||
if all(key in logs for key in ["grpo/prompt", "grpo/completion"]):
|
||||
self.logger.add_grpo_completion(
|
||||
step=step,
|
||||
prompt=logs.get("grpo/prompt", ""),
|
||||
completion=logs.get("grpo/completion", ""),
|
||||
reward=logs.get("grpo/reward"),
|
||||
advantage=logs.get("grpo/advantage"),
|
||||
)
|
||||
|
||||
# Periodically log to SwanLab
|
||||
if step - self._last_logged_step >= self.log_interval:
|
||||
if len(self.logger) > 0:
|
||||
self.logger.log_to_swanlab(table_name=self.table_name)
|
||||
self.logger.clear()
|
||||
self._last_logged_step = step
|
||||
|
||||
def on_train_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Log remaining completions at end of training."""
|
||||
if len(self.logger) > 0:
|
||||
LOG.info(
|
||||
f"Training complete, logging final {len(self.logger)} completions to SwanLab"
|
||||
)
|
||||
self.logger.log_to_swanlab(table_name=self.table_name)
|
||||
self._last_logged_step = state.global_step
|
||||
228
src/axolotl/integrations/swanlab/completion_logger.py
Normal file
228
src/axolotl/integrations/swanlab/completion_logger.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""SwanLab completion logger for RLHF/DPO/KTO/ORPO/GRPO training.
|
||||
|
||||
This module provides utilities for logging model completions during
|
||||
preference training to SwanLab for qualitative analysis.
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class CompletionLogger:
|
||||
"""Memory-bounded logger for RLHF completions.
|
||||
|
||||
Stores prompts, completions, and rewards in fixed-size deques to prevent
|
||||
memory leaks during long training runs. Logs completion tables to SwanLab
|
||||
for qualitative analysis of model outputs.
|
||||
|
||||
Example usage:
|
||||
>>> logger = CompletionLogger(maxlen=128)
|
||||
>>> logger.add_dpo_completion(
|
||||
... step=0,
|
||||
... prompt="What is AI?",
|
||||
... chosen="Artificial Intelligence is...",
|
||||
... rejected="AI means...",
|
||||
... reward_diff=0.5
|
||||
... )
|
||||
>>> logger.log_to_swanlab()
|
||||
|
||||
Attributes:
|
||||
maxlen: Maximum number of completions to store (older ones are dropped)
|
||||
data: Deque storing completion dictionaries
|
||||
"""
|
||||
|
||||
def __init__(self, maxlen: int = 128):
|
||||
"""Initialize completion logger with bounded buffer.
|
||||
|
||||
Args:
|
||||
maxlen: Maximum number of completions to store. When the buffer
|
||||
is full, oldest completions are automatically discarded.
|
||||
Default: 128 (sufficient for most RLHF runs without memory issues)
|
||||
"""
|
||||
self.maxlen = maxlen
|
||||
self.data: deque[Mapping[str, Any]] = deque(maxlen=maxlen)
|
||||
|
||||
def add_dpo_completion(
|
||||
self,
|
||||
step: int,
|
||||
prompt: str,
|
||||
chosen: str,
|
||||
rejected: str,
|
||||
reward_diff: float | None = None,
|
||||
) -> None:
|
||||
"""Add a DPO completion to the buffer.
|
||||
|
||||
Args:
|
||||
step: Training step number
|
||||
prompt: Input prompt
|
||||
chosen: Chosen (preferred) completion
|
||||
rejected: Rejected (non-preferred) completion
|
||||
reward_diff: Reward difference (chosen - rejected), if available
|
||||
"""
|
||||
entry = {
|
||||
"step": step,
|
||||
"prompt": prompt,
|
||||
"chosen": chosen,
|
||||
"rejected": rejected,
|
||||
}
|
||||
if reward_diff is not None:
|
||||
entry["reward_diff"] = reward_diff
|
||||
|
||||
self.data.append(entry)
|
||||
|
||||
def add_kto_completion(
|
||||
self,
|
||||
step: int,
|
||||
prompt: str,
|
||||
completion: str,
|
||||
label: bool,
|
||||
reward: float | None = None,
|
||||
) -> None:
|
||||
"""Add a KTO completion to the buffer.
|
||||
|
||||
Args:
|
||||
step: Training step number
|
||||
prompt: Input prompt
|
||||
completion: Model-generated completion
|
||||
label: True if desirable, False if undesirable
|
||||
reward: Reward score, if available
|
||||
"""
|
||||
entry = {
|
||||
"step": step,
|
||||
"prompt": prompt,
|
||||
"completion": completion,
|
||||
"label": "desirable" if label else "undesirable",
|
||||
}
|
||||
if reward is not None:
|
||||
entry["reward"] = reward
|
||||
|
||||
self.data.append(entry)
|
||||
|
||||
def add_orpo_completion(
|
||||
self,
|
||||
step: int,
|
||||
prompt: str,
|
||||
chosen: str,
|
||||
rejected: str,
|
||||
log_odds_ratio: float | None = None,
|
||||
) -> None:
|
||||
"""Add an ORPO completion to the buffer.
|
||||
|
||||
Args:
|
||||
step: Training step number
|
||||
prompt: Input prompt
|
||||
chosen: Chosen (preferred) completion
|
||||
rejected: Rejected (non-preferred) completion
|
||||
log_odds_ratio: Log odds ratio between chosen and rejected
|
||||
"""
|
||||
entry = {
|
||||
"step": step,
|
||||
"prompt": prompt,
|
||||
"chosen": chosen,
|
||||
"rejected": rejected,
|
||||
}
|
||||
if log_odds_ratio is not None:
|
||||
entry["log_odds_ratio"] = log_odds_ratio
|
||||
|
||||
self.data.append(entry)
|
||||
|
||||
def add_grpo_completion(
|
||||
self,
|
||||
step: int,
|
||||
prompt: str,
|
||||
completion: str,
|
||||
reward: float | None = None,
|
||||
advantage: float | None = None,
|
||||
) -> None:
|
||||
"""Add a GRPO completion to the buffer.
|
||||
|
||||
Args:
|
||||
step: Training step number
|
||||
prompt: Input prompt
|
||||
completion: Model-generated completion
|
||||
reward: Reward score from reward model
|
||||
advantage: Advantage estimate (reward - baseline)
|
||||
"""
|
||||
entry = {
|
||||
"step": step,
|
||||
"prompt": prompt,
|
||||
"completion": completion,
|
||||
}
|
||||
if reward is not None:
|
||||
entry["reward"] = reward
|
||||
if advantage is not None:
|
||||
entry["advantage"] = advantage
|
||||
|
||||
self.data.append(entry)
|
||||
|
||||
def log_to_swanlab(self, table_name: str = "completions") -> bool:
|
||||
"""Log buffered completions to SwanLab as a table.
|
||||
|
||||
Creates a SwanLab echarts Table with all buffered completions.
|
||||
Only logs if SwanLab is initialized and data is available.
|
||||
|
||||
Args:
|
||||
table_name: Name of the table in SwanLab dashboard.
|
||||
Default: "completions"
|
||||
|
||||
Returns:
|
||||
True if logging succeeded, False otherwise
|
||||
"""
|
||||
if not self.data:
|
||||
LOG.debug("No completions to log to SwanLab")
|
||||
return False
|
||||
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
if swanlab.get_run() is None:
|
||||
LOG.debug("SwanLab not initialized, skipping completion logging")
|
||||
return False
|
||||
|
||||
# Convert deque to list of dicts
|
||||
completions = list(self.data)
|
||||
|
||||
# Extract headers from first entry (all entries should have same structure)
|
||||
headers = list(completions[0].keys())
|
||||
|
||||
# Build rows: each completion becomes one row
|
||||
rows = []
|
||||
for completion in completions:
|
||||
row = [completion.get(header, "") for header in headers]
|
||||
rows.append(row)
|
||||
|
||||
# Log to SwanLab as echarts Table
|
||||
swanlab.log({table_name: swanlab.echarts.Table().add(headers, rows)})
|
||||
|
||||
LOG.info(f"Logged {len(rows)} completions to SwanLab table '{table_name}'")
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
LOG.warning(
|
||||
"SwanLab not installed, cannot log completions. "
|
||||
"Install with: pip install swanlab"
|
||||
)
|
||||
return False
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.exception("Failed to log completions to SwanLab: %s", err)
|
||||
return False
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all buffered completions."""
|
||||
self.data.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return number of buffered completions."""
|
||||
return len(self.data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation showing buffer status."""
|
||||
return (
|
||||
f"CompletionLogger(maxlen={self.maxlen}, "
|
||||
f"buffered={len(self.data)}/{self.maxlen})"
|
||||
)
|
||||
554
src/axolotl/integrations/swanlab/plugins.py
Normal file
554
src/axolotl/integrations/swanlab/plugins.py
Normal file
@@ -0,0 +1,554 @@
|
||||
"""SwanLab Plugin for Axolotl"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class SwanLabPlugin(BasePlugin):
|
||||
"""
|
||||
SwanLab integration plugin for Axolotl.
|
||||
|
||||
Provides experiment tracking, visualization, and logging capabilities
|
||||
using SwanLab (https://swanlab.cn).
|
||||
|
||||
Usage in config.yaml:
|
||||
plugins:
|
||||
- axolotl.integrations.swanlab.SwanLabPlugin
|
||||
|
||||
use_swanlab: true
|
||||
swanlab_project: my-project
|
||||
swanlab_experiment_name: my-experiment
|
||||
swanlab_mode: cloud # or 'local', 'offline', 'disabled'
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.swanlab_initialized = False
|
||||
LOG.info("SwanLab plugin initialized")
|
||||
|
||||
def get_input_args(self) -> str:
|
||||
"""Returns the configuration model for SwanLab integration."""
|
||||
return "axolotl.integrations.swanlab.SwanLabConfig"
|
||||
|
||||
def register(self, cfg: dict):
|
||||
"""Register SwanLab plugin with configuration and conflict detection."""
|
||||
LOG.info("Registering SwanLab plugin")
|
||||
|
||||
# === Conflict Detection: Required Fields ===
|
||||
|
||||
# Check if SwanLab is enabled
|
||||
if cfg.get("use_swanlab"):
|
||||
# 1. Validate project name is set
|
||||
if not cfg.get("swanlab_project"):
|
||||
raise ValueError(
|
||||
"SwanLab enabled but 'swanlab_project' is not set.\n\n"
|
||||
"Solutions:\n"
|
||||
" 1. Add 'swanlab_project: your-project-name' to your config\n"
|
||||
" 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
|
||||
"See: src/axolotl/integrations/swanlab/README.md for examples"
|
||||
)
|
||||
|
||||
# 2. Validate swanlab_mode value
|
||||
valid_modes = ["cloud", "local", "offline", "disabled"]
|
||||
mode = cfg.get("swanlab_mode")
|
||||
if mode and mode not in valid_modes:
|
||||
raise ValueError(
|
||||
f"Invalid swanlab_mode: '{mode}'.\n\n"
|
||||
f"Valid options: {', '.join(valid_modes)}\n\n"
|
||||
f"Example:\n"
|
||||
f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
|
||||
f" swanlab_mode: local # Local only, no cloud sync\n"
|
||||
)
|
||||
|
||||
# 3. Check API key for cloud mode
|
||||
import os
|
||||
|
||||
mode = cfg.get("swanlab_mode", "cloud") # Default is cloud
|
||||
if mode == "cloud":
|
||||
api_key = cfg.get("swanlab_api_key") or os.environ.get(
|
||||
"SWANLAB_API_KEY"
|
||||
)
|
||||
if not api_key:
|
||||
LOG.warning(
|
||||
"SwanLab cloud mode enabled but no API key found.\n"
|
||||
"SwanLab may fail to initialize during training.\n\n"
|
||||
"Solutions:\n"
|
||||
" 1. Set SWANLAB_API_KEY environment variable:\n"
|
||||
" export SWANLAB_API_KEY=your-api-key\n"
|
||||
" 2. Add 'swanlab_api_key: your-api-key' to config (less secure)\n"
|
||||
" 3. Run 'swanlab login' before training\n"
|
||||
" 4. Use 'swanlab_mode: local' for offline tracking\n"
|
||||
)
|
||||
|
||||
# === Conflict Detection: Multi-Logger Performance Warning ===
|
||||
|
||||
# Detect all active logging tools
|
||||
active_loggers = []
|
||||
if cfg.get("use_wandb"):
|
||||
active_loggers.append("WandB")
|
||||
if cfg.get("use_mlflow"):
|
||||
active_loggers.append("MLflow")
|
||||
if cfg.get("comet_api_key") or cfg.get("comet_project_name"):
|
||||
active_loggers.append("Comet")
|
||||
if cfg.get("use_swanlab"):
|
||||
active_loggers.append("SwanLab")
|
||||
|
||||
if len(active_loggers) > 1:
|
||||
LOG.warning(
|
||||
f"\n{'=' * 70}\n"
|
||||
f"Multiple logging tools enabled: {', '.join(active_loggers)}\n"
|
||||
f"{'=' * 70}\n"
|
||||
f"This may cause:\n"
|
||||
f" - Performance overhead (~1-2% per logger, cumulative)\n"
|
||||
f" - Increased memory usage\n"
|
||||
f" - Longer training time per step\n"
|
||||
f" - Potential config/callback conflicts\n\n"
|
||||
f"Recommendations:\n"
|
||||
f" - Choose ONE primary logging tool for production training\n"
|
||||
f" - Use multiple loggers only for:\n"
|
||||
f" * Migration period (transitioning between tools)\n"
|
||||
f" * Short comparison runs\n"
|
||||
f" * Debugging specific tool issues\n"
|
||||
f" - Monitor system resources (CPU, memory) during training\n"
|
||||
f"{'=' * 70}\n"
|
||||
)
|
||||
|
||||
if len(active_loggers) >= 3:
|
||||
LOG.error(
|
||||
f"\n{'!' * 70}\n"
|
||||
f"WARNING: {len(active_loggers)} logging tools enabled simultaneously!\n"
|
||||
f"{'!' * 70}\n"
|
||||
f"This is likely unintentional and WILL significantly impact performance.\n"
|
||||
f"Expected overhead: ~{len(active_loggers) * 1.5:.1f}% per training step.\n\n"
|
||||
f"STRONGLY RECOMMEND:\n"
|
||||
f" - Disable all but ONE logging tool\n"
|
||||
f" - Use config inheritance to manage multiple configs\n"
|
||||
f"{'!' * 70}\n"
|
||||
)
|
||||
|
||||
# === Auto-Enable Logic ===
|
||||
|
||||
# Enable SwanLab if project is specified
|
||||
if cfg.get("swanlab_project") and not cfg.get("use_swanlab"):
|
||||
cfg["use_swanlab"] = True
|
||||
LOG.info("Automatically enabled use_swanlab because swanlab_project is set")
|
||||
|
||||
def pre_model_load(self, cfg: DictDefault):
|
||||
"""Initialize SwanLab before model loading with runtime checks."""
|
||||
if not cfg.use_swanlab:
|
||||
return
|
||||
|
||||
# === Runtime Check: Import Availability ===
|
||||
try:
|
||||
import swanlab
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"SwanLab is not installed.\n\n"
|
||||
"Install with:\n"
|
||||
" pip install swanlab\n\n"
|
||||
"Or add to requirements:\n"
|
||||
" swanlab>=0.3.0\n\n"
|
||||
f"Original error: {err}"
|
||||
) from err
|
||||
|
||||
# Log SwanLab version
|
||||
try:
|
||||
swanlab_version = swanlab.__version__
|
||||
LOG.info(f"SwanLab version: {swanlab_version}")
|
||||
except AttributeError:
|
||||
LOG.warning("Could not determine SwanLab version")
|
||||
|
||||
# === Runtime Check: Distributed Training Setup ===
|
||||
from axolotl.utils.distributed import get_world_size, is_main_process
|
||||
|
||||
world_size = get_world_size()
|
||||
if world_size > 1:
|
||||
mode = getattr(cfg, "swanlab_mode", "cloud")
|
||||
LOG.info(
|
||||
f"\n{'=' * 70}\n"
|
||||
f"Distributed training detected (world_size={world_size})\n"
|
||||
f"SwanLab mode: {mode}\n"
|
||||
f"{'=' * 70}\n"
|
||||
f"Behavior:\n"
|
||||
f" - Only rank 0 will initialize SwanLab\n"
|
||||
f" - Other ranks will skip SwanLab to avoid conflicts\n"
|
||||
)
|
||||
|
||||
if mode == "cloud":
|
||||
LOG.info(
|
||||
f" - Only rank 0 will upload to SwanLab cloud\n"
|
||||
f" - Other ranks run without SwanLab overhead\n"
|
||||
f"{'=' * 70}\n"
|
||||
)
|
||||
|
||||
# Only initialize SwanLab on the main process (rank 0)
|
||||
# to avoid creating multiple runs in distributed training
|
||||
if not is_main_process():
|
||||
LOG.debug("Skipping SwanLab initialization on non-main process")
|
||||
return
|
||||
|
||||
# Initialize SwanLab run (passing all params directly to init)
|
||||
try:
|
||||
init_kwargs = self._get_swanlab_init_kwargs(cfg)
|
||||
swanlab.init(**init_kwargs)
|
||||
self.swanlab_initialized = True
|
||||
LOG.info(f"SwanLab initialized with project: {cfg.swanlab_project}")
|
||||
|
||||
# Register Lark notification callback (if configured)
|
||||
self._register_lark_callback(cfg)
|
||||
|
||||
# Log configuration (with error handling)
|
||||
try:
|
||||
config_dict = self._prepare_config_for_logging(cfg)
|
||||
swanlab.config.update(config_dict)
|
||||
LOG.debug("Successfully logged config to SwanLab")
|
||||
except Exception as config_err: # pylint: disable=broad-except
|
||||
LOG.warning(
|
||||
f"Failed to log config to SwanLab: {config_err}. Continuing anyway."
|
||||
)
|
||||
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.exception("Failed to initialize SwanLab: %s", err)
|
||||
self.swanlab_initialized = False
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg: DictDefault, model):
|
||||
"""Add SwanLab callbacks before trainer creation."""
|
||||
callbacks: list[TrainerCallback] = []
|
||||
|
||||
if not cfg.use_swanlab:
|
||||
return callbacks
|
||||
|
||||
if not self.swanlab_initialized:
|
||||
LOG.warning("SwanLab not initialized, skipping callback registration")
|
||||
return callbacks
|
||||
|
||||
try:
|
||||
from axolotl.utils.callbacks.swanlab import (
|
||||
CustomSwanLabCallback,
|
||||
SaveAxolotlConfigtoSwanLabCallback,
|
||||
)
|
||||
|
||||
# Add our custom lightweight SwanLabCallback
|
||||
# (avoids omegaconf/antlr4 version conflicts)
|
||||
swanlab_callback = CustomSwanLabCallback()
|
||||
callbacks.append(swanlab_callback)
|
||||
LOG.info("Added CustomSwanLabCallback for metrics logging")
|
||||
|
||||
# Add Axolotl config logging callback
|
||||
if cfg.axolotl_config_path:
|
||||
config_callback = SaveAxolotlConfigtoSwanLabCallback(
|
||||
cfg.axolotl_config_path
|
||||
)
|
||||
callbacks.append(config_callback)
|
||||
LOG.info("Added SaveAxolotlConfigtoSwanLabCallback")
|
||||
|
||||
except ImportError as err:
|
||||
LOG.exception("Failed to import SwanLab callbacks: %s", err)
|
||||
|
||||
return callbacks
|
||||
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer):
|
||||
"""Post-trainer creation hook."""
|
||||
if cfg.use_swanlab and self.swanlab_initialized:
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
# Log additional trainer information (with safe conversion)
|
||||
trainer_config = {
|
||||
"total_steps": int(trainer.state.max_steps)
|
||||
if trainer.state.max_steps
|
||||
else None,
|
||||
"num_train_epochs": float(trainer.args.num_train_epochs)
|
||||
if trainer.args.num_train_epochs
|
||||
else None,
|
||||
"train_batch_size": int(trainer.args.train_batch_size)
|
||||
if hasattr(trainer.args, "train_batch_size")
|
||||
else None,
|
||||
"gradient_accumulation_steps": int(
|
||||
trainer.args.gradient_accumulation_steps
|
||||
)
|
||||
if trainer.args.gradient_accumulation_steps
|
||||
else None,
|
||||
}
|
||||
# Remove None values
|
||||
trainer_config = {
|
||||
k: v for k, v in trainer_config.items() if v is not None
|
||||
}
|
||||
|
||||
if trainer_config:
|
||||
swanlab.config.update(trainer_config)
|
||||
LOG.info("Logged trainer configuration to SwanLab")
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.debug(f"Failed to log trainer config to SwanLab: {err}")
|
||||
|
||||
# Register RLHF completion logging callback if enabled
|
||||
self._register_completion_callback(cfg, trainer)
|
||||
|
||||
def _get_swanlab_init_kwargs(self, cfg: DictDefault) -> dict:
|
||||
"""Prepare kwargs for swanlab.init().
|
||||
|
||||
Passes all configuration parameters directly to swanlab.init()
|
||||
instead of using environment variables as an intermediate layer.
|
||||
|
||||
Returns:
|
||||
dict: Keyword arguments for swanlab.init()
|
||||
"""
|
||||
init_kwargs = {}
|
||||
|
||||
# Project name (required)
|
||||
if cfg.swanlab_project:
|
||||
init_kwargs["project"] = cfg.swanlab_project
|
||||
|
||||
# Experiment name
|
||||
if cfg.swanlab_experiment_name:
|
||||
init_kwargs["experiment_name"] = cfg.swanlab_experiment_name
|
||||
|
||||
# Description
|
||||
if cfg.swanlab_description:
|
||||
init_kwargs["description"] = cfg.swanlab_description
|
||||
|
||||
# Workspace (organization)
|
||||
if cfg.swanlab_workspace:
|
||||
init_kwargs["workspace"] = cfg.swanlab_workspace
|
||||
|
||||
# Mode: cloud, local, offline, disabled
|
||||
if cfg.swanlab_mode:
|
||||
init_kwargs["mode"] = cfg.swanlab_mode
|
||||
|
||||
# API key (pass directly instead of via env var)
|
||||
if cfg.swanlab_api_key:
|
||||
init_kwargs["api_key"] = cfg.swanlab_api_key
|
||||
|
||||
# Private deployment hosts (pass directly instead of via env var)
|
||||
if cfg.swanlab_web_host:
|
||||
init_kwargs["web_host"] = cfg.swanlab_web_host
|
||||
|
||||
if cfg.swanlab_api_host:
|
||||
init_kwargs["api_host"] = cfg.swanlab_api_host
|
||||
|
||||
# Log model checkpoints (coming soon in SwanLab)
|
||||
if cfg.swanlab_log_model:
|
||||
init_kwargs["log_model"] = cfg.swanlab_log_model
|
||||
|
||||
# Custom branding - adds Axolotl identifier to SwanLab UI
|
||||
# This helps identify runs from Axolotl vs other frameworks
|
||||
init_kwargs["config"] = {"UPPERFRAME": "🦎 Axolotl"}
|
||||
|
||||
return init_kwargs
|
||||
|
||||
def _prepare_config_for_logging(self, cfg: DictDefault) -> dict:
|
||||
"""Prepare configuration dict for logging to SwanLab."""
|
||||
|
||||
def safe_convert(value):
|
||||
"""Convert value to JSON-serializable type."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
# Convert everything else to string
|
||||
return str(value)
|
||||
|
||||
try:
|
||||
# Extract important training parameters with safe conversion
|
||||
config_dict = {
|
||||
"base_model": safe_convert(getattr(cfg, "base_model", "")),
|
||||
"model_type": safe_convert(getattr(cfg, "model_type", "")),
|
||||
"sequence_len": safe_convert(getattr(cfg, "sequence_len", None)),
|
||||
"micro_batch_size": safe_convert(
|
||||
getattr(cfg, "micro_batch_size", None)
|
||||
),
|
||||
"gradient_accumulation_steps": safe_convert(
|
||||
getattr(cfg, "gradient_accumulation_steps", None)
|
||||
),
|
||||
"num_epochs": safe_convert(getattr(cfg, "num_epochs", None)),
|
||||
"max_steps": safe_convert(getattr(cfg, "max_steps", None)),
|
||||
"learning_rate": safe_convert(getattr(cfg, "learning_rate", None)),
|
||||
"lr_scheduler": safe_convert(getattr(cfg, "lr_scheduler", "")),
|
||||
"optimizer": safe_convert(getattr(cfg, "optimizer", "")),
|
||||
"warmup_ratio": safe_convert(getattr(cfg, "warmup_ratio", None)),
|
||||
"weight_decay": safe_convert(getattr(cfg, "weight_decay", None)),
|
||||
"seed": safe_convert(getattr(cfg, "seed", None)),
|
||||
"bf16": safe_convert(getattr(cfg, "bf16", None)),
|
||||
"tf32": safe_convert(getattr(cfg, "tf32", None)),
|
||||
"flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
|
||||
"sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
|
||||
}
|
||||
|
||||
# Add FSDP/parallel config - only boolean flags
|
||||
if hasattr(cfg, "fsdp_config") and cfg.fsdp_config:
|
||||
config_dict["fsdp_enabled"] = True
|
||||
config_dict["fsdp_version"] = safe_convert(
|
||||
getattr(cfg, "fsdp_version", None)
|
||||
)
|
||||
|
||||
if hasattr(cfg, "deepspeed") and cfg.deepspeed:
|
||||
config_dict["deepspeed_enabled"] = True
|
||||
|
||||
# Add context parallel info
|
||||
if hasattr(cfg, "context_parallel_size"):
|
||||
config_dict["context_parallel_size"] = safe_convert(
|
||||
getattr(cfg, "context_parallel_size", None)
|
||||
)
|
||||
if hasattr(cfg, "tensor_parallel_size"):
|
||||
config_dict["tensor_parallel_size"] = safe_convert(
|
||||
getattr(cfg, "tensor_parallel_size", None)
|
||||
)
|
||||
if hasattr(cfg, "dp_shard_size"):
|
||||
config_dict["dp_shard_size"] = safe_convert(
|
||||
getattr(cfg, "dp_shard_size", None)
|
||||
)
|
||||
|
||||
# Remove None values and empty strings
|
||||
config_dict = {
|
||||
k: v
|
||||
for k, v in config_dict.items()
|
||||
if v is not None and v != "" and v != "None"
|
||||
}
|
||||
|
||||
return config_dict
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.warning(f"Failed to prepare config for logging: {err}")
|
||||
# Return minimal config
|
||||
try:
|
||||
lr = getattr(cfg, "learning_rate", None)
|
||||
lr_value = float(lr) if lr is not None else None
|
||||
except (TypeError, ValueError):
|
||||
lr_value = None
|
||||
return {
|
||||
"base_model": str(getattr(cfg, "base_model", "unknown")),
|
||||
"learning_rate": lr_value,
|
||||
}
|
||||
|
||||
def _register_lark_callback(self, cfg: DictDefault):
|
||||
"""Register Lark (Feishu) notification callback if configured.
|
||||
|
||||
Lark notifications enable sending training updates to team chat channels,
|
||||
useful for production monitoring and team collaboration.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object with Lark webhook settings
|
||||
"""
|
||||
# Check if Lark webhook URL is configured
|
||||
lark_webhook_url = getattr(cfg, "swanlab_lark_webhook_url", None)
|
||||
if not lark_webhook_url:
|
||||
return # Lark not configured, skip
|
||||
|
||||
try:
|
||||
import swanlab
|
||||
from swanlab.plugin.notification import LarkCallback
|
||||
|
||||
# Get optional secret for HMAC signature authentication
|
||||
lark_secret = getattr(cfg, "swanlab_lark_secret", None)
|
||||
|
||||
# Create Lark callback with webhook URL and optional secret
|
||||
lark_callback = LarkCallback(
|
||||
webhook_url=lark_webhook_url,
|
||||
secret=lark_secret,
|
||||
)
|
||||
|
||||
# Register callback with SwanLab
|
||||
swanlab.register_callbacks([lark_callback])
|
||||
|
||||
if lark_secret:
|
||||
LOG.info(
|
||||
"Registered Lark notification callback with HMAC authentication"
|
||||
)
|
||||
else:
|
||||
LOG.info("Registered Lark notification callback (no HMAC secret)")
|
||||
LOG.warning(
|
||||
"Lark webhook has no secret configured. "
|
||||
"For production use, set 'swanlab_lark_secret' to enable HMAC signature verification."
|
||||
)
|
||||
|
||||
except ImportError as err:
|
||||
LOG.warning(
|
||||
f"Failed to import SwanLab Lark plugin: {err}\n\n"
|
||||
"Lark notifications require SwanLab >= 0.3.0 with plugin support.\n"
|
||||
"Install with: pip install 'swanlab>=0.3.0'\n\n"
|
||||
"Continuing without Lark notifications..."
|
||||
)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.exception(
|
||||
"Failed to register Lark callback: %s\n\n"
|
||||
"Check your Lark webhook URL and secret configuration.\n"
|
||||
"Continuing without Lark notifications...",
|
||||
err,
|
||||
)
|
||||
|
||||
def _register_completion_callback(self, cfg: DictDefault, trainer):
|
||||
"""Register RLHF completion logging callback if enabled and applicable.
|
||||
|
||||
This callback logs model completions (prompts, chosen/rejected responses,
|
||||
rewards) to SwanLab during RLHF training for qualitative analysis.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object with completion logging settings
|
||||
trainer: The trainer instance to add callback to
|
||||
"""
|
||||
# Check if completion logging is enabled
|
||||
log_completions = getattr(cfg, "swanlab_log_completions", True)
|
||||
if not log_completions:
|
||||
LOG.debug("SwanLab completion logging disabled by config")
|
||||
return
|
||||
|
||||
# Check if trainer is an RLHF trainer
|
||||
trainer_name = trainer.__class__.__name__
|
||||
rlhf_trainers = ["DPO", "KTO", "ORPO", "GRPO", "CPO"]
|
||||
is_rlhf_trainer = any(name in trainer_name for name in rlhf_trainers)
|
||||
|
||||
if not is_rlhf_trainer:
|
||||
LOG.debug(
|
||||
f"Trainer {trainer_name} is not an RLHF trainer, "
|
||||
"skipping completion logging callback"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
from axolotl.integrations.swanlab.callbacks import (
|
||||
SwanLabRLHFCompletionCallback,
|
||||
)
|
||||
|
||||
# Get configuration parameters
|
||||
log_interval = getattr(cfg, "swanlab_completion_log_interval", 100)
|
||||
max_buffer = getattr(cfg, "swanlab_completion_max_buffer", 128)
|
||||
|
||||
# Create and register callback
|
||||
completion_callback = SwanLabRLHFCompletionCallback(
|
||||
log_interval=log_interval,
|
||||
max_completions=max_buffer,
|
||||
table_name="rlhf_completions",
|
||||
)
|
||||
|
||||
trainer.add_callback(completion_callback)
|
||||
|
||||
LOG.info(
|
||||
f"Registered SwanLab RLHF completion logging callback for {trainer_name} "
|
||||
f"(log_interval={log_interval}, max_buffer={max_buffer})"
|
||||
)
|
||||
|
||||
except ImportError as err:
|
||||
LOG.warning(
|
||||
f"Failed to import SwanLab completion callback: {err}\n\n"
|
||||
"This is a bug - the callback should be available.\n"
|
||||
"Please report this issue.\n\n"
|
||||
"Continuing without completion logging..."
|
||||
)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.exception(
|
||||
"Failed to register SwanLab completion callback: %s\n\n"
|
||||
"Continuing without completion logging...",
|
||||
err,
|
||||
)
|
||||
203
src/axolotl/integrations/swanlab/profiling.py
Normal file
203
src/axolotl/integrations/swanlab/profiling.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""SwanLab profiling utilities for Axolotl trainers.
|
||||
|
||||
This module provides decorators and context managers for profiling
|
||||
trainer methods and logging execution times to SwanLab.
|
||||
"""
|
||||
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def swanlab_profiling_context(trainer: Any, func_name: str):
|
||||
"""Context manager for profiling trainer methods.
|
||||
|
||||
Measures execution time and logs to SwanLab if enabled.
|
||||
|
||||
Example usage:
|
||||
>>> with swanlab_profiling_context(self, "training_step"):
|
||||
... result = do_expensive_computation()
|
||||
|
||||
Args:
|
||||
trainer: Trainer instance (must have cfg attribute with use_swanlab flag)
|
||||
func_name: Name of the function being profiled
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
duration = time.perf_counter() - start_time
|
||||
|
||||
# Check if SwanLab is enabled and initialized
|
||||
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
|
||||
if use_swanlab:
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
if swanlab.get_run() is not None:
|
||||
# Log profiling metric
|
||||
trainer_class = trainer.__class__.__name__
|
||||
metric_name = f"profiling/Time taken: {trainer_class}.{func_name}"
|
||||
|
||||
swanlab.log({metric_name: duration})
|
||||
|
||||
except ImportError:
|
||||
# SwanLab not installed, silently skip
|
||||
pass
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
# Log error but don't fail training
|
||||
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")
|
||||
|
||||
|
||||
def swanlab_profile(func: Callable) -> Callable:
|
||||
"""Decorator to profile and log function execution time to SwanLab.
|
||||
|
||||
Automatically measures execution time of trainer methods and logs
|
||||
to SwanLab as profiling metrics.
|
||||
|
||||
Example usage:
|
||||
>>> class MyTrainer:
|
||||
... @swanlab_profile
|
||||
... def training_step(self, model, inputs):
|
||||
... return super().training_step(model, inputs)
|
||||
|
||||
Args:
|
||||
func: Function to profile (must be a method of a trainer instance)
|
||||
|
||||
Returns:
|
||||
Wrapped function with profiling
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with swanlab_profiling_context(self, func.__name__):
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ProfilingConfig:
|
||||
"""Configuration for SwanLab profiling.
|
||||
|
||||
This class provides a centralized way to control profiling behavior.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether profiling is enabled globally
|
||||
min_duration_ms: Minimum duration (in ms) to log (filters out very fast ops)
|
||||
log_interval: Log every N function calls (to reduce overhead)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool = True,
|
||||
min_duration_ms: float = 0.1,
|
||||
log_interval: int = 1,
|
||||
):
|
||||
"""Initialize profiling configuration.
|
||||
|
||||
Args:
|
||||
enabled: Enable profiling. Default: True
|
||||
min_duration_ms: Minimum duration to log (ms). Default: 0.1
|
||||
log_interval: Log every N calls. Default: 1 (log all)
|
||||
"""
|
||||
self.enabled = enabled
|
||||
self.min_duration_ms = min_duration_ms
|
||||
self.log_interval = log_interval
|
||||
self._call_counts: dict[str, int] = {}
|
||||
|
||||
def should_log(self, func_name: str, duration_seconds: float) -> bool:
|
||||
"""Check if a profiling measurement should be logged.
|
||||
|
||||
Args:
|
||||
func_name: Name of the profiled function
|
||||
duration_seconds: Execution duration in seconds
|
||||
|
||||
Returns:
|
||||
True if should log, False otherwise
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
# Check minimum duration threshold
|
||||
duration_ms = duration_seconds * 1000
|
||||
if duration_ms < self.min_duration_ms:
|
||||
return False
|
||||
|
||||
# Check log interval
|
||||
self._call_counts.setdefault(func_name, 0)
|
||||
self._call_counts[func_name] += 1
|
||||
|
||||
# Always log on first call OR at intervals
|
||||
count = self._call_counts[func_name]
|
||||
if count == 1 or count % self.log_interval == 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Global profiling config (can be modified by users)
|
||||
DEFAULT_PROFILING_CONFIG = ProfilingConfig()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def swanlab_profiling_context_advanced(
|
||||
trainer: Any,
|
||||
func_name: str,
|
||||
config: ProfilingConfig | None = None,
|
||||
):
|
||||
"""Advanced profiling context with configurable behavior.
|
||||
|
||||
Similar to swanlab_profiling_context but with additional configuration
|
||||
options for filtering and throttling profiling logs.
|
||||
|
||||
Example usage:
|
||||
>>> config = ProfilingConfig(min_duration_ms=1.0, log_interval=10)
|
||||
>>> with swanlab_profiling_context_advanced(self, "forward", config):
|
||||
... output = model(inputs)
|
||||
|
||||
Args:
|
||||
trainer: Trainer instance
|
||||
func_name: Function name
|
||||
config: Profiling configuration. If None, uses DEFAULT_PROFILING_CONFIG
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
if config is None:
|
||||
config = DEFAULT_PROFILING_CONFIG
|
||||
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
duration = time.perf_counter() - start_time
|
||||
|
||||
# Check if should log based on config
|
||||
if config.should_log(func_name, duration):
|
||||
# Check if SwanLab is enabled
|
||||
use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False)
|
||||
if use_swanlab:
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
if swanlab.get_run() is not None:
|
||||
trainer_class = trainer.__class__.__name__
|
||||
metric_name = (
|
||||
f"profiling/Time taken: {trainer_class}.{func_name}"
|
||||
)
|
||||
|
||||
swanlab.log({metric_name: duration})
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
LOG.debug(f"Failed to log profiling metric for {func_name}: {err}")
|
||||
248
src/axolotl/utils/callbacks/swanlab.py
Normal file
248
src/axolotl/utils/callbacks/swanlab.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Callbacks for SwanLab integration"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class CustomSwanLabCallback(TrainerCallback):
|
||||
"""
|
||||
Lightweight SwanLab callback that directly logs metrics without using
|
||||
SwanLab's transformers integration (which requires omegaconf).
|
||||
|
||||
This avoids the antlr4 version conflict between omegaconf and axolotl.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._initialized = False
|
||||
self.swanlab = None
|
||||
|
||||
def setup(self):
|
||||
"""Lazy initialization of SwanLab"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
self.swanlab = swanlab
|
||||
|
||||
# Check if SwanLab run is initialized
|
||||
if swanlab.get_run() is None:
|
||||
LOG.warning("SwanLab run is not initialized")
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
LOG.info("CustomSwanLabCallback initialized successfully")
|
||||
except ImportError:
|
||||
LOG.error("SwanLab is not installed")
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Called at the beginning of training"""
|
||||
if not state.is_world_process_zero:
|
||||
return control
|
||||
|
||||
self.setup()
|
||||
|
||||
if not self._initialized:
|
||||
return control
|
||||
|
||||
# Log training configuration
|
||||
try:
|
||||
self.swanlab.config.update(
|
||||
{
|
||||
"train_batch_size": args.per_device_train_batch_size,
|
||||
"eval_batch_size": args.per_device_eval_batch_size,
|
||||
"learning_rate": args.learning_rate,
|
||||
"num_train_epochs": args.num_train_epochs,
|
||||
"max_steps": args.max_steps,
|
||||
"warmup_steps": args.warmup_steps,
|
||||
"logging_steps": args.logging_steps,
|
||||
"save_steps": args.save_steps,
|
||||
"gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||
}
|
||||
)
|
||||
LOG.debug("Training configuration logged to SwanLab")
|
||||
except Exception as err:
|
||||
LOG.warning(f"Failed to log training config: {err}")
|
||||
|
||||
return control
|
||||
|
||||
def on_log(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
logs=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Called when logging metrics"""
|
||||
if not state.is_world_process_zero:
|
||||
return control
|
||||
|
||||
if not self._initialized:
|
||||
self.setup()
|
||||
|
||||
if not self._initialized or logs is None:
|
||||
return control
|
||||
|
||||
# Log metrics to SwanLab
|
||||
try:
|
||||
# Filter out non-numeric values and prepare for logging
|
||||
metrics = {}
|
||||
for key, value in logs.items():
|
||||
if isinstance(value, (int, float)):
|
||||
# Use step from state
|
||||
metrics[key] = value
|
||||
|
||||
if metrics and state.global_step is not None:
|
||||
self.swanlab.log(metrics, step=state.global_step)
|
||||
except Exception as err:
|
||||
LOG.warning(f"Failed to log metrics to SwanLab: {err}")
|
||||
|
||||
return control
|
||||
|
||||
def on_train_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
"""Called at the end of training"""
|
||||
if not state.is_world_process_zero:
|
||||
return control
|
||||
|
||||
if self._initialized:
|
||||
LOG.info("Training completed. SwanLab logs are available.")
|
||||
|
||||
return control
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoSwanLabCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to SwanLab"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: AxolotlTrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if state.is_world_process_zero:
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
# Check if SwanLab is initialized
|
||||
if swanlab.get_run() is None:
|
||||
LOG.warning(
|
||||
"SwanLab run is not initialized. Please initialize SwanLab before training."
|
||||
)
|
||||
return control
|
||||
|
||||
# Log Axolotl config as artifact
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
|
||||
# Log config file to SwanLab
|
||||
with open(temp_file.name, "r", encoding="utf-8") as config_file:
|
||||
swanlab.log(
|
||||
{
|
||||
"axolotl_config": swanlab.Text(
|
||||
config_file.read(), caption="Axolotl Config"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the SwanLab run under logs."
|
||||
)
|
||||
|
||||
# Clean up temp file
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
except ImportError:
|
||||
LOG.warning(
|
||||
"SwanLab is not installed. Install it with: pip install swanlab"
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to SwanLab: {err}")
|
||||
|
||||
# Log DeepSpeed config if available
|
||||
if args.deepspeed:
|
||||
try:
|
||||
import swanlab
|
||||
|
||||
with NamedTemporaryFile(
|
||||
mode="w",
|
||||
delete=False,
|
||||
suffix=".json",
|
||||
prefix="deepspeed_config_",
|
||||
) as temp_file:
|
||||
skip_upload = False
|
||||
if isinstance(args.deepspeed, dict):
|
||||
json.dump(args.deepspeed, temp_file, indent=4)
|
||||
elif isinstance(args.deepspeed, str) and os.path.exists(
|
||||
args.deepspeed
|
||||
):
|
||||
copyfile(args.deepspeed, temp_file.name)
|
||||
else:
|
||||
skip_upload = True
|
||||
|
||||
if not skip_upload:
|
||||
temp_file.flush()
|
||||
with open(
|
||||
temp_file.name, "r", encoding="utf-8"
|
||||
) as ds_config_file:
|
||||
swanlab.log(
|
||||
{
|
||||
"deepspeed_config": swanlab.Text(
|
||||
ds_config_file.read(),
|
||||
caption="DeepSpeed Config",
|
||||
)
|
||||
}
|
||||
)
|
||||
LOG.info(
|
||||
"The DeepSpeed config has been saved to the SwanLab run under logs."
|
||||
)
|
||||
|
||||
# Clean up temp file
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(
|
||||
f"Error while saving DeepSpeed config to SwanLab: {err}"
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return control
|
||||
Reference in New Issue
Block a user