feat: Add SwanLab integration for experiment tracking (#3334)

* feat(swanlab): add SwanLab integration for experiment tracking

SwanLab integration provides comprehensive experiment tracking and monitoring for Axolotl training.

Features:
- Hyperparameter logging
- Training metrics tracking
- RLHF completion logging
- Performance profiling
- Configuration validation and conflict detection

Includes:
- Plugin in src/axolotl/integrations/swanlab/
- Callback in src/axolotl/utils/callbacks/swanlab.py
- Tests in tests/integrations/test_swanlab.py
- Examples in examples/swanlab/

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix(swanlab): address PR #3334 review feedback from winglian and CodeRabbit

- Change use_swanlab default to True (winglian)
- Clear buffer after periodic logging to prevent duplicates (CodeRabbit Major)
- Add safe exception handling in config fallback (CodeRabbit)
- Use context managers for file operations (CodeRabbit)
- Replace LOG.error with LOG.exception for better debugging (CodeRabbit)
- Sort __all__ alphabetically (CodeRabbit)
- Add language specifiers to README code blocks (CodeRabbit)
- Fix end-of-file newline in README (pre-commit)

Resolves actionable comments and nitpicks from CodeRabbit review.
Addresses reviewer feedback from @winglian.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* only run swanlab integration tests if package is available

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
PraMamba
2026-01-06 22:19:18 +08:00
committed by GitHub
parent ee59e4de97
commit 8aab807e67
14 changed files with 5438 additions and 0 deletions

examples/swanlab/README.md Normal file

@@ -0,0 +1,285 @@
# SwanLab Integration Examples
This directory contains example configurations demonstrating SwanLab integration with Axolotl.
## Examples Overview
### 1. DPO with Completion Logging
**File**: `dpo-swanlab-completions.yml`
Demonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging.
**Features**:
- Basic SwanLab experiment tracking
- Completion table logging (prompts, chosen/rejected responses, rewards)
- Memory-bounded buffer for long training runs
- Cloud sync configuration
**Best for**: RLHF practitioners who want to analyze model outputs qualitatively
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
```
---
### 2. LoRA with Performance Profiling
**File**: `lora-swanlab-profiling.yml`
Demonstrates standard LoRA fine-tuning with performance profiling enabled.
**Features**:
- SwanLab experiment tracking
- Automatic profiling of trainer methods
- Profiling metrics visualization
- Performance optimization guidance
**Best for**: Engineers optimizing training performance and comparing different configurations
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
```
---
### 3. Full-Featured DPO Production Setup
**File**: `dpo-swanlab-full-featured.yml`
Comprehensive production-ready configuration with ALL SwanLab features enabled.
**Features**:
- Experiment tracking with team workspace
- RLHF completion logging
- Performance profiling
- Lark (Feishu) team notifications
- Private deployment support
- Production checklist and troubleshooting
**Best for**: Production RLHF training with team collaboration
**Quick start**:
```bash
export SWANLAB_API_KEY=your-api-key
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
export SWANLAB_LARK_SECRET=your-webhook-secret
accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
```
---
### 4. Custom Trainer Profiling (Python)
**File**: `custom_trainer_profiling.py`
Python code examples showing how to add SwanLab profiling to custom trainers.
**Features**:
- `@swanlab_profile` decorator examples
- Context manager profiling for fine-grained timing
- `ProfilingConfig` for advanced filtering and throttling
- Multiple profiling patterns and best practices
**Best for**: Advanced users creating custom trainers
**Usage**:
```python
from custom_trainer_profiling import CustomTrainerWithProfiling
# See file for detailed examples and patterns
```
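As a rough mental model of decorator-based profiling, the sketch below times a method and records the duration under a `profiling/Time taken: Class.method` key, including when the method raises. `profile_sketch`, `RECORDED`, and `DemoTrainer` are hypothetical names used only for illustration; the actual `@swanlab_profile` implementation lives in `src/axolotl/integrations/swanlab/profiling.py` and may differ.

```python
import functools
import time

# Stand-in for a metrics backend (hypothetical); the real integration logs to SwanLab.
RECORDED = {}


def profile_sketch(func):
    """Time a method and record it under 'profiling/Time taken: Class.method'."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        start = time.perf_counter()
        try:
            return func(self, *args, **kwargs)
        finally:
            # Runs on normal return and on exceptions, so the duration is always kept.
            key = f"profiling/Time taken: {type(self).__name__}.{func.__name__}"
            RECORDED.setdefault(key, []).append(time.perf_counter() - start)

    return wrapper


class DemoTrainer:
    @profile_sketch
    def training_step(self, batch):
        return sum(batch)

    @profile_sketch
    def failing_step(self):
        raise ValueError("boom")


trainer = DemoTrainer()
result = trainer.training_step([1, 2, 3])
try:
    trainer.failing_step()
except ValueError:
    pass  # the duration was still recorded before the exception propagated
```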
---
## Feature Matrix
| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace |
|---------|----------|-------------------|-----------|-------------------|----------------|
| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | (commented) | (commented) |
| lora-swanlab-profiling.yml | ✅ | (disabled) | ✅ (auto) | (commented) | (commented) |
| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ |
| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A |
---
## Configuration Quick Reference
### Basic SwanLab Setup
```yaml
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-project
swanlab_experiment_name: my-experiment
swanlab_mode: cloud # cloud, local, offline, disabled
```
### RLHF Completion Logging
```yaml
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 steps
swanlab_completion_max_buffer: 128 # Memory-bounded buffer
```
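The interplay between the log interval and the memory-bounded buffer can be sketched as follows. This is a minimal illustration, not the plugin's code: `CompletionBufferSketch` and its fields are hypothetical, and it assumes the buffer drops the oldest rows when full and is cleared after each periodic flush.

```python
from collections import deque


class CompletionBufferSketch:
    """Hypothetical bounded buffer: keeps at most max_buffer rows, flushes every log_interval steps."""

    def __init__(self, log_interval=100, max_buffer=128):
        self.log_interval = log_interval
        self.rows = deque(maxlen=max_buffer)  # oldest rows are dropped automatically
        self.flushed = []  # stand-in for tables sent to the dashboard

    def add(self, step, prompt, chosen, rejected):
        self.rows.append(
            {"step": step, "prompt": prompt, "chosen": chosen, "rejected": rejected}
        )
        if step > 0 and step % self.log_interval == 0:
            self.flush()

    def flush(self):
        if self.rows:
            self.flushed.append(list(self.rows))
            self.rows.clear()  # avoid logging the same rows twice


buf = CompletionBufferSketch(log_interval=100, max_buffer=128)
for step in range(1, 301):
    buf.add(step, f"prompt {step}", "chosen text", "rejected text")
```

With these settings the buffer never holds more than 128 rows, no matter how long training runs.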
### Lark Team Notifications
```yaml
swanlab_lark_webhook_url: https://open.feishu.cn/...
swanlab_lark_secret: your-webhook-secret # Required for production
```
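For context on what the secret is used for: Feishu/Lark custom bots verify a signature derived from a timestamp and the secret. The sketch below assumes the signing scheme described in Feishu's custom bot documentation (HMAC-SHA256 keyed with `timestamp + "\n" + secret` over an empty message, then base64). `lark_sign` and `build_text_payload` are hypothetical helpers; whether the Axolotl callback builds its payload this way is not confirmed here.

```python
import base64
import hashlib
import hmac
import time


def lark_sign(secret, timestamp):
    """Signature for a signed custom-bot webhook (scheme assumed from Feishu's bot docs)."""
    string_to_sign = f"{timestamp}\n{secret}"
    digest = hmac.new(string_to_sign.encode("utf-8"), digestmod=hashlib.sha256).digest()
    return base64.b64encode(digest).decode("utf-8")


def build_text_payload(text, secret, timestamp=None):
    """Build a signed text-message body; the caller would POST it as JSON to the webhook URL."""
    timestamp = int(time.time()) if timestamp is None else timestamp
    return {
        "timestamp": str(timestamp),
        "sign": lark_sign(secret, timestamp),
        "msg_type": "text",
        "content": {"text": text},
    }


payload = build_text_payload("Training started", "your-webhook-secret", timestamp=1700000000)
```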
### Team Workspace
```yaml
swanlab_workspace: my-research-team
```
### Private Deployment
```yaml
swanlab_web_host: https://swanlab.yourcompany.com
swanlab_api_host: https://api.swanlab.yourcompany.com
```
---
## Authentication
### Recommended: Environment Variable
```bash
export SWANLAB_API_KEY=your-api-key
export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
export SWANLAB_LARK_SECRET=your-webhook-secret
```
### Alternative: Config File (less secure)
```yaml
swanlab_api_key: your-api-key
swanlab_lark_webhook_url: https://open.feishu.cn/...
swanlab_lark_secret: your-webhook-secret
```
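Before launching a run, it can help to confirm that the expected environment variables are set without echoing the secrets themselves. The sketch below is a standalone pre-flight check; `missing_vars` and `mask` are hypothetical helpers and not part of Axolotl or SwanLab.

```python
import os

REQUIRED = ["SWANLAB_API_KEY"]
OPTIONAL = ["SWANLAB_LARK_WEBHOOK_URL", "SWANLAB_LARK_SECRET"]


def missing_vars(names, environ=None):
    """Return the names that are unset or empty in the given environment mapping."""
    environ = os.environ if environ is None else environ
    return [name for name in names if not environ.get(name)]


def mask(value, visible=4):
    """Show only the first few characters of a secret when printing it."""
    if len(value) <= visible:
        return "*" * len(value)
    return value[:visible] + "*" * (len(value) - visible)


if __name__ == "__main__":
    for name in missing_vars(REQUIRED):
        print(f"missing required variable: {name}")
    for name in missing_vars(OPTIONAL):
        print(f"optional variable not set: {name}")
```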
---
## Common Use Cases
### Use Case 1: Migrate from WandB to SwanLab
Start with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB:
```yaml
use_swanlab: true
use_wandb: false
```
### Use Case 2: Analyze DPO Model Outputs
Use `dpo-swanlab-completions.yml` and set the completion logging interval to match your training length. Pick one value; a key must not appear twice in the config:
```yaml
swanlab_completion_log_interval: 50     # Short runs: log more frequently
# swanlab_completion_log_interval: 200  # Long runs: log less frequently
```
### Use Case 3: Optimize Training Performance
Use `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations:
- Baseline: `flash_attention: false, gradient_checkpointing: false`
- Flash Attention: `flash_attention: true`
- Gradient Checkpointing: `gradient_checkpointing: true`
- Both: `flash_attention: true, gradient_checkpointing: true`
Compare profiling metrics in SwanLab dashboard.
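The four runs above form a 2x2 grid over two flags. One way to generate the per-run overrides is sketched below; `optimization_grid` is a hypothetical helper, and merging each override dict into your base config is left to whatever tooling you already use.

```python
from itertools import product

FLAGS = ("flash_attention", "gradient_checkpointing")


def optimization_grid(flags=FLAGS):
    """Return one override dict per on/off combination of the given flags."""
    runs = []
    for values in product([False, True], repeat=len(flags)):
        overrides = dict(zip(flags, values))
        enabled = [name for name, on in overrides.items() if on]
        # Name each experiment after the flags it enables so runs are easy to tell apart.
        name = "+".join(enabled) if enabled else "baseline"
        runs.append({"swanlab_experiment_name": name, **overrides})
    return runs


runs = optimization_grid()
for run in runs:
    print(run)
```

Giving each run a distinct `swanlab_experiment_name` keeps the profiling metrics separable in the dashboard.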
### Use Case 4: Production RLHF with Team Collaboration
Use `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications:
```yaml
swanlab_workspace: ml-team
swanlab_lark_webhook_url: ...
swanlab_lark_secret: ...
```
---
## Viewing Your Experiments
### Cloud Mode
Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project.
**Dashboard sections**:
- **Metrics**: Training loss, learning rate, profiling metrics
- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO)
- **Config**: Hyperparameters and configuration
- **System**: Resource usage (GPU, memory, CPU)
- **Files**: Logged artifacts
### Local Mode
```bash
swanlab watch ./swanlog
# Open browser to http://localhost:5092
```
---
## Troubleshooting
### SwanLab not initializing
```bash
# Check API key
echo $SWANLAB_API_KEY
# Verify SwanLab is installed
pip show swanlab
# Check config
grep -A 5 "use_swanlab" your-config.yml
```
### Completions not appearing
- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
- Check `swanlab_log_completions: true`
- Wait for `swanlab_completion_log_interval` steps
- Look for "Registered SwanLab RLHF completion logging" in logs
### Lark notifications not working
- Test webhook manually: `curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...`
- Verify `SWANLAB_LARK_SECRET` is set correctly
- Check bot is added to Lark group chat
- Look for "Registered Lark notification callback" in logs
### Profiling metrics not appearing
- Verify `use_swanlab: true`
- Check that SwanLab initialized (look for the init log message)
- Look for the metrics under the `profiling/` namespace in the dashboard
- Profiling is enabled automatically when SwanLab is enabled
---
## Performance Notes
### Overhead Comparison
| Feature | Overhead per Step | Memory Usage |
|---------|------------------|--------------|
| Basic tracking | < 0.1% | ~10 MB |
| Completion logging | < 0.5% | ~64 KB (buffer=128) |
| Profiling | < 0.1% | ~1 KB |
| **Total** | **< 0.7%** | **~10 MB** |
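The buffer figure in the table scales linearly with `swanlab_completion_max_buffer`. As a back-of-the-envelope sketch, the ~64 KB quoted for 128 entries implies roughly 512 bytes per buffered completion; that per-entry size is an inference from the table, not a measured value, and real usage depends on prompt and response length.

```python
# Assumption: ~64 KB for 128 entries implies about 512 bytes per buffered completion.
BYTES_PER_COMPLETION = 512


def buffer_kib(max_buffer, bytes_per_completion=BYTES_PER_COMPLETION):
    """Estimate buffer memory in KiB for a given swanlab_completion_max_buffer."""
    return max_buffer * bytes_per_completion / 1024


for size in (128, 256, 512):
    print(f"max_buffer={size}: ~{buffer_kib(size):.0f} KB")
```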
### Best Practices
1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab)
2. Adjust completion log interval based on training length (100-200 steps)
3. Keep completion buffer size reasonable (128-512)
4. Profile critical path methods first (training_step, compute_loss)
5. Use ProfilingConfig to throttle high-frequency operations
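
Throttling as described in item 5 combines two filters: a minimum duration and a call interval. The sketch below shows one plausible way they could interact; `ThrottleSketch` is a hypothetical stand-in, and the real `ProfilingConfig` may combine the two settings differently.

```python
from dataclasses import dataclass


@dataclass
class ThrottleSketch:
    """Hypothetical throttle: skip fast calls, then log only every Nth call."""

    min_duration_ms: float = 0.5
    log_interval: int = 50
    calls: int = 0

    def should_log(self, duration_ms):
        self.calls += 1
        if duration_ms < self.min_duration_ms:
            return False  # too fast to be worth logging
        return self.calls % self.log_interval == 0


throttle = ThrottleSketch(min_duration_ms=0.5, log_interval=50)
decisions = [throttle.should_log(1.0) for _ in range(100)]
logged_calls = [i + 1 for i, logged in enumerate(decisions) if logged]
```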
---
## Further Reading
- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md)
- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn)
- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/)
- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
---
## Contributing
Found an issue or have an improvement? Please submit a PR or open an issue:
- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues)
- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues)


@@ -0,0 +1,299 @@
"""Example: Custom Trainer with SwanLab Profiling
This example demonstrates how to add SwanLab profiling to your custom trainer.
Features:
- @swanlab_profile decorator for automatic profiling
- swanlab_profiling_context for fine-grained profiling
- ProfilingConfig for advanced filtering and throttling
Usage:
1. Create your custom trainer extending AxolotlTrainer
2. Add @swanlab_profile decorators to methods you want to profile
3. Use swanlab_profiling_context for fine-grained profiling within methods
4. Enable SwanLab in your config (use_swanlab: true)
See also:
- examples/swanlab/lora-swanlab-profiling.yml for config
- src/axolotl/integrations/swanlab/profiling.py for implementation
"""
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.integrations.swanlab.profiling import (
    ProfilingConfig,
    swanlab_profile,
    swanlab_profiling_context,
    swanlab_profiling_context_advanced,
)


class CustomTrainerWithProfiling(AxolotlTrainer):
    """Custom trainer with SwanLab profiling enabled.

    This trainer demonstrates four profiling patterns:
    1. Decorator-based profiling (@swanlab_profile)
    2. Context manager profiling (swanlab_profiling_context)
    3. Advanced profiling with filtering (ProfilingConfig)
    4. Exception-safe profiling (duration logged even when a method raises)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Create custom profiling config for high-frequency operations
        self.fast_op_config = ProfilingConfig(
            enabled=True,
            min_duration_ms=0.5,  # Only log if duration > 0.5ms
            log_interval=50,  # Log every 50th call
        )

    # ========================================================================
    # Pattern 1: Decorator-based Profiling
    # ========================================================================
    # Best for: Methods you always want to profile
    # Overhead: ~2-5 microseconds per call (negligible)

    @swanlab_profile
    def training_step(self, model, inputs, *args, **kwargs):
        """Main training step - always profile.

        Extra arguments (e.g. num_items_in_batch, which recent transformers
        releases pass to training_step) are forwarded unchanged.

        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step
        """
        return super().training_step(model, inputs, *args, **kwargs)

    @swanlab_profile
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Loss computation - always profile.

        Extra keyword arguments (e.g. num_items_in_batch) are forwarded unchanged.

        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss
        """
        return super().compute_loss(
            model, inputs, return_outputs=return_outputs, **kwargs
        )

    @swanlab_profile
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Prediction step - always profile.

        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step
        """
        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)

    # ========================================================================
    # Pattern 2: Fine-grained Context Manager Profiling
    # ========================================================================
    # Best for: Profiling specific code blocks within a method
    # Use case: When you want to profile forward vs backward separately

    def complex_training_step(self, model, inputs):
        """Training step with fine-grained profiling.

        Profiling metrics:
        - profiling/Time taken: CustomTrainerWithProfiling.forward_pass
        - profiling/Time taken: CustomTrainerWithProfiling.backward_pass
        - profiling/Time taken: CustomTrainerWithProfiling.optimizer_step
        """
        # Profile just the forward pass
        with swanlab_profiling_context(self, "forward_pass"):
            outputs = model(**inputs)
            loss = outputs.loss

        # Profile just the backward pass
        with swanlab_profiling_context(self, "backward_pass"):
            loss.backward()

        # Profile optimizer step
        with swanlab_profiling_context(self, "optimizer_step"):
            self.optimizer.step()
            self.optimizer.zero_grad()

        return outputs

    # ========================================================================
    # Pattern 3: Advanced Profiling with Filtering
    # ========================================================================
    # Best for: High-frequency operations where you want to throttle logging
    # Use case: Methods called 100+ times per step

    def _prepare_inputs(self, inputs):
        """Prepare inputs - throttled profiling.

        This method is called frequently (once per batch), so we throttle
        profiling to reduce overhead:
        - Only log if duration > 0.5ms (skip very fast operations)
        - Only log every 50th call (reduce logging frequency)

        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs
        """
        with swanlab_profiling_context_advanced(
            self, "prepare_inputs", config=self.fast_op_config
        ):
            return super()._prepare_inputs(inputs)

    def _prepare_input_for_model(self, input_ids):
        """Another high-frequency operation - throttled profiling.

        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model
        """
        with swanlab_profiling_context_advanced(
            self, "prepare_input_for_model", config=self.fast_op_config
        ):
            # Your custom input preparation logic
            return input_ids

    # ========================================================================
    # Pattern 4: Exception-safe Profiling
    # ========================================================================
    # Profiling is exception-safe: duration is logged even if method raises

    @swanlab_profile
    def potentially_failing_method(self):
        """This method may raise an exception.

        SwanLab profiling will still log the duration before re-raising.

        Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method
        """
        # Do some work
        result = self._do_risky_computation()
        # If this raises, profiling duration is still logged
        if result < 0:
            raise ValueError("Invalid result")
        return result

    def _do_risky_computation(self):
        """Placeholder for risky computation."""
        return 42


# ============================================================================
# Advanced Example: Custom ProfilingConfig Per Method
# ============================================================================


class AdvancedProfilingTrainer(AxolotlTrainer):
    """Trainer with method-specific profiling configurations."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Different profiling configs for different method types
        self.critical_path_config = ProfilingConfig(
            enabled=True,
            min_duration_ms=0.0,  # Log everything on critical path
            log_interval=1,  # Log every call
        )
        self.fast_path_config = ProfilingConfig(
            enabled=True,
            min_duration_ms=1.0,  # Only log if > 1ms
            log_interval=100,  # Log every 100th call
        )
        self.debug_config = ProfilingConfig(
            enabled=True,
            min_duration_ms=0.0,  # Log everything
            log_interval=1,  # Log every call
        )

    def training_step(self, model, inputs, *args, **kwargs):
        """Critical path - log everything.

        Extra arguments (e.g. num_items_in_batch) are forwarded unchanged.
        """
        with swanlab_profiling_context_advanced(
            self, "training_step", config=self.critical_path_config
        ):
            return super().training_step(model, inputs, *args, **kwargs)

    def _prepare_inputs(self, inputs):
        """Fast path - throttle logging."""
        with swanlab_profiling_context_advanced(
            self, "prepare_inputs", config=self.fast_path_config
        ):
            return super()._prepare_inputs(inputs)

    def _debug_method(self, data):
        """Debug-only method - verbose logging."""
        with swanlab_profiling_context_advanced(
            self, "debug_method", config=self.debug_config
        ):
            # Your debug logic
            pass
# ============================================================================
# How to Use This Custom Trainer
# ============================================================================
"""
To use this custom trainer:
1. Save this file to your project (e.g., my_custom_trainer.py)
2. Create a config file that uses your custom trainer:
# config.yml
base_model: NousResearch/Llama-3.2-1B
# ... other config ...
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
use_swanlab: true
swanlab_project: my-profiling-experiment
# Optional: Specify custom trainer
# (Or modify axolotl to use your custom trainer class)
3. Run training:
export SWANLAB_API_KEY=your-api-key
accelerate launch -m axolotl.cli.train config.yml
4. View profiling metrics in SwanLab dashboard:
- profiling/Time taken: CustomTrainerWithProfiling.training_step
- profiling/Time taken: CustomTrainerWithProfiling.forward_pass
- profiling/Time taken: CustomTrainerWithProfiling.backward_pass
- etc.
5. Compare profiling metrics across runs:
- Run baseline without optimizations
- Run with flash_attention enabled
- Run with gradient_checkpointing enabled
- Compare profiling metrics to see performance impact
"""
# ============================================================================
# Tips for Effective Profiling
# ============================================================================
"""
1. Profile the critical path first:
- training_step, compute_loss, prediction_step
- These methods are called most frequently and have biggest impact
2. Use throttling for high-frequency operations:
- Methods called 100+ times per step
- Use log_interval=50 or log_interval=100
- Reduces profiling overhead and dashboard clutter
3. Filter noise with min_duration_ms:
- Set min_duration_ms=1.0 to skip very fast operations
- Focus on operations that actually take time
4. Compare across runs:
- Run same config multiple times to check consistency
- Compare different optimization strategies
- Track profiling trends over time
5. Monitor distributed training:
- Check for per-rank timing differences
- Look for stragglers (slower ranks)
- Identify synchronization bottlenecks
6. Disable profiling in production:
- from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG
- DEFAULT_PROFILING_CONFIG.enabled = False
7. Exception handling:
- Profiling is exception-safe
- Duration logged even if method raises
- Useful for debugging methods that fail intermittently
"""


@@ -0,0 +1,168 @@
# SwanLab DPO Training Example with Completion Logging
#
# This example demonstrates DPO (Direct Preference Optimization) training
# with SwanLab integration for experiment tracking and completion table logging.
#
# Features enabled:
# - SwanLab experiment tracking
# - RLHF completion table logging (prompts, chosen/rejected responses, rewards)
# - Lark (Feishu) team notifications (optional)
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml
# Model Configuration
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>
# Quantization
load_in_8bit: true
load_in_4bit: false
# LoRA Configuration
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
# DPO Configuration
chat_template: llama3
rl: dpo
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_property_mappings:
      role: role
      content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant
# Dataset and Output
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-swanlab-out
# Training Configuration
sequence_len: 4096
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 4
# Optimization
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# Precision
bf16: auto
tf32: false
# Performance
gradient_checkpointing: true
flash_attention: true
# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# ============================================================================
# SwanLab Integration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: dpo-training
swanlab_experiment_name: llama-3-dpo-completions-demo
swanlab_description: "DPO training with completion table logging"
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# SwanLab Authentication
# Recommended: Set via environment variable
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# Optional: Team workspace
# swanlab_workspace: my-research-team
# ============================================================================
# RLHF Completion Table Logging
# ============================================================================
#
# Automatically logs model completions to SwanLab for qualitative analysis:
# - Prompts from your DPO dataset
# - Chosen responses (preferred)
# - Rejected responses (non-preferred)
# - Reward differences
#
# View the table in SwanLab dashboard under "rlhf_completions"
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 training steps
swanlab_completion_max_buffer: 128 # Keep last 128 completions in memory
# Memory Usage Notes:
# - Buffer size 128: ~64 KB (default, recommended)
# - Buffer size 512: ~256 KB (for more historical completions)
# - Buffer size 1024: ~512 KB (maximum for very long training runs)
# Performance Notes:
# - Completion logging overhead: < 0.5% per training step
# - Only logs every N steps to minimize impact
# - Memory-bounded buffer prevents memory leaks
# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get real-time training notifications in your team chat
# Uncomment to enable:
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret # Recommended for production
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)
# ============================================================================
# Optional: Private SwanLab Deployment
# ============================================================================
#
# For enterprise users with private SwanLab deployment:
# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com
# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================
# wandb_project:
# wandb_entity:
# use_wandb: false


@@ -0,0 +1,329 @@
# SwanLab Full-Featured DPO Training Example
#
# This example demonstrates ALL SwanLab integration features:
# - Experiment tracking with cloud sync
# - RLHF completion table logging
# - Performance profiling
# - Lark (Feishu) team notifications
# - Team workspace collaboration
#
# Use this as a reference for production RLHF training setups.
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
# export SWANLAB_LARK_SECRET=your-webhook-secret
# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml
# ============================================================================
# Model Configuration
# ============================================================================
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>
# Quantization for efficient training
load_in_8bit: true
load_in_4bit: false
# ============================================================================
# LoRA Configuration
# ============================================================================
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true # Target all linear layers
# ============================================================================
# DPO (Direct Preference Optimization) Configuration
# ============================================================================
chat_template: llama3
rl: dpo # Enable DPO trainer
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_property_mappings:
      role: role
      content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant
# ============================================================================
# Dataset and Output Configuration
# ============================================================================
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-swanlab-full-featured-out
# ============================================================================
# Training Configuration
# ============================================================================
sequence_len: 4096
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 4
# ============================================================================
# Optimization
# ============================================================================
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# ============================================================================
# Precision and Performance
# ============================================================================
bf16: auto
tf32: false
gradient_checkpointing: true
flash_attention: true
# ============================================================================
# Checkpointing and Logging
# ============================================================================
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# ============================================================================
# SwanLab Integration - Full Configuration
# ============================================================================
plugins:
- axolotl.integrations.swanlab.SwanLabPlugin
# ------------------------------------------------------------------------------
# Basic SwanLab Configuration
# ------------------------------------------------------------------------------
use_swanlab: true
swanlab_project: dpo-production
swanlab_experiment_name: llama-3-dpo-full-featured-v1
swanlab_description: |
  Production DPO training with all SwanLab features enabled:
  - Completion table logging for qualitative analysis
  - Performance profiling for optimization
  - Lark notifications for team collaboration
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# ------------------------------------------------------------------------------
# Team Collaboration
# ------------------------------------------------------------------------------
# Workspace for team collaboration (shared experiments)
swanlab_workspace: ml-research-team
# Authentication (recommended: use environment variable)
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# ------------------------------------------------------------------------------
# RLHF Completion Table Logging
# ------------------------------------------------------------------------------
# Automatically logs model completions for qualitative analysis:
# - Prompts from your DPO dataset
# - Chosen responses (preferred)
# - Rejected responses (non-preferred)
# - Reward differences
#
# View in SwanLab dashboard under "rlhf_completions" table
swanlab_log_completions: true
swanlab_completion_log_interval: 100 # Log every 100 steps
swanlab_completion_max_buffer: 256 # Larger buffer for long training runs
# Buffer size recommendations:
# - 128: Default, ~64 KB memory (recommended for most cases)
# - 256: ~128 KB memory (this config, good for longer training)
# - 512: ~256 KB memory (maximum for very long runs)
# ------------------------------------------------------------------------------
# Lark (Feishu) Team Notifications
# ------------------------------------------------------------------------------
# Get real-time training notifications in your team chat
#
# Notifications sent for:
# - Training start
# - Training completion
# - Training errors
# - Metric milestones (if configured)
# Recommended: Set via environment variables
# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/...
# export SWANLAB_LARK_SECRET=your-webhook-secret
# Or set in config (less secure):
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret # REQUIRED for production
# Security note: ALWAYS use swanlab_lark_secret in production to prevent
# unauthorized parties from sending fake notifications to your team chat.
# ------------------------------------------------------------------------------
# Performance Profiling
# ------------------------------------------------------------------------------
# Profiling is automatically enabled when SwanLab is enabled.
# Metrics logged to SwanLab under "profiling/" namespace:
# profiling/Time taken: AxolotlTrainer.training_step
# profiling/Time taken: AxolotlTrainer.compute_loss
# profiling/Time taken: AxolotlTrainer.prediction_step
#
# Use these metrics to:
# - Identify bottlenecks in training loop
# - Compare performance across different configurations
# - Monitor performance regressions over time
# - Debug unexpected slowdowns
# For custom profiling in your own trainer, see:
# examples/swanlab/custom_trainer_profiling.py
# ------------------------------------------------------------------------------
# Optional: Private SwanLab Deployment
# ------------------------------------------------------------------------------
# For enterprise users with private SwanLab deployment:
# swanlab_web_host: https://swanlab.yourcompany.com
# swanlab_api_host: https://api.swanlab.yourcompany.com
# ------------------------------------------------------------------------------
# Optional: Model Checkpointing to SwanLab
# ------------------------------------------------------------------------------
# Log model checkpoints to SwanLab (coming soon)
swanlab_log_model: false
# ============================================================================
# Disable Other Logging Tools (Recommended)
# ============================================================================
# Using multiple logging tools simultaneously can impact performance:
# - Expected overhead: ~1-2% per logger
# - Potential config/callback conflicts
#
# For production training, use ONLY SwanLab:
# wandb_project:
# use_wandb: false
#
# use_mlflow: false
#
# use_comet: false
# ============================================================================
# Expected Training Behavior
# ============================================================================
# With this configuration, you should see:
#
# 1. SwanLab Initialization (rank 0 only):
# INFO: SwanLab initialized for project: dpo-production
# INFO: SwanLab experiment: llama-3-dpo-full-featured-v1
# INFO: SwanLab mode: cloud
# INFO: SwanLab workspace: ml-research-team
#
# 2. Completion Logging (rank 0 only):
# INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer
# (log_interval=100, max_buffer=256)
#
# 3. Lark Notifications (rank 0 only):
# INFO: Registered Lark notification callback with HMAC authentication
#
# 4. Distributed Training Detection (if multi-GPU):
# INFO: Distributed training detected (world_size=N)
# INFO: Only rank 0 will initialize SwanLab
# INFO: Other ranks will skip SwanLab to avoid conflicts
#
# 5. Training Start Notification (Lark):
# Your team chat receives: "Training started: llama-3-dpo-full-featured-v1"
#
# 6. Periodic Completion Logging:
# Every 100 steps, completion table is updated in SwanLab dashboard
#
# 7. Training Complete Notification (Lark):
# Your team chat receives: "Training completed: llama-3-dpo-full-featured-v1"
# With link to SwanLab dashboard and final metrics
#
# 8. SwanLab Dashboard Shows:
# - Training metrics (loss, learning rate, etc.)
# - Completion table (rlhf_completions)
# - Profiling metrics (profiling/Time taken: ...)
# - Hyperparameters and configuration
# - System resource usage
# ============================================================================
# Production Checklist
# ============================================================================
# Before deploying to production, verify:
# ✅ SwanLab API key is set via environment variable (not in config)
# ✅ Lark webhook secret is set (required for HMAC authentication)
# ✅ Workspace is set to your team's workspace
# ✅ Experiment name is descriptive and unique
# ✅ Only SwanLab is enabled (other loggers disabled)
# ✅ Completion logging buffer size is appropriate for your training duration
# ✅ Private deployment hosts are set (if using enterprise SwanLab)
# ✅ Test run completes successfully and shows up in SwanLab dashboard
# ✅ Lark notifications are received in team chat
# ✅ Profiling metrics are logged correctly
# ============================================================================
# Troubleshooting
# ============================================================================
# If SwanLab initialization fails:
# 1. Check SWANLAB_API_KEY environment variable is set
# 2. Verify swanlab_project is set in config
# 3. Check swanlab_mode is valid (cloud/local/offline/disabled)
# 4. Verify internet connectivity (for cloud mode)
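#
# To separate credential or network problems from config problems, the checks
# above can be narrowed with a short smoke-test config. This is a minimal
# sketch, assuming offline mode needs neither an API key nor connectivity;
# the project name is a placeholder, not a value from this example:
#   use_swanlab: true
#   swanlab_project: swanlab-smoke-test   # hypothetical project name
#   swanlab_mode: offline                 # one of the modes listed in step 3
# If this run initializes cleanly, the failure is more likely in the API key
# or network path than in the config itself.
#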
# If Lark notifications are not received:
# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly
# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings
# 3. Test webhook manually: curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...
# 4. Check training logs for "Registered Lark notification callback"
# 5. Verify bot is added to the target Lark group chat
# If completions are not appearing in SwanLab:
# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO)
# 2. Check swanlab_log_completions is true
# 3. Wait for log_interval steps (default: 100)
# 4. Check training logs for "Registered SwanLab RLHF completion logging"
# If profiling metrics are not appearing:
# 1. Verify use_swanlab is true
# 2. Check SwanLab is initialized (check logs)
# 3. Look under "profiling/" namespace in dashboard
# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False
# For more help:
# - SwanLab docs: https://docs.swanlab.cn
# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md
# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues

@@ -0,0 +1,178 @@
# SwanLab LoRA Training Example with Performance Profiling
#
# This example demonstrates standard LoRA fine-tuning with SwanLab integration
# for performance profiling and optimization.
#
# Features enabled:
# - SwanLab experiment tracking
# - Performance profiling (training_step, compute_loss, and prediction_step timing)
# - Real-time metrics visualization
#
# To run:
# export SWANLAB_API_KEY=your-api-key
# accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml
# Model Configuration
base_model: NousResearch/Llama-3.2-1B
# Dataset Configuration
datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
val_set_size: 0.1
output_dir: ./outputs/lora-swanlab-profiling-out
# LoRA Configuration
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
# Training Configuration
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
micro_batch_size: 2
gradient_accumulation_steps: 2
num_epochs: 1
# Optimization
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
warmup_ratio: 0.1
weight_decay: 0.0
# Precision
bf16: auto
tf32: false
# Performance
gradient_checkpointing: true
flash_attention: true
# Checkpointing and Logging
logging_steps: 1
evals_per_epoch: 4
saves_per_epoch: 1
# Loss Monitoring
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
special_tokens:
  pad_token: "<|end_of_text|>"
# ============================================================================
# SwanLab Integration
# ============================================================================
plugins:
  - axolotl.integrations.swanlab.SwanLabPlugin
# Basic SwanLab Configuration
use_swanlab: true
swanlab_project: lora-profiling
swanlab_experiment_name: llama-3.2-1b-profiling-demo
swanlab_description: "LoRA fine-tuning with performance profiling"
swanlab_mode: cloud # Options: cloud, local, offline, disabled
# SwanLab Authentication
# Recommended: Set via environment variable
# export SWANLAB_API_KEY=your-api-key
# Or set in config (less secure):
# swanlab_api_key: your-api-key
# Optional: Team workspace
# swanlab_workspace: my-ml-team
# ============================================================================
# Performance Profiling
# ============================================================================
#
# SwanLab automatically profiles trainer methods when enabled.
# Profiling metrics appear in SwanLab dashboard under "profiling/" namespace.
#
# Built-in profiling:
# - Minimal overhead (< 0.1% per step)
# - High-precision timing (microsecond accuracy)
# - Exception-safe (logs duration even if method fails)
#
# View profiling metrics in SwanLab dashboard:
# profiling/Time taken: AxolotlTrainer.training_step
# profiling/Time taken: AxolotlTrainer.compute_loss
# profiling/Time taken: AxolotlTrainer.prediction_step
#
# For custom profiling in your own trainer, see:
# examples/swanlab/custom_trainer_profiling.py
# Completion logging is disabled for non-RLHF trainers
swanlab_log_completions: false # Only works with DPO/KTO/ORPO/GRPO
# ============================================================================
# Optional: Compare with Multiple Runs
# ============================================================================
#
# To compare profiling metrics across different configurations:
#
# 1. Run a baseline with neither optimization:
# swanlab_experiment_name: llama-3.2-1b-no-flash-attn
# flash_attention: false
# gradient_checkpointing: false
#
# 2. Run with gradient checkpointing only:
# swanlab_experiment_name: llama-3.2-1b-grad-checkpoint
# flash_attention: false
# gradient_checkpointing: true
#
# 3. Run with both (the settings used in this file):
# swanlab_experiment_name: llama-3.2-1b-optimized
# flash_attention: true
# gradient_checkpointing: true
#
# Then compare profiling metrics in SwanLab dashboard to see performance impact
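#
# One way to organize these runs, as a sketch: keep this file as the shared
# base and give each variant its own copy that overrides only the keys listed
# for that run. The file name below is hypothetical. For run 1, the copy
# would override at least:
#   # lora-swanlab-no-flash-attn.yml (hypothetical copy of this file)
#   swanlab_experiment_name: llama-3.2-1b-no-flash-attn
#   flash_attention: false
# Leaving swanlab_project unchanged across the copies should keep all
# variants in the same SwanLab project, so their profiling charts can be
# viewed side by side.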
# ============================================================================
# Optional: Lark (Feishu) Team Notifications
# ============================================================================
#
# Get notified when profiling experiments complete:
# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx
# swanlab_lark_secret: your-webhook-secret
# ============================================================================
# Profiling Best Practices
# ============================================================================
#
# 1. Run multiple epochs to see profiling trends over time
# 2. Ignore the first ~10 steps (warmup period; typically slower)
# 3. Look for outliers (steps that take significantly longer)
# 4. Compare profiling metrics before/after optimization changes
# 5. Monitor per-rank profiling in distributed training
#
# Common bottlenecks to profile:
# - training_step: Overall step time (should be consistent)
# - compute_loss: Loss computation (scales with sequence length)
# - prediction_step: Evaluation time (can be slow for large val sets)
#
# If you see inconsistent timing:
# - Check for data loading bottlenecks
# - Monitor GPU utilization (may be CPU-bound)
# - Check for gradient accumulation effects
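# - Account for asynchronous CUDA execution: kernels launch asynchronously, so
#   wall-clock time can be attributed to whichever later call synchronizes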
# ============================================================================
# Disable WandB if you're migrating from it
# ============================================================================
# wandb_project:
# use_wandb: false