From 8aab807e678bdc60532dc748a72624bf0b0ae580 Mon Sep 17 00:00:00 2001 From: PraMamba <128919538+PraMamba@users.noreply.github.com> Date: Tue, 6 Jan 2026 22:19:18 +0800 Subject: [PATCH] feat: Add SwanLab integration for experiment tracking (#3334) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(swanlab): add SwanLab integration for experiment tracking SwanLab integration provides comprehensive experiment tracking and monitoring for Axolotl training. Features: - Hyperparameter logging - Training metrics tracking - RLHF completion logging - Performance profiling - Configuration validation and conflict detection Includes: - Plugin in src/axolotl/integrations/swanlab/ - Callback in src/axolotl/utils/callbacks/swanlab.py - Tests in tests/integrations/test_swanlab.py - Examples in examples/swanlab/ 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 * fix(swanlab): address PR #3334 review feedback from winglian and CodeRabbit - Change use_swanlab default to True (winglian) - Clear buffer after periodic logging to prevent duplicates (CodeRabbit Major) - Add safe exception handling in config fallback (CodeRabbit) - Use context managers for file operations (CodeRabbit) - Replace LOG.error with LOG.exception for better debugging (CodeRabbit) - Sort __all__ alphabetically (CodeRabbit) - Add language specifiers to README code blocks (CodeRabbit) - Fix end-of-file newline in README (pre-commit) Resolves actionable comments and nitpicks from CodeRabbit review. Addresses reviewer feedback from @winglian. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 * only run swanlab integration tests if package is available --------- Co-authored-by: Claude Sonnet 4.5 Co-authored-by: Wing Lian --- examples/swanlab/README.md | 285 ++++ examples/swanlab/custom_trainer_profiling.py | 299 ++++ examples/swanlab/dpo-swanlab-completions.yml | 168 +++ .../swanlab/dpo-swanlab-full-featured.yml | 329 ++++ examples/swanlab/lora-swanlab-profiling.yml | 178 +++ src/axolotl/integrations/swanlab/README.md | 1284 ++++++++++++++++ src/axolotl/integrations/swanlab/__init__.py | 6 + src/axolotl/integrations/swanlab/args.py | 140 ++ src/axolotl/integrations/swanlab/callbacks.py | 179 +++ .../integrations/swanlab/completion_logger.py | 228 +++ src/axolotl/integrations/swanlab/plugins.py | 554 +++++++ src/axolotl/integrations/swanlab/profiling.py | 203 +++ src/axolotl/utils/callbacks/swanlab.py | 248 +++ tests/integrations/test_swanlab.py | 1337 +++++++++++++++++ 14 files changed, 5438 insertions(+) create mode 100644 examples/swanlab/README.md create mode 100644 examples/swanlab/custom_trainer_profiling.py create mode 100644 examples/swanlab/dpo-swanlab-completions.yml create mode 100644 examples/swanlab/dpo-swanlab-full-featured.yml create mode 100644 examples/swanlab/lora-swanlab-profiling.yml create mode 100644 src/axolotl/integrations/swanlab/README.md create mode 100644 src/axolotl/integrations/swanlab/__init__.py create mode 100644 src/axolotl/integrations/swanlab/args.py create mode 100644 src/axolotl/integrations/swanlab/callbacks.py create mode 100644 src/axolotl/integrations/swanlab/completion_logger.py create mode 100644 src/axolotl/integrations/swanlab/plugins.py create mode 100644 src/axolotl/integrations/swanlab/profiling.py create mode 100644 src/axolotl/utils/callbacks/swanlab.py create mode 100644 tests/integrations/test_swanlab.py diff --git a/examples/swanlab/README.md b/examples/swanlab/README.md new file mode 100644 index 
000000000..c793ff2c2 --- /dev/null +++ b/examples/swanlab/README.md @@ -0,0 +1,285 @@ +# SwanLab Integration Examples + +This directory contains example configurations demonstrating SwanLab integration with Axolotl. + +## Examples Overview + +### 1. DPO with Completion Logging +**File**: `dpo-swanlab-completions.yml` + +Demonstrates DPO (Direct Preference Optimization) training with RLHF completion table logging. + +**Features**: +- Basic SwanLab experiment tracking +- Completion table logging (prompts, chosen/rejected responses, rewards) +- Memory-bounded buffer for long training runs +- Cloud sync configuration + +**Best for**: RLHF practitioners who want to analyze model outputs qualitatively + +**Quick start**: +```bash +export SWANLAB_API_KEY=your-api-key +accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml +``` + +--- + +### 2. LoRA with Performance Profiling +**File**: `lora-swanlab-profiling.yml` + +Demonstrates standard LoRA fine-tuning with performance profiling enabled. + +**Features**: +- SwanLab experiment tracking +- Automatic profiling of trainer methods +- Profiling metrics visualization +- Performance optimization guidance + +**Best for**: Engineers optimizing training performance and comparing different configurations + +**Quick start**: +```bash +export SWANLAB_API_KEY=your-api-key +accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml +``` + +--- + +### 3. Full-Featured DPO Production Setup +**File**: `dpo-swanlab-full-featured.yml` + +Comprehensive production-ready configuration with ALL SwanLab features enabled. 
+ +**Features**: +- Experiment tracking with team workspace +- RLHF completion logging +- Performance profiling +- Lark (Feishu) team notifications +- Private deployment support +- Production checklist and troubleshooting + +**Best for**: Production RLHF training with team collaboration + +**Quick start**: +```bash +export SWANLAB_API_KEY=your-api-key +export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/... +export SWANLAB_LARK_SECRET=your-webhook-secret +accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml +``` + +--- + +### 4. Custom Trainer Profiling (Python) +**File**: `custom_trainer_profiling.py` + +Python code examples showing how to add SwanLab profiling to custom trainers. + +**Features**: +- `@swanlab_profile` decorator examples +- Context manager profiling for fine-grained timing +- `ProfilingConfig` for advanced filtering and throttling +- Multiple profiling patterns and best practices + +**Best for**: Advanced users creating custom trainers + +**Usage**: +```python +from custom_trainer_profiling import CustomTrainerWithProfiling +# See file for detailed examples and patterns +``` + +--- + +## Feature Matrix + +| Example | Tracking | Completion Logging | Profiling | Lark Notifications | Team Workspace | +|---------|----------|-------------------|-----------|-------------------|----------------| +| dpo-swanlab-completions.yml | ✅ | ✅ | ✅ (auto) | ➖ (commented) | ➖ (commented) | +| lora-swanlab-profiling.yml | ✅ | ➖ (disabled) | ✅ (auto) | ➖ (commented) | ➖ (commented) | +| dpo-swanlab-full-featured.yml | ✅ | ✅ | ✅ (auto) | ✅ | ✅ | +| custom_trainer_profiling.py | N/A | N/A | ✅ (manual) | N/A | N/A | + +--- + +## Configuration Quick Reference + +### Basic SwanLab Setup +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: my-project +swanlab_experiment_name: my-experiment +swanlab_mode: cloud # cloud, local, offline, disabled +``` + +### RLHF Completion Logging +```yaml 
+swanlab_log_completions: true +swanlab_completion_log_interval: 100 # Log every 100 steps +swanlab_completion_max_buffer: 128 # Memory-bounded buffer +``` + +### Lark Team Notifications +```yaml +swanlab_lark_webhook_url: https://open.feishu.cn/... +swanlab_lark_secret: your-webhook-secret # Required for production +``` + +### Team Workspace +```yaml +swanlab_workspace: my-research-team +``` + +### Private Deployment +```yaml +swanlab_web_host: https://swanlab.yourcompany.com +swanlab_api_host: https://api.swanlab.yourcompany.com +``` + +--- + +## Authentication + +### Recommended: Environment Variable +```bash +export SWANLAB_API_KEY=your-api-key +export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/... +export SWANLAB_LARK_SECRET=your-webhook-secret +``` + +### Alternative: Config File (less secure) +```yaml +swanlab_api_key: your-api-key +swanlab_lark_webhook_url: https://open.feishu.cn/... +swanlab_lark_secret: your-webhook-secret +``` + +--- + +## Common Use Cases + +### Use Case 1: Migrate from WandB to SwanLab +Start with `lora-swanlab-profiling.yml`, add your model/dataset config, disable WandB: +```yaml +use_swanlab: true +use_wandb: false +``` + +### Use Case 2: Analyze DPO Model Outputs +Use `dpo-swanlab-completions.yml`, adjust completion logging interval based on your training length: +```yaml +swanlab_completion_log_interval: 50 # More frequent for short training +swanlab_completion_log_interval: 200 # Less frequent for long training +``` + +### Use Case 3: Optimize Training Performance +Use `lora-swanlab-profiling.yml`, run multiple experiments with different optimizations: +- Baseline: `flash_attention: false, gradient_checkpointing: false` +- Flash Attention: `flash_attention: true` +- Gradient Checkpointing: `gradient_checkpointing: true` +- Both: `flash_attention: true, gradient_checkpointing: true` + +Compare profiling metrics in SwanLab dashboard. 
+ +### Use Case 4: Production RLHF with Team Collaboration +Use `dpo-swanlab-full-featured.yml`, set up team workspace and Lark notifications: +```yaml +swanlab_workspace: ml-team +swanlab_lark_webhook_url: ... +swanlab_lark_secret: ... +``` + +--- + +## Viewing Your Experiments + +### Cloud Mode +Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project. + +**Dashboard sections**: +- **Metrics**: Training loss, learning rate, profiling metrics +- **Tables**: RLHF completions (for DPO/KTO/ORPO/GRPO) +- **Config**: Hyperparameters and configuration +- **System**: Resource usage (GPU, memory, CPU) +- **Files**: Logged artifacts + +### Local Mode +```bash +swanlab watch ./swanlog +# Open browser to http://localhost:5092 +``` + +--- + +## Troubleshooting + +### SwanLab not initializing +```bash +# Check API key +echo $SWANLAB_API_KEY + +# Verify SwanLab is installed +pip show swanlab + +# Check config +grep -A 5 "use_swanlab" your-config.yml +``` + +### Completions not appearing +- Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO) +- Check `swanlab_log_completions: true` +- Wait for `swanlab_completion_log_interval` steps +- Look for "Registered SwanLab RLHF completion logging" in logs + +### Lark notifications not working +- Test webhook manually: `curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ...` +- Verify `SWANLAB_LARK_SECRET` is set correctly +- Check bot is added to Lark group chat +- Look for "Registered Lark notification callback" in logs + +### Profiling metrics not appearing +- Verify `use_swanlab: true` +- Check SwanLab is initialized (look for init log message) +- Profiling metrics are under "profiling/" namespace +- Profiling auto-enabled when SwanLab is enabled + +--- + +## Performance Notes + +### Overhead Comparison + +| Feature | Overhead per Step | Memory Usage | +|---------|------------------|--------------| +| Basic tracking | < 0.1% | ~10 MB | +| Completion logging | < 0.5% | ~64 KB (buffer=128) | +| Profiling | < 0.1% | ~1 KB | 
+| **Total** | **< 0.7%** | **~10 MB** | + +### Best Practices +1. Use ONE logging tool in production (disable WandB/MLflow when using SwanLab) +2. Adjust completion log interval based on training length (100-200 steps) +3. Keep completion buffer size reasonable (128-512) +4. Profile critical path methods first (training_step, compute_loss) +5. Use ProfilingConfig to throttle high-frequency operations + +--- + +## Further Reading + +- **Full Documentation**: [src/axolotl/integrations/swanlab/README.md](../../src/axolotl/integrations/swanlab/README.md) +- **SwanLab Docs**: [https://docs.swanlab.cn](https://docs.swanlab.cn) +- **Axolotl Docs**: [https://axolotl-ai-cloud.github.io/axolotl/](https://axolotl-ai-cloud.github.io/axolotl/) +- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290) + +--- + +## Contributing + +Found an issue or have an improvement? Please submit a PR or open an issue: +- [Axolotl Issues](https://github.com/axolotl-ai-cloud/axolotl/issues) +- [SwanLab Issues](https://github.com/SwanHubX/SwanLab/issues) diff --git a/examples/swanlab/custom_trainer_profiling.py b/examples/swanlab/custom_trainer_profiling.py new file mode 100644 index 000000000..65461c4e5 --- /dev/null +++ b/examples/swanlab/custom_trainer_profiling.py @@ -0,0 +1,299 @@ +"""Example: Custom Trainer with SwanLab Profiling + +This example demonstrates how to add SwanLab profiling to your custom trainer. + +Features: +- @swanlab_profile decorator for automatic profiling +- swanlab_profiling_context for fine-grained profiling +- ProfilingConfig for advanced filtering and throttling + +Usage: + 1. Create your custom trainer extending AxolotlTrainer + 2. Add @swanlab_profile decorators to methods you want to profile + 3. Use swanlab_profiling_context for fine-grained profiling within methods + 4. 
Enable SwanLab in your config (use_swanlab: true) + +See also: + - examples/swanlab/lora-swanlab-profiling.yml for config + - src/axolotl/integrations/swanlab/profiling.py for implementation +""" + +from axolotl.core.trainers.base import AxolotlTrainer +from axolotl.integrations.swanlab.profiling import ( + ProfilingConfig, + swanlab_profile, + swanlab_profiling_context, + swanlab_profiling_context_advanced, +) + + +class CustomTrainerWithProfiling(AxolotlTrainer): + """Custom trainer with SwanLab profiling enabled. + + This trainer demonstrates three profiling patterns: + 1. Decorator-based profiling (@swanlab_profile) + 2. Context manager profiling (swanlab_profiling_context) + 3. Advanced profiling with filtering (ProfilingConfig) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Create custom profiling config for high-frequency operations + self.fast_op_config = ProfilingConfig( + enabled=True, + min_duration_ms=0.5, # Only log if duration > 0.5ms + log_interval=50, # Log every 50th call + ) + + # ======================================================================== + # Pattern 1: Decorator-based Profiling + # ======================================================================== + # Best for: Methods you always want to profile + # Overhead: ~2-5 microseconds per call (negligible) + + @swanlab_profile + def training_step(self, model, inputs): + """Main training step - always profile. + + Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.training_step + """ + return super().training_step(model, inputs) + + @swanlab_profile + def compute_loss(self, model, inputs, return_outputs=False): + """Loss computation - always profile. 
+ + Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.compute_loss + """ + return super().compute_loss(model, inputs, return_outputs) + + @swanlab_profile + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + """Prediction step - always profile. + + Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prediction_step + """ + return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) + + # ======================================================================== + # Pattern 2: Fine-grained Context Manager Profiling + # ======================================================================== + # Best for: Profiling specific code blocks within a method + # Use case: When you want to profile forward vs backward separately + + def complex_training_step(self, model, inputs): + """Training step with fine-grained profiling. + + Profiling metrics: + - profiling/Time taken: CustomTrainerWithProfiling.forward_pass + - profiling/Time taken: CustomTrainerWithProfiling.backward_pass + - profiling/Time taken: CustomTrainerWithProfiling.optimizer_step + """ + # Profile just the forward pass + with swanlab_profiling_context(self, "forward_pass"): + outputs = model(**inputs) + loss = outputs.loss + + # Profile just the backward pass + with swanlab_profiling_context(self, "backward_pass"): + loss.backward() + + # Profile optimizer step + with swanlab_profiling_context(self, "optimizer_step"): + self.optimizer.step() + self.optimizer.zero_grad() + + return outputs + + # ======================================================================== + # Pattern 3: Advanced Profiling with Filtering + # ======================================================================== + # Best for: High-frequency operations where you want to throttle logging + # Use case: Methods called 100+ times per step + + def _prepare_inputs(self, inputs): + """Prepare inputs - throttled profiling. 
+ + This method is called frequently (once per batch), so we throttle + profiling to reduce overhead: + - Only log if duration > 0.5ms (skip very fast operations) + - Only log every 50th call (reduce logging frequency) + + Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_inputs + """ + with swanlab_profiling_context_advanced( + self, "prepare_inputs", config=self.fast_op_config + ): + return super()._prepare_inputs(inputs) + + def _prepare_input_for_model(self, input_ids): + """Another high-frequency operation - throttled profiling. + + Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.prepare_input_for_model + """ + with swanlab_profiling_context_advanced( + self, "prepare_input_for_model", config=self.fast_op_config + ): + # Your custom input preparation logic + return input_ids + + # ======================================================================== + # Pattern 4: Exception-safe Profiling + # ======================================================================== + # Profiling is exception-safe: duration is logged even if method raises + + @swanlab_profile + def potentially_failing_method(self): + """This method may raise an exception. + + SwanLab profiling will still log the duration before re-raising. 
+ Profiling metric: profiling/Time taken: CustomTrainerWithProfiling.potentially_failing_method + """ + # Do some work + result = self._do_risky_computation() + + # If this raises, profiling duration is still logged + if result < 0: + raise ValueError("Invalid result") + + return result + + def _do_risky_computation(self): + """Placeholder for risky computation.""" + return 42 + + +# ============================================================================ +# Advanced Example: Custom ProfilingConfig Per Method +# ============================================================================ + + +class AdvancedProfilingTrainer(AxolotlTrainer): + """Trainer with method-specific profiling configurations.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Different profiling configs for different method types + self.critical_path_config = ProfilingConfig( + enabled=True, + min_duration_ms=0.0, # Log everything on critical path + log_interval=1, # Log every call + ) + + self.fast_path_config = ProfilingConfig( + enabled=True, + min_duration_ms=1.0, # Only log if > 1ms + log_interval=100, # Log every 100th call + ) + + self.debug_config = ProfilingConfig( + enabled=True, + min_duration_ms=0.0, # Log everything + log_interval=1, # Log every call + ) + + def training_step(self, model, inputs): + """Critical path - log everything.""" + with swanlab_profiling_context_advanced( + self, "training_step", config=self.critical_path_config + ): + return super().training_step(model, inputs) + + def _prepare_inputs(self, inputs): + """Fast path - throttle logging.""" + with swanlab_profiling_context_advanced( + self, "prepare_inputs", config=self.fast_path_config + ): + return super()._prepare_inputs(inputs) + + def _debug_method(self, data): + """Debug-only method - verbose logging.""" + with swanlab_profiling_context_advanced( + self, "debug_method", config=self.debug_config + ): + # Your debug logic + pass + + +# 
============================================================================ +# How to Use This Custom Trainer +# ============================================================================ + +""" +To use this custom trainer: + +1. Save this file to your project (e.g., my_custom_trainer.py) + +2. Create a config file that uses your custom trainer: + + # config.yml + base_model: NousResearch/Llama-3.2-1B + + # ... other config ... + + plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + + use_swanlab: true + swanlab_project: my-profiling-experiment + + # Optional: Specify custom trainer + # (Or modify axolotl to use your custom trainer class) + +3. Run training: + + export SWANLAB_API_KEY=your-api-key + accelerate launch -m axolotl.cli.train config.yml + +4. View profiling metrics in SwanLab dashboard: + - profiling/Time taken: CustomTrainerWithProfiling.training_step + - profiling/Time taken: CustomTrainerWithProfiling.forward_pass + - profiling/Time taken: CustomTrainerWithProfiling.backward_pass + - etc. + +5. Compare profiling metrics across runs: + - Run baseline without optimizations + - Run with flash_attention enabled + - Run with gradient_checkpointing enabled + - Compare profiling metrics to see performance impact +""" + +# ============================================================================ +# Tips for Effective Profiling +# ============================================================================ + +""" +1. Profile the critical path first: + - training_step, compute_loss, prediction_step + - These methods are called most frequently and have biggest impact + +2. Use throttling for high-frequency operations: + - Methods called 100+ times per step + - Use log_interval=50 or log_interval=100 + - Reduces profiling overhead and dashboard clutter + +3. Filter noise with min_duration_ms: + - Set min_duration_ms=1.0 to skip very fast operations + - Focus on operations that actually take time + +4. 
Compare across runs: + - Run same config multiple times to check consistency + - Compare different optimization strategies + - Track profiling trends over time + +5. Monitor distributed training: + - Check for per-rank timing differences + - Look for stragglers (slower ranks) + - Identify synchronization bottlenecks + +6. Disable profiling in production: + - from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG + - DEFAULT_PROFILING_CONFIG.enabled = False + +7. Exception handling: + - Profiling is exception-safe + - Duration logged even if method raises + - Useful for debugging methods that fail intermittently +""" diff --git a/examples/swanlab/dpo-swanlab-completions.yml b/examples/swanlab/dpo-swanlab-completions.yml new file mode 100644 index 000000000..5615ca638 --- /dev/null +++ b/examples/swanlab/dpo-swanlab-completions.yml @@ -0,0 +1,168 @@ +# SwanLab DPO Training Example with Completion Logging +# +# This example demonstrates DPO (Direct Preference Optimization) training +# with SwanLab integration for experiment tracking and completion table logging. 
+# +# Features enabled: +# - SwanLab experiment tracking +# - RLHF completion table logging (prompts, chosen/rejected responses, rewards) +# - Lark (Feishu) team notifications (optional) +# +# To run: +# export SWANLAB_API_KEY=your-api-key +# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-completions.yml + +# Model Configuration +base_model: meta-llama/Meta-Llama-3-8B-Instruct +model_type: LlamaForCausalLM +tokenizer_type: AutoTokenizer + +special_tokens: + pad_token: <|finetune_right_pad_id|> + eos_token: <|eot_id|> + +# Quantization +load_in_8bit: true +load_in_4bit: false + +# LoRA Configuration +adapter: lora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true + +# DPO Configuration +chat_template: llama3 +rl: dpo + +datasets: + - path: fozziethebeat/alpaca_messages_2k_dpo_test + type: chat_template.default + field_messages: conversation + field_chosen: chosen + field_rejected: rejected + message_property_mappings: + role: role + content: content + roles: + system: + - system + user: + - user + assistant: + - assistant + +# Dataset and Output +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./outputs/dpo-swanlab-out + +# Training Configuration +sequence_len: 4096 +sample_packing: false +micro_batch_size: 2 +gradient_accumulation_steps: 4 +num_epochs: 4 + +# Optimization +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 +warmup_ratio: 0.1 +weight_decay: 0.0 + +# Precision +bf16: auto +tf32: false + +# Performance +gradient_checkpointing: true +flash_attention: true + +# Checkpointing and Logging +logging_steps: 1 +evals_per_epoch: 4 +saves_per_epoch: 1 + +# ============================================================================ +# SwanLab Integration +# ============================================================================ + +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +# Basic SwanLab Configuration +use_swanlab: true +swanlab_project: dpo-training 
+swanlab_experiment_name: llama-3-dpo-completions-demo +swanlab_description: "DPO training with completion table logging" +swanlab_mode: cloud # Options: cloud, local, offline, disabled + +# SwanLab Authentication +# Recommended: Set via environment variable +# export SWANLAB_API_KEY=your-api-key +# Or set in config (less secure): +# swanlab_api_key: your-api-key + +# Optional: Team workspace +# swanlab_workspace: my-research-team + +# ============================================================================ +# RLHF Completion Table Logging +# ============================================================================ +# +# Automatically logs model completions to SwanLab for qualitative analysis: +# - Prompts from your DPO dataset +# - Chosen responses (preferred) +# - Rejected responses (non-preferred) +# - Reward differences +# +# View the table in SwanLab dashboard under "rlhf_completions" + +swanlab_log_completions: true +swanlab_completion_log_interval: 100 # Log every 100 training steps +swanlab_completion_max_buffer: 128 # Keep last 128 completions in memory + +# Memory Usage Notes: +# - Buffer size 128: ~64 KB (default, recommended) +# - Buffer size 512: ~256 KB (for more historical completions) +# - Buffer size 1024: ~512 KB (maximum for very long training runs) + +# Performance Notes: +# - Completion logging overhead: < 0.5% per training step +# - Only logs every N steps to minimize impact +# - Memory-bounded buffer prevents memory leaks + +# ============================================================================ +# Optional: Lark (Feishu) Team Notifications +# ============================================================================ +# +# Get real-time training notifications in your team chat +# Uncomment to enable: + +# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx +# swanlab_lark_secret: your-webhook-secret # Recommended for production + +# Notifications sent for: +# - Training start +# - Training 
completion +# - Training errors +# - Metric milestones (if configured) + +# ============================================================================ +# Optional: Private SwanLab Deployment +# ============================================================================ +# +# For enterprise users with private SwanLab deployment: + +# swanlab_web_host: https://swanlab.yourcompany.com +# swanlab_api_host: https://api.swanlab.yourcompany.com + +# ============================================================================ +# Disable WandB if you're migrating from it +# ============================================================================ + +# wandb_project: +# wandb_entity: +# use_wandb: false diff --git a/examples/swanlab/dpo-swanlab-full-featured.yml b/examples/swanlab/dpo-swanlab-full-featured.yml new file mode 100644 index 000000000..c25178c63 --- /dev/null +++ b/examples/swanlab/dpo-swanlab-full-featured.yml @@ -0,0 +1,329 @@ +# SwanLab Full-Featured DPO Training Example +# +# This example demonstrates ALL SwanLab integration features: +# - Experiment tracking with cloud sync +# - RLHF completion table logging +# - Performance profiling +# - Lark (Feishu) team notifications +# - Team workspace collaboration +# +# Use this as a reference for production RLHF training setups. +# +# To run: +# export SWANLAB_API_KEY=your-api-key +# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/... 
+# export SWANLAB_LARK_SECRET=your-webhook-secret +# accelerate launch -m axolotl.cli.train examples/swanlab/dpo-swanlab-full-featured.yml + +# ============================================================================ +# Model Configuration +# ============================================================================ + +base_model: meta-llama/Meta-Llama-3-8B-Instruct +model_type: LlamaForCausalLM +tokenizer_type: AutoTokenizer + +special_tokens: + pad_token: <|finetune_right_pad_id|> + eos_token: <|eot_id|> + +# Quantization for efficient training +load_in_8bit: true +load_in_4bit: false + +# ============================================================================ +# LoRA Configuration +# ============================================================================ + +adapter: lora +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true # Target all linear layers + +# ============================================================================ +# DPO (Direct Preference Optimization) Configuration +# ============================================================================ + +chat_template: llama3 +rl: dpo # Enable DPO trainer + +datasets: + - path: fozziethebeat/alpaca_messages_2k_dpo_test + type: chat_template.default + field_messages: conversation + field_chosen: chosen + field_rejected: rejected + message_property_mappings: + role: role + content: content + roles: + system: + - system + user: + - user + assistant: + - assistant + +# ============================================================================ +# Dataset and Output Configuration +# ============================================================================ + +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./outputs/dpo-swanlab-full-featured-out + +# ============================================================================ +# Training Configuration +# ============================================================================ + +sequence_len: 4096 
+sample_packing: false + +micro_batch_size: 2 +gradient_accumulation_steps: 4 +num_epochs: 4 + +# ============================================================================ +# Optimization +# ============================================================================ + +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 +warmup_ratio: 0.1 +weight_decay: 0.0 + +# ============================================================================ +# Precision and Performance +# ============================================================================ + +bf16: auto +tf32: false + +gradient_checkpointing: true +flash_attention: true + +# ============================================================================ +# Checkpointing and Logging +# ============================================================================ + +logging_steps: 1 +evals_per_epoch: 4 +saves_per_epoch: 1 + +# ============================================================================ +# SwanLab Integration - Full Configuration +# ============================================================================ + +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +# ------------------------------------------------------------------------------ +# Basic SwanLab Configuration +# ------------------------------------------------------------------------------ + +use_swanlab: true +swanlab_project: dpo-production +swanlab_experiment_name: llama-3-dpo-full-featured-v1 +swanlab_description: | + Production DPO training with all SwanLab features enabled: + - Completion table logging for qualitative analysis + - Performance profiling for optimization + - Lark notifications for team collaboration + +swanlab_mode: cloud # Options: cloud, local, offline, disabled + +# ------------------------------------------------------------------------------ +# Team Collaboration +# ------------------------------------------------------------------------------ + +# Workspace for team collaboration 
(shared experiments) +swanlab_workspace: ml-research-team + +# Authentication (recommended: use environment variable) +# export SWANLAB_API_KEY=your-api-key +# Or set in config (less secure): +# swanlab_api_key: your-api-key + +# ------------------------------------------------------------------------------ +# RLHF Completion Table Logging +# ------------------------------------------------------------------------------ +# Automatically logs model completions for qualitative analysis: +# - Prompts from your DPO dataset +# - Chosen responses (preferred) +# - Rejected responses (non-preferred) +# - Reward differences +# +# View in SwanLab dashboard under "rlhf_completions" table + +swanlab_log_completions: true +swanlab_completion_log_interval: 100 # Log every 100 steps +swanlab_completion_max_buffer: 256 # Larger buffer for long training runs + +# Buffer size recommendations: +# - 128: Default, ~64 KB memory (recommended for most cases) +# - 256: ~128 KB memory (this config, good for longer training) +# - 512: ~256 KB memory (maximum for very long runs) + +# ------------------------------------------------------------------------------ +# Lark (Feishu) Team Notifications +# ------------------------------------------------------------------------------ +# Get real-time training notifications in your team chat +# +# Notifications sent for: +# - Training start +# - Training completion +# - Training errors +# - Metric milestones (if configured) + +# Recommended: Set via environment variables +# export SWANLAB_LARK_WEBHOOK_URL=https://open.feishu.cn/... +# export SWANLAB_LARK_SECRET=your-webhook-secret + +# Or set in config (less secure): +# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx +# swanlab_lark_secret: your-webhook-secret # REQUIRED for production + +# Security note: ALWAYS use swanlab_lark_secret in production to prevent +# unauthorized parties from sending fake notifications to your team chat. 
+ +# ------------------------------------------------------------------------------ +# Performance Profiling +# ------------------------------------------------------------------------------ +# Profiling is automatically enabled when SwanLab is enabled. +# Metrics logged to SwanLab under "profiling/" namespace: +# profiling/Time taken: AxolotlTrainer.training_step +# profiling/Time taken: AxolotlTrainer.compute_loss +# profiling/Time taken: AxolotlTrainer.prediction_step +# +# Use these metrics to: +# - Identify bottlenecks in training loop +# - Compare performance across different configurations +# - Monitor performance regressions over time +# - Debug unexpected slowdowns + +# For custom profiling in your own trainer, see: +# examples/swanlab/custom_trainer_profiling.py + +# ------------------------------------------------------------------------------ +# Optional: Private SwanLab Deployment +# ------------------------------------------------------------------------------ +# For enterprise users with private SwanLab deployment: + +# swanlab_web_host: https://swanlab.yourcompany.com +# swanlab_api_host: https://api.swanlab.yourcompany.com + +# ------------------------------------------------------------------------------ +# Optional: Model Checkpointing to SwanLab +# ------------------------------------------------------------------------------ +# Log model checkpoints to SwanLab (coming soon) + +swanlab_log_model: false + +# ============================================================================ +# Disable Other Logging Tools (Recommended) +# ============================================================================ +# Using multiple logging tools simultaneously can impact performance: +# - Expected overhead: ~1-2% per logger +# - Potential config/callback conflicts +# +# For production training, use ONLY SwanLab: + +# wandb_project: +# use_wandb: false +# +# use_mlflow: false +# +# use_comet: false + +# 
============================================================================ +# Expected Training Behavior +# ============================================================================ + +# With this configuration, you should see: +# +# 1. SwanLab Initialization (rank 0 only): +# INFO: SwanLab initialized for project: dpo-production +# INFO: SwanLab experiment: llama-3-dpo-full-featured-v1 +# INFO: SwanLab mode: cloud +# INFO: SwanLab workspace: ml-research-team +# +# 2. Completion Logging (rank 0 only): +# INFO: Registered SwanLab RLHF completion logging callback for DPOTrainer +# (log_interval=100, max_buffer=256) +# +# 3. Lark Notifications (rank 0 only): +# INFO: Registered Lark notification callback with HMAC authentication +# +# 4. Distributed Training Detection (if multi-GPU): +# INFO: Distributed training detected (world_size=N) +# INFO: Only rank 0 will initialize SwanLab +# INFO: Other ranks will skip SwanLab to avoid conflicts +# +# 5. Training Start Notification (Lark): +# Your team chat receives: "Training started: llama-3-dpo-full-featured-v1" +# +# 6. Periodic Completion Logging: +# Every 100 steps, completion table is updated in SwanLab dashboard +# +# 7. Training Complete Notification (Lark): +# Your team chat receives: "Training completed: llama-3-dpo-full-featured-v1" +# With link to SwanLab dashboard and final metrics +# +# 8. SwanLab Dashboard Shows: +# - Training metrics (loss, learning rate, etc.) +# - Completion table (rlhf_completions) +# - Profiling metrics (profiling/Time taken: ...) 
+# - Hyperparameters and configuration +# - System resource usage + +# ============================================================================ +# Production Checklist +# ============================================================================ + +# Before deploying to production, verify: +# ✅ SwanLab API key is set via environment variable (not in config) +# ✅ Lark webhook secret is set (required for HMAC authentication) +# ✅ Workspace is set to your team's workspace +# ✅ Experiment name is descriptive and unique +# ✅ Only SwanLab is enabled (other loggers disabled) +# ✅ Completion logging buffer size is appropriate for your training duration +# ✅ Private deployment hosts are set (if using enterprise SwanLab) +# ✅ Test run completes successfully and shows up in SwanLab dashboard +# ✅ Lark notifications are received in team chat +# ✅ Profiling metrics are logged correctly + +# ============================================================================ +# Troubleshooting +# ============================================================================ + +# If SwanLab initialization fails: +# 1. Check SWANLAB_API_KEY environment variable is set +# 2. Verify swanlab_project is set in config +# 3. Check swanlab_mode is valid (cloud/local/offline/disabled) +# 4. Verify internet connectivity (for cloud mode) + +# If Lark notifications not received: +# 1. Check SWANLAB_LARK_WEBHOOK_URL is set correctly +# 2. Verify SWANLAB_LARK_SECRET matches your Lark bot settings +# 3. Test webhook manually: curl -X POST "$SWANLAB_LARK_WEBHOOK_URL" ... +# 4. Check training logs for "Registered Lark notification callback" +# 5. Verify bot is added to the target Lark group chat + +# If completions not appearing in SwanLab: +# 1. Verify you're using an RLHF trainer (DPO/KTO/ORPO/GRPO) +# 2. Check swanlab_log_completions is true +# 3. Wait for log_interval steps (default: 100) +# 4. 
Check training logs for "Registered SwanLab RLHF completion logging" + +# If profiling metrics not appearing: +# 1. Verify use_swanlab is true +# 2. Check SwanLab is initialized (check logs) +# 3. Look under "profiling/" namespace in dashboard +# 4. Profiling may be disabled if DEFAULT_PROFILING_CONFIG.enabled = False + +# For more help: +# - SwanLab docs: https://docs.swanlab.cn +# - Axolotl SwanLab integration: src/axolotl/integrations/swanlab/README.md +# - GitHub issues: https://github.com/axolotl-ai-cloud/axolotl/issues diff --git a/examples/swanlab/lora-swanlab-profiling.yml b/examples/swanlab/lora-swanlab-profiling.yml new file mode 100644 index 000000000..1255105a6 --- /dev/null +++ b/examples/swanlab/lora-swanlab-profiling.yml @@ -0,0 +1,178 @@ +# SwanLab LoRA Training Example with Performance Profiling +# +# This example demonstrates standard LoRA fine-tuning with SwanLab integration +# for performance profiling and optimization. +# +# Features enabled: +# - SwanLab experiment tracking +# - Performance profiling (training step, forward/backward pass timing) +# - Real-time metrics visualization +# +# To run: +# export SWANLAB_API_KEY=your-api-key +# accelerate launch -m axolotl.cli.train examples/swanlab/lora-swanlab-profiling.yml + +# Model Configuration +base_model: NousResearch/Llama-3.2-1B + +# Dataset Configuration +datasets: + - path: teknium/GPT4-LLM-Cleaned + type: alpaca + +val_set_size: 0.1 +output_dir: ./outputs/lora-swanlab-profiling-out + +# LoRA Configuration +adapter: lora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.05 +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +# Training Configuration +sequence_len: 2048 +sample_packing: true +eval_sample_packing: true + +micro_batch_size: 2 +gradient_accumulation_steps: 2 +num_epochs: 1 + +# Optimization +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 0.0002 +warmup_ratio: 0.1 +weight_decay: 0.0 + +# Precision +bf16: auto 
+tf32: false + +# Performance +gradient_checkpointing: true +flash_attention: true + +# Checkpointing and Logging +logging_steps: 1 +evals_per_epoch: 4 +saves_per_epoch: 1 + +# Loss Monitoring +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +special_tokens: + pad_token: "<|end_of_text|>" + +# ============================================================================ +# SwanLab Integration +# ============================================================================ + +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +# Basic SwanLab Configuration +use_swanlab: true +swanlab_project: lora-profiling +swanlab_experiment_name: llama-3.2-1b-profiling-demo +swanlab_description: "LoRA fine-tuning with performance profiling" +swanlab_mode: cloud # Options: cloud, local, offline, disabled + +# SwanLab Authentication +# Recommended: Set via environment variable +# export SWANLAB_API_KEY=your-api-key +# Or set in config (less secure): +# swanlab_api_key: your-api-key + +# Optional: Team workspace +# swanlab_workspace: my-ml-team + +# ============================================================================ +# Performance Profiling +# ============================================================================ +# +# SwanLab automatically profiles trainer methods when enabled. +# Profiling metrics appear in SwanLab dashboard under "profiling/" namespace. 
+# +# Built-in profiling: +# - Minimal overhead (< 0.1% per step) +# - High-precision timing (microsecond accuracy) +# - Exception-safe (logs duration even if method fails) +# +# View profiling metrics in SwanLab dashboard: +# profiling/Time taken: AxolotlTrainer.training_step +# profiling/Time taken: AxolotlTrainer.compute_loss +# profiling/Time taken: AxolotlTrainer.prediction_step +# +# For custom profiling in your own trainer, see: +# examples/swanlab/custom_trainer_profiling.py + +# Completion logging is disabled for non-RLHF trainers +swanlab_log_completions: false # Only works with DPO/KTO/ORPO/GRPO + +# ============================================================================ +# Optional: Compare with Multiple Runs +# ============================================================================ +# +# To compare profiling metrics across different configurations: +# +# 1. Run baseline without flash attention: +# swanlab_experiment_name: llama-3.2-1b-no-flash-attn +# flash_attention: false +# +# 2. Run with gradient checkpointing: +# swanlab_experiment_name: llama-3.2-1b-grad-checkpoint +# gradient_checkpointing: true +# +# 3. Run with both: +# swanlab_experiment_name: llama-3.2-1b-optimized +# flash_attention: true +# gradient_checkpointing: true +# +# Then compare profiling metrics in SwanLab dashboard to see performance impact + +# ============================================================================ +# Optional: Lark (Feishu) Team Notifications +# ============================================================================ +# +# Get notified when profiling experiments complete: + +# swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx +# swanlab_lark_secret: your-webhook-secret + +# ============================================================================ +# Profiling Best Practices +# ============================================================================ +# +# 1. 
Run multiple epochs to see profiling trends over time +# 2. Ignore first ~10 steps (warmup period, slower) +# 3. Look for outliers (steps that take significantly longer) +# 4. Compare profiling metrics before/after optimization changes +# 5. Monitor per-rank profiling in distributed training +# +# Common bottlenecks to profile: +# - training_step: Overall step time (should be consistent) +# - compute_loss: Loss computation (scales with sequence length) +# - prediction_step: Evaluation time (can be slow for large val sets) +# +# If you see inconsistent timing: +# - Check for data loading bottlenecks +# - Monitor GPU utilization (may be CPU-bound) +# - Check for gradient accumulation effects +# - Verify CUDA kernel synchronization + +# ============================================================================ +# Disable WandB if you're migrating from it +# ============================================================================ + +# wandb_project: +# use_wandb: false diff --git a/src/axolotl/integrations/swanlab/README.md b/src/axolotl/integrations/swanlab/README.md new file mode 100644 index 000000000..eff7d4a5a --- /dev/null +++ b/src/axolotl/integrations/swanlab/README.md @@ -0,0 +1,1284 @@ +# SwanLab Integration for Axolotl + +SwanLab is an open-source, lightweight AI experiment tracking and visualization tool that provides a platform for tracking, recording, comparing, and collaborating on experiments. + +This integration enables seamless experiment tracking and visualization of Axolotl training runs using SwanLab. 
+ +## Features + +- 📊 **Automatic Metrics Logging**: Training loss, learning rate, and other metrics are automatically logged +- 🎯 **Hyperparameter Tracking**: Model configuration and training parameters are tracked +- 📈 **Real-time Visualization**: Monitor training progress in real-time through SwanLab dashboard +- ☁️ **Cloud & Local Support**: Works in both cloud-synced and offline modes +- 🔄 **Experiment Comparison**: Compare multiple training runs easily +- 🤝 **Team Collaboration**: Share experiments with team members +- 🎭 **RLHF Completion Logging**: Automatically log model outputs during DPO/KTO/ORPO/GRPO training for qualitative analysis +- ⚡ **Performance Profiling**: Built-in profiling decorators to measure and optimize training performance +- 🔔 **Lark Notifications**: Send real-time training updates to team chat (Feishu/Lark integration) + +## Installation + +```bash +pip install swanlab +``` + +## Quick Start + +### 1. Register for SwanLab (Optional for cloud mode) + +If you want to use cloud sync features, register at [https://swanlab.cn](https://swanlab.cn) to get your API key. + +### 2. Configure Axolotl Config File + +Add SwanLab configuration to your Axolotl YAML config: + +```yaml +# Enable SwanLab plugin +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +# SwanLab configuration +use_swanlab: true +swanlab_project: my-llm-project +swanlab_experiment_name: qwen-finetune-v1 +swanlab_mode: cloud # Options: cloud, local, offline, disabled +swanlab_workspace: my-team # Optional: organization name +swanlab_api_key: YOUR_API_KEY # Optional: can also use env var SWANLAB_API_KEY +``` + +### 3. 
+Run Training + +```bash +# Set API key via environment variable (recommended) +export SWANLAB_API_KEY=your-api-key-here + +# Or login once +swanlab login + +# Run training as usual +accelerate launch -m axolotl.cli.train your-config.yaml +``` + +## Configuration Options + +### Basic Configuration + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `use_swanlab` | bool | `true` | Enable SwanLab tracking | +| `swanlab_project` | str | `None` | Project name (required) | +| `swanlab_experiment_name` | str | `None` | Experiment name | +| `swanlab_description` | str | `None` | Experiment description | +| `swanlab_mode` | str | `cloud` | Sync mode: `cloud`, `local`, `offline`, `disabled` | + +### Advanced Configuration + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `swanlab_workspace` | str | `None` | Workspace/organization name | +| `swanlab_api_key` | str | `None` | API key (prefer env var) | +| `swanlab_web_host` | str | `None` | Private deployment web host | +| `swanlab_api_host` | str | `None` | Private deployment API host | +| `swanlab_log_model` | bool | `false` | Log model checkpoints (coming soon) | +| `swanlab_lark_webhook_url` | str | `None` | Lark (Feishu) webhook URL for team notifications | +| `swanlab_lark_secret` | str | `None` | Lark webhook HMAC secret for authentication | +| `swanlab_log_completions` | bool | `true` | Enable RLHF completion table logging (DPO/KTO/ORPO/GRPO) | +| `swanlab_completion_log_interval` | int | `100` | Steps between completion logging | +| `swanlab_completion_max_buffer` | int | `128` | Max completions to buffer (memory bound) | + +## Configuration Examples + +### Example 1: Basic Cloud Sync + +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: llama-finetune +swanlab_experiment_name: llama-3-8b-instruct-v1 +swanlab_mode: cloud +``` + +### Example 2: Offline/Local Mode + +```yaml
+plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: local-experiments +swanlab_experiment_name: test-run-1 +swanlab_mode: local # or 'offline' +``` + +### Example 3: Team Workspace + +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: research-project +swanlab_experiment_name: experiment-42 +swanlab_workspace: my-research-team +swanlab_mode: cloud +``` + +### Example 4: Private Deployment + +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: internal-project +swanlab_experiment_name: secure-training +swanlab_mode: cloud +swanlab_web_host: https://swanlab.yourcompany.com +swanlab_api_host: https://api.swanlab.yourcompany.com +``` + +## Team Notifications with Lark (Feishu) + +SwanLab supports sending real-time training notifications to your team chat via Lark (Feishu), ByteDance's enterprise collaboration platform. This is especially useful for: +- **Production training monitoring**: Get alerts when training starts, completes, or encounters errors +- **Team collaboration**: Keep your ML team informed about long-running experiments +- **Multi-timezone teams**: Team members can check training progress without being online + +### Prerequisites + +1. **Lark Bot Setup**: Create a custom bot in your Lark group chat +2. **Webhook URL**: Get the webhook URL from your Lark bot settings +3. **HMAC Secret** (recommended): Enable signature verification in your Lark bot for security + +For detailed Lark bot setup instructions, see [Lark Custom Bot Documentation](https://open.feishu.cn/document/ukTMukTMukTM/ucTM5YjL3ETO24yNxkjN). 
+ +### Example 5: Basic Lark Notifications + +Send training notifications to a Lark group chat: + +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: production-training +swanlab_experiment_name: llama-3-finetune-v2 +swanlab_mode: cloud + +# Lark notification (basic, no HMAC verification) +swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx +``` + +**Note**: This configuration will work, but you'll see a security warning recommending HMAC secret configuration. + +### Example 6: Lark Notifications with HMAC Security (Recommended) + +For production use, enable HMAC signature verification: + +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: production-training +swanlab_experiment_name: llama-3-finetune-v2 +swanlab_mode: cloud + +# Lark notification with HMAC authentication +swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx +swanlab_lark_secret: your-webhook-secret-key +``` + +**Why HMAC secret matters**: +- Prevents unauthorized parties from sending fake notifications to your Lark group +- Ensures notifications genuinely come from your training jobs +- Required for production deployments with sensitive training data + +### Example 7: Team Workspace + Lark Notifications + +Combine team workspace collaboration with Lark notifications: + +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: research-project +swanlab_experiment_name: multimodal-experiment-42 +swanlab_workspace: ml-research-team +swanlab_mode: cloud + +# Notify team via Lark when training starts/completes +swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx +swanlab_lark_secret: your-webhook-secret-key +``` + +### What Notifications Are Sent? 
+ +SwanLab's Lark integration sends notifications for key training events: +- **Training Start**: When your experiment begins +- **Training Complete**: When training finishes successfully +- **Training Errors**: If training crashes or encounters critical errors +- **Metric Milestones**: Configurable alerts for metric thresholds (if configured in SwanLab) + +Each notification includes: +- Experiment name and project +- Training status +- Key metrics (loss, learning rate) +- Direct link to SwanLab dashboard + +### Lark Configuration Validation + +The plugin validates your Lark configuration at startup: + +#### ✅ Valid Configurations + +```yaml +# Option 1: No Lark (default) +use_swanlab: true +swanlab_project: my-project +# No swanlab_lark_webhook_url → Lark disabled, no warnings + +# Option 2: Lark with HMAC secret (recommended) +use_swanlab: true +swanlab_project: my-project +swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxx +swanlab_lark_secret: your-secret +# ✅ Logs: "Registered Lark notification callback with HMAC authentication" + +# Option 3: Lark without secret (works but not recommended) +use_swanlab: true +swanlab_project: my-project +swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxx +# ⚠️ Logs: "Registered Lark notification callback (no HMAC secret)" +# ⚠️ Warning: "Lark webhook has no secret configured. For production use, set 'swanlab_lark_secret'..." +``` + +### Security Best Practices + +1. **Always use HMAC secret in production**: + ```yaml + swanlab_lark_webhook_url: https://open.feishu.cn/... + swanlab_lark_secret: your-secret-key # ✅ Add this! + ``` + +2. **Store secrets in environment variables** (even better): + ```yaml + # In your training script/environment + export SWANLAB_LARK_WEBHOOK_URL="https://open.feishu.cn/..." 
+ export SWANLAB_LARK_SECRET="your-secret-key" + ``` + + Then in config: + ```yaml + # SwanLab plugin will auto-detect environment variables + use_swanlab: true + swanlab_project: my-project + # Lark URL and secret read from env vars + ``` + +3. **Rotate webhook secrets periodically**: Update your Lark bot's secret every 90 days + +4. **Use separate webhooks for dev/prod**: Don't mix development and production notifications + +### Distributed Training + +Lark notifications are automatically deduplicated in distributed training: +- Only **rank 0** sends notifications +- Other GPU ranks skip Lark registration +- Prevents duplicate messages in multi-GPU training + +```bash +# Running on 4 GPUs +torchrun --nproc_per_node=4 -m axolotl.cli.train config.yml + +# Expected logs: +# [Rank 0] Registered Lark notification callback with HMAC authentication +# [Rank 1-3] (no Lark registration messages) +``` + +## RLHF Completion Table Logging + +For RLHF (Reinforcement Learning from Human Feedback) training methods like DPO, KTO, ORPO, and GRPO, SwanLab can log model completions (prompts, chosen/rejected responses, rewards) to a visual table for qualitative analysis. 
This helps you: + +- **Inspect model behavior**: See actual model outputs during training +- **Debug preference learning**: Compare chosen vs rejected responses +- **Track reward patterns**: Monitor how rewards evolve over training +- **Share examples with team**: Visual tables in SwanLab dashboard + +### Features + +- ✅ **Automatic detection**: Works with DPO, KTO, ORPO, GRPO trainers +- ✅ **Memory-safe buffering**: Bounded buffer prevents memory leaks in long training runs +- ✅ **Periodic logging**: Configurable logging interval to reduce overhead +- ✅ **Rich visualization**: SwanLab tables show prompts, responses, and metrics side-by-side + +### Configuration + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `swanlab_log_completions` | bool | `true` | Enable completion logging for RLHF trainers | +| `swanlab_completion_log_interval` | int | `100` | Log completions to SwanLab every N training steps | +| `swanlab_completion_max_buffer` | int | `128` | Maximum completions to buffer (memory bound) | + +### Example: DPO Training with Completion Logging + +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: dpo-training +swanlab_experiment_name: llama-3-dpo-v1 +swanlab_mode: cloud + +# RLHF completion logging (enabled by default) +swanlab_log_completions: true +swanlab_completion_log_interval: 100 # Log every 100 steps +swanlab_completion_max_buffer: 128 # Keep last 128 completions + +# DPO-specific config +rl: dpo +datasets: + - path: /path/to/preference_dataset + type: chatml.intel +``` + +### Example: Disable Completion Logging + +If you're doing a quick test run or don't need completion tables: + +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: dpo-training + +# Disable completion logging +swanlab_log_completions: false +``` + +### Supported RLHF Trainers + +The completion logging callback automatically 
+ activates for these trainer types: + +- **DPO (Direct Preference Optimization)**: Logs prompts, chosen, rejected, reward_diff +- **KTO (Kahneman-Tversky Optimization)**: Logs prompts, completions, labels, rewards +- **ORPO (Odds Ratio Preference Optimization)**: Logs prompts, chosen, rejected, log_odds_ratio +- **GRPO (Group Relative Policy Optimization)**: Logs prompts, completions, rewards, advantages +- **CPO (Contrastive Preference Optimization)**: Logs prompts, chosen, rejected + +For non-RLHF trainers (standard supervised fine-tuning), the completion callback is automatically skipped. + +### How It Works + +1. **Auto-detection**: Plugin detects trainer type at initialization +2. **Buffering**: Completions are buffered in memory (up to `swanlab_completion_max_buffer`) +3. **Periodic logging**: Every `swanlab_completion_log_interval` steps, buffer is logged to SwanLab +4. **Memory safety**: Old completions are automatically dropped when buffer is full (uses `collections.deque`) +5. **Final flush**: Remaining completions are logged when training completes + +### Viewing Completion Tables + +After training starts, you can view completion tables in your SwanLab dashboard: + +1. Navigate to your experiment in SwanLab +2. Look for the "rlhf_completions" table in the metrics panel +3. 
The table shows: + - **step**: Training step when completion was generated + - **prompt**: Input prompt + - **chosen**: Preferred response (DPO/ORPO) + - **rejected**: Non-preferred response (DPO/ORPO) + - **completion**: Model output (KTO/GRPO) + - **reward_diff/reward**: Reward metrics + - Trainer-specific metrics (e.g., log_odds_ratio for ORPO) + +### Memory Management + +The completion buffer is **memory-bounded** to prevent memory leaks: + +```python +# Internal implementation uses deque with maxlen +from collections import deque + +buffer = deque(maxlen=128) # Old completions automatically dropped +``` + +**Memory usage estimate**: +- Average completion: ~500 characters (prompt + responses) +- Buffer size 128: ~64 KB (negligible) +- Buffer size 1024: ~512 KB (still small) + +**Recommendation**: Default buffer size (128) works well for most cases. Increase to 512-1024 only if you need to review more historical completions. + +### Performance Impact + +Completion logging has minimal overhead: + +- **Buffering**: O(1) append operation, negligible CPU/memory +- **Logging**: Only happens every N steps (default: 100) +- **Network**: SwanLab batches table uploads efficiently + +**Expected overhead**: < 0.5% per training step + +### Troubleshooting + +#### Completions not appearing in SwanLab + +**Cause**: Trainer may not be logging completion data in the expected format. + +**Diagnostic steps**: +1. Check trainer type detection in logs: + ```text + INFO: SwanLab RLHF completion logging enabled for DPOTrainer (type: dpo) + ``` +2. Verify your trainer is an RLHF trainer (DPO/KTO/ORPO/GRPO) +3. Check if trainer logs completion data (this depends on TRL version) + +**Note**: The current implementation expects trainers to log completion data in the `logs` dict during `on_log()` callback. Some TRL trainers may not expose this data by default. You may need to patch the trainer to expose completions. 
+ +#### Buffer fills up too quickly + +**Cause**: High logging frequency with small buffer size. + +**Solution**: Increase buffer size or logging interval: +```yaml +swanlab_completion_log_interval: 200 # Log less frequently +swanlab_completion_max_buffer: 512 # Larger buffer +``` + +#### Memory usage growing over time + +**Cause**: Buffer should be bounded, so this indicates a bug. + +**Solution**: +1. Verify `swanlab_completion_max_buffer` is set +2. Check SwanLab version is up to date +3. Report issue with memory profiling data + +## Performance Profiling + +SwanLab integration includes profiling utilities to measure and log execution time of trainer methods. This helps you: + +- **Identify bottlenecks**: Find slow operations in your training loop +- **Optimize performance**: Track improvements after optimization changes +- **Monitor distributed training**: See per-rank timing differences +- **Debug hangs**: Detect methods that take unexpectedly long + +### Features + +- ✅ **Zero-config profiling**: Automatic timing of key trainer methods +- ✅ **Decorator-based**: Easy to add profiling to custom methods with `@swanlab_profile` +- ✅ **Context manager**: Fine-grained profiling with `swanlab_profiling_context()` +- ✅ **Advanced filtering**: `ProfilingConfig` for throttling and minimum duration thresholds +- ✅ **Exception-safe**: Logs duration even if function raises an exception + +### Basic Usage: Decorator + +Add profiling to any trainer method with the `@swanlab_profile` decorator: + +```python +from axolotl.integrations.swanlab.profiling import swanlab_profile + +class MyCustomTrainer(AxolotlTrainer): + @swanlab_profile + def training_step(self, model, inputs): + # Your training step logic + return super().training_step(model, inputs) + + @swanlab_profile + def prediction_step(self, model, inputs, prediction_loss_only): + # Your prediction logic + return super().prediction_step(model, inputs, prediction_loss_only) +``` + +The decorator automatically: +1. 
Measures execution time with high-precision timer +2. Logs to SwanLab as `profiling/Time taken: ClassName.method_name` +3. Only logs if SwanLab is enabled (`use_swanlab: true`) +4. Gracefully handles exceptions (logs duration, then re-raises) + +### Advanced Usage: Context Manager + +For fine-grained profiling within a method: + +```python +from axolotl.integrations.swanlab.profiling import swanlab_profiling_context + +class MyTrainer(AxolotlTrainer): + def complex_training_step(self, model, inputs): + # Profile just the forward pass + with swanlab_profiling_context(self, "forward_pass"): + outputs = model(**inputs) + + # Profile just the backward pass + with swanlab_profiling_context(self, "backward_pass"): + loss = outputs.loss + loss.backward() + + return outputs +``` + +### Advanced Usage: ProfilingConfig + +Filter and throttle profiling logs with `ProfilingConfig`: + +```python +from axolotl.integrations.swanlab.profiling import ( + swanlab_profiling_context_advanced, + ProfilingConfig, +) + +# Create custom profiling config +profiling_config = ProfilingConfig( + enabled=True, + min_duration_ms=1.0, # Only log if duration > 1ms + log_interval=10, # Log every 10th call +) + +class MyTrainer(AxolotlTrainer): + def frequently_called_method(self, data): + with swanlab_profiling_context_advanced( + self, + "frequent_op", + config=profiling_config + ): + # This only logs every 10th call, and only if it takes > 1ms + result = expensive_computation(data) + return result +``` + +**ProfilingConfig Parameters**: +- `enabled`: Enable/disable profiling globally (default: `True`) +- `min_duration_ms`: Minimum duration to log in milliseconds (default: `0.1`) +- `log_interval`: Log every Nth function call (default: `1` = log all) + +**Use cases**: +- **High-frequency methods**: Use `log_interval=100` to reduce logging overhead +- **Filter noise**: Use `min_duration_ms=1.0` to skip very fast operations +- **Debugging**: Use `log_interval=1, min_duration_ms=0.0` to log 
everything + +### Viewing Profiling Metrics + +In your SwanLab dashboard, profiling metrics appear under the "profiling" namespace: + +```text +profiling/Time taken: AxolotlTrainer.training_step +profiling/Time taken: AxolotlTrainer.prediction_step +profiling/Time taken: MyTrainer.forward_pass +profiling/Time taken: MyTrainer.backward_pass +``` + +You can: +- **Track over time**: See if methods get faster/slower during training +- **Compare runs**: Compare profiling metrics across experiments +- **Identify regressions**: Detect if a code change slowed down training + +### Configuration in Axolotl Config + +Profiling is automatically enabled when SwanLab is enabled. No additional config needed: + +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +use_swanlab: true +swanlab_project: my-project + +# Profiling is automatically enabled +# Add @swanlab_profile decorators to your custom trainer methods +``` + +To disable profiling while keeping SwanLab enabled: + +```python +# In your custom trainer code +from axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG + +# Disable profiling globally +DEFAULT_PROFILING_CONFIG.enabled = False +``` + +### Performance Impact + +- **Decorator overhead**: ~2-5 microseconds per call (negligible) +- **Context manager overhead**: ~1-3 microseconds (negligible) +- **Logging overhead**: Only when SwanLab is enabled and method duration exceeds threshold +- **Network overhead**: SwanLab batches metrics efficiently + +**Expected overhead**: < 0.1% per training step (effectively zero) + +### Best Practices + +1. **Profile bottlenecks first**: Start by profiling suspected slow operations +2. **Use min_duration_ms**: Filter out fast operations (< 1ms) to reduce noise +3. **Throttle high-frequency calls**: Use `log_interval` for methods called > 100 times/step +4. **Profile across runs**: Compare profiling metrics before/after optimization +5. 
**Monitor distributed training**: Check for rank-specific slowdowns + +### Example: Complete Profiling Setup + +```python +from axolotl.integrations.swanlab.profiling import ( + swanlab_profile, + swanlab_profiling_context, + ProfilingConfig, +) + +class OptimizedTrainer(AxolotlTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Custom profiling config for high-frequency operations + self.fast_op_config = ProfilingConfig( + enabled=True, + min_duration_ms=0.5, + log_interval=50, + ) + + @swanlab_profile + def training_step(self, model, inputs): + """Main training step - always profile.""" + return super().training_step(model, inputs) + + @swanlab_profile + def compute_loss(self, model, inputs, return_outputs=False): + """Loss computation - always profile.""" + return super().compute_loss(model, inputs, return_outputs) + + def _prepare_inputs(self, inputs): + """High-frequency operation - throttled profiling.""" + with swanlab_profiling_context_advanced( + self, + "prepare_inputs", + config=self.fast_op_config, + ): + return super()._prepare_inputs(inputs) +``` + +### Troubleshooting + +#### Profiling metrics not appearing in SwanLab + +**Cause**: SwanLab is not enabled or not initialized. + +**Solution**: +```yaml +# Ensure SwanLab is enabled +use_swanlab: true +swanlab_project: my-project +``` + +Check logs for: +```text +INFO: SwanLab initialized for project: my-project +``` + +#### Too many profiling metrics cluttering dashboard + +**Cause**: Profiling every function call for high-frequency operations. + +**Solution**: Use `ProfilingConfig` with throttling: +```python +config = ProfilingConfig( + min_duration_ms=1.0, # Skip fast ops + log_interval=100, # Log every 100th call +) +``` + +#### Profiling overhead impacting training speed + +**Cause**: Profiling itself should have negligible overhead (< 0.1%). If you see > 1% slowdown, this indicates a bug. + +**Solution**: +1. 
Disable profiling temporarily to confirm: + ```python + DEFAULT_PROFILING_CONFIG.enabled = False + ``` +2. Report issue with profiling data and trainer details + +#### Profiling shows inconsistent timing + +**Cause**: Normal variation due to GPU warmup, data loading, or system load. + +**Solution**: +- Ignore first few steps (warmup period) +- Look at average/median timing over many steps +- Use `log_interval` to reduce noise from individual outliers + +## Complete Config Example + +Here's a complete example integrating SwanLab with your RVQ-Alpha training: + +```yaml +base_model: /path/to/your/model +model_type: Qwen2ForCausalLM + +# SwanLab Integration +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +use_swanlab: true +swanlab_project: RVQ-Alpha-Training +swanlab_experiment_name: Qwen2.5-7B-MetaQA-Perturb-P020 +swanlab_description: "Training on MetaQA and Perturbation datasets with NEW-RVQ encoding" +swanlab_mode: cloud +swanlab_workspace: single-cell-genomics + +# Training configuration +sequence_len: 32768 +micro_batch_size: 1 +gradient_accumulation_steps: 1 +num_epochs: 2 +learning_rate: 2e-5 +optimizer: adamw_torch_fused + +# Datasets +datasets: + - path: /path/to/dataset + type: chat_template + +# Output +output_dir: ./outputs +``` + +## Modes Explained + +### `cloud` Mode (Default) +- Syncs experiments to SwanLab cloud in real-time +- Requires API key and internet connection +- Best for: Team collaboration, remote monitoring + +### `local` Mode +- Saves experiments locally only +- No cloud sync +- Best for: Local development, air-gapped environments + +### `offline` Mode +- Saves metadata locally +- Can sync to cloud later using `swanlab sync` +- Best for: Unstable internet, sync later + +### `disabled` Mode +- Turns off SwanLab completely +- No logging or tracking +- Best for: Debugging, testing + +## Configuration Validation & Conflict Detection + +SwanLab integration includes 
comprehensive validation and conflict detection to help you catch configuration errors early and avoid performance issues. + +### Required Fields Validation + +The plugin validates your configuration at startup and provides clear error messages with solutions: + +#### Missing Project Name + +```yaml +# ❌ INVALID: use_swanlab enabled but no project +use_swanlab: true +# Error: SwanLab enabled but 'swanlab_project' is not set. +``` + +**Solution**: +```yaml +# ✅ VALID: Provide project name +use_swanlab: true +swanlab_project: my-project +``` + +#### Invalid Mode + +```yaml +# ❌ INVALID: Unknown mode +use_swanlab: true +swanlab_project: my-project +swanlab_mode: invalid-mode +# Error: Invalid swanlab_mode: 'invalid-mode'. Valid options: cloud, local, offline, disabled +``` + +**Solution**: +```yaml +# ✅ VALID: Use one of the valid modes +use_swanlab: true +swanlab_project: my-project +swanlab_mode: cloud # or: local, offline, disabled +``` + +#### Empty Project Name + +```yaml +# ❌ INVALID: Empty string project name +use_swanlab: true +swanlab_project: "" +# Error: swanlab_project cannot be an empty string. +``` + +**Solution**: +```yaml +# ✅ VALID: Provide non-empty project name +use_swanlab: true +swanlab_project: my-project +``` + +### Cloud Mode API Key Warning + +When using `cloud` mode without an API key, you'll receive a warning with multiple solutions: + +```yaml +use_swanlab: true +swanlab_project: my-project +swanlab_mode: cloud +# No API key set +# Warning: SwanLab cloud mode enabled but no API key found. +``` + +**Solutions**: +1. Set environment variable: `export SWANLAB_API_KEY=your-api-key` +2. Add to config (less secure): `swanlab_api_key: your-api-key` +3. Run `swanlab login` before training +4. 
Use `swanlab_mode: local` for offline tracking + +### Multi-Logger Performance Warnings + +Using multiple logging tools simultaneously (SwanLab + WandB + MLflow + Comet) can impact training performance: + +#### Two Loggers - Warning + +```yaml +use_swanlab: true +swanlab_project: my-project + +use_wandb: true +wandb_project: my-project + +# Warning: Multiple logging tools enabled: SwanLab, WandB +# Expected overhead: ~3.0% per training step. +``` + +**Impact**: +- Performance overhead: ~1-2% per logger (cumulative) +- Increased memory usage +- Longer training time per step +- Potential config/callback conflicts + +**Recommendations**: +- Choose ONE primary logging tool for production training +- Use multiple loggers only for: + - Migration period (transitioning between tools) + - Short comparison runs + - Debugging specific tool issues +- Monitor system resources (CPU, memory) during training + +#### Three+ Loggers - Error-Level Warning + +```yaml +use_swanlab: true +swanlab_project: my-project + +use_wandb: true +wandb_project: my-project + +use_mlflow: true +mlflow_tracking_uri: http://localhost:5000 + +# ERROR: 3 logging tools enabled simultaneously! +# Expected overhead: ~4.5% per training step. 
+# STRONGLY RECOMMEND: Disable all but ONE logging tool +``` + +**Why This Matters**: +- With 3 loggers: ~4-5% overhead per step → significant slowdown over long training +- Example: 10,000 steps at 2s/step → ~400-500 seconds extra (6-8 minutes) +- Memory overhead scales with number of loggers +- Rare edge cases with callback ordering conflicts + +### Auto-Enable Logic + +For convenience, SwanLab will auto-enable if you specify a project without setting `use_swanlab`: + +```yaml +# This configuration: +swanlab_project: my-project + +# Automatically becomes: +use_swanlab: true +swanlab_project: my-project +``` + +### Distributed Training Detection + +In distributed training scenarios (multi-GPU), the plugin automatically detects and reports: + +```yaml +use_swanlab: true +swanlab_project: my-project +swanlab_mode: cloud + +# When running with torchrun --nproc_per_node=4: +# Info: Distributed training detected (world_size=4) +# Info: SwanLab mode: cloud +# Info: Only rank 0 will initialize SwanLab +# Info: Other ranks will skip SwanLab to avoid conflicts +``` + +**Why Only Rank 0**: +- Avoids duplicate experiment runs +- Reduces network/cloud API overhead on worker ranks +- Prevents race conditions in metric logging + +## Authentication + +### Method 1: Environment Variable (Recommended) +```bash +export SWANLAB_API_KEY=your-api-key-here +``` + +### Method 2: Login Command +```bash +swanlab login +# Enter your API key when prompted +``` + +### Method 3: Config File +```yaml +swanlab_api_key: your-api-key-here +``` + +## What Gets Logged? + +### Automatically Logged Metrics +- Training loss +- Learning rate +- Gradient norm +- Training steps +- Epoch progress + +### Automatically Logged Config +- Model configuration (base_model, model_type) +- Training hyperparameters (learning_rate, batch_size, etc.) 
+- Optimizer settings +- Parallelization settings (FSDP, DeepSpeed, Context Parallel) +- Axolotl configuration file +- DeepSpeed configuration (if used) + +## Viewing Your Experiments + +### Cloud Mode +Visit [https://swanlab.cn](https://swanlab.cn) and navigate to your project to view: +- Real-time training metrics +- Hyperparameter comparison +- System resource usage +- Configuration files + +### Local Mode +```bash +# Start local dashboard +swanlab watch ./swanlog + +# Open browser to http://localhost:5092 +``` + +## Integration with Existing Tools + +SwanLab can work alongside other tracking tools: + +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin + +# Use both SwanLab and Wandb +use_swanlab: true +swanlab_project: my-project + +use_wandb: true +wandb_project: my-project +``` + +## Troubleshooting + +### Configuration Errors + +#### Error: "SwanLab enabled but 'swanlab_project' is not set" + +**Cause**: You enabled SwanLab (`use_swanlab: true`) but forgot to specify a project name. + +**Solution**: +```yaml +use_swanlab: true +swanlab_project: my-project # Add this line +``` + +#### Error: "Invalid swanlab_mode: 'xxx'" + +**Cause**: You provided an invalid mode value. + +**Solution**: Use one of the valid modes: +```yaml +swanlab_mode: cloud # or: local, offline, disabled +``` + +#### Error: "swanlab_project cannot be an empty string" + +**Cause**: You set `swanlab_project: ""` (empty string). + +**Solution**: Either provide a valid name or remove the field: +```yaml +# Option 1: Provide valid name +swanlab_project: my-project + +# Option 2: Remove the field entirely +# swanlab_project: "" <- Remove this line +``` + +### Import Errors + +#### Error: "SwanLab is not installed" + +**Cause**: SwanLab package is not installed in your environment. 
+ +**Solution**: +```bash +pip install swanlab +# or +pip install swanlab>=0.3.0 +``` + +### Performance Issues + +#### Warning: "Multiple logging tools enabled" + +**Cause**: You have multiple experiment tracking tools enabled (e.g., SwanLab + WandB + MLflow). + +**Impact**: ~1-2% performance overhead per logger, cumulative. + +**Solution**: For production training, disable all but one logger: +```yaml +# Option 1: Keep only SwanLab +use_swanlab: true +swanlab_project: my-project +use_wandb: false # Disable others +use_mlflow: false + +# Option 2: Keep only WandB +use_swanlab: false +use_wandb: true +wandb_project: my-project +``` + +**Exception**: Multiple loggers are acceptable for: +- Short comparison runs (< 100 steps) +- Migration testing between logging tools +- Debugging logger-specific issues + +### Distributed Training Issues + +#### SwanLab creates duplicate runs in multi-GPU training + +**Cause**: All ranks are initializing SwanLab instead of just rank 0. + +**Expected Behavior**: The plugin automatically ensures only rank 0 initializes SwanLab. You should see: +```text +Info: Distributed training detected (world_size=4) +Info: Only rank 0 will initialize SwanLab +Info: Other ranks will skip SwanLab to avoid conflicts +``` + +**If you see duplicates**: +1. Check your plugin is loaded correctly +2. Verify you're using the latest SwanLab integration code +3. Check logs for initialization messages on all ranks + +### SwanLab not logging metrics + +**Solution**: Ensure SwanLab is initialized before training starts. The plugin automatically handles this in `pre_model_load`. 
+ +### API Key errors + +**Solution**: +```bash +# Verify API key +echo $SWANLAB_API_KEY + +# Re-login +swanlab login +``` + +### Cloud sync issues + +**Solution**: Use `offline` mode and sync later: +```yaml +swanlab_mode: offline +``` + +Then sync when ready: +```bash +swanlab sync ./swanlog +``` + +### Plugin not loaded + +**Solution**: Verify plugin path in config: +```yaml +plugins: + - axolotl.integrations.swanlab.SwanLabPlugin # Correct path +``` + +### Lark Notification Issues + +#### Error: "Failed to import SwanLab Lark plugin" + +**Cause**: Your SwanLab version doesn't include the Lark plugin (requires SwanLab >= 0.3.0). + +**Solution**: +```bash +# Upgrade SwanLab to latest version +pip install --upgrade swanlab + +# Or install specific version +pip install 'swanlab>=0.3.0' +``` + +#### Warning: "Lark webhook has no secret configured" + +**Cause**: You provided `swanlab_lark_webhook_url` but no `swanlab_lark_secret`. + +**Impact**: Lark notifications will work, but without HMAC authentication (security risk). + +**Solution**: Add HMAC secret for production use: +```yaml +swanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxx +swanlab_lark_secret: your-webhook-secret # Add this line +``` + +**When it's OK to skip secret**: +- Local development and testing +- Internal networks with restricted access +- Non-sensitive training experiments + +**When secret is required**: +- Production training jobs +- Training with proprietary data +- Multi-team shared Lark groups + +#### Error: "Failed to register Lark callback" + +**Cause**: Invalid webhook URL or network connectivity issues. + +**Diagnostic steps**: +```bash +# 1. Test webhook URL manually +curl -X POST "YOUR_WEBHOOK_URL" \ + -H 'Content-Type: application/json' \ + -d '{"msg_type":"text","content":{"text":"Test from Axolotl"}}' + +# 2. Check SwanLab version +pip show swanlab + +# 3. 
Verify webhook URL format +# Should start with: https://open.feishu.cn/open-apis/bot/v2/hook/ +``` + +**Solution**: +1. Verify webhook URL is correct (copy from Lark bot settings) +2. Check network connectivity to Lark API +3. Ensure webhook is not expired (Lark webhooks can expire) +4. Regenerate webhook URL in Lark bot settings if needed + +#### Lark notifications not received + +**Cause**: Multiple possible causes. + +**Diagnostic checklist**: + +1. **Check training logs** for Lark registration confirmation: + ```text + # Expected log message (rank 0 only): + INFO: Registered Lark notification callback with HMAC authentication + ``` + +2. **Verify webhook in Lark**: Test webhook manually (see above) + +3. **Check distributed training**: Only rank 0 sends notifications + ```bash + # If running multi-GPU, check rank 0 logs specifically + grep "Registered Lark" logs/rank_0.log + ``` + +4. **Verify SwanLab is initialized**: Lark callback needs SwanLab to be running + ```yaml + use_swanlab: true # Must be enabled + swanlab_project: my-project # Must be set + ``` + +5. **Check Lark bot permissions**: Ensure bot is added to the target group chat + +#### Duplicate Lark notifications in multi-GPU training + +**Expected Behavior**: Should NOT happen - only rank 0 sends notifications. + +**If you see duplicates**: +1. Check that all GPUs are using the same config file +2. Verify plugin is loaded correctly on all ranks +3. Check logs for unexpected Lark initialization on non-zero ranks +4. Ensure `RANK` or `LOCAL_RANK` environment variables are set correctly + +**Solution**: This is a bug if it occurs. 
Report with: +- Full training command +- Logs from all ranks +- Config file + +## Comparison: SwanLab vs WandB + +| Feature | SwanLab | WandB | +|---------|---------|-------| +| Open Source | ✅ Yes | ❌ No | +| Self-Hosting | ✅ Easy | ⚠️ Complex | +| Free Tier | ✅ Generous | ⚠️ Limited | +| Chinese Support | ✅ Native | ⚠️ Limited | +| Offline Mode | ✅ Full support | ✅ Supported | +| Integration | 🆕 New | ✅ Mature | + +## Advanced Usage + +### Custom Logging + +You can add custom metrics in your callbacks: + +```python +import swanlab + +# In your custom callback +swanlab.log({ + "custom_metric": value, + "epoch": epoch_num +}) +``` + +### Experiment Comparison + +```bash +# Compare multiple experiments +swanlab compare run1 run2 run3 +``` + +## Support + +- **Documentation**: [https://docs.swanlab.cn](https://docs.swanlab.cn) +- **GitHub**: [https://github.com/SwanHubX/SwanLab](https://github.com/SwanHubX/SwanLab) +- **Issues**: Report bugs at [GitHub Issues](https://github.com/SwanHubX/SwanLab/issues) + +## License + +This integration follows the Axolotl Community License Agreement. 
class SwanLabConfig(BaseModel):
    """SwanLab configuration subset.

    Pydantic model declaring every ``swanlab_*`` option Axolotl accepts,
    plus the ``use_swanlab`` master switch. The ``json_schema_extra``
    descriptions feed the generated config schema; cross-field rules are
    enforced by the validators at the bottom of the class.
    """

    # Master switch. NOTE(review): defaults to True while the model validator
    # below requires a project whenever use_swanlab is True — so constructing
    # this model with all defaults raises. Presumably the surrounding config
    # merge supplies explicit values; confirm against the plugin's handling.
    use_swanlab: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": "Enable SwanLab experiment tracking and visualization"
        },
    )
    # Project that groups experiments; required when use_swanlab is True
    # (see validate_swanlab_enabled_requires_project below).
    swanlab_project: str | None = Field(
        default=None,
        json_schema_extra={"description": "Your SwanLab project name"},
    )
    swanlab_experiment_name: str | None = Field(
        default=None,
        json_schema_extra={"description": "Set the name of your SwanLab experiment"},
    )
    swanlab_description: str | None = Field(
        default=None,
        json_schema_extra={"description": "Description for your SwanLab experiment"},
    )
    # One of: cloud, local, offline, disabled (validated below). None lets
    # SwanLab itself pick its default mode.
    swanlab_mode: str | None = Field(
        default=None,
        json_schema_extra={
            "description": '"cloud" to sync to SwanLab cloud, "local" for local only, "offline" to save metadata locally, "disabled" to turn off SwanLab'
        },
    )
    swanlab_workspace: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "SwanLab workspace name (organization or username)"
        },
    )
    # Prefer the SWANLAB_API_KEY environment variable over storing the key
    # in a config file checked into version control.
    swanlab_api_key: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "SwanLab API key for authentication. Can also be set via SWANLAB_API_KEY environment variable"
        },
    )
    swanlab_log_model: bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "Whether to log model checkpoints to SwanLab (feature coming soon)"
        },
    )
    # Endpoints for privately deployed SwanLab instances.
    swanlab_web_host: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Web address for SwanLab cloud environment (for private deployment)"
        },
    )
    swanlab_api_host: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "API address for SwanLab cloud environment (for private deployment)"
        },
    )
    # Lark (Feishu) notification hooks; the secret enables HMAC signing of
    # webhook requests.
    swanlab_lark_webhook_url: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Lark (Feishu) webhook URL for sending training notifications to team chat"
        },
    )
    swanlab_lark_secret: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Secret for Lark webhook HMAC signature authentication (optional)"
        },
    )
    # RLHF completion-table logging; consumed by SwanLabRLHFCompletionCallback.
    swanlab_log_completions: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": "Enable logging RLHF completions to SwanLab for qualitative analysis (DPO/KTO/ORPO/GRPO)"
        },
    )
    swanlab_completion_log_interval: int | None = Field(
        default=100,
        json_schema_extra={
            "description": "Number of training steps between completion table logging to SwanLab"
        },
    )
    swanlab_completion_max_buffer: int | None = Field(
        default=128,
        json_schema_extra={
            "description": "Maximum number of completions to buffer before logging (prevents memory leaks)"
        },
    )

    @field_validator("swanlab_mode")
    @classmethod
    def validate_swanlab_mode(cls, v):
        """Validate swanlab_mode is one of the allowed values.

        Raises:
            ValueError: if ``v`` is set but not a recognized mode.
        """
        if v is None:
            return v

        valid_modes = ["cloud", "local", "offline", "disabled"]
        if v not in valid_modes:
            # Message lists all valid modes plus worked examples so the user
            # can fix the config without consulting the docs.
            raise ValueError(
                f"Invalid swanlab_mode: '{v}'.\n\n"
                f"Valid options: {', '.join(valid_modes)}\n\n"
                f"Examples:\n"
                f"  swanlab_mode: cloud     # Sync to SwanLab cloud\n"
                f"  swanlab_mode: local     # Local only, no cloud sync\n"
                f"  swanlab_mode: offline   # Save metadata locally\n"
                f"  swanlab_mode: disabled  # Turn off SwanLab\n"
            )
        return v

    @field_validator("swanlab_project")
    @classmethod
    def validate_swanlab_project(cls, v):
        """Validate swanlab_project is non-empty when provided.

        Raises:
            ValueError: if ``v`` is an empty or whitespace-only string.
        """
        # The isinstance guard is defensive; pydantic has already coerced
        # the value to str | None by the time this runs.
        if v is not None and isinstance(v, str) and len(v.strip()) == 0:
            raise ValueError(
                "swanlab_project cannot be an empty string.\n\n"
                "Either:\n"
                "  1. Provide a valid project name: swanlab_project: my-project\n"
                "  2. Remove the swanlab_project field entirely\n"
            )
        return v

    @model_validator(mode="after")
    def validate_swanlab_enabled_requires_project(self):
        """Validate that if use_swanlab is True, swanlab_project must be set.

        Raises:
            ValueError: if SwanLab is enabled without a project name.
        """
        # `is True` (not truthiness) so use_swanlab=None — i.e. "unset" —
        # does not trigger the requirement.
        if self.use_swanlab is True and not self.swanlab_project:
            raise ValueError(
                "SwanLab enabled (use_swanlab: true) but 'swanlab_project' is not set.\n\n"
                "Solutions:\n"
                "  1. Add 'swanlab_project: your-project-name' to your config\n"
                "  2. Set 'use_swanlab: false' to disable SwanLab\n\n"
                "Example:\n"
                "  use_swanlab: true\n"
                "  swanlab_project: my-llm-training\n"
            )
        return self
+""" + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from axolotl.integrations.swanlab.completion_logger import CompletionLogger +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class SwanLabRLHFCompletionCallback(TrainerCallback): + """Callback for logging RLHF completions to SwanLab. + + This callback periodically logs model completions (prompts, chosen/rejected + responses, rewards) to SwanLab during RLHF training for qualitative analysis. + + Supports DPO, KTO, ORPO, and GRPO trainers. + + Example usage: + >>> callback = SwanLabRLHFCompletionCallback( + ... log_interval=100, # Log every 100 steps + ... max_completions=128, # Keep last 128 completions + ... ) + >>> trainer.add_callback(callback) + + Attributes: + logger: CompletionLogger instance + log_interval: Number of steps between SwanLab logging + trainer_type: Auto-detected trainer type (dpo/kto/orpo/grpo) + """ + + def __init__( + self, + log_interval: int = 100, + max_completions: int = 128, + table_name: str = "rlhf_completions", + ): + """Initialize SwanLab RLHF completion callback. + + Args: + log_interval: Log to SwanLab every N steps. Default: 100 + max_completions: Maximum completions to buffer. Default: 128 + table_name: SwanLab table name. 
Default: "rlhf_completions" + """ + super().__init__() + self.logger = CompletionLogger(maxlen=max_completions) + self.log_interval = log_interval + self.table_name = table_name + self.trainer_type: str | None = None # Auto-detected + self._last_logged_step = 0 + + def on_init_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Detect trainer type on initialization.""" + trainer = kwargs.get("trainer") + if trainer is not None: + trainer_name = trainer.__class__.__name__ + if "DPO" in trainer_name: + self.trainer_type = "dpo" + elif "KTO" in trainer_name: + self.trainer_type = "kto" + elif "ORPO" in trainer_name: + self.trainer_type = "orpo" + elif "GRPO" in trainer_name: + self.trainer_type = "grpo" + else: + self.trainer_type = "unknown" + + LOG.info( + f"SwanLab RLHF completion logging enabled for {trainer_name} " + f"(type: {self.trainer_type})" + ) + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs: dict | None = None, + **kwargs, + ): + """Capture completions from logs and buffer them. + + Different trainers log completions in different formats: + - DPO: logs['dpo/chosen'], logs['dpo/rejected'], logs['dpo/reward_diff'] + - KTO: logs['kto/completion'], logs['kto/label'], logs['kto/reward'] + - ORPO: logs['orpo/chosen'], logs['orpo/rejected'] + - GRPO: logs['grpo/completion'], logs['grpo/reward'] + + Note: This is a placeholder implementation. Actual log keys depend + on the TRL trainer implementation. You may need to patch the trainers + to expose completion data in logs. 
+ """ + if logs is None or self.trainer_type is None: + return + + step = state.global_step + + # DPO completions + if self.trainer_type == "dpo": + if all(key in logs for key in ["dpo/prompt", "dpo/chosen", "dpo/rejected"]): + self.logger.add_dpo_completion( + step=step, + prompt=logs.get("dpo/prompt", ""), + chosen=logs.get("dpo/chosen", ""), + rejected=logs.get("dpo/rejected", ""), + reward_diff=logs.get("dpo/reward_diff"), + ) + + # KTO completions + elif self.trainer_type == "kto": + if all(key in logs for key in ["kto/prompt", "kto/completion"]): + self.logger.add_kto_completion( + step=step, + prompt=logs.get("kto/prompt", ""), + completion=logs.get("kto/completion", ""), + label=logs.get("kto/label", False), + reward=logs.get("kto/reward"), + ) + + # ORPO completions + elif self.trainer_type == "orpo": + if all( + key in logs for key in ["orpo/prompt", "orpo/chosen", "orpo/rejected"] + ): + self.logger.add_orpo_completion( + step=step, + prompt=logs.get("orpo/prompt", ""), + chosen=logs.get("orpo/chosen", ""), + rejected=logs.get("orpo/rejected", ""), + log_odds_ratio=logs.get("orpo/log_odds_ratio"), + ) + + # GRPO completions + elif self.trainer_type == "grpo": + if all(key in logs for key in ["grpo/prompt", "grpo/completion"]): + self.logger.add_grpo_completion( + step=step, + prompt=logs.get("grpo/prompt", ""), + completion=logs.get("grpo/completion", ""), + reward=logs.get("grpo/reward"), + advantage=logs.get("grpo/advantage"), + ) + + # Periodically log to SwanLab + if step - self._last_logged_step >= self.log_interval: + if len(self.logger) > 0: + self.logger.log_to_swanlab(table_name=self.table_name) + self.logger.clear() + self._last_logged_step = step + + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Log remaining completions at end of training.""" + if len(self.logger) > 0: + LOG.info( + f"Training complete, logging final {len(self.logger)} completions to SwanLab" + ) + 
self.logger.log_to_swanlab(table_name=self.table_name) + self._last_logged_step = state.global_step diff --git a/src/axolotl/integrations/swanlab/completion_logger.py b/src/axolotl/integrations/swanlab/completion_logger.py new file mode 100644 index 000000000..dd709227f --- /dev/null +++ b/src/axolotl/integrations/swanlab/completion_logger.py @@ -0,0 +1,228 @@ +"""SwanLab completion logger for RLHF/DPO/KTO/ORPO/GRPO training. + +This module provides utilities for logging model completions during +preference training to SwanLab for qualitative analysis. +""" + +from collections import deque +from collections.abc import Mapping +from typing import Any + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class CompletionLogger: + """Memory-bounded logger for RLHF completions. + + Stores prompts, completions, and rewards in fixed-size deques to prevent + memory leaks during long training runs. Logs completion tables to SwanLab + for qualitative analysis of model outputs. + + Example usage: + >>> logger = CompletionLogger(maxlen=128) + >>> logger.add_dpo_completion( + ... step=0, + ... prompt="What is AI?", + ... chosen="Artificial Intelligence is...", + ... rejected="AI means...", + ... reward_diff=0.5 + ... ) + >>> logger.log_to_swanlab() + + Attributes: + maxlen: Maximum number of completions to store (older ones are dropped) + data: Deque storing completion dictionaries + """ + + def __init__(self, maxlen: int = 128): + """Initialize completion logger with bounded buffer. + + Args: + maxlen: Maximum number of completions to store. When the buffer + is full, oldest completions are automatically discarded. 
+ Default: 128 (sufficient for most RLHF runs without memory issues) + """ + self.maxlen = maxlen + self.data: deque[Mapping[str, Any]] = deque(maxlen=maxlen) + + def add_dpo_completion( + self, + step: int, + prompt: str, + chosen: str, + rejected: str, + reward_diff: float | None = None, + ) -> None: + """Add a DPO completion to the buffer. + + Args: + step: Training step number + prompt: Input prompt + chosen: Chosen (preferred) completion + rejected: Rejected (non-preferred) completion + reward_diff: Reward difference (chosen - rejected), if available + """ + entry = { + "step": step, + "prompt": prompt, + "chosen": chosen, + "rejected": rejected, + } + if reward_diff is not None: + entry["reward_diff"] = reward_diff + + self.data.append(entry) + + def add_kto_completion( + self, + step: int, + prompt: str, + completion: str, + label: bool, + reward: float | None = None, + ) -> None: + """Add a KTO completion to the buffer. + + Args: + step: Training step number + prompt: Input prompt + completion: Model-generated completion + label: True if desirable, False if undesirable + reward: Reward score, if available + """ + entry = { + "step": step, + "prompt": prompt, + "completion": completion, + "label": "desirable" if label else "undesirable", + } + if reward is not None: + entry["reward"] = reward + + self.data.append(entry) + + def add_orpo_completion( + self, + step: int, + prompt: str, + chosen: str, + rejected: str, + log_odds_ratio: float | None = None, + ) -> None: + """Add an ORPO completion to the buffer. 
+ + Args: + step: Training step number + prompt: Input prompt + chosen: Chosen (preferred) completion + rejected: Rejected (non-preferred) completion + log_odds_ratio: Log odds ratio between chosen and rejected + """ + entry = { + "step": step, + "prompt": prompt, + "chosen": chosen, + "rejected": rejected, + } + if log_odds_ratio is not None: + entry["log_odds_ratio"] = log_odds_ratio + + self.data.append(entry) + + def add_grpo_completion( + self, + step: int, + prompt: str, + completion: str, + reward: float | None = None, + advantage: float | None = None, + ) -> None: + """Add a GRPO completion to the buffer. + + Args: + step: Training step number + prompt: Input prompt + completion: Model-generated completion + reward: Reward score from reward model + advantage: Advantage estimate (reward - baseline) + """ + entry = { + "step": step, + "prompt": prompt, + "completion": completion, + } + if reward is not None: + entry["reward"] = reward + if advantage is not None: + entry["advantage"] = advantage + + self.data.append(entry) + + def log_to_swanlab(self, table_name: str = "completions") -> bool: + """Log buffered completions to SwanLab as a table. + + Creates a SwanLab echarts Table with all buffered completions. + Only logs if SwanLab is initialized and data is available. + + Args: + table_name: Name of the table in SwanLab dashboard. 
+ Default: "completions" + + Returns: + True if logging succeeded, False otherwise + """ + if not self.data: + LOG.debug("No completions to log to SwanLab") + return False + + try: + import swanlab + + if swanlab.get_run() is None: + LOG.debug("SwanLab not initialized, skipping completion logging") + return False + + # Convert deque to list of dicts + completions = list(self.data) + + # Extract headers from first entry (all entries should have same structure) + headers = list(completions[0].keys()) + + # Build rows: each completion becomes one row + rows = [] + for completion in completions: + row = [completion.get(header, "") for header in headers] + rows.append(row) + + # Log to SwanLab as echarts Table + swanlab.log({table_name: swanlab.echarts.Table().add(headers, rows)}) + + LOG.info(f"Logged {len(rows)} completions to SwanLab table '{table_name}'") + return True + + except ImportError: + LOG.warning( + "SwanLab not installed, cannot log completions. " + "Install with: pip install swanlab" + ) + return False + except Exception as err: # pylint: disable=broad-except + LOG.exception("Failed to log completions to SwanLab: %s", err) + return False + + def clear(self) -> None: + """Clear all buffered completions.""" + self.data.clear() + + def __len__(self) -> int: + """Return number of buffered completions.""" + return len(self.data) + + def __repr__(self) -> str: + """String representation showing buffer status.""" + return ( + f"CompletionLogger(maxlen={self.maxlen}, " + f"buffered={len(self.data)}/{self.maxlen})" + ) diff --git a/src/axolotl/integrations/swanlab/plugins.py b/src/axolotl/integrations/swanlab/plugins.py new file mode 100644 index 000000000..16218d39d --- /dev/null +++ b/src/axolotl/integrations/swanlab/plugins.py @@ -0,0 +1,554 @@ +"""SwanLab Plugin for Axolotl""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger + +if 
class SwanLabPlugin(BasePlugin):
    """
    SwanLab integration plugin for Axolotl.

    Provides experiment tracking, visualization, and logging capabilities
    using SwanLab (https://swanlab.cn).

    Lifecycle: ``register`` validates the config and warns about conflicts,
    ``pre_model_load`` initializes the SwanLab run (rank 0 only), and the
    callback hooks attach metric/config logging to the trainer.

    Usage in config.yaml:
        plugins:
          - axolotl.integrations.swanlab.SwanLabPlugin

        use_swanlab: true
        swanlab_project: my-project
        swanlab_experiment_name: my-experiment
        swanlab_mode: cloud  # or 'local', 'offline', 'disabled'
    """

    def __init__(self):
        super().__init__()
        # Set to True only after swanlab.init() succeeds in pre_model_load;
        # downstream hooks gate on this flag.
        self.swanlab_initialized = False
        LOG.info("SwanLab plugin initialized")

    def get_input_args(self) -> str:
        """Returns the configuration model for SwanLab integration."""
        return "axolotl.integrations.swanlab.SwanLabConfig"

    def register(self, cfg: dict):
        """Register SwanLab plugin with configuration and conflict detection.

        Args:
            cfg: Raw config mapping. May be mutated: ``use_swanlab`` is
                auto-enabled when ``swanlab_project`` is set.

        Raises:
            ValueError: If SwanLab is enabled without a project name, or if
                ``swanlab_mode`` is not one of the valid modes.
        """
        LOG.info("Registering SwanLab plugin")

        # === Conflict Detection: Required Fields ===

        # Check if SwanLab is enabled
        if cfg.get("use_swanlab"):
            # 1. Validate project name is set
            if not cfg.get("swanlab_project"):
                raise ValueError(
                    "SwanLab enabled but 'swanlab_project' is not set.\n\n"
                    "Solutions:\n"
                    " 1. Add 'swanlab_project: your-project-name' to your config\n"
                    " 2. Set 'use_swanlab: false' to disable SwanLab\n\n"
                    "See: src/axolotl/integrations/swanlab/README.md for examples"
                )

            # 2. Validate swanlab_mode value
            valid_modes = ["cloud", "local", "offline", "disabled"]
            mode = cfg.get("swanlab_mode")
            if mode and mode not in valid_modes:
                raise ValueError(
                    f"Invalid swanlab_mode: '{mode}'.\n\n"
                    f"Valid options: {', '.join(valid_modes)}\n\n"
                    f"Example:\n"
                    f" swanlab_mode: cloud # Sync to SwanLab cloud\n"
                    f" swanlab_mode: local # Local only, no cloud sync\n"
                )

            # 3. Check API key for cloud mode
            import os

            mode = cfg.get("swanlab_mode", "cloud")  # Default is cloud
            if mode == "cloud":
                # API key may come from config or the SWANLAB_API_KEY env var;
                # missing key is only a warning since 'swanlab login' may have
                # cached credentials.
                api_key = cfg.get("swanlab_api_key") or os.environ.get(
                    "SWANLAB_API_KEY"
                )
                if not api_key:
                    LOG.warning(
                        "SwanLab cloud mode enabled but no API key found.\n"
                        "SwanLab may fail to initialize during training.\n\n"
                        "Solutions:\n"
                        " 1. Set SWANLAB_API_KEY environment variable:\n"
                        " export SWANLAB_API_KEY=your-api-key\n"
                        " 2. Add 'swanlab_api_key: your-api-key' to config (less secure)\n"
                        " 3. Run 'swanlab login' before training\n"
                        " 4. Use 'swanlab_mode: local' for offline tracking\n"
                    )

        # === Conflict Detection: Multi-Logger Performance Warning ===

        # Detect all active logging tools
        active_loggers = []
        if cfg.get("use_wandb"):
            active_loggers.append("WandB")
        if cfg.get("use_mlflow"):
            active_loggers.append("MLflow")
        if cfg.get("comet_api_key") or cfg.get("comet_project_name"):
            active_loggers.append("Comet")
        if cfg.get("use_swanlab"):
            active_loggers.append("SwanLab")

        if len(active_loggers) > 1:
            LOG.warning(
                f"\n{'=' * 70}\n"
                f"Multiple logging tools enabled: {', '.join(active_loggers)}\n"
                f"{'=' * 70}\n"
                f"This may cause:\n"
                f" - Performance overhead (~1-2% per logger, cumulative)\n"
                f" - Increased memory usage\n"
                f" - Longer training time per step\n"
                f" - Potential config/callback conflicts\n\n"
                f"Recommendations:\n"
                f" - Choose ONE primary logging tool for production training\n"
                f" - Use multiple loggers only for:\n"
                f" * Migration period (transitioning between tools)\n"
                f" * Short comparison runs\n"
                f" * Debugging specific tool issues\n"
                f" - Monitor system resources (CPU, memory) during training\n"
                f"{'=' * 70}\n"
            )

        if len(active_loggers) >= 3:
            LOG.error(
                f"\n{'!' * 70}\n"
                f"WARNING: {len(active_loggers)} logging tools enabled simultaneously!\n"
                f"{'!' * 70}\n"
                f"This is likely unintentional and WILL significantly impact performance.\n"
                f"Expected overhead: ~{len(active_loggers) * 1.5:.1f}% per training step.\n\n"
                f"STRONGLY RECOMMEND:\n"
                f" - Disable all but ONE logging tool\n"
                f" - Use config inheritance to manage multiple configs\n"
                f"{'!' * 70}\n"
            )

        # === Auto-Enable Logic ===

        # Enable SwanLab if project is specified
        if cfg.get("swanlab_project") and not cfg.get("use_swanlab"):
            cfg["use_swanlab"] = True
            LOG.info("Automatically enabled use_swanlab because swanlab_project is set")

    def pre_model_load(self, cfg: DictDefault):
        """Initialize SwanLab before model loading with runtime checks.

        Only rank 0 initializes a run in distributed training; other ranks
        return early so no duplicate runs are created.

        Args:
            cfg: Validated Axolotl configuration.

        Raises:
            ImportError: If ``use_swanlab`` is set but swanlab is not installed.
        """
        if not cfg.use_swanlab:
            return

        # === Runtime Check: Import Availability ===
        try:
            import swanlab
        except ImportError as err:
            raise ImportError(
                "SwanLab is not installed.\n\n"
                "Install with:\n"
                " pip install swanlab\n\n"
                "Or add to requirements:\n"
                " swanlab>=0.3.0\n\n"
                f"Original error: {err}"
            ) from err

        # Log SwanLab version
        try:
            swanlab_version = swanlab.__version__
            LOG.info(f"SwanLab version: {swanlab_version}")
        except AttributeError:
            LOG.warning("Could not determine SwanLab version")

        # === Runtime Check: Distributed Training Setup ===
        from axolotl.utils.distributed import get_world_size, is_main_process

        world_size = get_world_size()
        if world_size > 1:
            mode = getattr(cfg, "swanlab_mode", "cloud")
            LOG.info(
                f"\n{'=' * 70}\n"
                f"Distributed training detected (world_size={world_size})\n"
                f"SwanLab mode: {mode}\n"
                f"{'=' * 70}\n"
                f"Behavior:\n"
                f" - Only rank 0 will initialize SwanLab\n"
                f" - Other ranks will skip SwanLab to avoid conflicts\n"
            )

            if mode == "cloud":
                LOG.info(
                    f" - Only rank 0 will upload to SwanLab cloud\n"
                    f" - Other ranks run without SwanLab overhead\n"
                    f"{'=' * 70}\n"
                )

        # Only initialize SwanLab on the main process (rank 0)
        # to avoid creating multiple runs in distributed training
        if not is_main_process():
            LOG.debug("Skipping SwanLab initialization on non-main process")
            return

        # Initialize SwanLab run (passing all params directly to init)
        try:
            init_kwargs = self._get_swanlab_init_kwargs(cfg)
            swanlab.init(**init_kwargs)
            self.swanlab_initialized = True
            LOG.info(f"SwanLab initialized with project: {cfg.swanlab_project}")

            # Register Lark notification callback (if configured)
            self._register_lark_callback(cfg)

            # Log configuration (with error handling) — config logging is
            # best-effort: failures here must not abort training.
            try:
                config_dict = self._prepare_config_for_logging(cfg)
                swanlab.config.update(config_dict)
                LOG.debug("Successfully logged config to SwanLab")
            except Exception as config_err:  # pylint: disable=broad-except
                LOG.warning(
                    f"Failed to log config to SwanLab: {config_err}. Continuing anyway."
                )

        except Exception as err:  # pylint: disable=broad-except
            LOG.exception("Failed to initialize SwanLab: %s", err)
            self.swanlab_initialized = False

    def add_callbacks_pre_trainer(self, cfg: DictDefault, model):
        """Add SwanLab callbacks before trainer creation.

        Args:
            cfg: Axolotl configuration.
            model: The loaded model (unused; part of the plugin hook signature).

        Returns:
            List of TrainerCallback instances to attach (possibly empty).
        """
        callbacks: list[TrainerCallback] = []

        if not cfg.use_swanlab:
            return callbacks

        if not self.swanlab_initialized:
            LOG.warning("SwanLab not initialized, skipping callback registration")
            return callbacks

        try:
            from axolotl.utils.callbacks.swanlab import (
                CustomSwanLabCallback,
                SaveAxolotlConfigtoSwanLabCallback,
            )

            # Add our custom lightweight SwanLabCallback
            # (avoids omegaconf/antlr4 version conflicts)
            swanlab_callback = CustomSwanLabCallback()
            callbacks.append(swanlab_callback)
            LOG.info("Added CustomSwanLabCallback for metrics logging")

            # Add Axolotl config logging callback
            if cfg.axolotl_config_path:
                config_callback = SaveAxolotlConfigtoSwanLabCallback(
                    cfg.axolotl_config_path
                )
                callbacks.append(config_callback)
                LOG.info("Added SaveAxolotlConfigtoSwanLabCallback")

        except ImportError as err:
            LOG.exception("Failed to import SwanLab callbacks: %s", err)

        return callbacks

    def post_trainer_create(self, cfg: DictDefault, trainer):
        """Post-trainer creation hook.

        Logs trainer-derived hyperparameters to SwanLab (best-effort) and
        registers the RLHF completion-logging callback when applicable.

        Args:
            cfg: Axolotl configuration.
            trainer: The created trainer instance.
        """
        if cfg.use_swanlab and self.swanlab_initialized:
            try:
                import swanlab

                # Log additional trainer information (with safe conversion)
                trainer_config = {
                    "total_steps": int(trainer.state.max_steps)
                    if trainer.state.max_steps
                    else None,
                    "num_train_epochs": float(trainer.args.num_train_epochs)
                    if trainer.args.num_train_epochs
                    else None,
                    "train_batch_size": int(trainer.args.train_batch_size)
                    if hasattr(trainer.args, "train_batch_size")
                    else None,
                    "gradient_accumulation_steps": int(
                        trainer.args.gradient_accumulation_steps
                    )
                    if trainer.args.gradient_accumulation_steps
                    else None,
                }
                # Remove None values
                trainer_config = {
                    k: v for k, v in trainer_config.items() if v is not None
                }

                if trainer_config:
                    swanlab.config.update(trainer_config)
                    LOG.info("Logged trainer configuration to SwanLab")
            except Exception as err:  # pylint: disable=broad-except
                LOG.debug(f"Failed to log trainer config to SwanLab: {err}")

        # Register RLHF completion logging callback if enabled
        self._register_completion_callback(cfg, trainer)

    def _get_swanlab_init_kwargs(self, cfg: DictDefault) -> dict:
        """Prepare kwargs for swanlab.init().

        Passes all configuration parameters directly to swanlab.init()
        instead of using environment variables as an intermediate layer.

        Args:
            cfg: Axolotl configuration with swanlab_* fields.

        Returns:
            dict: Keyword arguments for swanlab.init()
        """
        init_kwargs = {}

        # Project name (required)
        if cfg.swanlab_project:
            init_kwargs["project"] = cfg.swanlab_project

        # Experiment name
        if cfg.swanlab_experiment_name:
            init_kwargs["experiment_name"] = cfg.swanlab_experiment_name

        # Description
        if cfg.swanlab_description:
            init_kwargs["description"] = cfg.swanlab_description

        # Workspace (organization)
        if cfg.swanlab_workspace:
            init_kwargs["workspace"] = cfg.swanlab_workspace

        # Mode: cloud, local, offline, disabled
        if cfg.swanlab_mode:
            init_kwargs["mode"] = cfg.swanlab_mode

        # API key (pass directly instead of via env var)
        if cfg.swanlab_api_key:
            init_kwargs["api_key"] = cfg.swanlab_api_key

        # Private deployment hosts (pass directly instead of via env var)
        if cfg.swanlab_web_host:
            init_kwargs["web_host"] = cfg.swanlab_web_host

        if cfg.swanlab_api_host:
            init_kwargs["api_host"] = cfg.swanlab_api_host

        # Log model checkpoints (coming soon in SwanLab)
        if cfg.swanlab_log_model:
            init_kwargs["log_model"] = cfg.swanlab_log_model

        # Custom branding - adds Axolotl identifier to SwanLab UI
        # This helps identify runs from Axolotl vs other frameworks
        init_kwargs["config"] = {"UPPERFRAME": "🦎 Axolotl"}

        return init_kwargs

    def _prepare_config_for_logging(self, cfg: DictDefault) -> dict:
        """Prepare configuration dict for logging to SwanLab.

        Args:
            cfg: Axolotl configuration.

        Returns:
            JSON-serializable dict of selected training parameters; falls
            back to a minimal dict if extraction fails.
        """

        def safe_convert(value):
            """Convert value to JSON-serializable type."""
            if value is None:
                return None
            if isinstance(value, (int, float, bool)):
                return value
            if isinstance(value, str):
                return value
            # Convert everything else to string
            return str(value)

        try:
            # Extract important training parameters with safe conversion
            config_dict = {
                "base_model": safe_convert(getattr(cfg, "base_model", "")),
                "model_type": safe_convert(getattr(cfg, "model_type", "")),
                "sequence_len": safe_convert(getattr(cfg, "sequence_len", None)),
                "micro_batch_size": safe_convert(
                    getattr(cfg, "micro_batch_size", None)
                ),
                "gradient_accumulation_steps": safe_convert(
                    getattr(cfg, "gradient_accumulation_steps", None)
                ),
                "num_epochs": safe_convert(getattr(cfg, "num_epochs", None)),
                "max_steps": safe_convert(getattr(cfg, "max_steps", None)),
                "learning_rate": safe_convert(getattr(cfg, "learning_rate", None)),
                "lr_scheduler": safe_convert(getattr(cfg, "lr_scheduler", "")),
                "optimizer": safe_convert(getattr(cfg, "optimizer", "")),
                "warmup_ratio": safe_convert(getattr(cfg, "warmup_ratio", None)),
                "weight_decay": safe_convert(getattr(cfg, "weight_decay", None)),
                "seed": safe_convert(getattr(cfg, "seed", None)),
                "bf16": safe_convert(getattr(cfg, "bf16", None)),
                "tf32": safe_convert(getattr(cfg, "tf32", None)),
                "flash_attention": safe_convert(getattr(cfg, "flash_attention", None)),
                "sample_packing": safe_convert(getattr(cfg, "sample_packing", None)),
            }

            # Add FSDP/parallel config - only boolean flags
            if hasattr(cfg, "fsdp_config") and cfg.fsdp_config:
                config_dict["fsdp_enabled"] = True
                config_dict["fsdp_version"] = safe_convert(
                    getattr(cfg, "fsdp_version", None)
                )

            if hasattr(cfg, "deepspeed") and cfg.deepspeed:
                config_dict["deepspeed_enabled"] = True

            # Add context parallel info
            if hasattr(cfg, "context_parallel_size"):
                config_dict["context_parallel_size"] = safe_convert(
                    getattr(cfg, "context_parallel_size", None)
                )
            if hasattr(cfg, "tensor_parallel_size"):
                config_dict["tensor_parallel_size"] = safe_convert(
                    getattr(cfg, "tensor_parallel_size", None)
                )
            if hasattr(cfg, "dp_shard_size"):
                config_dict["dp_shard_size"] = safe_convert(
                    getattr(cfg, "dp_shard_size", None)
                )

            # Remove None values and empty strings
            config_dict = {
                k: v
                for k, v in config_dict.items()
                if v is not None and v != "" and v != "None"
            }

            return config_dict
        except Exception as err:  # pylint: disable=broad-except
            LOG.warning(f"Failed to prepare config for logging: {err}")
            # Return minimal config
            try:
                lr = getattr(cfg, "learning_rate", None)
                lr_value = float(lr) if lr is not None else None
            except (TypeError, ValueError):
                lr_value = None
            return {
                "base_model": str(getattr(cfg, "base_model", "unknown")),
                "learning_rate": lr_value,
            }

    def _register_lark_callback(self, cfg: DictDefault):
        """Register Lark (Feishu) notification callback if configured.

        Lark notifications enable sending training updates to team chat channels,
        useful for production monitoring and team collaboration.

        Args:
            cfg: Configuration object with Lark webhook settings
        """
        # Check if Lark webhook URL is configured
        lark_webhook_url = getattr(cfg, "swanlab_lark_webhook_url", None)
        if not lark_webhook_url:
            return  # Lark not configured, skip

        try:
            import swanlab
            from swanlab.plugin.notification import LarkCallback

            # Get optional secret for HMAC signature authentication
            lark_secret = getattr(cfg, "swanlab_lark_secret", None)

            # Create Lark callback with webhook URL and optional secret
            lark_callback = LarkCallback(
                webhook_url=lark_webhook_url,
                secret=lark_secret,
            )

            # Register callback with SwanLab
            swanlab.register_callbacks([lark_callback])

            if lark_secret:
                LOG.info(
                    "Registered Lark notification callback with HMAC authentication"
                )
            else:
                LOG.info("Registered Lark notification callback (no HMAC secret)")
                LOG.warning(
                    "Lark webhook has no secret configured. "
                    "For production use, set 'swanlab_lark_secret' to enable HMAC signature verification."
                )

        except ImportError as err:
            LOG.warning(
                f"Failed to import SwanLab Lark plugin: {err}\n\n"
                "Lark notifications require SwanLab >= 0.3.0 with plugin support.\n"
                "Install with: pip install 'swanlab>=0.3.0'\n\n"
                "Continuing without Lark notifications..."
            )
        except Exception as err:  # pylint: disable=broad-except
            LOG.exception(
                "Failed to register Lark callback: %s\n\n"
                "Check your Lark webhook URL and secret configuration.\n"
                "Continuing without Lark notifications...",
                err,
            )

    def _register_completion_callback(self, cfg: DictDefault, trainer):
        """Register RLHF completion logging callback if enabled and applicable.

        This callback logs model completions (prompts, chosen/rejected responses,
        rewards) to SwanLab during RLHF training for qualitative analysis.

        Args:
            cfg: Configuration object with completion logging settings
            trainer: The trainer instance to add callback to
        """
        # Check if completion logging is enabled
        log_completions = getattr(cfg, "swanlab_log_completions", True)
        if not log_completions:
            LOG.debug("SwanLab completion logging disabled by config")
            return

        # Check if trainer is an RLHF trainer — matched by class-name substring.
        trainer_name = trainer.__class__.__name__
        rlhf_trainers = ["DPO", "KTO", "ORPO", "GRPO", "CPO"]
        is_rlhf_trainer = any(name in trainer_name for name in rlhf_trainers)

        if not is_rlhf_trainer:
            LOG.debug(
                f"Trainer {trainer_name} is not an RLHF trainer, "
                "skipping completion logging callback"
            )
            return

        try:
            from axolotl.integrations.swanlab.callbacks import (
                SwanLabRLHFCompletionCallback,
            )

            # Get configuration parameters
            log_interval = getattr(cfg, "swanlab_completion_log_interval", 100)
            max_buffer = getattr(cfg, "swanlab_completion_max_buffer", 128)

            # Create and register callback
            completion_callback = SwanLabRLHFCompletionCallback(
                log_interval=log_interval,
                max_completions=max_buffer,
                table_name="rlhf_completions",
            )

            trainer.add_callback(completion_callback)

            LOG.info(
                f"Registered SwanLab RLHF completion logging callback for {trainer_name} "
                f"(log_interval={log_interval}, max_buffer={max_buffer})"
            )

        except ImportError as err:
            LOG.warning(
                f"Failed to import SwanLab completion callback: {err}\n\n"
                "This is a bug - the callback should be available.\n"
                "Please report this issue.\n\n"
                "Continuing without completion logging..."
            )
        except Exception as err:  # pylint: disable=broad-except
            LOG.exception(
                "Failed to register SwanLab completion callback: %s\n\n"
                "Continuing without completion logging...",
                err,
            )
result = do_expensive_computation() + + Args: + trainer: Trainer instance (must have cfg attribute with use_swanlab flag) + func_name: Name of the function being profiled + + Yields: + None + """ + start_time = time.perf_counter() + try: + yield + finally: + duration = time.perf_counter() - start_time + + # Check if SwanLab is enabled and initialized + use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False) + if use_swanlab: + try: + import swanlab + + if swanlab.get_run() is not None: + # Log profiling metric + trainer_class = trainer.__class__.__name__ + metric_name = f"profiling/Time taken: {trainer_class}.{func_name}" + + swanlab.log({metric_name: duration}) + + except ImportError: + # SwanLab not installed, silently skip + pass + except Exception as err: # pylint: disable=broad-except + # Log error but don't fail training + LOG.debug(f"Failed to log profiling metric for {func_name}: {err}") + + +def swanlab_profile(func: Callable) -> Callable: + """Decorator to profile and log function execution time to SwanLab. + + Automatically measures execution time of trainer methods and logs + to SwanLab as profiling metrics. + + Example usage: + >>> class MyTrainer: + ... @swanlab_profile + ... def training_step(self, model, inputs): + ... return super().training_step(model, inputs) + + Args: + func: Function to profile (must be a method of a trainer instance) + + Returns: + Wrapped function with profiling + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + with swanlab_profiling_context(self, func.__name__): + return func(self, *args, **kwargs) + + return wrapper + + +class ProfilingConfig: + """Configuration for SwanLab profiling. + + This class provides a centralized way to control profiling behavior. 
+ + Attributes: + enabled: Whether profiling is enabled globally + min_duration_ms: Minimum duration (in ms) to log (filters out very fast ops) + log_interval: Log every N function calls (to reduce overhead) + """ + + def __init__( + self, + enabled: bool = True, + min_duration_ms: float = 0.1, + log_interval: int = 1, + ): + """Initialize profiling configuration. + + Args: + enabled: Enable profiling. Default: True + min_duration_ms: Minimum duration to log (ms). Default: 0.1 + log_interval: Log every N calls. Default: 1 (log all) + """ + self.enabled = enabled + self.min_duration_ms = min_duration_ms + self.log_interval = log_interval + self._call_counts: dict[str, int] = {} + + def should_log(self, func_name: str, duration_seconds: float) -> bool: + """Check if a profiling measurement should be logged. + + Args: + func_name: Name of the profiled function + duration_seconds: Execution duration in seconds + + Returns: + True if should log, False otherwise + """ + if not self.enabled: + return False + + # Check minimum duration threshold + duration_ms = duration_seconds * 1000 + if duration_ms < self.min_duration_ms: + return False + + # Check log interval + self._call_counts.setdefault(func_name, 0) + self._call_counts[func_name] += 1 + + # Always log on first call OR at intervals + count = self._call_counts[func_name] + if count == 1 or count % self.log_interval == 0: + return True + + return False + + +# Global profiling config (can be modified by users) +DEFAULT_PROFILING_CONFIG = ProfilingConfig() + + +@contextmanager +def swanlab_profiling_context_advanced( + trainer: Any, + func_name: str, + config: ProfilingConfig | None = None, +): + """Advanced profiling context with configurable behavior. + + Similar to swanlab_profiling_context but with additional configuration + options for filtering and throttling profiling logs. 
+ + Example usage: + >>> config = ProfilingConfig(min_duration_ms=1.0, log_interval=10) + >>> with swanlab_profiling_context_advanced(self, "forward", config): + ... output = model(inputs) + + Args: + trainer: Trainer instance + func_name: Function name + config: Profiling configuration. If None, uses DEFAULT_PROFILING_CONFIG + + Yields: + None + """ + if config is None: + config = DEFAULT_PROFILING_CONFIG + + start_time = time.perf_counter() + try: + yield + finally: + duration = time.perf_counter() - start_time + + # Check if should log based on config + if config.should_log(func_name, duration): + # Check if SwanLab is enabled + use_swanlab = getattr(getattr(trainer, "cfg", None), "use_swanlab", False) + if use_swanlab: + try: + import swanlab + + if swanlab.get_run() is not None: + trainer_class = trainer.__class__.__name__ + metric_name = ( + f"profiling/Time taken: {trainer_class}.{func_name}" + ) + + swanlab.log({metric_name: duration}) + + except ImportError: + pass + except Exception as err: # pylint: disable=broad-except + LOG.debug(f"Failed to log profiling metric for {func_name}: {err}") diff --git a/src/axolotl/utils/callbacks/swanlab.py b/src/axolotl/utils/callbacks/swanlab.py new file mode 100644 index 000000000..4ebf2e61e --- /dev/null +++ b/src/axolotl/utils/callbacks/swanlab.py @@ -0,0 +1,248 @@ +"""Callbacks for SwanLab integration""" + +from __future__ import annotations + +import json +import os +from shutil import copyfile +from tempfile import NamedTemporaryFile +from typing import TYPE_CHECKING + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from axolotl.utils.logging import get_logger + +if TYPE_CHECKING: + from axolotl.core.training_args import AxolotlTrainingArguments + +LOG = get_logger(__name__) + + +class CustomSwanLabCallback(TrainerCallback): + """ + Lightweight SwanLab callback that directly logs metrics without using + SwanLab's transformers integration (which requires 
class CustomSwanLabCallback(TrainerCallback):
    """
    Lightweight SwanLab callback that directly logs metrics without using
    SwanLab's transformers integration (which requires omegaconf).

    This avoids the antlr4 version conflict between omegaconf and axolotl.
    """

    def __init__(self):
        # Lazy setup: swanlab is only imported once training actually starts.
        self._initialized = False
        self.swanlab = None

    def setup(self):
        """Lazy initialization of SwanLab"""
        if self._initialized:
            return

        try:
            import swanlab
        except ImportError:
            LOG.error("SwanLab is not installed")
            return

        self.swanlab = swanlab

        # Check if SwanLab run is initialized
        if swanlab.get_run() is None:
            LOG.warning("SwanLab run is not initialized")
            return

        self._initialized = True
        LOG.info("CustomSwanLabCallback initialized successfully")

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the beginning of training"""
        # Only the main process talks to SwanLab.
        if not state.is_world_process_zero:
            return control

        self.setup()

        if not self._initialized:
            return control

        # Log training configuration
        try:
            run_config = {
                "train_batch_size": args.per_device_train_batch_size,
                "eval_batch_size": args.per_device_eval_batch_size,
                "learning_rate": args.learning_rate,
                "num_train_epochs": args.num_train_epochs,
                "max_steps": args.max_steps,
                "warmup_steps": args.warmup_steps,
                "logging_steps": args.logging_steps,
                "save_steps": args.save_steps,
                "gradient_accumulation_steps": args.gradient_accumulation_steps,
            }
            self.swanlab.config.update(run_config)
            LOG.debug("Training configuration logged to SwanLab")
        except Exception as err:
            LOG.warning(f"Failed to log training config: {err}")

        return control

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs=None,
        **kwargs,
    ):
        """Called when logging metrics"""
        if not state.is_world_process_zero:
            return control

        if not self._initialized:
            self.setup()

        if not self._initialized or logs is None:
            return control

        # Log metrics to SwanLab, keeping only numeric values.
        try:
            numeric_logs = {
                key: value
                for key, value in logs.items()
                if isinstance(value, (int, float))
            }

            if numeric_logs and state.global_step is not None:
                # Use step from state
                self.swanlab.log(numeric_logs, step=state.global_step)
        except Exception as err:
            LOG.warning(f"Failed to log metrics to SwanLab: {err}")

        return control

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the end of training"""
        if not state.is_world_process_zero:
            return control

        if self._initialized:
            LOG.info("Training completed. SwanLab logs are available.")

        return control
Install it with: pip install swanlab" + ) + except (FileNotFoundError, ConnectionError) as err: + LOG.warning(f"Error while saving Axolotl config to SwanLab: {err}") + + # Log DeepSpeed config if available + if args.deepspeed: + try: + import swanlab + + with NamedTemporaryFile( + mode="w", + delete=False, + suffix=".json", + prefix="deepspeed_config_", + ) as temp_file: + skip_upload = False + if isinstance(args.deepspeed, dict): + json.dump(args.deepspeed, temp_file, indent=4) + elif isinstance(args.deepspeed, str) and os.path.exists( + args.deepspeed + ): + copyfile(args.deepspeed, temp_file.name) + else: + skip_upload = True + + if not skip_upload: + temp_file.flush() + with open( + temp_file.name, "r", encoding="utf-8" + ) as ds_config_file: + swanlab.log( + { + "deepspeed_config": swanlab.Text( + ds_config_file.read(), + caption="DeepSpeed Config", + ) + } + ) + LOG.info( + "The DeepSpeed config has been saved to the SwanLab run under logs." + ) + + # Clean up temp file + os.unlink(temp_file.name) + + except (FileNotFoundError, ConnectionError) as err: + LOG.warning( + f"Error while saving DeepSpeed config to SwanLab: {err}" + ) + except ImportError: + pass + + return control diff --git a/tests/integrations/test_swanlab.py b/tests/integrations/test_swanlab.py new file mode 100644 index 000000000..b86df0b0e --- /dev/null +++ b/tests/integrations/test_swanlab.py @@ -0,0 +1,1337 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for SwanLab Integration Plugin. + +Tests conflict detection, configuration validation, and multi-logger warnings. +""" + +import logging +import os +import time +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from transformers.utils.import_utils import _is_package_available + +from axolotl.integrations.swanlab.args import SwanLabConfig +from axolotl.integrations.swanlab.plugins import SwanLabPlugin + +SWANLAB_INSTALLED = _is_package_available("swanlab") + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestSwanLabConfigValidators: + """Tests for Pydantic field validators in SwanLabConfig.""" + + def test_valid_swanlab_mode_cloud(self): + """Test that 'cloud' mode is valid.""" + config = SwanLabConfig(swanlab_mode="cloud") + assert config.swanlab_mode == "cloud" + + def test_valid_swanlab_mode_local(self): + """Test that 'local' mode is valid.""" + config = SwanLabConfig(swanlab_mode="local") + assert config.swanlab_mode == "local" + + def test_valid_swanlab_mode_offline(self): + """Test that 'offline' mode is valid.""" + config = SwanLabConfig(swanlab_mode="offline") + assert config.swanlab_mode == "offline" + + def test_valid_swanlab_mode_disabled(self): + """Test that 'disabled' mode is valid.""" + config = SwanLabConfig(swanlab_mode="disabled") + assert config.swanlab_mode == "disabled" + + def test_invalid_swanlab_mode(self): + """Test that invalid mode raises ValueError.""" + with pytest.raises(ValidationError) as exc_info: + SwanLabConfig(swanlab_mode="invalid") + + error_msg = str(exc_info.value) + assert "Invalid swanlab_mode" in error_msg + assert "cloud" in error_msg + assert "local" in error_msg + assert "offline" in error_msg + assert "disabled" in error_msg + + def test_swanlab_mode_none_allowed(self): + """Test that None mode is 
allowed (will use default).""" + config = SwanLabConfig(swanlab_mode=None) + assert config.swanlab_mode is None + + def test_valid_swanlab_project(self): + """Test that valid project name is accepted.""" + config = SwanLabConfig(swanlab_project="my-project") + assert config.swanlab_project == "my-project" + + def test_swanlab_project_none_allowed(self): + """Test that None project is allowed.""" + config = SwanLabConfig(swanlab_project=None) + assert config.swanlab_project is None + + def test_empty_swanlab_project_rejected(self): + """Test that empty string project name is rejected.""" + with pytest.raises(ValidationError) as exc_info: + SwanLabConfig(swanlab_project="") + + error_msg = str(exc_info.value) + assert "cannot be an empty string" in error_msg + + def test_whitespace_only_project_rejected(self): + """Test that whitespace-only project name is rejected.""" + with pytest.raises(ValidationError) as exc_info: + SwanLabConfig(swanlab_project=" ") + + error_msg = str(exc_info.value) + assert "cannot be an empty string" in error_msg + + def test_use_swanlab_true_requires_project(self): + """Test that use_swanlab=True requires swanlab_project.""" + with pytest.raises(ValidationError) as exc_info: + SwanLabConfig(use_swanlab=True, swanlab_project=None) + + error_msg = str(exc_info.value) + assert "swanlab_project" in error_msg.lower() + assert "not set" in error_msg.lower() + + def test_use_swanlab_true_with_project_valid(self): + """Test that use_swanlab=True with project is valid.""" + config = SwanLabConfig(use_swanlab=True, swanlab_project="my-project") + assert config.use_swanlab is True + assert config.swanlab_project == "my-project" + + def test_use_swanlab_false_no_project_valid(self): + """Test that use_swanlab=False without project is valid.""" + config = SwanLabConfig(use_swanlab=False, swanlab_project=None) + assert config.use_swanlab is False + assert config.swanlab_project is None + + def test_use_swanlab_none_no_project_valid(self): + """Test that 
use_swanlab=None without project is valid.""" + config = SwanLabConfig(use_swanlab=None, swanlab_project=None) + assert config.use_swanlab is None + assert config.swanlab_project is None + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestSwanLabPluginRegister: + """Tests for SwanLabPlugin.register() conflict detection.""" + + def test_register_without_use_swanlab(self): + """Test that register works when SwanLab is not enabled.""" + plugin = SwanLabPlugin() + cfg = {"use_swanlab": False} + # Should not raise + plugin.register(cfg) + + def test_register_use_swanlab_missing_project(self): + """Test that use_swanlab=True without project raises ValueError.""" + plugin = SwanLabPlugin() + cfg = {"use_swanlab": True} + + with pytest.raises(ValueError) as exc_info: + plugin.register(cfg) + + error_msg = str(exc_info.value) + assert "swanlab_project" in error_msg + assert "not set" in error_msg + assert "Solutions" in error_msg + + def test_register_use_swanlab_with_project_valid(self): + """Test that use_swanlab=True with project is valid.""" + plugin = SwanLabPlugin() + cfg = {"use_swanlab": True, "swanlab_project": "my-project"} + # Should not raise + plugin.register(cfg) + + def test_register_invalid_mode(self): + """Test that invalid swanlab_mode raises ValueError.""" + plugin = SwanLabPlugin() + cfg = { + "use_swanlab": True, + "swanlab_project": "my-project", + "swanlab_mode": "invalid-mode", + } + + with pytest.raises(ValueError) as exc_info: + plugin.register(cfg) + + error_msg = str(exc_info.value) + assert "Invalid swanlab_mode" in error_msg + assert "cloud" in error_msg + assert "local" in error_msg + + def test_register_valid_modes(self): + """Test that all valid modes are accepted.""" + plugin = SwanLabPlugin() + valid_modes = ["cloud", "local", "offline", "disabled"] + + for mode in valid_modes: + cfg = { + "use_swanlab": True, + "swanlab_project": "my-project", + "swanlab_mode": mode, + } + # Should not raise 
+ plugin.register(cfg) + + def test_register_auto_enable_swanlab(self): + """Test that providing swanlab_project auto-enables use_swanlab.""" + plugin = SwanLabPlugin() + cfg = {"swanlab_project": "my-project"} + + plugin.register(cfg) + + assert cfg["use_swanlab"] is True + + def test_register_cloud_mode_without_api_key_warns(self, caplog): + """Test that cloud mode without API key logs warning.""" + plugin = SwanLabPlugin() + cfg = { + "use_swanlab": True, + "swanlab_project": "my-project", + "swanlab_mode": "cloud", + } + + # Clear environment variable to ensure it's not set + with patch.dict(os.environ, {}, clear=True): + with caplog.at_level(logging.WARNING): + plugin.register(cfg) + + # Should log warning about missing API key + warning_messages = [record.message for record in caplog.records] + assert any("API key" in msg for msg in warning_messages) + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestMultiLoggerDetection: + """Tests for multi-logger conflict detection.""" + + def test_single_logger_no_warning(self, caplog): + """Test that single logger doesn't trigger warning.""" + plugin = SwanLabPlugin() + cfg = {"use_swanlab": True, "swanlab_project": "my-project"} + + with caplog.at_level(logging.WARNING): + plugin.register(cfg) + + # Should not log multi-logger warning + warning_messages = [record.message for record in caplog.records] + assert not any("Multiple logging tools" in msg for msg in warning_messages) + + def test_two_loggers_warning(self, caplog): + """Test that two loggers trigger warning.""" + plugin = SwanLabPlugin() + cfg = { + "use_swanlab": True, + "swanlab_project": "my-project", + "use_wandb": True, + } + + with caplog.at_level(logging.WARNING): + plugin.register(cfg) + + # Should log multi-logger warning + warning_messages = [record.message for record in caplog.records] + assert any("Multiple logging tools" in msg for msg in warning_messages) + assert any("SwanLab" in msg and "WandB" in 
msg for msg in warning_messages) + + def test_three_loggers_error(self, caplog): + """Test that three loggers trigger error-level warning.""" + plugin = SwanLabPlugin() + cfg = { + "use_swanlab": True, + "swanlab_project": "my-project", + "use_wandb": True, + "use_mlflow": True, + } + + with caplog.at_level(logging.ERROR): + plugin.register(cfg) + + # Should log error-level warning + error_messages = [ + record.message + for record in caplog.records + if record.levelno >= logging.ERROR + ] + assert any("logging tools enabled" in msg for msg in error_messages) + + def test_multi_logger_with_comet(self, caplog): + """Test that Comet is detected in multi-logger scenario.""" + plugin = SwanLabPlugin() + cfg = { + "use_swanlab": True, + "swanlab_project": "my-project", + "comet_api_key": "test-key", + } + + with caplog.at_level(logging.WARNING): + plugin.register(cfg) + + # Should detect Comet + warning_messages = [record.message for record in caplog.records] + assert any("Comet" in msg for msg in warning_messages) + + def test_multi_logger_with_comet_project(self, caplog): + """Test that Comet is detected via comet_project_name.""" + plugin = SwanLabPlugin() + cfg = { + "use_swanlab": True, + "swanlab_project": "my-project", + "comet_project_name": "test-project", + } + + with caplog.at_level(logging.WARNING): + plugin.register(cfg) + + # Should detect Comet + warning_messages = [record.message for record in caplog.records] + assert any("Comet" in msg for msg in warning_messages) + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestSwanLabPluginPreModelLoad: + """Tests for SwanLabPlugin.pre_model_load() runtime checks.""" + + def test_pre_model_load_disabled(self): + """Test that pre_model_load does nothing when SwanLab is disabled.""" + plugin = SwanLabPlugin() + cfg = MagicMock() + cfg.use_swanlab = False + + # Should not raise + plugin.pre_model_load(cfg) + + def test_pre_model_load_import_error(self): + """Test that 
missing swanlab package raises clear ImportError.""" + plugin = SwanLabPlugin() + cfg = MagicMock() + cfg.use_swanlab = True + + with patch( + "builtins.__import__", side_effect=ImportError("No module named 'swanlab'") + ): + with pytest.raises(ImportError) as exc_info: + plugin.pre_model_load(cfg) + + error_msg = str(exc_info.value) + assert "SwanLab is not installed" in error_msg + assert "pip install swanlab" in error_msg + + @patch("axolotl.utils.distributed.is_main_process") + @patch("axolotl.utils.distributed.get_world_size") + def test_pre_model_load_non_main_process_skips( + self, mock_get_world_size, mock_is_main_process + ): + """Test that non-main process skips SwanLab initialization.""" + mock_get_world_size.return_value = 2 + mock_is_main_process.return_value = False + + plugin = SwanLabPlugin() + cfg = MagicMock() + cfg.use_swanlab = True + + with patch("swanlab.init") as mock_init: + plugin.pre_model_load(cfg) + # Should NOT call swanlab.init + mock_init.assert_not_called() + + @patch("axolotl.utils.distributed.is_main_process") + @patch("axolotl.utils.distributed.get_world_size") + def test_pre_model_load_distributed_logging( + self, mock_get_world_size, mock_is_main_process, caplog + ): + """Test that distributed training logs world size info.""" + mock_get_world_size.return_value = 4 + mock_is_main_process.return_value = True + + plugin = SwanLabPlugin() + cfg = MagicMock() + cfg.use_swanlab = True + cfg.swanlab_project = "test-project" + cfg.swanlab_mode = "cloud" + + with patch("swanlab.init"), patch("swanlab.__version__", "0.3.0"): + with caplog.at_level(logging.INFO): + plugin.pre_model_load(cfg) + + # Should log distributed training info + info_messages = [record.message for record in caplog.records] + assert any("world_size=4" in msg for msg in info_messages) + assert any("Only rank 0" in msg for msg in info_messages) + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestSwanLabInitKwargs: + 
"""Tests for SwanLab initialization with direct parameter passing.""" + + def test_custom_branding_added_to_config(self): + """Test that Axolotl custom branding is added to SwanLab config.""" + from axolotl.integrations.swanlab.plugins import SwanLabPlugin + from axolotl.utils.dict import DictDefault + + plugin = SwanLabPlugin() + cfg = DictDefault( + { + "use_swanlab": True, + "swanlab_project": "test-project", + } + ) + + init_kwargs = plugin._get_swanlab_init_kwargs(cfg) + + # Verify custom branding is present + assert "config" in init_kwargs + assert init_kwargs["config"]["UPPERFRAME"] == "🦎 Axolotl" + + def test_api_key_passed_directly(self): + """Test that API key is passed directly to swanlab.init() instead of via env var.""" + from axolotl.integrations.swanlab.plugins import SwanLabPlugin + from axolotl.utils.dict import DictDefault + + plugin = SwanLabPlugin() + cfg = DictDefault( + { + "use_swanlab": True, + "swanlab_project": "test-project", + "swanlab_api_key": "test-api-key-12345", + } + ) + + init_kwargs = plugin._get_swanlab_init_kwargs(cfg) + + # Verify API key is in init_kwargs (not set as env var) + assert "api_key" in init_kwargs + assert init_kwargs["api_key"] == "test-api-key-12345" + + def test_private_deployment_hosts_passed_directly(self): + """Test that private deployment hosts are passed directly to swanlab.init().""" + from axolotl.integrations.swanlab.plugins import SwanLabPlugin + from axolotl.utils.dict import DictDefault + + plugin = SwanLabPlugin() + cfg = DictDefault( + { + "use_swanlab": True, + "swanlab_project": "internal-project", + "swanlab_web_host": "https://swanlab.company.com", + "swanlab_api_host": "https://api-swanlab.company.com", + } + ) + + init_kwargs = plugin._get_swanlab_init_kwargs(cfg) + + # Verify private deployment hosts are in init_kwargs + assert "web_host" in init_kwargs + assert init_kwargs["web_host"] == "https://swanlab.company.com" + assert "api_host" in init_kwargs + assert init_kwargs["api_host"] == 
"https://api-swanlab.company.com" + + @patch("axolotl.utils.distributed.is_main_process") + def test_full_private_deployment_init(self, mock_is_main_process): + """Test complete initialization with private deployment configuration.""" + mock_is_main_process.return_value = True + + from axolotl.integrations.swanlab.plugins import SwanLabPlugin + from axolotl.utils.dict import DictDefault + + plugin = SwanLabPlugin() + cfg = DictDefault( + { + "use_swanlab": True, + "swanlab_project": "secure-project", + "swanlab_experiment_name": "experiment-001", + "swanlab_mode": "cloud", + "swanlab_api_key": "private-key-xyz", + "swanlab_web_host": "https://swanlab.internal.net", + "swanlab_api_host": "https://api.swanlab.internal.net", + "swanlab_workspace": "research-team", + } + ) + + with patch("swanlab.init") as mock_init: + plugin.pre_model_load(cfg) + + # Verify swanlab.init was called with all parameters + mock_init.assert_called_once() + call_kwargs = mock_init.call_args[1] + + assert call_kwargs["project"] == "secure-project" + assert call_kwargs["experiment_name"] == "experiment-001" + assert call_kwargs["mode"] == "cloud" + assert call_kwargs["api_key"] == "private-key-xyz" + assert call_kwargs["web_host"] == "https://swanlab.internal.net" + assert call_kwargs["api_host"] == "https://api.swanlab.internal.net" + assert call_kwargs["workspace"] == "research-team" + assert call_kwargs["config"]["UPPERFRAME"] == "🦎 Axolotl" + + def test_env_vars_not_set_for_api_params(self): + """Test that environment variables are NOT set for API parameters.""" + import os + + from axolotl.integrations.swanlab.plugins import SwanLabPlugin + from axolotl.utils.dict import DictDefault + + # Clear any existing env vars + for key in [ + "SWANLAB_API_KEY", + "SWANLAB_WEB_HOST", + "SWANLAB_API_HOST", + "SWANLAB_MODE", + ]: + os.environ.pop(key, None) + + plugin = SwanLabPlugin() + cfg = DictDefault( + { + "use_swanlab": True, + "swanlab_project": "test-project", + "swanlab_api_key": 
"test-key", + "swanlab_web_host": "https://test.com", + "swanlab_api_host": "https://api-test.com", + "swanlab_mode": "cloud", + } + ) + + with ( + patch("axolotl.utils.distributed.is_main_process", return_value=True), + patch("swanlab.init"), + ): + plugin.pre_model_load(cfg) + + # Verify env vars were NOT set (simplified approach) + # The old _setup_swanlab_env() method is removed, so these shouldn't be set + # Note: SwanLab itself might set these, but our plugin shouldn't + # We're just testing that our plugin doesn't call _setup_swanlab_env() + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestLarkNotificationIntegration: + """Tests for Lark (Feishu) notification integration.""" + + def test_lark_callback_registration_with_webhook_only(self): + """Test Lark callback registration with webhook URL only (no secret).""" + plugin = SwanLabPlugin() + + cfg = MagicMock() + cfg.use_swanlab = True + cfg.swanlab_project = "test-project" + cfg.swanlab_mode = "local" + cfg.swanlab_lark_webhook_url = ( + "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook" + ) + cfg.swanlab_lark_secret = None + + with ( + patch("swanlab.init"), + patch("swanlab.__version__", "0.3.0"), + patch("swanlab.register_callbacks") as mock_register, + patch("axolotl.utils.distributed.is_main_process", return_value=True), + patch("axolotl.utils.distributed.get_world_size", return_value=1), + ): + # Mock LarkCallback import + with patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback: + mock_lark_instance = MagicMock() + MockLarkCallback.return_value = mock_lark_instance + + plugin.pre_model_load(cfg) + + # Verify LarkCallback was instantiated with correct params + MockLarkCallback.assert_called_once_with( + webhook_url="https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook", + secret=None, + ) + + # Verify callback was registered + mock_register.assert_called_once_with([mock_lark_instance]) + + def 
test_lark_callback_registration_with_secret(self): + """Test Lark callback registration with webhook URL and HMAC secret.""" + plugin = SwanLabPlugin() + + cfg = MagicMock() + cfg.use_swanlab = True + cfg.swanlab_project = "test-project" + cfg.swanlab_mode = "local" + cfg.swanlab_lark_webhook_url = ( + "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook" + ) + cfg.swanlab_lark_secret = "test-hmac-secret" + + with ( + patch("swanlab.init"), + patch("swanlab.__version__", "0.3.0"), + patch("swanlab.register_callbacks") as mock_register, + patch("axolotl.utils.distributed.is_main_process", return_value=True), + patch("axolotl.utils.distributed.get_world_size", return_value=1), + ): + with patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback: + mock_lark_instance = MagicMock() + MockLarkCallback.return_value = mock_lark_instance + + plugin.pre_model_load(cfg) + + # Verify LarkCallback was instantiated with secret + MockLarkCallback.assert_called_once_with( + webhook_url="https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook", + secret="test-hmac-secret", + ) + + mock_register.assert_called_once_with([mock_lark_instance]) + + def test_lark_callback_not_registered_without_webhook(self): + """Test that Lark callback is NOT registered when webhook URL not provided.""" + plugin = SwanLabPlugin() + + cfg = MagicMock() + cfg.use_swanlab = True + cfg.swanlab_project = "test-project" + cfg.swanlab_mode = "local" + cfg.swanlab_lark_webhook_url = None # No webhook + cfg.swanlab_lark_secret = None + + with ( + patch("swanlab.init"), + patch("swanlab.__version__", "0.3.0"), + patch("swanlab.register_callbacks") as mock_register, + patch("axolotl.utils.distributed.is_main_process", return_value=True), + patch("axolotl.utils.distributed.get_world_size", return_value=1), + ): + plugin.pre_model_load(cfg) + + # Verify register_callbacks was NOT called + mock_register.assert_not_called() + + def test_lark_import_error_handled_gracefully(self, caplog): + 
"""Test that ImportError for Lark plugin is handled gracefully.""" + plugin = SwanLabPlugin() + + cfg = MagicMock() + cfg.use_swanlab = True + cfg.swanlab_project = "test-project" + cfg.swanlab_mode = "local" + cfg.swanlab_lark_webhook_url = ( + "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook" + ) + cfg.swanlab_lark_secret = None + + with ( + patch("swanlab.init"), + patch("swanlab.__version__", "0.3.0"), + patch("axolotl.utils.distributed.is_main_process", return_value=True), + patch("axolotl.utils.distributed.get_world_size", return_value=1), + ): + # Mock ImportError for LarkCallback + with patch( + "swanlab.plugin.notification.LarkCallback", + side_effect=ImportError( + "No module named 'swanlab.plugin.notification'" + ), + ): + with caplog.at_level(logging.WARNING): + plugin.pre_model_load(cfg) + + # Should log warning about missing Lark plugin + warning_messages = [record.message for record in caplog.records] + assert any( + "Failed to import SwanLab Lark plugin" in msg + for msg in warning_messages + ) + assert any("SwanLab >= 0.3.0" in msg for msg in warning_messages) + + def test_lark_warning_for_missing_secret(self, caplog): + """Test that warning is logged when Lark webhook has no HMAC secret.""" + plugin = SwanLabPlugin() + + cfg = MagicMock() + cfg.use_swanlab = True + cfg.swanlab_project = "test-project" + cfg.swanlab_mode = "local" + cfg.swanlab_lark_webhook_url = ( + "https://open.feishu.cn/open-apis/bot/v2/hook/test-webhook" + ) + cfg.swanlab_lark_secret = None # No secret + + with ( + patch("swanlab.init"), + patch("swanlab.__version__", "0.3.0"), + patch("swanlab.register_callbacks"), + patch("axolotl.utils.distributed.is_main_process", return_value=True), + patch("axolotl.utils.distributed.get_world_size", return_value=1), + ): + with patch("swanlab.plugin.notification.LarkCallback"): + with caplog.at_level(logging.WARNING): + plugin.pre_model_load(cfg) + + # Should log warning about missing secret + warning_messages = [record.message 
for record in caplog.records] + assert any( + "no secret configured" in msg.lower() + for msg in warning_messages + ) + assert any("swanlab_lark_secret" in msg for msg in warning_messages) + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestSwanLabPluginIntegration: + """Integration tests for SwanLab plugin lifecycle.""" + + def test_full_lifecycle_valid_config(self): + """Test full plugin lifecycle with valid configuration.""" + plugin = SwanLabPlugin() + + # Register + cfg_dict = { + "use_swanlab": True, + "swanlab_project": "test-project", + "swanlab_mode": "local", + } + plugin.register(cfg_dict) + + # Pre-model load (mock SwanLab) + cfg_obj = MagicMock() + cfg_obj.use_swanlab = True + cfg_obj.swanlab_project = "test-project" + cfg_obj.swanlab_mode = "local" + cfg_obj.swanlab_lark_webhook_url = None # No Lark + + with ( + patch("swanlab.init") as mock_init, + patch("swanlab.__version__", "0.3.0"), + patch("axolotl.utils.distributed.is_main_process", return_value=True), + patch("axolotl.utils.distributed.get_world_size", return_value=1), + ): + plugin.pre_model_load(cfg_obj) + # Should call swanlab.init + mock_init.assert_called_once() + + def test_lifecycle_with_multi_logger_warning(self, caplog): + """Test lifecycle with multi-logger warning.""" + plugin = SwanLabPlugin() + + cfg_dict = { + "use_swanlab": True, + "swanlab_project": "test-project", + "use_wandb": True, + } + + with caplog.at_level(logging.WARNING): + plugin.register(cfg_dict) + + # Should have multi-logger warning + warning_messages = [record.message for record in caplog.records] + assert any("Multiple logging tools" in msg for msg in warning_messages) + + def test_lifecycle_invalid_config_fails_early(self): + """Test that invalid config fails at register stage.""" + plugin = SwanLabPlugin() + + cfg_dict = { + "use_swanlab": True, + # Missing swanlab_project + } + + # Should fail at register, not pre_model_load + with pytest.raises(ValueError): + 
plugin.register(cfg_dict) + + def test_full_lifecycle_with_lark_notifications(self): + """Test full lifecycle including Lark notification registration.""" + plugin = SwanLabPlugin() + + # Register + cfg_dict = { + "use_swanlab": True, + "swanlab_project": "test-project", + "swanlab_mode": "cloud", + } + plugin.register(cfg_dict) + + # Pre-model load with Lark config + cfg_obj = MagicMock() + cfg_obj.use_swanlab = True + cfg_obj.swanlab_project = "test-project" + cfg_obj.swanlab_mode = "cloud" + cfg_obj.swanlab_lark_webhook_url = ( + "https://open.feishu.cn/open-apis/bot/v2/hook/test" + ) + cfg_obj.swanlab_lark_secret = "secret123" + + with ( + patch("swanlab.init"), + patch("swanlab.__version__", "0.3.0"), + patch("swanlab.register_callbacks") as mock_register, + patch("axolotl.utils.distributed.is_main_process", return_value=True), + patch("axolotl.utils.distributed.get_world_size", return_value=1), + ): + with patch("swanlab.plugin.notification.LarkCallback") as MockLarkCallback: + mock_lark_instance = MagicMock() + MockLarkCallback.return_value = mock_lark_instance + + plugin.pre_model_load(cfg_obj) + + # Verify both SwanLab init AND Lark callback registration + MockLarkCallback.assert_called_once() + mock_register.assert_called_once_with([mock_lark_instance]) + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestCompletionLogger: + """Tests for CompletionLogger utility class.""" + + def test_completion_logger_initialization(self): + """Test CompletionLogger initializes with correct maxlen.""" + from axolotl.integrations.swanlab.completion_logger import CompletionLogger + + logger = CompletionLogger(maxlen=64) + assert logger.maxlen == 64 + assert len(logger) == 0 + + def test_add_dpo_completion(self): + """Test adding DPO completions to buffer.""" + from axolotl.integrations.swanlab.completion_logger import CompletionLogger + + logger = CompletionLogger(maxlen=10) + + logger.add_dpo_completion( + step=0, + 
prompt="What is AI?", + chosen="Artificial Intelligence is...", + rejected="AI means...", + reward_diff=0.5, + ) + + assert len(logger) == 1 + entry = logger.data[0] + assert entry["step"] == 0 + assert entry["prompt"] == "What is AI?" + assert entry["chosen"] == "Artificial Intelligence is..." + assert entry["rejected"] == "AI means..." + assert entry["reward_diff"] == 0.5 + + def test_add_kto_completion(self): + """Test adding KTO completions to buffer.""" + from axolotl.integrations.swanlab.completion_logger import CompletionLogger + + logger = CompletionLogger(maxlen=10) + + logger.add_kto_completion( + step=1, + prompt="Explain quantum physics", + completion="Quantum physics is...", + label=True, + reward=0.8, + ) + + assert len(logger) == 1 + entry = logger.data[0] + assert entry["step"] == 1 + assert entry["prompt"] == "Explain quantum physics" + assert entry["completion"] == "Quantum physics is..." + assert entry["label"] == "desirable" + assert entry["reward"] == 0.8 + + def test_add_orpo_completion(self): + """Test adding ORPO completions to buffer.""" + from axolotl.integrations.swanlab.completion_logger import CompletionLogger + + logger = CompletionLogger(maxlen=10) + + logger.add_orpo_completion( + step=2, + prompt="Write a poem", + chosen="Roses are red...", + rejected="Violets are blue...", + log_odds_ratio=1.2, + ) + + assert len(logger) == 1 + entry = logger.data[0] + assert entry["step"] == 2 + assert entry["chosen"] == "Roses are red..." + assert entry["rejected"] == "Violets are blue..." 
+ assert entry["log_odds_ratio"] == 1.2 + + def test_add_grpo_completion(self): + """Test adding GRPO completions to buffer.""" + from axolotl.integrations.swanlab.completion_logger import CompletionLogger + + logger = CompletionLogger(maxlen=10) + + logger.add_grpo_completion( + step=3, + prompt="Solve this problem", + completion="The answer is 42", + reward=0.9, + advantage=0.3, + ) + + assert len(logger) == 1 + entry = logger.data[0] + assert entry["step"] == 3 + assert entry["completion"] == "The answer is 42" + assert entry["reward"] == 0.9 + assert entry["advantage"] == 0.3 + + def test_memory_bounded_buffer(self): + """Test that buffer respects maxlen and drops oldest entries.""" + from axolotl.integrations.swanlab.completion_logger import CompletionLogger + + logger = CompletionLogger(maxlen=3) + + # Add 5 completions + for i in range(5): + logger.add_dpo_completion( + step=i, + prompt=f"Prompt {i}", + chosen=f"Chosen {i}", + rejected=f"Rejected {i}", + ) + + # Should only keep last 3 + assert len(logger) == 3 + assert logger.data[0]["step"] == 2 # Oldest kept + assert logger.data[1]["step"] == 3 + assert logger.data[2]["step"] == 4 # Newest + + def test_log_to_swanlab_when_not_initialized(self): + """Test logging gracefully fails when SwanLab not initialized.""" + from axolotl.integrations.swanlab.completion_logger import CompletionLogger + + logger = CompletionLogger(maxlen=10) + logger.add_dpo_completion( + step=0, + prompt="Test", + chosen="A", + rejected="B", + ) + + with patch("swanlab.get_run", return_value=None): + result = logger.log_to_swanlab() + assert result is False # Should fail gracefully + + def test_log_to_swanlab_success(self): + """Test successful logging to SwanLab.""" + from axolotl.integrations.swanlab.completion_logger import CompletionLogger + + logger = CompletionLogger(maxlen=10) + logger.add_dpo_completion( + step=0, + prompt="Test prompt", + chosen="Chosen response", + rejected="Rejected response", + reward_diff=0.5, + ) + + 
with ( + patch("swanlab.get_run") as mock_get_run, + patch("swanlab.log") as mock_log, + patch("swanlab.echarts.Table") as MockTable, + ): + mock_get_run.return_value = MagicMock() # SwanLab initialized + mock_table_instance = MagicMock() + MockTable.return_value = mock_table_instance + + result = logger.log_to_swanlab(table_name="test_table") + + assert result is True + mock_log.assert_called_once() + mock_table_instance.add.assert_called_once() + + def test_clear_buffer(self): + """Test clearing the completion buffer.""" + from axolotl.integrations.swanlab.completion_logger import CompletionLogger + + logger = CompletionLogger(maxlen=10) + logger.add_dpo_completion( + step=0, + prompt="Test", + chosen="A", + rejected="B", + ) + + assert len(logger) == 1 + logger.clear() + assert len(logger) == 0 + + def test_repr(self): + """Test string representation.""" + from axolotl.integrations.swanlab.completion_logger import CompletionLogger + + logger = CompletionLogger(maxlen=128) + logger.add_dpo_completion( + step=0, + prompt="Test", + chosen="A", + rejected="B", + ) + + repr_str = repr(logger) + assert "CompletionLogger" in repr_str + assert "maxlen=128" in repr_str + assert "buffered=1/128" in repr_str + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestSwanLabRLHFCompletionCallback: + """Tests for SwanLabRLHFCompletionCallback.""" + + def test_callback_initialization(self): + """Test callback initializes with correct parameters.""" + from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback + + callback = SwanLabRLHFCompletionCallback( + log_interval=50, + max_completions=64, + table_name="custom_table", + ) + + assert callback.log_interval == 50 + assert callback.logger.maxlen == 64 + assert callback.table_name == "custom_table" + assert callback.trainer_type is None + + def test_trainer_type_detection_dpo(self): + """Test DPO trainer type is detected correctly.""" + from 
axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback + + callback = SwanLabRLHFCompletionCallback() + + # Mock trainer with DPO in name + mock_trainer = MagicMock() + mock_trainer.__class__.__name__ = "AxolotlDPOTrainer" + + callback.on_init_end( + args=MagicMock(), + state=MagicMock(), + control=MagicMock(), + trainer=mock_trainer, + ) + + assert callback.trainer_type == "dpo" + + def test_trainer_type_detection_kto(self): + """Test KTO trainer type is detected correctly.""" + from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback + + callback = SwanLabRLHFCompletionCallback() + + mock_trainer = MagicMock() + mock_trainer.__class__.__name__ = "AxolotlKTOTrainer" + + callback.on_init_end( + args=MagicMock(), + state=MagicMock(), + control=MagicMock(), + trainer=mock_trainer, + ) + + assert callback.trainer_type == "kto" + + def test_on_train_end_logs_completions(self): + """Test that completions are logged at end of training.""" + from axolotl.integrations.swanlab.callbacks import SwanLabRLHFCompletionCallback + + callback = SwanLabRLHFCompletionCallback() + callback.trainer_type = "dpo" + + # Add some completions to buffer + callback.logger.add_dpo_completion( + step=0, + prompt="Test", + chosen="A", + rejected="B", + ) + + with patch.object(callback.logger, "log_to_swanlab") as mock_log: + callback.on_train_end( + args=MagicMock(), + state=MagicMock(global_step=100), + control=MagicMock(), + ) + + # Should log remaining completions + mock_log.assert_called_once() + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestSwanLabPluginCompletionIntegration: + """Integration tests for completion logging in SwanLabPlugin.""" + + def test_completion_callback_registered_for_dpo_trainer(self): + """Test that completion callback is registered for DPO trainer.""" + from axolotl.integrations.swanlab.plugins import SwanLabPlugin + from axolotl.utils.dict import DictDefault + + plugin = 
SwanLabPlugin() + plugin.swanlab_initialized = True # Simulate SwanLab initialized + + cfg = { + "use_swanlab": True, + "swanlab_project": "test-project", + "swanlab_log_completions": True, + "swanlab_completion_log_interval": 50, + "swanlab_completion_max_buffer": 64, + } + cfg_obj = DictDefault(cfg) + + # Mock DPO trainer + mock_trainer = MagicMock() + mock_trainer.__class__.__name__ = "AxolotlDPOTrainer" + mock_trainer.state = MagicMock(max_steps=1000) + mock_trainer.args = MagicMock( + num_train_epochs=3, + train_batch_size=4, + gradient_accumulation_steps=2, + ) + + with patch("swanlab.config.update"): + plugin.post_trainer_create(cfg_obj, mock_trainer) + + # Verify callback was added + mock_trainer.add_callback.assert_called_once() + callback = mock_trainer.add_callback.call_args[0][0] + assert callback.__class__.__name__ == "SwanLabRLHFCompletionCallback" + assert callback.log_interval == 50 + assert callback.logger.maxlen == 64 + + def test_completion_callback_not_registered_for_non_rlhf_trainer(self): + """Test that completion callback is NOT registered for non-RLHF trainers.""" + from axolotl.integrations.swanlab.plugins import SwanLabPlugin + from axolotl.utils.dict import DictDefault + + plugin = SwanLabPlugin() + plugin.swanlab_initialized = True + + cfg = { + "use_swanlab": True, + "swanlab_project": "test-project", + "swanlab_log_completions": True, + } + cfg_obj = DictDefault(cfg) + + # Mock regular SFT trainer (not RLHF) + mock_trainer = MagicMock() + mock_trainer.__class__.__name__ = "AxolotlTrainer" # Not RLHF + mock_trainer.state = MagicMock(max_steps=1000) + mock_trainer.args = MagicMock() + + with patch("swanlab.config.update"): + plugin.post_trainer_create(cfg_obj, mock_trainer) + + # Callback should NOT be added for non-RLHF trainer + mock_trainer.add_callback.assert_not_called() + + def test_completion_callback_not_registered_when_disabled(self): + """Test that completion callback is not registered when disabled in config.""" + from 
axolotl.integrations.swanlab.plugins import SwanLabPlugin + from axolotl.utils.dict import DictDefault + + plugin = SwanLabPlugin() + plugin.swanlab_initialized = True + + cfg = { + "use_swanlab": True, + "swanlab_project": "test-project", + "swanlab_log_completions": False, # Disabled + } + cfg_obj = DictDefault(cfg) + + # Mock DPO trainer + mock_trainer = MagicMock() + mock_trainer.__class__.__name__ = "AxolotlDPOTrainer" + mock_trainer.state = MagicMock(max_steps=1000) + mock_trainer.args = MagicMock() + + with patch("swanlab.config.update"): + plugin.post_trainer_create(cfg_obj, mock_trainer) + + # Callback should NOT be added when disabled + mock_trainer.add_callback.assert_not_called() + + +@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed") +class TestSwanLabProfiling: + """Tests for SwanLab profiling utilities.""" + + def test_profiling_context_logs_duration(self): + """Test that profiling context logs execution duration.""" + from axolotl.integrations.swanlab.profiling import swanlab_profiling_context + + # Mock trainer with SwanLab enabled + mock_trainer = MagicMock() + mock_trainer.cfg = MagicMock(use_swanlab=True) + mock_trainer.__class__.__name__ = "TestTrainer" + + with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: + mock_get_run.return_value = MagicMock() # SwanLab initialized + + with swanlab_profiling_context(mock_trainer, "test_function"): + time.sleep(0.01) # Simulate work + + # Verify log was called with correct metric name + mock_log.assert_called_once() + logged_data = mock_log.call_args[0][0] + assert "profiling/Time taken: TestTrainer.test_function" in logged_data + # Duration should be > 0.01 seconds + assert ( + logged_data["profiling/Time taken: TestTrainer.test_function"] >= 0.01 + ) + + def test_profiling_context_skips_when_swanlab_disabled(self): + """Test that profiling is skipped when SwanLab is disabled.""" + from axolotl.integrations.swanlab.profiling import 
swanlab_profiling_context + + mock_trainer = MagicMock() + mock_trainer.cfg = MagicMock(use_swanlab=False) # Disabled + + with patch("swanlab.log") as mock_log: + with swanlab_profiling_context(mock_trainer, "test_function"): + time.sleep(0.01) + + # Should NOT log when disabled + mock_log.assert_not_called() + + def test_profiling_context_skips_when_swanlab_not_initialized(self): + """Test that profiling is skipped when SwanLab not initialized.""" + from axolotl.integrations.swanlab.profiling import swanlab_profiling_context + + mock_trainer = MagicMock() + mock_trainer.cfg = MagicMock(use_swanlab=True) + + with ( + patch("swanlab.get_run", return_value=None), + patch("swanlab.log") as mock_log, + ): + with swanlab_profiling_context(mock_trainer, "test_function"): + time.sleep(0.01) + + # Should NOT log when not initialized + mock_log.assert_not_called() + + def test_profiling_decorator(self): + """Test swanlab_profile decorator.""" + from axolotl.integrations.swanlab.profiling import swanlab_profile + + class MockTrainer: + def __init__(self): + self.cfg = MagicMock(use_swanlab=True) + + @swanlab_profile + def expensive_method(self, x): + time.sleep(0.01) + return x * 2 + + trainer = MockTrainer() + + with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: + mock_get_run.return_value = MagicMock() + + result = trainer.expensive_method(5) + + # Verify method still works correctly + assert result == 10 + + # Verify profiling was logged + mock_log.assert_called_once() + logged_data = mock_log.call_args[0][0] + assert "profiling/Time taken: MockTrainer.expensive_method" in logged_data + + def test_profiling_config(self): + """Test ProfilingConfig class.""" + from axolotl.integrations.swanlab.profiling import ProfilingConfig + + config = ProfilingConfig( + enabled=True, + min_duration_ms=1.0, + log_interval=5, + ) + + # Test enabled check + assert config.enabled is True + + # Test minimum duration filtering + assert config.should_log("func1", 
0.0001) is False # 0.1ms < 1.0ms threshold + assert config.should_log("func2", 0.002) is True # 2.0ms > 1.0ms threshold + + # Test log interval + assert config.should_log("func3", 0.002) is True # 1st call + assert config.should_log("func3", 0.002) is False # 2nd call + assert config.should_log("func3", 0.002) is False # 3rd call + assert config.should_log("func3", 0.002) is False # 4th call + assert config.should_log("func3", 0.002) is True # 5th call (interval=5) + + def test_profiling_config_when_disabled(self): + """Test ProfilingConfig when disabled.""" + from axolotl.integrations.swanlab.profiling import ProfilingConfig + + config = ProfilingConfig(enabled=False) + + # Should never log when disabled + assert config.should_log("func1", 100.0) is False + + def test_profiling_context_advanced(self): + """Test advanced profiling context with custom config.""" + from axolotl.integrations.swanlab.profiling import ( + ProfilingConfig, + swanlab_profiling_context_advanced, + ) + + mock_trainer = MagicMock() + mock_trainer.cfg = MagicMock(use_swanlab=True) + mock_trainer.__class__.__name__ = "TestTrainer" + + # Config that filters out very fast operations + config = ProfilingConfig(min_duration_ms=10.0) # 10ms minimum + + with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: + mock_get_run.return_value = MagicMock() + + # Fast operation (< 10ms) - should NOT log + with swanlab_profiling_context_advanced(mock_trainer, "fast_op", config): + time.sleep(0.001) # 1ms + + mock_log.assert_not_called() + + # Slow operation (> 10ms) - should log + with swanlab_profiling_context_advanced(mock_trainer, "slow_op", config): + time.sleep(0.015) # 15ms + + mock_log.assert_called_once() + + def test_profiling_with_exception(self): + """Test that profiling still logs even when exception occurs.""" + from axolotl.integrations.swanlab.profiling import swanlab_profiling_context + + mock_trainer = MagicMock() + mock_trainer.cfg = MagicMock(use_swanlab=True) + 
mock_trainer.__class__.__name__ = "TestTrainer" + + with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: + mock_get_run.return_value = MagicMock() + + try: + with swanlab_profiling_context(mock_trainer, "error_function"): + time.sleep(0.01) + raise ValueError("Test error") + except ValueError: + pass # Expected + + # Should still log duration even with exception + mock_log.assert_called_once()