"""Example: Custom Trainer with SwanLab Profiling
|
|
|
|
This example demonstrates how to add SwanLab profiling to your custom trainer.
|
|
|
|
Features:
|
|
- @swanlab_profile decorator for automatic profiling
|
|
- swanlab_profiling_context for fine-grained profiling
|
|
- ProfilingConfig for advanced filtering and throttling
|
|
|
|
Usage:
|
|
1. Create your custom trainer extending AxolotlTrainer
|
|
2. Add @swanlab_profile decorators to methods you want to profile
|
|
3. Use swanlab_profiling_context for fine-grained profiling within methods
|
|
4. Enable SwanLab in your config (use_swanlab: true)
|
|
|
|
See also:
|
|
- examples/swanlab/lora-swanlab-profiling.yml for config
|
|
- src/axolotl/integrations/swanlab/profiling.py for implementation
|
|
"""

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import (
    ProfilingConfig,
    swanlab_profile,
    swanlab_profiling_context,
    swanlab_profiling_context_advanced,
)


class CustomTrainerWithProfiling(AxolotlTrainer):
    """Custom trainer with SwanLab profiling enabled.

    This trainer demonstrates three profiling patterns:
    1. Decorator-based profiling (@swanlab_profile)
    2. Context manager profiling (swanlab_profiling_context)
    3. Advanced profiling with filtering (ProfilingConfig)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Create a custom profiling config for high-frequency operations
        self.fast_op_config = ProfilingConfig(
            enabled=True,
            min_duration_ms=0.5,  # Only log if duration > 0.5ms
            log_interval=50,  # Log every 50th call
        )
|
# ========================================================================
|
|
# Pattern 1: Decorator-based Profiling
|
|
# ========================================================================
|
|
# Best for: Methods you always want to profile
|
|
# Overhead: ~2-5 microseconds per call (negligible)
|
|
|
|
@swanlab_profile
|
|
def training_step(self, model, inputs):
|
|
"""Main training step - always profile.
|
|
|
|
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step
|
|
"""
|
|
return super().training_step(model, inputs)
|
|
|
|
@swanlab_profile
|
|
def compute_loss(self, model, inputs, return_outputs=False):
|
|
"""Loss computation - always profile.
|
|
|
|
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss
|
|
"""
|
|
return super().compute_loss(model, inputs, return_outputs)
|
|
|
|
@swanlab_profile
|
|
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
|
"""Prediction step - always profile.
|
|
|
|
Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step
|
|
"""
|
|
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
|
|
|
|

    # ========================================================================
    # Pattern 2: Fine-grained Context Manager Profiling
    # ========================================================================
    # Best for: Profiling specific code blocks within a method
    # Use case: When you want to profile forward vs backward separately

    def complex_training_step(self, model, inputs):
        """Training step with fine-grained profiling.

        Profiling metrics:
        - profiling/Time taken: CustomTrainerWithProfiling.forward_pass
        - profiling/Time taken: CustomTrainerWithProfiling.backward_pass
        - profiling/Time taken: CustomTrainerWithProfiling.optimizer_step
        """
        # Profile just the forward pass
        with swanlab_profiling_context(self, "forward_pass"):
            outputs = model(**inputs)
            loss = outputs.loss

        # Profile just the backward pass
        with swanlab_profiling_context(self, "backward_pass"):
            loss.backward()

        # Profile the optimizer step
        with swanlab_profiling_context(self, "optimizer_step"):
            self.optimizer.step()
            self.optimizer.zero_grad()

        return outputs
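
    # Note: the HF Trainer loop calls training_step, not complex_training_step;
    # the method above is a sketch of the fine-grained pattern, so to exercise
    # it you would call it yourself or rename it to training_step in place of
    # the decorated override in Pattern 1.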

    # ========================================================================
    # Pattern 3: Advanced Profiling with Filtering
    # ========================================================================
    # Best for: High-frequency operations where you want to throttle logging
    # Use case: Methods called 100+ times per step

    def _prepare_inputs(self, inputs):
        """Prepare inputs - throttled profiling.

        This method is called frequently (once per batch), so we throttle
        profiling to reduce overhead:
        - Only log if duration > 0.5ms (skip very fast operations)
        - Only log every 50th call (reduce logging frequency)

        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs
        """
        with swanlab_profiling_context_advanced(
            self, "prepare_inputs", config=self.fast_op_config
        ):
            return super()._prepare_inputs(inputs)

    def _prepare_input_for_model(self, input_ids):
        """Another high-frequency operation - throttled profiling.

        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model
        """
        with swanlab_profiling_context_advanced(
            self, "prepare_input_for_model", config=self.fast_op_config
        ):
            # Your custom input preparation logic
            return input_ids

    # ========================================================================
    # Pattern 4: Exception-safe Profiling
    # ========================================================================
    # Profiling is exception-safe: the duration is logged even if the method raises

    @swanlab_profile
    def potentially_failing_method(self):
        """This method may raise an exception.

        SwanLab profiling will still log the duration before re-raising.
        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method
        """
        # Do some work
        result = self._do_risky_computation()

        # If this raises, the profiled duration is still logged
        if result < 0:
            raise ValueError("Invalid result")

        return result

    def _do_risky_computation(self):
        """Placeholder for risky computation."""
        return 42
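

# Usage sketch for Pattern 4 (kept commented out: a hypothetical instantiation,
# since a real trainer needs the full AxolotlTrainer constructor arguments).
# The profiled duration is recorded even when the method raises:
#
#     trainer = CustomTrainerWithProfiling(...)
#     try:
#         trainer.potentially_failing_method()
#     except ValueError:
#         pass  # the duration metric was already logged by @swanlab_profile
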

# ============================================================================
# Advanced Example: Custom ProfilingConfig Per Method
# ============================================================================


class AdvancedProfilingTrainer(AxolotlTrainer):
    """Trainer with method-specific profiling configurations."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Different profiling configs for different method types
        self.critical_path_config = ProfilingConfig(
            enabled=True,
            min_duration_ms=0.0,  # Log everything on the critical path
            log_interval=1,  # Log every call
        )

        self.fast_path_config = ProfilingConfig(
            enabled=True,
            min_duration_ms=1.0,  # Only log if > 1ms
            log_interval=100,  # Log every 100th call
        )

        self.debug_config = ProfilingConfig(
            enabled=True,
            min_duration_ms=0.0,  # Log everything
            log_interval=1,  # Log every call
        )

    def training_step(self, model, inputs):
        """Critical path - log everything."""
        with swanlab_profiling_context_advanced(
            self, "training_step", config=self.critical_path_config
        ):
            return super().training_step(model, inputs)

    def _prepare_inputs(self, inputs):
        """Fast path - throttle logging."""
        with swanlab_profiling_context_advanced(
            self, "prepare_inputs", config=self.fast_path_config
        ):
            return super()._prepare_inputs(inputs)

    def _debug_method(self, data):
        """Debug-only method - verbose logging."""
        with swanlab_profiling_context_advanced(
            self, "debug_method", config=self.debug_config
        ):
            # Your debug logic
            pass


# ============================================================================
# How to Use This Custom Trainer
# ============================================================================

"""
To use this custom trainer:

1. Save this file to your project (e.g., my_custom_trainer.py)

2. Create a config file that uses your custom trainer:

    # config.yml
    base_model: NousResearch/Llama-3.2-1B

    # ... other config ...

    plugins:
      - axolotl.integrations.swanlab.SwanLabPlugin

    use_swanlab: true
    swanlab_project: my-profiling-experiment

    # Optional: Specify custom trainer
    # (Or modify axolotl to use your custom trainer class)

3. Run training:

    export SWANLAB_API_KEY=your-api-key
    accelerate launch -m axolotl.cli.train config.yml

4. View profiling metrics in the SwanLab dashboard:
   - profiling/Time taken: CustomTrainerWithProfiling.training_step
   - profiling/Time taken: CustomTrainerWithProfiling.forward_pass
   - profiling/Time taken: CustomTrainerWithProfiling.backward_pass
   - etc.

5. Compare profiling metrics across runs:
   - Run a baseline without optimizations
   - Run with flash_attention enabled
   - Run with gradient_checkpointing enabled
   - Compare the profiling metrics to see the performance impact
"""

# ============================================================================
# Tips for Effective Profiling
# ============================================================================

"""
1. Profile the critical path first:
   - training_step, compute_loss, prediction_step
   - These methods are called most frequently and have the biggest impact

2. Use throttling for high-frequency operations:
   - Methods called 100+ times per step
   - Use log_interval=50 or log_interval=100
   - Reduces profiling overhead and dashboard clutter

3. Filter noise with min_duration_ms:
   - Set min_duration_ms=1.0 to skip very fast operations
   - Focus on operations that actually take time

4. Compare across runs:
   - Run the same config multiple times to check consistency
   - Compare different optimization strategies
   - Track profiling trends over time

5. Monitor distributed training:
   - Check for per-rank timing differences
   - Look for stragglers (slower ranks)
   - Identify synchronization bottlenecks

6. Disable profiling in production:
   - from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG
   - DEFAULT_PROFILING_CONFIG.enabled = False

7. Exception handling:
   - Profiling is exception-safe
   - The duration is logged even if a method raises
   - Useful for debugging methods that fail intermittently
"""
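

# Minimal smoke test: a sketch assuming ProfilingConfig can be constructed
# outside of a training run. The keyword arguments mirror the usage above
# (enabled, min_duration_ms, log_interval), so running this file directly
# checks that the import and constructor keywords line up with the integration.
if __name__ == "__main__":
    smoke_config = ProfilingConfig(
        enabled=True,
        min_duration_ms=1.0,  # skip sub-millisecond operations
        log_interval=100,  # keep only every 100th sample
    )
    print("ProfilingConfig constructed:", smoke_config)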