From 889b27ecf108a9671ad869c62cd548687a3515d6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 22 Aug 2025 05:08:02 +0000 Subject: [PATCH] tui --- requirements.txt | 4 + src/axolotl/cli/main.py | 20 + src/axolotl/tui/README.md | 216 ++++++++++ src/axolotl/tui/__init__.py | 1 + src/axolotl/tui/app.py | 180 +++++++++ src/axolotl/tui/dialogs/__init__.py | 1 + src/axolotl/tui/dialogs/training.py | 112 ++++++ src/axolotl/tui/screens/__init__.py | 1 + src/axolotl/tui/screens/base.py | 50 +++ src/axolotl/tui/screens/config.py | 383 ++++++++++++++++++ src/axolotl/tui/screens/datasets.py | 505 ++++++++++++++++++++++++ src/axolotl/tui/screens/inference.py | 395 +++++++++++++++++++ src/axolotl/tui/screens/models.py | 393 ++++++++++++++++++ src/axolotl/tui/screens/monitor.py | 444 +++++++++++++++++++++ src/axolotl/tui/screens/training.py | 568 +++++++++++++++++++++++++++ 15 files changed, 3273 insertions(+) create mode 100644 src/axolotl/tui/README.md create mode 100644 src/axolotl/tui/__init__.py create mode 100644 src/axolotl/tui/app.py create mode 100644 src/axolotl/tui/dialogs/__init__.py create mode 100644 src/axolotl/tui/dialogs/training.py create mode 100644 src/axolotl/tui/screens/__init__.py create mode 100644 src/axolotl/tui/screens/base.py create mode 100644 src/axolotl/tui/screens/config.py create mode 100644 src/axolotl/tui/screens/datasets.py create mode 100644 src/axolotl/tui/screens/inference.py create mode 100644 src/axolotl/tui/screens/models.py create mode 100644 src/axolotl/tui/screens/monitor.py create mode 100644 src/axolotl/tui/screens/training.py diff --git a/requirements.txt b/requirements.txt index c2552002f..67598c91a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -72,3 +72,7 @@ axolotl-contribs-lgpl==0.0.6 axolotl-contribs-mit==0.0.5 mistral-common==1.8.3 + +# TUI dependencies +textual==1.0.0 +rich==14.1.0 diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index e63392802..931073ac9 100644 --- a/src/axolotl/cli/main.py +++ 
b/src/axolotl/cli/main.py @@ -344,6 +344,26 @@ def delinearize_llama4(model: str, output: str): cli.add_command(lm_eval) +@cli.command() +def tui(): + """ + Launch the Axolotl Terminal User Interface (TUI). + + Provides an interactive interface for configuration management, + training monitoring, dataset handling, and model operations. + """ + try: + from axolotl.tui.app import run + + run() + except ImportError: + click.echo( + "TUI dependencies not installed. Install with: pip install textual rich" + ) + except Exception as e: + click.echo(f"Error launching TUI: {e}") + + def main(): cli() diff --git a/src/axolotl/tui/README.md b/src/axolotl/tui/README.md new file mode 100644 index 000000000..73195fa74 --- /dev/null +++ b/src/axolotl/tui/README.md @@ -0,0 +1,216 @@ +# Axolotl TUI (Terminal User Interface) + +A comprehensive Terminal User Interface for Axolotl, providing an interactive way to manage configurations, training jobs, datasets, models, and system monitoring. + +## Features + +### 🏠 Main Dashboard +- **Welcome Screen**: Central hub with quick access to all features +- **Keyboard Navigation**: Efficient navigation with keyboard shortcuts +- **Screen Management**: Easy switching between different functional areas + +### 📝 Configuration Management +- **YAML Editor**: Syntax-highlighted editor for Axolotl configurations +- **Real-time Validation**: Instant config validation with detailed error reporting +- **File Browser**: Navigate and select configuration files +- **Template Loading**: Load example configurations +- **Remote Config Support**: Load configurations from URLs + +**Key Shortcuts:** +- `Ctrl+N`: New configuration +- `Ctrl+S`: Save configuration +- `Ctrl+V`: Validate configuration +- `Ctrl+E`: Toggle edit mode + +### 🚀 Training Management +- **Job Launcher**: Start training with different launchers (accelerate, torchrun) +- **Real-time Monitoring**: Live training progress and metrics +- **Loss Visualization**: Sparkline charts for loss
curves +- **Job Control**: Start, stop, resume, and manage multiple training jobs +- **Log Streaming**: Real-time log viewing and filtering + +**Key Shortcuts:** +- `Ctrl+T`: New training job +- `Ctrl+R`: Resume training +- `Ctrl+X`: Stop training +- `R`: Refresh status + +### 📊 Dataset Management +- **Dataset Browser**: Explore local and remote datasets +- **Preview & Statistics**: View dataset samples and metadata +- **Preprocessing**: Run dataset preprocessing with progress tracking +- **HuggingFace Integration**: Download and manage HF datasets +- **Format Detection**: Automatic dataset format recognition + +**Key Shortcuts:** +- `Ctrl+P`: Preprocess dataset +- `Ctrl+V`: Preview dataset +- `Ctrl+I`: Dataset information +- `R`: Refresh dataset list + +### 🤖 Model Management +- **Model Discovery**: Automatically find trained models +- **LoRA Operations**: Merge LoRA adapters with base models +- **Quantization**: Quantize models for deployment +- **Evaluation**: Run model evaluation benchmarks +- **Storage Info**: View model sizes and storage details + +**Key Shortcuts:** +- `Ctrl+M`: Merge LoRA +- `Ctrl+Q`: Quantize model +- `Ctrl+E`: Evaluate model +- `R`: Refresh model list + +### 💬 Inference & Testing +- **Interactive Chat**: Chat interface for model testing +- **Parameter Tuning**: Adjust inference parameters (temperature, top-p, max tokens) +- **Model Loading**: Load and switch between different models +- **Chat History**: Save and load conversation history +- **Gradio Integration**: Launch Gradio web interface + +**Key Shortcuts:** +- `Ctrl+Enter`: Send message +- `Ctrl+C`: Clear chat +- `Ctrl+L`: Load model +- `Ctrl+S`: Save chat + +### 📈 System Monitoring +- **Resource Monitoring**: Real-time CPU, GPU, and memory usage +- **Process Management**: View and manage running processes +- **Performance Graphs**: Historical usage charts with sparklines +- **GPU Information**: Detailed GPU status and memory usage +- **Temperature Monitoring**: System
temperature tracking + +**Key Shortcuts:** +- `R`: Refresh metrics +- `Ctrl+K`: Kill selected process + +## Installation + +### Dependencies +```bash +pip install textual==1.0.0 rich==14.1.0 +``` + +### Launch TUI +```bash +# From command line +python -m axolotl.cli.main tui + +# From Python code +from axolotl.tui.app import run +run() +``` + +## Architecture + +### Screen Structure +``` +AxolotlTUI (Main App) +├── WelcomeScreen (Dashboard) +├── ConfigScreen (Configuration Management) +├── TrainingScreen (Training Management) +├── DatasetScreen (Dataset Management) +├── ModelScreen (Model Management) +├── InferenceScreen (Inference & Testing) +└── MonitorScreen (System Monitoring) +``` + +### Key Components +- **BaseScreen**: Common functionality for all screens +- **Screen Navigation**: Stack-based screen management +- **Event Handling**: Reactive UI updates +- **Background Tasks**: Non-blocking operations +- **State Management**: Shared application state + +### Integration Points +- **CLI Commands**: Seamless integration with existing axolotl CLI +- **Configuration System**: Uses axolotl's native config loading +- **Training Pipeline**: Integrates with axolotl training functions +- **Model Loading**: Compatible with axolotl model management + +## Usage Examples + +### 1. Creating a New Configuration +1. Launch TUI: `python -m axolotl.cli.main tui` +2. Select "Configuration Management" or press `C` +3. Press `Ctrl+N` for new configuration +4. Edit the template configuration +5. Press `Ctrl+V` to validate +6. Press `Ctrl+S` to save + +### 2. Starting a Training Job +1. Navigate to "Training Management" or press `T` +2. Press `Ctrl+T` for new training job +3. Select configuration file and launcher +4. Monitor progress in real-time +5. View loss curves and logs + +### 3. Interactive Model Testing +1. Go to "Inference & Testing" or press `I` +2. Load a trained model with `Ctrl+L` +3. Adjust inference parameters as needed +4.
Start chatting with the model +5. Save conversation with `Ctrl+S` + +## Navigation + +### Global Shortcuts +- `Ctrl+Q`: Quit application +- `Escape`: Go back/close current screen +- `Tab`: Navigate between UI elements +- `Enter`: Select/activate element +- `Space`: Toggle switches/checkboxes + +### Screen Shortcuts +Each screen has specific shortcuts displayed in the footer. Common patterns: +- `Ctrl+[Letter]`: Primary actions +- `R`: Refresh/reload +- `F1-F12`: Function keys for advanced features + +## Customization + +### Themes +The TUI uses Textual's theming system and can be customized by modifying the CSS in each screen class. + +### Adding New Screens +1. Create a new screen class inheriting from `BaseScreen` +2. Implement the `compose()` method for UI layout +3. Add event handlers for user interactions +4. Register the screen in the main app navigation + +### Extending Functionality +- Add new widgets to existing screens +- Implement custom data visualization +- Integrate with external tools and APIs +- Add new keyboard shortcuts + +## Troubleshooting + +### Common Issues +1. **Import Errors**: Ensure textual and rich are installed +2. **Permission Errors**: Check file system permissions for config directories +3. **GPU Monitoring**: Install pynvml for GPU monitoring features +4. **Config Validation**: Ensure axolotl dependencies are properly installed + +### Debug Mode +Launch with debug logging: +```bash +TEXTUAL_LOG=DEBUG python -m axolotl.cli.main tui +``` + +### Performance +- Use `Ctrl+\` to open Textual's debug console +- Monitor memory usage with the system monitor +- Disable auto-refresh for better performance on slower systems + +## Contributing + +The TUI is designed to be extensible. Contributions are welcome for: +- New screen implementations +- Enhanced visualizations +- Better keyboard navigation +- Additional integrations +- Performance improvements + +See the main Axolotl repository for contribution guidelines. 
diff --git a/src/axolotl/tui/__init__.py b/src/axolotl/tui/__init__.py
new file mode 100644
index 000000000..7f50cd651
--- /dev/null
+++ b/src/axolotl/tui/__init__.py
@@ -0,0 +1 @@
+"""Axolotl Terminal User Interface (TUI)."""
diff --git a/src/axolotl/tui/app.py b/src/axolotl/tui/app.py
new file mode 100644
index 000000000..7a6cdb1fd
--- /dev/null
+++ b/src/axolotl/tui/app.py
@@ -0,0 +1,185 @@
+"""Main TUI application for Axolotl."""
+
+from textual import on
+from textual.app import App, ComposeResult
+from textual.binding import Binding
+from textual.containers import Container, Horizontal, ScrollableContainer
+from textual.screen import Screen
+from textual.widgets import Button, Footer, Header, Label, Static
+
+from axolotl.tui.screens.config import ConfigScreen
+from axolotl.tui.screens.datasets import DatasetScreen
+from axolotl.tui.screens.inference import InferenceScreen
+from axolotl.tui.screens.models import ModelScreen
+from axolotl.tui.screens.monitor import MonitorScreen
+from axolotl.tui.screens.training import TrainingScreen
+
+
+class WelcomeScreen(Screen):
+    """Welcome screen with main menu."""
+
+    BINDINGS = [
+        Binding("q", "quit", "Quit"),
+        Binding("c", "config", "Configuration"),
+        Binding("t", "training", "Training"),
+        Binding("d", "datasets", "Datasets"),
+        Binding("m", "models", "Models"),
+        Binding("i", "inference", "Inference"),
+        Binding("s", "monitor", "System Monitor"),
+    ]
+
+    def compose(self) -> ComposeResult:
+        """Compose the welcome screen."""
+        yield Header()
+        yield Container(
+            Static("🦾 Axolotl TUI", classes="title"),
+            Static(
+                "A Terminal User Interface for fine-tuning LLMs", classes="subtitle"
+            ),
+            Container(
+                Button("Configuration Management [C]", id="config", variant="primary"),
+                Button("Training Management [T]", id="training", variant="primary"),
+                Button("Dataset Management [D]", id="datasets", variant="primary"),
+                Button("Model Management [M]", id="models", variant="primary"),
+                Button("Inference & Testing [I]", id="inference", variant="primary"),
+                Button("System Monitor [S]", id="monitor", variant="primary"),
+                classes="menu-container",
+            ),
+            classes="welcome-container",
+        )
+        yield Footer()
+
+    def action_quit(self) -> None:
+        """Quit the application."""
+        self.app.exit()
+
+    def action_config(self) -> None:
+        """Navigate to config screen."""
+        self.app.push_screen(ConfigScreen())
+
+    def action_training(self) -> None:
+        """Navigate to training screen."""
+        self.app.push_screen(TrainingScreen())
+
+    def action_datasets(self) -> None:
+        """Navigate to datasets screen."""
+        self.app.push_screen(DatasetScreen())
+
+    def action_models(self) -> None:
+        """Navigate to models screen."""
+        self.app.push_screen(ModelScreen())
+
+    def action_inference(self) -> None:
+        """Navigate to inference screen."""
+        self.app.push_screen(InferenceScreen())
+
+    def action_monitor(self) -> None:
+        """Navigate to monitor screen."""
+        self.app.push_screen(MonitorScreen())
+
+    @on(Button.Pressed, "#config")
+    def on_config_pressed(self) -> None:
+        """Handle config button press."""
+        self.action_config()
+
+    @on(Button.Pressed, "#training")
+    def on_training_pressed(self) -> None:
+        """Handle training button press."""
+        self.action_training()
+
+    @on(Button.Pressed, "#datasets")
+    def on_datasets_pressed(self) -> None:
+        """Handle datasets button press."""
+        self.action_datasets()
+
+    @on(Button.Pressed, "#models")
+    def on_models_pressed(self) -> None:
+        """Handle models button press."""
+        self.action_models()
+
+    @on(Button.Pressed, "#inference")
+    def on_inference_pressed(self) -> None:
+        """Handle inference button press."""
+        self.action_inference()
+
+    @on(Button.Pressed, "#monitor")
+    def on_monitor_pressed(self) -> None:
+        """Handle monitor button press."""
+        self.action_monitor()
+
+
+class AxolotlTUI(App):
+    """Main Axolotl TUI Application."""
+
+    CSS = """
+    .title {
+        text-align: center;
+        text-style: bold;
+        padding: 1;
+        color: $primary;
+        content-align: center middle;
+    }
+
+    .subtitle {
+        text-align: center;
+        padding: 1;
+        color: $text-muted;
+    }
+
+    .welcome-container {
+        align: center middle;
+        height: 100%;
+    }
+
+    .menu-container {
+        layout: vertical;
+        align: center middle;
+        padding: 2;
+        width: auto;
+        height: auto;
+    }
+
+    .menu-container Button {
+        width: 35;
+        margin: 1;
+    }
+
+    Screen {
+        align: center middle;
+    }
+    """
+
+    BINDINGS = [
+        Binding("ctrl+q", "quit", "Quit", priority=True),
+        Binding("escape", "back", "Back", priority=True),
+    ]
+
+    def on_mount(self) -> None:
+        """Called when the app is mounted."""
+        self.title = "Axolotl TUI"
+        self.sub_title = "Fine-tuning LLMs made easy"
+        self.push_screen(WelcomeScreen())
+
+    def action_quit(self) -> None:
+        """Quit the application."""
+        self.exit()
+
+    def action_back(self) -> None:
+        """Go back to the previous screen.
+
+        The welcome screen is the root of the navigation: it sits directly on
+        top of Textual's implicit default screen, so popping it would leave
+        the user looking at a blank screen.
+        """
+        if not isinstance(self.screen, WelcomeScreen):
+            self.pop_screen()
+
+
+def run():
+    """Run the Axolotl TUI application."""
+    app = AxolotlTUI()
+    app.run()
+
+
+if __name__ == "__main__":
+    run()
diff --git a/src/axolotl/tui/dialogs/__init__.py b/src/axolotl/tui/dialogs/__init__.py
new file mode 100644
index 000000000..88280c743
--- /dev/null
+++ b/src/axolotl/tui/dialogs/__init__.py
@@ -0,0 +1 @@
+"""TUI dialogs for Axolotl."""
diff --git a/src/axolotl/tui/dialogs/training.py b/src/axolotl/tui/dialogs/training.py
new file mode 100644
index 000000000..837f011d7
--- /dev/null
+++ b/src/axolotl/tui/dialogs/training.py
@@ -0,0 +1,122 @@
+"""Training dialogs for Axolotl TUI."""
+
+from pathlib import Path
+
+from textual import on
+from textual.app import ComposeResult
+from textual.containers import Container, Horizontal, Vertical
+from textual.screen import ModalScreen
+from textual.widgets import Button, Input, Label, Select, Static
+
+
+class NewTrainingDialog(ModalScreen):
+    """Dialog for starting a new training job."""
+
+    CSS = """
+    NewTrainingDialog {
+        align: center middle;
+    }
+
+    .dialog-container {
+        background: $surface;
+        border: thick $primary;
+        padding: 2;
+        width: 60;
+        height: auto;
+    }
+
+    .dialog-title {
+        text-align: center;
+        text-style: bold;
+        padding: 1;
+        color: $primary;
+    }
+
+    .form-field {
+        margin: 1 0;
+    }
+
+    .form-label {
+        margin: 0 0 1 0;
+        color: $text-muted;
+    }
+
+    .button-container {
+        layout: horizontal;
+        align: center middle;
+        margin: 2 0 0 0;
+    }
+
+    .button-container Button {
+        margin: 0 1;
+    }
+    """
+
+    def compose(self) -> ComposeResult:
+        """Compose the dialog."""
+        yield Container(
+            Static("Start New Training Job", classes="dialog-title"),
+            Container(
+                Label("Configuration File:", classes="form-label"),
+                Input(
+                    placeholder="Path to config YAML file",
+                    id="config-path",
+                    value="/workspace/configs/",
+                ),
+                classes="form-field",
+            ),
+            Container(
+                Label("Launcher:", classes="form-label"),
+                Select(
+                    # Textual expects (prompt, value) pairs; the value is what
+                    # ``Select.value`` returns and what ``value=`` must match.
+                    [
+                        ("Accelerate (Recommended)", "accelerate"),
+                        ("TorchRun", "torchrun"),
+                        ("DeepSpeed", "deepspeed"),
+                    ],
+                    id="launcher",
+                    value="accelerate",
+                    allow_blank=False,
+                ),
+                classes="form-field",
+            ),
+            Container(
+                Button("Start Training", variant="primary", id="start"),
+                Button("Cancel", variant="default", id="cancel"),
+                classes="button-container",
+            ),
+            classes="dialog-container",
+        )
+
+    @on(Button.Pressed, "#start")
+    def handle_start(self) -> None:
+        """Validate the form and dismiss with the chosen settings.
+
+        Dismisses with a dict holding ``config_path`` and ``launcher``. When
+        the path is empty or is not an existing file the dialog stays open
+        and the user is told why, rather than the button doing nothing.
+        """
+        config_input = self.query_one("#config-path", Input)
+        launcher_select = self.query_one("#launcher", Select)
+
+        config_path = config_input.value.strip()
+        if not config_path:
+            self.notify("Enter a configuration file path.", severity="warning")
+            return
+
+        if not Path(config_path).is_file():
+            self.notify("Configuration file not found.", severity="error")
+            return
+
+        result = {
+            "config_path": config_path,
+            "launcher": launcher_select.value,
+        }
+
+        self.dismiss(result)
+
+    @on(Button.Pressed, "#cancel")
+    def handle_cancel(self) -> None:
+        """Handle cancel button press."""
+        self.dismiss(None)
diff --git a/src/axolotl/tui/screens/__init__.py b/src/axolotl/tui/screens/__init__.py
new file mode 100644
index 000000000..0e9069f65
--- /dev/null
+++ b/src/axolotl/tui/screens/__init__.py
@@ -0,0 +1 @@
+"""TUI screens for Axolotl."""
diff
--git a/src/axolotl/tui/screens/base.py b/src/axolotl/tui/screens/base.py
new file mode 100644
index 000000000..0728562c0
--- /dev/null
+++ b/src/axolotl/tui/screens/base.py
@@ -0,0 +1,50 @@
+"""Base screen class for Axolotl TUI screens."""
+
+from textual.app import ComposeResult
+from textual.binding import Binding
+from textual.containers import Container
+from textual.screen import Screen
+from textual.widgets import Footer, Header, Static
+
+
+class BaseScreen(Screen):
+    """Base class for all Axolotl TUI screens."""
+
+    BINDINGS = [
+        Binding("escape", "back", "Back"),
+        Binding("q", "quit", "Quit"),
+    ]
+
+    def __init__(self, title: str = "Axolotl", subtitle: str = ""):
+        """Initialize the base screen.
+
+        Args:
+            title: The screen title
+            subtitle: Optional subtitle for the screen
+        """
+        super().__init__()
+        self.screen_title = title
+        self.screen_subtitle = subtitle
+
+    def compose(self) -> ComposeResult:
+        """Compose the base screen layout."""
+        yield Header()
+        yield Container(
+            Static(f"🦾 {self.screen_title}", classes="screen-title"),
+            (
+                Static(self.screen_subtitle, classes="screen-subtitle")
+                if self.screen_subtitle
+                else Static("")
+            ),
+            Container(id="content"),
+            id="main-container",
+        )
+        yield Footer()
+
+    def action_back(self) -> None:
+        """Go back to previous screen."""
+        self.app.pop_screen()
+
+    def action_quit(self) -> None:
+        """Quit the application."""
+        self.app.exit()
diff --git a/src/axolotl/tui/screens/config.py b/src/axolotl/tui/screens/config.py
new file mode 100644
index 000000000..3eabd4a8e
--- /dev/null
+++ b/src/axolotl/tui/screens/config.py
@@ -0,0 +1,413 @@
+"""Configuration management screen for Axolotl TUI."""
+
+import os
+from pathlib import Path
+from typing import Optional
+
+import yaml
+from textual import on, work
+from textual.app import ComposeResult
+from textual.binding import Binding
+from textual.containers import Container, Horizontal, ScrollableContainer, Vertical
+from textual.message import Message
+from textual.reactive import reactive
+from textual.widgets import (
+    Button,
+    DataTable,
+    DirectoryTree,
+    Footer,
+    Header,
+    Input,
+    Label,
+    LoadingIndicator,
+    Log,
+    Select,
+    Static,
+    Switch,
+    TextArea,
+)
+
+from axolotl.cli.config import load_cfg
+from axolotl.tui.screens.base import BaseScreen
+
+
+class ConfigScreen(BaseScreen):
+    """Configuration management screen."""
+
+    BINDINGS = [
+        Binding("ctrl+n", "new_config", "New Config"),
+        Binding("ctrl+o", "open_config", "Open Config"),
+        Binding("ctrl+s", "save_config", "Save Config"),
+        Binding("ctrl+v", "validate_config", "Validate"),
+        Binding("ctrl+e", "edit_mode", "Toggle Edit Mode"),
+    ]
+
+    CSS = """
+    .config-container {
+        layout: horizontal;
+        height: 100%;
+    }
+
+    .file-browser {
+        width: 30%;
+        border: solid $primary;
+        padding: 1;
+        margin: 1;
+    }
+
+    .config-editor {
+        width: 70%;
+        border: solid $secondary;
+        padding: 1;
+        margin: 1;
+    }
+
+    .config-form {
+        height: 80%;
+    }
+
+    .config-actions {
+        layout: horizontal;
+        height: 3;
+        align: center middle;
+        padding: 1;
+    }
+
+    .config-actions Button {
+        margin: 0 1;
+    }
+
+    TextArea {
+        height: 100%;
+    }
+
+    .validation-log {
+        height: 20%;
+        border: solid $warning;
+        padding: 1;
+    }
+
+    .screen-title {
+        text-align: center;
+        text-style: bold;
+        padding: 1;
+        color: $primary;
+    }
+
+    .screen-subtitle {
+        text-align: center;
+        padding: 0 0 1 0;
+        color: $text-muted;
+    }
+    """
+
+    def __init__(self):
+        """Initialize the config screen."""
+        super().__init__(
+            title="Configuration Management",
+            subtitle="Create, edit, and validate Axolotl configurations",
+        )
+        self.current_config_path: Optional[Path] = None
+        # Plain bool on purpose: ``reactive`` only works as a class-level
+        # descriptor. Assigned per instance it is just an always-truthy object,
+        # which made the first "Edit Mode" toggle switch to view mode.
+        self.edit_mode = False
+        self.config_data = {}
+
+    def compose(self) -> ComposeResult:
+        """Compose the config screen layout."""
+        yield Header()
+        yield Container(
+            Static("🦾 Configuration Management", classes="screen-title"),
+            Static(
+                "Create, edit, and validate Axolotl configurations",
+                classes="screen-subtitle",
+            ),
+            Container(
+                Container(
+                    Label("Config Files"),
+                    DirectoryTree(
+                        (
+                            Path("/workspace/configs")
+                            if Path("/workspace/configs").exists()
+                            else Path.cwd()
+                        ),
+                        id="config-tree",
+                    ),
+                    classes="file-browser",
+                ),
+                Container(
+                    Container(
+                        TextArea(
+                            "",
+                            language="yaml",
+                            theme="monokai",
+                            id="config-editor",
+                            read_only=True,
+                        ),
+                        classes="config-form",
+                    ),
+                    Container(
+                        Button("New", id="new-config", variant="primary"),
+                        Button("Open", id="open-config", variant="primary"),
+                        Button("Save", id="save-config", variant="success"),
+                        Button("Validate", id="validate-config", variant="warning"),
+                        Button("Edit Mode", id="toggle-edit", variant="default"),
+                        Button("Load Example", id="load-example", variant="default"),
+                        classes="config-actions",
+                    ),
+                    Container(
+                        Log(id="validation-log", wrap=True, highlight=True),
+                        classes="validation-log",
+                    ),
+                    classes="config-editor",
+                ),
+                classes="config-container",
+            ),
+            id="content",
+        )
+        yield Footer()
+
+    def on_mount(self) -> None:
+        """Called when the screen is mounted."""
+        tree = self.query_one("#config-tree", DirectoryTree)
+        tree.show_root = False
+        tree.guide_depth = 3
+
+        log = self.query_one("#validation-log", Log)
+        log.write_line("Ready to load configuration files...")
+
+    @on(DirectoryTree.FileSelected)
+    def handle_file_selected(self, event: DirectoryTree.FileSelected) -> None:
+        """Handle file selection from the directory tree."""
+        if event.path.suffix in [".yaml", ".yml"]:
+            self.load_config_file(event.path)
+
+    def load_config_file(self, path: Path) -> None:
+        """Load a configuration file into the editor.
+
+        ``current_config_path`` only changes once the file has actually been
+        read, so a failed load can never redirect a later save onto a file
+        whose contents are not the ones shown in the editor. A file that is
+        readable but not valid YAML is still opened so it can be repaired.
+        """
+        log = self.query_one("#validation-log", Log)
+        try:
+            content = path.read_text(encoding="utf-8")
+        except (OSError, UnicodeDecodeError) as e:
+            log.write_line(f"❌ Error loading {path.name}: {str(e)}")
+            return
+
+        self.current_config_path = path
+        self.query_one("#config-editor", TextArea).load_text(content)
+        log.clear()
+
+        try:
+            # An empty document parses to None; keep config_data a dict.
+            self.config_data = yaml.safe_load(content) or {}
+        except yaml.YAMLError as e:
+            self.config_data = {}
+            log.write_line(f"⚠️ Loaded {path.name}, but it is not valid YAML: {e}")
+            return
+
+        log.write_line(f"✅ Loaded: {path.name}")
+
+    @on(Button.Pressed, "#new-config")
+    def handle_new_config(self) -> None:
+        """Create a new configuration."""
+        template = """# Axolotl Configuration
+base_model:
+model_type:
+tokenizer_type:
+
+# Dataset Configuration
+datasets:
+  - path:
+    type:
+
+# Training Configuration
+output_dir: ./outputs
+num_epochs: 3
+micro_batch_size: 1
+gradient_accumulation_steps: 4
+learning_rate: 0.00002
+warmup_steps: 100
+eval_steps: 100
+save_steps: 500
+
+# LoRA Configuration (optional)
+adapter: lora
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+
+# Training optimizations
+gradient_checkpointing: true
+flash_attention: true
+bf16: auto
+tf32: true
+
+# Logging
+logging_steps: 10
+wandb_project:
+wandb_entity:
+"""
+        editor = self.query_one("#config-editor", TextArea)
+        editor.load_text(template)
+        editor.read_only = False
+        self.edit_mode = True
+        self.update_edit_button()
+
+        log = self.query_one("#validation-log", Log)
+        log.clear()
+        log.write_line("📝 New configuration created. Edit and save when ready.")
+
+    @on(Button.Pressed, "#save-config")
+    def handle_save_config(self) -> None:
+        """Save the editor contents to the current configuration path.
+
+        With no file open yet, the config is written to
+        ``/workspace/configs/new_config.yaml``.
+        """
+        editor = self.query_one("#config-editor", TextArea)
+        log = self.query_one("#validation-log", Log)
+        content = editor.text
+
+        if not content.strip():
+            log.write_line("⚠️ Cannot save empty configuration")
+            return
+
+        target = self.current_config_path or Path("/workspace/configs/new_config.yaml")
+        try:
+            # Directory creation can fail too (e.g. read-only mount), so it is
+            # reported through the same path as a failed write.
+            target.parent.mkdir(parents=True, exist_ok=True)
+            target.write_text(content, encoding="utf-8")
+        except (OSError, UnicodeError) as e:
+            log.write_line(f"❌ Error saving: {str(e)}")
+            return
+
+        self.current_config_path = target
+        log.write_line(f"💾 Saved: {target.name}")
+
+    @on(Button.Pressed, "#validate-config")
+    @work(thread=True)
+    def handle_validate_config(self) -> None:
+        """Validate the editor contents in a background thread.
+
+        The text is written to a temporary YAML file that is handed to
+        ``check_user_config``; the outcome is reported in the validation log.
+        The temporary file is removed on every path, and all log updates are
+        marshalled to the UI thread because Textual widgets are not thread-safe.
+        """
+        import tempfile
+        from argparse import Namespace
+
+        log = self.query_one("#validation-log", Log)
+        content = self.query_one("#config-editor", TextArea).text
+
+        def report(message: str) -> None:
+            self.app.call_from_thread(log.write_line, message)
+
+        if not content.strip():
+            report("⚠️ No configuration to validate")
+            return
+
+        self.app.call_from_thread(log.clear)
+        report("🔍 Validating configuration...")
+
+        temp_path = None
+        try:
+            with tempfile.NamedTemporaryFile(
+                mode="w", suffix=".yaml", delete=False, encoding="utf-8"
+            ) as f:
+                temp_path = f.name
+                f.write(content)
+
+            # NOTE(review): assumes axolotl.cli.config exposes check_user_config
+            # with this argument shape - confirm. A failed import is reported
+            # below like any other validation error.
+            from axolotl.cli.config import check_user_config
+
+            args = Namespace(
+                config=temp_path,
+                debug=False,
+                debug_text_only=False,
+                debug_num_examples=5,
+                accelerate_config=None,
+                multi_gpu=False,
+            )
+            check_user_config(args)
+            report("✅ Configuration is valid!")
+        except Exception as e:
+            report(f"❌ Validation failed: {str(e)}")
+        finally:
+            if temp_path is not None:
+                os.unlink(temp_path)
+
+    @on(Button.Pressed, "#toggle-edit")
+    def handle_toggle_edit(self) -> None:
+        """Toggle edit mode for the configuration."""
+        editor = self.query_one("#config-editor", TextArea)
+        self.edit_mode = not self.edit_mode
+        editor.read_only = not self.edit_mode
+        self.update_edit_button()
+
+        log = self.query_one("#validation-log", Log)
+        if self.edit_mode:
+            log.write_line("✏️ Edit mode enabled")
+        else:
+            log.write_line("👁️ View mode enabled")
+
+    @on(Button.Pressed, "#load-example")
+    async def handle_load_example(self) -> None:
+        """Load an example configuration."""
+        examples_dir = Path("/workspace/axolotl/examples")
+        if not examples_dir.exists():
+            log = self.query_one("#validation-log", Log)
+            log.write_line("⚠️ Examples directory not found")
+            return
+
+        yaml_files = list(examples_dir.glob("**/*.yml")) + list(
+            examples_dir.glob("**/*.yaml")
+        )
+        if yaml_files:
+            self.load_config_file(yaml_files[0])
+            log = self.query_one("#validation-log", Log)
+            log.write_line(f"📚 Loaded example: {yaml_files[0].name}")
+
+    def update_edit_button(self) -> None:
+        """Update the edit button appearance."""
+        button = self.query_one("#toggle-edit", Button)
+        if self.edit_mode:
+            button.variant = "warning"
+            button.label = "Edit Mode: ON"
+        else:
+            button.variant = "default"
+            button.label = "Edit Mode: OFF"
+
+    def action_new_config(self) -> None:
+        """Create a new configuration."""
+        self.handle_new_config()
+
+    def action_save_config(self) -> None:
+        """Save the current configuration."""
+        self.handle_save_config()
+
+    def action_validate_config(self) -> None:
+        """Validate the current configuration."""
+        self.handle_validate_config()
+
+    def action_edit_mode(self) -> None:
+        """Toggle edit mode."""
+        self.handle_toggle_edit()
+
+    @on(Button.Pressed, "#open-config")
+    def handle_open_config(self) -> None:
+        """Move focus to the file browser so a configuration can be picked."""
+        self.query_one("#config-tree", DirectoryTree).focus()
+
+    def action_open_config(self) -> None:
+        """Open a configuration (target of the ctrl+o binding)."""
+        self.handle_open_config()
diff --git a/src/axolotl/tui/screens/datasets.py b/src/axolotl/tui/screens/datasets.py
new file mode 100644
index 000000000..22670b096
--- /dev/null
+++ b/src/axolotl/tui/screens/datasets.py
@@ -0,0 +1,505 @@
class DatasetScreen(BaseScreen):
    """Dataset management screen.

    Lists the dataset directories found under ``/workspace/datasets`` and the
    datasets referenced by YAML configs in ``/workspace/configs``, previews the
    selected one, and can download or preprocess it.

    Discovery, preview loading, downloading and preprocessing run in thread
    workers.  Widgets are only touched on the app thread: workers hand every
    UI update to ``App.call_from_thread`` through ``_on_ui_thread``.
    """

    # NOTE(review): this class defines no ``action_preview`` / ``action_info``;
    # confirm BaseScreen supplies them, otherwise those two bindings are inert.
    BINDINGS = [
        Binding("ctrl+p", "preprocess", "Preprocess"),
        Binding("ctrl+v", "preview", "Preview"),
        Binding("ctrl+i", "info", "Info"),
        Binding("r", "refresh", "Refresh"),
    ]

    # NOTE(review): ``$info`` is not defined anywhere in this file; confirm the
    # app/theme declares it, otherwise the stylesheet will not resolve.
    CSS = """
    .dataset-container {
        layout: horizontal;
        height: 100%;
    }

    .dataset-list {
        width: 40%;
        border: solid $primary;
        padding: 1;
        margin: 1;
    }

    .dataset-details {
        width: 60%;
        border: solid $secondary;
        padding: 1;
        margin: 1;
    }

    .dataset-actions {
        layout: horizontal;
        height: 4;
        align: center middle;
        padding: 1;
    }

    .dataset-actions Button {
        margin: 0 1;
    }

    DataTable {
        height: 100%;
    }

    .preview-container {
        height: 100%;
        border: solid $info;
        padding: 1;
    }

    TextArea {
        height: 100%;
    }

    .stats-container {
        layout: vertical;
        padding: 1;
    }

    .stat-row {
        layout: horizontal;
        padding: 0 0 1 0;
    }

    .stat-label {
        width: 50%;
        color: $text-muted;
    }

    .stat-value {
        width: 50%;
        text-align: right;
        text-style: bold;
    }

    .screen-title {
        text-align: center;
        text-style: bold;
        padding: 1;
        color: $primary;
    }

    .screen-subtitle {
        text-align: center;
        padding: 0 0 1 0;
        color: $text-muted;
    }

    .progress-container {
        padding: 1;
        border: solid $warning;
        margin: 1;
    }
    """

    def __init__(self):
        """Initialize the dataset screen with no datasets and no selection."""
        super().__init__(
            title="Dataset Management",
            subtitle="Browse, preview, and preprocess datasets",
        )
        # Maps dataset name -> {"name", "path", "type", "size", "status"}.
        self.datasets: Dict[str, Dict] = {}
        self.selected_dataset: Optional[str] = None
        # True while a preprocessing worker is running; blocks a second run.
        self.preprocessing_active = False

    def compose(self) -> ComposeResult:
        """Compose the dataset screen layout."""
        yield Header()
        yield Container(
            Static("🦾 Dataset Management", classes="screen-title"),
            Static(
                "Browse, preview, and preprocess datasets", classes="screen-subtitle"
            ),
            Container(
                Container(
                    Label("Available Datasets"),
                    DataTable(id="dataset-table"),
                    Container(
                        Button("Load Dataset", id="load-dataset", variant="primary"),
                        Button("Preprocess", id="preprocess", variant="success"),
                        Button("Download", id="download", variant="default"),
                        Button("Refresh", id="refresh", variant="default"),
                        classes="dataset-actions",
                    ),
                    classes="dataset-list",
                ),
                Container(
                    TabbedContent(
                        TabPane(
                            "Preview",
                            Container(
                                TextArea(
                                    "",
                                    language="json",
                                    theme="monokai",
                                    id="dataset-preview",
                                    read_only=True,
                                ),
                                classes="preview-container",
                            ),
                        ),
                        TabPane(
                            "Statistics",
                            Container(
                                Container(
                                    Static("Dataset Name:", classes="stat-label"),
                                    Static("-", id="stat-name", classes="stat-value"),
                                    classes="stat-row",
                                ),
                                Container(
                                    Static("Type:", classes="stat-label"),
                                    Static("-", id="stat-type", classes="stat-value"),
                                    classes="stat-row",
                                ),
                                Container(
                                    Static("Size:", classes="stat-label"),
                                    Static("-", id="stat-size", classes="stat-value"),
                                    classes="stat-row",
                                ),
                                Container(
                                    Static("Samples:", classes="stat-label"),
                                    Static(
                                        "-", id="stat-samples", classes="stat-value"
                                    ),
                                    classes="stat-row",
                                ),
                                Container(
                                    Static("Features:", classes="stat-label"),
                                    Static(
                                        "-", id="stat-features", classes="stat-value"
                                    ),
                                    classes="stat-row",
                                ),
                                Container(
                                    Static("Format:", classes="stat-label"),
                                    Static("-", id="stat-format", classes="stat-value"),
                                    classes="stat-row",
                                ),
                                Container(
                                    Static("Preprocessed:", classes="stat-label"),
                                    Static(
                                        "-",
                                        id="stat-preprocessed",
                                        classes="stat-value",
                                    ),
                                    classes="stat-row",
                                ),
                                classes="stats-container",
                            ),
                        ),
                        TabPane(
                            "Processing",
                            Container(
                                # ``Log`` takes no ``wrap`` argument (that is a
                                # ``RichLog`` option), so it is not passed here.
                                Log(id="processing-log", highlight=True),
                                Container(
                                    Label("Preprocessing Progress:"),
                                    ProgressBar(
                                        total=100,
                                        id="preprocessing-progress",
                                    ),
                                    classes="progress-container",
                                ),
                            ),
                        ),
                    ),
                    classes="dataset-details",
                ),
                classes="dataset-container",
            ),
            id="content",
        )
        yield Footer()

    def on_mount(self) -> None:
        """Set up the table, start dataset discovery and greet in the log."""
        self.setup_dataset_table()
        self.load_datasets()

        log = self.query_one("#processing-log", Log)
        log.write_line("Dataset manager ready.")

    def setup_dataset_table(self) -> None:
        """Add the table columns and enable row selection."""
        table = self.query_one("#dataset-table", DataTable)
        table.add_columns("Name", "Type", "Size", "Status")
        table.cursor_type = "row"
        table.zebra_stripes = True

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _on_ui_thread(self, callback, *args, **kwargs):
        """Run ``callback`` on the app thread and return its result.

        Must only be called from a thread worker; exceptions raised by
        ``callback`` propagate to the caller.
        """
        return self.app.call_from_thread(callback, *args, **kwargs)

    def _write_log(self, message: str, clear: bool = False) -> None:
        """Append ``message`` to the processing log, optionally clearing it first."""
        log = self.query_one("#processing-log", Log)
        if clear:
            log.clear()
        log.write_line(message)

    def _update_progress(
        self, value: Optional[float] = None, advance: Optional[float] = None
    ) -> None:
        """Set the preprocessing bar to ``value`` and/or advance it by ``advance``."""
        bar = self.query_one("#preprocessing-progress", ProgressBar)
        if value is not None:
            bar.update(progress=value)
        if advance is not None:
            bar.advance(advance)

    def _show_preview(self, text: str) -> None:
        """Replace the contents of the preview pane with ``text``."""
        self.query_one("#dataset-preview", TextArea).load_text(text)

    def _apply_datasets(self, discovered: Dict[str, Dict]) -> None:
        """Install a freshly discovered dataset map and redraw the table."""
        self.datasets = discovered
        # Drop a selection that points at a dataset which no longer exists.
        if self.selected_dataset not in discovered:
            self.selected_dataset = None
        self.refresh_dataset_table()

    @staticmethod
    def _config_dataset_entries(config_file: Path) -> List[Dict]:
        """Return the well-formed ``datasets`` entries of one YAML config.

        Unreadable or malformed configs yield an empty list: discovery is
        best-effort and one bad file must not hide the others.
        """
        import yaml

        try:
            with open(config_file, encoding="utf-8") as handle:
                config = yaml.safe_load(handle)
        except (OSError, ValueError, yaml.YAMLError):
            return []

        # An empty file loads as None; a scalar/list document has no "datasets".
        if not isinstance(config, dict):
            return []
        entries = config.get("datasets")
        if not isinstance(entries, list):
            return []
        return [
            entry
            for entry in entries
            if isinstance(entry, dict) and isinstance(entry.get("path"), str)
        ]

    # ------------------------------------------------------------------
    # Discovery and display
    # ------------------------------------------------------------------

    @work(thread=True)
    def load_datasets(self) -> None:
        """Discover available datasets and redraw the table.

        The map is rebuilt from scratch on every call so that datasets removed
        from disk or from the configs disappear from the list.
        """
        discovered: Dict[str, Dict] = {}

        # Local datasets: one entry per sub-directory.
        datasets_dir = Path("/workspace/datasets")
        if datasets_dir.exists():
            for dataset_path in datasets_dir.glob("*"):
                if dataset_path.is_dir():
                    discovered[dataset_path.name] = {
                        "name": dataset_path.name,
                        "path": str(dataset_path),
                        "type": "local",
                        "size": self.get_dir_size(dataset_path),
                        "status": "available",
                    }

        # Datasets referenced by training configs.
        configs_dir = Path("/workspace/configs")
        if configs_dir.exists():
            for config_file in configs_dir.glob("*.yaml"):
                for ds in self._config_dataset_entries(config_file):
                    ds_name = ds["path"].split("/")[-1]
                    # setdefault: an already-downloaded local copy wins over
                    # the remote reference with the same name.
                    discovered.setdefault(
                        ds_name,
                        {
                            "name": ds_name,
                            "path": ds["path"],
                            "type": ds.get("type", "huggingface"),
                            "size": "Unknown",
                            "status": "remote",
                        },
                    )

        self._on_ui_thread(self._apply_datasets, discovered)

    def get_dir_size(self, path: Path) -> str:
        """Return the total size of the files under ``path`` as a readable string.

        Returns ``"Unknown"`` when the tree cannot be read (for example a
        permission error or a file vanishing mid-scan).
        """
        try:
            total_size = float(
                sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
            )
        except OSError:
            return "Unknown"

        for unit in ("B", "KB", "MB", "GB"):
            if total_size < 1024.0:
                return f"{total_size:.2f} {unit}"
            total_size /= 1024.0
        return f"{total_size:.2f} TB"

    def refresh_dataset_table(self) -> None:
        """Redraw the dataset table from ``self.datasets`` (app thread only)."""
        table = self.query_one("#dataset-table", DataTable)
        table.clear()

        for name, info in self.datasets.items():
            # The full name is the row key, so selection does not depend on
            # row order or on the 30-character display truncation.
            table.add_row(
                name[:30],
                info["type"],
                info["size"],
                info["status"],
                key=name,
            )

    @on(DataTable.RowSelected)
    def handle_dataset_selected(self, event: DataTable.RowSelected) -> None:
        """Select the dataset whose row was activated and show its details."""
        name = event.row_key.value
        if name in self.datasets:
            self.selected_dataset = name
            self.load_dataset_preview()
            self.update_dataset_stats()

    @work(thread=True)
    def load_dataset_preview(self) -> None:
        """Load a preview of the selected dataset into the preview pane.

        Local datasets show the parsed contents of up to three ``*.json``
        files; anything else shows the dataset's metadata entry.
        """
        name = self.selected_dataset
        dataset_info = self.datasets.get(name) if name else None
        if dataset_info is None:
            return

        try:
            dataset_dir = Path(dataset_info["path"])
            if dataset_info["type"] == "local" and dataset_dir.exists():
                samples = []
                for sample_file in list(dataset_dir.glob("*.json"))[:3]:
                    with open(sample_file, encoding="utf-8") as handle:
                        samples.append(json.load(handle))
                preview_text = json.dumps(samples, indent=2)
            else:
                preview_text = json.dumps(dataset_info, indent=2)
        except Exception as e:  # shown to the user instead of killing the worker
            preview_text = f"Error loading preview: {str(e)}"

        self._on_ui_thread(self._show_preview, preview_text)

    def update_dataset_stats(self) -> None:
        """Fill the Statistics tab from the selected dataset's metadata."""
        name = self.selected_dataset
        info = self.datasets.get(name) if name else None
        if info is None:
            return

        self.query_one("#stat-name", Static).update(info["name"])
        self.query_one("#stat-type", Static).update(info["type"])
        self.query_one("#stat-size", Static).update(info["size"])
        self.query_one("#stat-samples", Static).update("N/A")
        self.query_one("#stat-features", Static).update("N/A")
        self.query_one("#stat-format", Static).update("JSON")
        self.query_one("#stat-preprocessed", Static).update(
            "Yes" if info["status"] == "preprocessed" else "No"
        )

    # ------------------------------------------------------------------
    # Actions
    # ------------------------------------------------------------------

    @on(Button.Pressed, "#preprocess")
    @work(thread=True)
    def handle_preprocess(self) -> None:
        """Preprocess the selected dataset with ``axolotl.cli.preprocess``.

        Writes a throw-away YAML config, runs the CLI as a subprocess with the
        current interpreter, streams its output into the processing log and
        marks the dataset as ``preprocessed`` on success.  Does nothing when
        no dataset is selected or a run is already active.
        """
        import subprocess
        import sys
        import tempfile

        dataset_name = self.selected_dataset
        if not dataset_name or self.preprocessing_active:
            return
        dataset_info = self.datasets.get(dataset_name)
        if dataset_info is None:
            return

        self.preprocessing_active = True
        temp_config: Optional[str] = None
        try:
            import yaml

            self._on_ui_thread(
                self._write_log,
                f"🔄 Starting preprocessing for {dataset_name}...",
                clear=True,
            )
            self._on_ui_thread(self._update_progress, value=0)

            # "local" / "huggingface" are this screen's own source markers,
            # not prompt formats, so fall back to the default format for them.
            prompt_type = dataset_info.get("type", "alpaca")
            if prompt_type in ("local", "huggingface"):
                prompt_type = "alpaca"

            config = {
                "datasets": [
                    {
                        "path": dataset_info["path"],
                        "type": prompt_type,
                    }
                ],
                "output_dir": f"/tmp/preprocessed_{dataset_name}",
            }
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", delete=False, encoding="utf-8"
            ) as handle:
                temp_config = handle.name
                yaml.dump(config, handle)

            # sys.executable: run the CLI with the interpreter that has
            # axolotl installed, not whatever "python" is first on PATH.
            cmd = [sys.executable, "-m", "axolotl.cli.preprocess", temp_config]
            with subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            ) as process:
                for line in process.stdout:
                    self._on_ui_thread(self._write_log, line.strip())
                    # Coarse progress estimate based on the CLI's output.
                    if "Processing" in line:
                        self._on_ui_thread(self._update_progress, advance=10)

            if process.returncode == 0:
                self._on_ui_thread(
                    self._write_log, "✅ Preprocessing completed successfully!"
                )
                dataset_info["status"] = "preprocessed"
                self._on_ui_thread(self._update_progress, value=100)
            else:
                self._on_ui_thread(
                    self._write_log,
                    f"❌ Preprocessing failed with code {process.returncode}",
                )

        except Exception as e:  # reported in the log; the worker must not die
            self._on_ui_thread(
                self._write_log, f"❌ Error during preprocessing: {str(e)}"
            )
        finally:
            # Always remove the temporary config, including on failure.
            if temp_config is not None:
                try:
                    Path(temp_config).unlink(missing_ok=True)
                except OSError:
                    pass
            self.preprocessing_active = False
            self._on_ui_thread(self.refresh_dataset_table)

    @on(Button.Pressed, "#load-dataset")
    async def handle_load_dataset(self) -> None:
        """Placeholder for loading a new dataset; only logs a notice."""
        log = self.query_one("#processing-log", Log)
        log.write_line("📦 Load dataset functionality coming soon...")

    @on(Button.Pressed, "#download")
    @work(thread=True)
    def handle_download(self) -> None:
        """Download the selected remote dataset to ``/workspace/datasets``.

        Only datasets whose status is ``remote`` are downloaded.  The gate is
        the status rather than the ``type`` field because ``type`` carries the
        config's prompt format whenever one is declared.
        """
        name = self.selected_dataset
        dataset_info = self.datasets.get(name) if name else None
        if dataset_info is None or dataset_info["status"] != "remote":
            return

        self._on_ui_thread(
            self._write_log,
            f"📥 Downloading {name} from HuggingFace...",
            clear=True,
        )

        try:
            from datasets import load_dataset

            dataset = load_dataset(dataset_info["path"])
            save_path = Path(f"/workspace/datasets/{name}")
            save_path.mkdir(parents=True, exist_ok=True)

            dataset.save_to_disk(str(save_path))

            self._on_ui_thread(self._write_log, f"✅ Downloaded to {save_path}")
            dataset_info["type"] = "local"
            dataset_info["status"] = "available"
            dataset_info["path"] = str(save_path)
            self._on_ui_thread(self.refresh_dataset_table)

        except Exception as e:  # reported in the log; the worker must not die
            self._on_ui_thread(self._write_log, f"❌ Download failed: {str(e)}")

    @on(Button.Pressed, "#refresh")
    def handle_refresh(self) -> None:
        """Re-run dataset discovery."""
        self.load_datasets()

    def action_preprocess(self) -> None:
        """Key-binding entry point for preprocessing the selected dataset."""
        self.handle_preprocess()

    def action_refresh(self) -> None:
        """Key-binding entry point for refreshing the dataset list."""
        self.handle_refresh()
class InferenceScreen(BaseScreen):
    """Inference and testing screen.

    Offers a model picker, a few sampling parameters and a chat log.  Model
    loading and response generation are placeholders that sleep instead of
    running a model.

    Methods decorated with ``@work(thread=True)`` run in thread workers and
    hand every widget update to ``App.call_from_thread`` via ``_on_ui_thread``.
    """

    # Files whose presence marks a directory under ./outputs as a saved model
    # (full weights or adapter, ``.bin`` or safetensors).
    MODEL_WEIGHT_FILES = (
        "pytorch_model.bin",
        "model.safetensors",
        "model.safetensors.index.json",
        "adapter_model.bin",
        "adapter_model.safetensors",
    )

    BINDINGS = [
        Binding("ctrl+enter", "send_message", "Send"),
        Binding("ctrl+c", "clear_chat", "Clear"),
        Binding("ctrl+l", "load_model", "Load Model"),
        Binding("ctrl+s", "save_chat", "Save Chat"),
    ]

    # NOTE(review): ``$info`` is not defined anywhere in this file; confirm the
    # app/theme declares it, otherwise the stylesheet will not resolve.
    CSS = """
    .inference-container {
        layout: horizontal;
        height: 100%;
    }

    .model-selector {
        width: 30%;
        border: solid $primary;
        padding: 1;
        margin: 1;
    }

    .chat-interface {
        width: 70%;
        border: solid $secondary;
        padding: 1;
        margin: 1;
    }

    .chat-history {
        height: 70%;
        border: solid $info;
        padding: 1;
        margin: 0 0 1 0;
    }

    .input-area {
        height: 20%;
        border: solid $warning;
        padding: 1;
        margin: 0 0 1 0;
    }

    .chat-controls {
        layout: horizontal;
        height: 4;
        align: center middle;
        padding: 1;
    }

    .chat-controls Button {
        margin: 0 1;
    }

    .model-info {
        padding: 1;
        border: solid $surface;
        margin: 1 0;
    }

    .screen-title {
        text-align: center;
        text-style: bold;
        padding: 1;
        color: $primary;
    }

    .screen-subtitle {
        text-align: center;
        padding: 0 0 1 0;
        color: $text-muted;
    }

    TextArea {
        height: 100%;
    }

    Log {
        height: 100%;
    }
    """

    def __init__(self):
        """Initialize the inference screen with no model and an empty chat."""
        super().__init__(
            title="Inference & Testing", subtitle="Interactive chat and model testing"
        )
        # Value of the loaded model's Select option (its path), or None.
        self.loaded_model: Optional[str] = None
        # Chat turns as {"role": ..., "content": ...} dicts, oldest first.
        self.chat_history: List[Dict[str, str]] = []

    def compose(self) -> ComposeResult:
        """Compose the inference screen layout."""
        yield Container(
            Static("🦾 Inference & Testing", classes="screen-title"),
            Static("Interactive chat and model testing", classes="screen-subtitle"),
            Container(
                Container(
                    Label("Model Selection"),
                    # Select options are (prompt, value): "none" is the value.
                    Select(
                        [("No model loaded", "none")],
                        id="model-select",
                        value="none",
                    ),
                    Container(
                        Button("Load Model", id="load-model", variant="primary"),
                        Button("Unload", id="unload-model", variant="default"),
                        Button("Gradio UI", id="gradio-ui", variant="success"),
                    ),
                    Container(
                        Static("No model loaded", id="model-status"),
                        classes="model-info",
                    ),
                    Label("Inference Parameters"),
                    Container(
                        Label("Temperature:"),
                        Input(value="0.7", id="temperature"),
                        Label("Max Tokens:"),
                        Input(value="256", id="max-tokens"),
                        Label("Top P:"),
                        Input(value="0.9", id="top-p"),
                    ),
                    classes="model-selector",
                ),
                Container(
                    Container(
                        # ``Log`` takes no ``wrap`` argument (that is a
                        # ``RichLog`` option), so it is not passed here.
                        Log(id="chat-history", highlight=True),
                        classes="chat-history",
                    ),
                    Container(
                        # NOTE(review): confirm the pinned Textual release
                        # accepts ``placeholder`` on TextArea.
                        TextArea(
                            placeholder="Type your message here...",
                            id="message-input",
                        ),
                        classes="input-area",
                    ),
                    Container(
                        Button("Send [Ctrl+Enter]", id="send", variant="primary"),
                        Button("Clear Chat", id="clear", variant="warning"),
                        Button("Save Chat", id="save-chat", variant="default"),
                        Button("Load Examples", id="load-examples", variant="default"),
                        classes="chat-controls",
                    ),
                    classes="chat-interface",
                ),
                classes="inference-container",
            ),
            id="content",
        )

    def on_mount(self) -> None:
        """Start model discovery and print the welcome message."""
        self.load_available_models()

        chat = self.query_one("#chat-history", Log)
        chat.write_line("💬 Welcome to Axolotl Inference!")
        chat.write_line("Load a model to start chatting.")

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _on_ui_thread(self, callback, *args, **kwargs):
        """Run ``callback`` on the app thread and return its result.

        Must only be called from a thread worker; exceptions raised by
        ``callback`` propagate to the caller.
        """
        return self.app.call_from_thread(callback, *args, **kwargs)

    def _write_chat(self, message: str) -> None:
        """Append one line to the chat log."""
        self.query_one("#chat-history", Log).write_line(message)

    def _set_model_status(self, text: str) -> None:
        """Replace the text of the model status panel."""
        self.query_one("#model-status", Static).update(text)

    def _set_model_options(self, models) -> None:
        """Install ``models`` as the picker's options, keeping the selection.

        ``Select.set_options`` resets the current value, so the previous
        selection is restored when still offered and ``"none"`` otherwise.
        """
        select = self.query_one("#model-select", Select)
        previous = select.value
        select.set_options(models)
        offered = {value for _, value in models}
        select.value = previous if previous in offered else "none"

    def _selected_model_path(self) -> Optional[str]:
        """Return the picked model's path, or None when nothing is picked."""
        value = self.query_one("#model-select", Select).value
        # A blank Select holds a non-string sentinel; "none" is our own marker.
        if not isinstance(value, str) or value == "none":
            return None
        return value

    def _read_generation_params(self):
        """Return ``(temperature, max_tokens, top_p)`` parsed from the inputs.

        Raises:
            ValueError: if any of the three fields is not a valid number.
        """
        temperature = float(self.query_one("#temperature", Input).value)
        max_tokens = int(self.query_one("#max-tokens", Input).value)
        top_p = float(self.query_one("#top-p", Input).value)
        return temperature, max_tokens, top_p

    # ------------------------------------------------------------------
    # Model handling
    # ------------------------------------------------------------------

    @work(thread=True)
    def load_available_models(self) -> None:
        """Populate the model picker from ``./outputs`` and the HF cache."""
        models = [("No model loaded", "none")]

        # Trained models: sub-directories of ./outputs holding weight files.
        outputs_dir = Path("./outputs")
        if outputs_dir.exists():
            for model_dir in outputs_dir.glob("*"):
                if model_dir.is_dir() and any(
                    (model_dir / marker).exists()
                    for marker in self.MODEL_WEIGHT_FILES
                ):
                    models.append((model_dir.name, str(model_dir)))

        # Cached Hugging Face models: "hub" is the current cache directory,
        # "transformers" the legacy one.
        hf_root = Path.home() / ".cache" / "huggingface"
        for cache_name in ("hub", "transformers"):
            hf_cache = hf_root / cache_name
            if not hf_cache.exists():
                continue
            for model_dir in hf_cache.glob("models--*"):
                if model_dir.is_dir():
                    model_name = model_dir.name.replace("models--", "").replace(
                        "--", "/"
                    )
                    models.append((f"HF: {model_name}", str(model_dir)))

        self._on_ui_thread(self._set_model_options, models)

    @on(Button.Pressed, "#load-model")
    @work(thread=True)
    def handle_load_model(self) -> None:
        """Load the selected model for inference (placeholder: sleeps 2 s)."""
        import time

        model_path = self._on_ui_thread(self._selected_model_path)
        if model_path is None:
            return

        self._on_ui_thread(self._write_chat, f"🔄 Loading model: {model_path}")
        self._on_ui_thread(self._set_model_status, "Loading...")

        try:
            # Simulate model loading (a real implementation would load it here).
            time.sleep(2)

            self.loaded_model = model_path
            self._on_ui_thread(
                self._set_model_status, f"✅ Loaded: {Path(model_path).name}"
            )
            self._on_ui_thread(self._write_chat, "✅ Model loaded successfully!")
            self._on_ui_thread(self._write_chat, "You can now start chatting.")

        except Exception as e:  # reported in the chat; the worker must not die
            self._on_ui_thread(self._set_model_status, "❌ Failed to load")
            self._on_ui_thread(
                self._write_chat, f"❌ Failed to load model: {str(e)}"
            )

    @on(Button.Pressed, "#unload-model")
    def handle_unload_model(self) -> None:
        """Forget the loaded model and reset the picker and status panel."""
        self.loaded_model = None
        status = self.query_one("#model-status", Static)
        status.update("No model loaded")

        select = self.query_one("#model-select", Select)
        select.value = "none"

        chat = self.query_one("#chat-history", Log)
        chat.write_line("🔄 Model unloaded")

    # ------------------------------------------------------------------
    # Chat
    # ------------------------------------------------------------------

    @on(Button.Pressed, "#send")
    async def handle_send_message(self) -> None:
        """Send the typed message and start generating a reply."""
        chat = self.query_one("#chat-history", Log)
        if not self.loaded_model:
            chat.write_line("⚠️ Please load a model first")
            return

        message_input = self.query_one("#message-input", TextArea)
        message = message_input.text.strip()

        if not message:
            return

        chat.write_line(f"👤 User: {message}")
        message_input.clear()
        self.chat_history.append({"role": "user", "content": message})

        # ``generate_response`` is a worker: calling it schedules the work and
        # returns a Worker object, which must not be awaited.
        self.generate_response(message)

    @work(thread=True)
    def generate_response(self, message: str) -> None:
        """Generate the reply to ``message`` (placeholder implementation)."""
        import time

        self._on_ui_thread(self._write_chat, "🤖 Assistant: Thinking...")

        try:
            temperature, max_tokens, top_p = self._on_ui_thread(
                self._read_generation_params
            )

            # Simulate inference time (a real implementation would call the model).
            time.sleep(1)

            response = f"This is a placeholder response to: '{message}'. In a real implementation, this would be generated by the loaded model using the parameters: temperature={temperature}, max_tokens={max_tokens}, top_p={top_p}."

            self._on_ui_thread(self._write_chat, f"🤖 Assistant: {response}")
            self.chat_history.append({"role": "assistant", "content": response})

        except Exception as e:  # e.g. a non-numeric parameter field
            self._on_ui_thread(
                self._write_chat, f"❌ Error generating response: {str(e)}"
            )

    @on(Button.Pressed, "#clear")
    def handle_clear_chat(self) -> None:
        """Clear the chat log and the stored history."""
        chat = self.query_one("#chat-history", Log)
        chat.clear()
        self.chat_history = []
        chat.write_line("💬 Chat cleared. Start a new conversation!")

    @on(Button.Pressed, "#save-chat")
    def handle_save_chat(self) -> None:
        """Write the chat history to a timestamped JSON file in the cwd."""
        import json
        from datetime import datetime

        chat = self.query_one("#chat-history", Log)
        if not self.chat_history:
            chat.write_line("⚠️ No chat history to save")
            return

        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"chat_history_{timestamp}.json"

            with open(filename, "w", encoding="utf-8") as f:
                json.dump(self.chat_history, f, indent=2)

            chat.write_line(f"💾 Chat saved to {filename}")

        except Exception as e:  # reported to the user in the chat log
            chat.write_line(f"❌ Error saving chat: {str(e)}")

    @on(Button.Pressed, "#load-examples")
    def handle_load_examples(self) -> None:
        """Print a numbered list of example prompts into the chat log."""
        examples = [
            "Explain the concept of machine learning in simple terms.",
            "Write a Python function to calculate fibonacci numbers.",
            "What are the benefits of fine-tuning language models?",
            "Describe the difference between supervised and unsupervised learning.",
        ]

        chat = self.query_one("#chat-history", Log)
        chat.write_line("📚 Example prompts:")
        for i, example in enumerate(examples, 1):
            chat.write_line(f"{i}. {example}")
        chat.write_line("Copy and paste any example to try it out!")

    @on(Button.Pressed, "#gradio-ui")
    @work(thread=True)
    def handle_gradio_ui(self) -> None:
        """Launch the Gradio web interface as a detached subprocess."""
        import subprocess
        import sys

        self._on_ui_thread(self._write_chat, "🌐 Launching Gradio web interface...")

        try:
            # sys.executable: use the interpreter that has axolotl installed.
            cmd = [sys.executable, "-m", "axolotl.cli.inference"]
            if self.loaded_model:
                cmd.append(self.loaded_model)
            else:
                self._on_ui_thread(
                    self._write_chat,
                    "⚠️ No model loaded. Loading default interface...",
                )
            cmd.append("--gradio")

            # Detach the child's stdio so its output cannot scribble over the
            # terminal that the TUI is drawing on.
            subprocess.Popen(
                cmd,
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            self._on_ui_thread(
                self._write_chat, "✅ Gradio interface launched! Check your browser."
            )

        except Exception as e:  # reported in the chat; the worker must not die
            self._on_ui_thread(
                self._write_chat, f"❌ Error launching Gradio: {str(e)}"
            )

    # ------------------------------------------------------------------
    # Key-binding actions
    # ------------------------------------------------------------------

    async def action_send_message(self) -> None:
        """Send message action.

        ``handle_send_message`` is a coroutine function, so it has to be
        awaited; calling it bare would create a coroutine that never runs.
        """
        await self.handle_send_message()

    def action_clear_chat(self) -> None:
        """Clear chat action."""
        self.handle_clear_chat()

    def action_load_model(self) -> None:
        """Load model action."""
        self.handle_load_model()

    def action_save_chat(self) -> None:
        """Save chat action."""
        self.handle_save_chat()
class ModelScreen(BaseScreen):
    """Model management screen.

    Lists the sub-directories of ``./outputs`` as models and runs the axolotl
    ``merge_lora``, ``quantize`` and ``evaluate`` CLIs on the selected one,
    streaming their output into the operations log.

    Discovery and the CLI operations run in thread workers and hand every
    widget update to ``App.call_from_thread`` via ``_on_ui_thread``.
    """

    # NOTE(review): many terminals report ctrl+m as Enter, and Textual may
    # reserve ctrl+q for quitting; confirm these two bindings are reachable.
    BINDINGS = [
        Binding("ctrl+m", "merge_lora", "Merge LoRA"),
        Binding("ctrl+q", "quantize", "Quantize"),
        Binding("ctrl+e", "evaluate", "Evaluate"),
        Binding("r", "refresh", "Refresh"),
    ]

    CSS = """
    .model-container {
        layout: horizontal;
        height: 100%;
    }

    .model-list {
        width: 50%;
        border: solid $primary;
        padding: 1;
        margin: 1;
    }

    .model-operations {
        width: 50%;
        border: solid $secondary;
        padding: 1;
        margin: 1;
    }

    .model-actions {
        layout: horizontal;
        height: 4;
        align: center middle;
        padding: 1;
    }

    .model-actions Button {
        margin: 0 1;
    }

    DataTable {
        height: 80%;
    }

    .screen-title {
        text-align: center;
        text-style: bold;
        padding: 1;
        color: $primary;
    }

    .screen-subtitle {
        text-align: center;
        padding: 0 0 1 0;
        color: $text-muted;
    }
    """

    def __init__(self):
        """Initialize the model screen with no models and no selection."""
        super().__init__(
            title="Model Management",
            subtitle="Manage trained models, merge LoRA adapters, and quantize models",
        )
        # Maps model name -> {"name", "path", "type", "size", "status"}.
        self.models: Dict[str, Dict] = {}
        self.selected_model: Optional[str] = None

    def compose(self) -> ComposeResult:
        """Compose the model screen layout."""
        yield Header()
        yield Container(
            Static("🦾 Model Management", classes="screen-title"),
            Static(
                "Manage trained models, merge LoRA adapters, and quantize models",
                classes="screen-subtitle",
            ),
            Container(
                Container(
                    Label("Available Models"),
                    DataTable(id="model-table"),
                    Container(
                        Button("Merge LoRA", id="merge-lora", variant="primary"),
                        Button("Quantize", id="quantize", variant="success"),
                        Button("Evaluate", id="evaluate", variant="warning"),
                        Button("Refresh", id="refresh", variant="default"),
                        classes="model-actions",
                    ),
                    classes="model-list",
                ),
                Container(
                    TabbedContent(
                        TabPane(
                            "Operations",
                            Container(
                                # ``Log`` takes no ``wrap`` argument (that is a
                                # ``RichLog`` option), so it is not passed here.
                                Log(id="operations-log", highlight=True),
                                Container(
                                    Label("Operation Progress:"),
                                    ProgressBar(
                                        total=100,
                                        id="operation-progress",
                                    ),
                                ),
                            ),
                        ),
                        TabPane(
                            "Model Info",
                            ScrollableContainer(
                                Static(
                                    "Model information will appear here",
                                    id="model-info",
                                ),
                            ),
                        ),
                    ),
                    classes="model-operations",
                ),
                classes="model-container",
            ),
            id="content",
        )
        yield Footer()

    def on_mount(self) -> None:
        """Set up the table, start model discovery and greet in the log."""
        self.setup_model_table()
        self.load_models()

        log = self.query_one("#operations-log", Log)
        log.write_line("Model manager ready.")

    def setup_model_table(self) -> None:
        """Add the table columns and enable row selection."""
        table = self.query_one("#model-table", DataTable)
        table.add_columns("Name", "Type", "Size", "Status")
        table.cursor_type = "row"
        table.zebra_stripes = True

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _on_ui_thread(self, callback, *args, **kwargs):
        """Run ``callback`` on the app thread and return its result.

        Must only be called from a thread worker; exceptions raised by
        ``callback`` propagate to the caller.
        """
        return self.app.call_from_thread(callback, *args, **kwargs)

    def _write_log(self, message: str, clear: bool = False) -> None:
        """Append ``message`` to the operations log, optionally clearing it first."""
        log = self.query_one("#operations-log", Log)
        if clear:
            log.clear()
        log.write_line(message)

    def _update_progress(
        self, value: Optional[float] = None, advance: Optional[float] = None
    ) -> None:
        """Set the operation bar to ``value`` and/or advance it by ``advance``."""
        bar = self.query_one("#operation-progress", ProgressBar)
        if value is not None:
            bar.update(progress=value)
        if advance is not None:
            bar.advance(advance)

    def _apply_models(self, discovered: Dict[str, Dict]) -> None:
        """Install a freshly discovered model map and redraw the table."""
        self.models = discovered
        # Drop a selection that points at a model which no longer exists.
        if self.selected_model not in discovered:
            self.selected_model = None
        self.refresh_model_table()

    def _require_selected_model(self) -> Optional[Dict]:
        """Return the selected model's entry, logging a warning when there is none.

        Called from thread workers only.
        """
        name = self.selected_model
        info = self.models.get(name) if name else None
        if info is None:
            self._on_ui_thread(self._write_log, "⚠️ No model selected")
        return info

    def _run_cli_operation(
        self, cmd, start_message: str, label: str, error_label: str, step: int
    ) -> None:
        """Run ``cmd`` and stream its output into the operations log.

        Called from thread workers only.

        Args:
            cmd: Argument list for the subprocess.
            start_message: Line written after clearing the log.
            label: Operation name used in the success and failure messages.
            error_label: Operation name used in the exception message.
            step: Progress-bar increment per line of subprocess output.
        """
        import subprocess

        self._on_ui_thread(self._write_log, start_message, clear=True)
        self._on_ui_thread(self._update_progress, value=0)

        try:
            with subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            ) as process:
                for line in process.stdout:
                    self._on_ui_thread(self._write_log, line.strip())
                    self._on_ui_thread(self._update_progress, advance=step)

            if process.returncode == 0:
                self._on_ui_thread(
                    self._write_log, f"✅ {label} completed successfully!"
                )
                self._on_ui_thread(self._update_progress, value=100)
            else:
                self._on_ui_thread(
                    self._write_log,
                    f"❌ {label} failed with code {process.returncode}",
                )

        except Exception as e:  # reported in the log; the worker must not die
            self._on_ui_thread(
                self._write_log, f"❌ Error during {error_label}: {str(e)}"
            )

    # ------------------------------------------------------------------
    # Discovery and display
    # ------------------------------------------------------------------

    @work(thread=True)
    def load_models(self) -> None:
        """Discover the models under ``./outputs`` and redraw the table.

        The map is rebuilt from scratch on every call so that models deleted
        from disk disappear from the list.
        """
        discovered: Dict[str, Dict] = {}

        outputs_dir = Path("./outputs")
        if outputs_dir.exists():
            for model_dir in outputs_dir.glob("*"):
                if model_dir.is_dir():
                    discovered[model_dir.name] = {
                        "name": model_dir.name,
                        "path": str(model_dir),
                        "type": "checkpoint",
                        "size": self.get_dir_size(model_dir),
                        "status": "available",
                    }

        self._on_ui_thread(self._apply_models, discovered)

    def get_dir_size(self, path: Path) -> str:
        """Return the total size of the files under ``path`` as a readable string.

        Returns ``"Unknown"`` when the size cannot be determined.
        """
        try:
            total_size = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())

            for unit in ("B", "KB", "MB", "GB"):
                if total_size < 1024.0:
                    return f"{total_size:.2f} {unit}"
                total_size /= 1024.0
            return f"{total_size:.2f} TB"
        except Exception:
            return "Unknown"

    def refresh_model_table(self) -> None:
        """Redraw the model table from ``self.models`` (app thread only)."""
        table = self.query_one("#model-table", DataTable)
        table.clear()

        for name, info in self.models.items():
            # The full name is the row key, so selection does not depend on
            # row order or on the 30-character display truncation.
            table.add_row(
                name[:30],
                info["type"],
                info["size"],
                info["status"],
                key=name,
            )

    @on(DataTable.RowSelected)
    def handle_model_selected(self, event: DataTable.RowSelected) -> None:
        """Select the model whose row was activated and show its details."""
        name = event.row_key.value
        if name in self.models:
            self.selected_model = name
            self.update_model_info()

    def update_model_info(self) -> None:
        """Fill the Model Info tab from the selected model's metadata."""
        name = self.selected_model
        info = self.models.get(name) if name else None
        if info is None:
            return

        info_text = f"""
Model Name: {info['name']}
Path: {info['path']}
Type: {info['type']}
Size: {info['size']}
Status: {info['status']}
        """

        self.query_one("#model-info", Static).update(info_text)

    # ------------------------------------------------------------------
    # Operations
    # ------------------------------------------------------------------

    @on(Button.Pressed, "#merge-lora")
    @work(thread=True)
    def handle_merge_lora(self) -> None:
        """Merge LoRA adapters of the selected model via ``axolotl.cli.merge_lora``."""
        import sys

        model_info = self._require_selected_model()
        if model_info is None:
            return

        # sys.executable: use the interpreter that has axolotl installed.
        self._run_cli_operation(
            [sys.executable, "-m", "axolotl.cli.merge_lora", model_info["path"]],
            start_message=f"🔄 Merging LoRA adapters for {model_info['name']}...",
            label="LoRA merge",
            error_label="LoRA merge",
            step=10,
        )

    @on(Button.Pressed, "#quantize")
    @work(thread=True)
    def handle_quantize(self) -> None:
        """Quantize the selected model into ``<model path>_quantized``."""
        import sys

        model_info = self._require_selected_model()
        if model_info is None:
            return

        self._run_cli_operation(
            [
                sys.executable,
                "-m",
                "axolotl.cli.quantize",
                model_info["path"],
                "--output-dir",
                f"{model_info['path']}_quantized",
            ],
            start_message=f"🔄 Quantizing {model_info['name']}...",
            label="Quantization",
            error_label="quantization",
            step=5,
        )

    @on(Button.Pressed, "#evaluate")
    @work(thread=True)
    def handle_evaluate(self) -> None:
        """Evaluate the selected model via ``axolotl.cli.evaluate``."""
        import sys

        model_info = self._require_selected_model()
        if model_info is None:
            return

        self._run_cli_operation(
            [sys.executable, "-m", "axolotl.cli.evaluate", model_info["path"]],
            start_message=f"🔄 Evaluating {model_info['name']}...",
            label="Evaluation",
            error_label="evaluation",
            step=10,
        )

    @on(Button.Pressed, "#refresh")
    def handle_refresh(self) -> None:
        """Re-run model discovery."""
        self.load_models()

    def action_merge_lora(self) -> None:
        """Merge LoRA adapters."""
        self.handle_merge_lora()

    def action_quantize(self) -> None:
        """Quantize model."""
        self.handle_quantize()

    def action_evaluate(self) -> None:
        """Evaluate model."""
        self.handle_evaluate()

    def action_refresh(self) -> None:
        """Refresh model list."""
        self.handle_refresh()
text-align: center; + text-style: bold; + padding: 1; + color: $primary; + } + + .screen-subtitle { + text-align: center; + padding: 0 0 1 0; + color: $text-muted; + } + + Sparkline { + height: 8; + } + + ProgressBar { + margin: 1 0; + } + """ + + def __init__(self): + """Initialize the monitor screen.""" + super().__init__( + title="System Monitor", + subtitle="Monitor system resources and running processes", + ) + self.cpu_history = [] + self.memory_history = [] + self.gpu_history = [] + + def compose(self) -> ComposeResult: + """Compose the monitor screen layout.""" + yield Header() + yield Container( + Static("๐Ÿฆพ System Monitor", classes="screen-title"), + Static( + "Monitor system resources and running processes", + classes="screen-subtitle", + ), + Container( + Container( + Container( + Static("CPU Usage", classes="metric-label"), + Static("0%", id="cpu-usage", classes="metric-value"), + ProgressBar(total=100, id="cpu-progress"), + classes="metric-card", + ), + Container( + Static("Memory", classes="metric-label"), + Static("0%", id="memory-usage", classes="metric-value"), + ProgressBar(total=100, id="memory-progress"), + classes="metric-card", + ), + Container( + Static("GPU Usage", classes="metric-label"), + Static("0%", id="gpu-usage", classes="metric-value"), + ProgressBar(total=100, id="gpu-progress"), + classes="metric-card", + ), + Container( + Static("Temperature", classes="metric-label"), + Static("0ยฐC", id="temperature", classes="metric-value"), + classes="metric-card", + ), + classes="metrics-grid", + ), + Container( + Container( + Label("CPU History"), + Sparkline([], id="cpu-sparkline"), + classes="chart-panel", + ), + Container( + Label("Memory History"), + Sparkline([], id="memory-sparkline"), + classes="chart-panel", + ), + classes="charts-container", + ), + Container( + TabbedContent( + TabPane( + "Processes", + Container( + DataTable(id="process-table"), + Container( + Button( + "Kill Process", + id="kill-process", + variant="error", + ), 
+ Button("Refresh", id="refresh", variant="default"), + Button( + "Auto Refresh", + id="auto-refresh", + variant="primary", + ), + classes="process-controls", + ), + ), + ), + TabPane( + "GPU Info", + Log(id="gpu-info", wrap=True, highlight=True), + ), + TabPane( + "System Logs", + Log(id="system-logs", wrap=True, highlight=True), + ), + ), + classes="processes-container", + ), + classes="monitor-container", + ), + id="content", + ) + yield Footer() + + def on_mount(self) -> None: + """Called when the screen is mounted.""" + self.setup_process_table() + self.start_monitoring() + + # Initial system info + self.update_system_info() + self.update_gpu_info() + + def setup_process_table(self) -> None: + """Setup the process table.""" + table = self.query_one("#process-table", DataTable) + table.add_columns("PID", "Name", "CPU%", "Memory%", "Status") + table.cursor_type = "row" + table.zebra_stripes = True + + def start_monitoring(self) -> None: + """Start the monitoring timer.""" + self.set_interval(2.0, self.update_system_metrics) + + @work(thread=True) + async def update_system_metrics(self) -> None: + """Update system metrics.""" + try: + # CPU usage + cpu_percent = psutil.cpu_percent(interval=None) + self.cpu_history.append(cpu_percent) + if len(self.cpu_history) > 50: + self.cpu_history.pop(0) + + # Memory usage + memory = psutil.virtual_memory() + memory_percent = memory.percent + self.memory_history.append(memory_percent) + if len(self.memory_history) > 50: + self.memory_history.pop(0) + + # GPU usage (if available) + gpu_percent = self.get_gpu_usage() + self.gpu_history.append(gpu_percent) + if len(self.gpu_history) > 50: + self.gpu_history.pop(0) + + # Temperature + temperature = self.get_temperature() + + # Update UI + self.update_metrics_display( + cpu_percent, memory_percent, gpu_percent, temperature + ) + self.update_sparklines() + self.update_process_table() + + except Exception as e: + log = self.query_one("#system-logs", Log) + log.write_line(f"Error 
updating metrics: {str(e)}") + + def get_gpu_usage(self) -> float: + """Get GPU usage percentage.""" + try: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + util = pynvml.nvmlDeviceGetUtilizationRates(handle) + return util.gpu + except Exception: + return 0.0 + + def get_temperature(self) -> str: + """Get system temperature.""" + try: + temps = psutil.sensors_temperatures() + if temps: + for name, entries in temps.items(): + if entries: + return f"{entries[0].current:.1f}ยฐC" + return "N/A" + except Exception: + return "N/A" + + def update_metrics_display( + self, cpu: float, memory: float, gpu: float, temp: str + ) -> None: + """Update metrics display.""" + self.query_one("#cpu-usage", Static).update(f"{cpu:.1f}%") + self.query_one("#memory-usage", Static).update(f"{memory:.1f}%") + self.query_one("#gpu-usage", Static).update(f"{gpu:.1f}%") + self.query_one("#temperature", Static).update(temp) + + self.query_one("#cpu-progress", ProgressBar).update(progress=cpu) + self.query_one("#memory-progress", ProgressBar).update(progress=memory) + self.query_one("#gpu-progress", ProgressBar).update(progress=gpu) + + def update_sparklines(self) -> None: + """Update sparkline charts.""" + if self.cpu_history: + cpu_sparkline = self.query_one("#cpu-sparkline", Sparkline) + cpu_sparkline.data = self.cpu_history + + if self.memory_history: + memory_sparkline = self.query_one("#memory-sparkline", Sparkline) + memory_sparkline.data = self.memory_history + + def update_process_table(self) -> None: + """Update the process table.""" + table = self.query_one("#process-table", DataTable) + table.clear() + + try: + # Get top processes by CPU usage + processes = [] + for proc in psutil.process_iter( + ["pid", "name", "cpu_percent", "memory_percent", "status"] + ): + try: + pinfo = proc.info + if pinfo["cpu_percent"] > 0.1: # Only show processes using CPU + processes.append(pinfo) + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + + # Sort 
by CPU usage + processes.sort(key=lambda x: x["cpu_percent"], reverse=True) + + # Add top 20 processes + for proc in processes[:20]: + table.add_row( + str(proc["pid"]), + proc["name"][:20], + f"{proc['cpu_percent']:.1f}%", + f"{proc['memory_percent']:.1f}%", + proc["status"], + ) + + except Exception as e: + log = self.query_one("#system-logs", Log) + log.write_line(f"Error updating process table: {str(e)}") + + def update_system_info(self) -> None: + """Update system information.""" + try: + # System info + boot_time = psutil.boot_time() + cpu_count = psutil.cpu_count() + memory = psutil.virtual_memory() + + log = self.query_one("#system-logs", Log) + log.write_line(f"System started. CPU cores: {cpu_count}") + log.write_line(f"Total memory: {memory.total / (1024**3):.1f} GB") + log.write_line(f"Available memory: {memory.available / (1024**3):.1f} GB") + + except Exception as e: + log = self.query_one("#system-logs", Log) + log.write_line(f"Error getting system info: {str(e)}") + + def update_gpu_info(self) -> None: + """Update GPU information.""" + try: + import pynvml + + pynvml.nvmlInit() + + device_count = pynvml.nvmlDeviceGetCount() + log = self.query_one("#gpu-info", Log) + log.clear() + log.write_line(f"Found {device_count} GPU(s)") + + for i in range(device_count): + handle = pynvml.nvmlDeviceGetHandleByIndex(i) + name = pynvml.nvmlDeviceGetName(handle).decode() + memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + + log.write_line(f"\nGPU {i}: {name}") + log.write_line( + f"Memory: {memory_info.used / (1024**3):.1f} / {memory_info.total / (1024**3):.1f} GB" + ) + log.write_line(f"Free: {memory_info.free / (1024**3):.1f} GB") + + except Exception as e: + log = self.query_one("#gpu-info", Log) + log.clear() + log.write_line(f"GPU info unavailable: {str(e)}") + + @on(Button.Pressed, "#kill-process") + def handle_kill_process(self) -> None: + """Kill selected process.""" + table = self.query_one("#process-table", DataTable) + if table.cursor_row >= 0: + 
try: + row = table.get_row_at(table.cursor_row) + pid = int(row[0]) + + process = psutil.Process(pid) + process.terminate() + + log = self.query_one("#system-logs", Log) + log.write_line(f"Terminated process {pid}") + + except Exception as e: + log = self.query_one("#system-logs", Log) + log.write_line(f"Error killing process: {str(e)}") + + @on(Button.Pressed, "#refresh") + def handle_refresh(self) -> None: + """Refresh all metrics.""" + self.update_system_info() + self.update_gpu_info() + + log = self.query_one("#system-logs", Log) + log.write_line("Metrics refreshed") + + @on(Button.Pressed, "#auto-refresh") + def handle_auto_refresh(self) -> None: + """Toggle auto refresh.""" + log = self.query_one("#system-logs", Log) + log.write_line("Auto refresh is always enabled (every 2 seconds)") + + def action_refresh(self) -> None: + """Refresh action.""" + self.handle_refresh() + + def action_kill_process(self) -> None: + """Kill process action.""" + self.handle_kill_process() diff --git a/src/axolotl/tui/screens/training.py b/src/axolotl/tui/screens/training.py new file mode 100644 index 000000000..8658f653a --- /dev/null +++ b/src/axolotl/tui/screens/training.py @@ -0,0 +1,568 @@ +"""Training management screen for Axolotl TUI.""" + +import asyncio +import os +import subprocess +import threading +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + +from textual import on, work +from textual.app import ComposeResult +from textual.binding import Binding +from textual.containers import Container, Horizontal, ScrollableContainer, Vertical +from textual.reactive import reactive +from textual.widgets import ( + Button, + DataTable, + Footer, + Header, + Input, + Label, + LoadingIndicator, + Log, + ProgressBar, + Select, + Sparkline, + Static, + Switch, + TabbedContent, + TabPane, +) + +from axolotl.tui.screens.base import BaseScreen + + +@dataclass +class TrainingJob: + """Represents a 
training job.""" + + id: str + config_path: str + status: str # pending, running, completed, failed + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + process: Optional[subprocess.Popen] = None + log_file: Optional[str] = None + current_epoch: int = 0 + total_epochs: int = 0 + current_loss: float = 0.0 + losses: List[float] = None + + def __post_init__(self): + if self.losses is None: + self.losses = [] + + +class TrainingScreen(BaseScreen): + """Training management screen.""" + + BINDINGS = [ + Binding("ctrl+t", "new_training", "New Training"), + Binding("ctrl+r", "resume_training", "Resume"), + Binding("ctrl+x", "stop_training", "Stop"), + Binding("ctrl+l", "view_logs", "View Logs"), + Binding("r", "refresh", "Refresh"), + ] + + CSS = """ + .training-container { + layout: vertical; + height: 100%; + } + + .job-list-container { + height: 40%; + border: solid $primary; + padding: 1; + margin: 1; + } + + .job-details-container { + height: 60%; + padding: 1; + } + + .control-panel { + layout: horizontal; + height: 4; + align: center middle; + padding: 1; + border: solid $secondary; + margin: 1; + } + + .control-panel Button { + margin: 0 1; + } + + .metrics-panel { + layout: horizontal; + height: 10; + border: solid $info; + padding: 1; + margin: 1; + } + + .metric-card { + width: 25%; + border: tall $surface; + padding: 1; + margin: 0 1; + } + + .metric-label { + text-style: bold; + color: $text-muted; + } + + .metric-value { + text-style: bold; + text-align: center; + padding: 1; + } + + .log-viewer { + border: solid $warning; + padding: 1; + margin: 1; + } + + #training-logs { + height: 100%; + } + + DataTable { + height: 100%; + } + + .screen-title { + text-align: center; + text-style: bold; + padding: 1; + color: $primary; + } + + .screen-subtitle { + text-align: center; + padding: 0 0 1 0; + color: $text-muted; + } + + .sparkline-container { + height: 5; + border: solid $success; + padding: 1; + margin: 1; + } + """ + + def 
__init__(self): + """Initialize the training screen.""" + super().__init__( + title="Training Management", + subtitle="Launch, monitor, and manage training jobs", + ) + self.jobs: Dict[str, TrainingJob] = {} + self.selected_job_id: Optional[str] = None + self.update_timer = None + + def compose(self) -> ComposeResult: + """Compose the training screen layout.""" + yield Header() + yield Container( + Static("๐Ÿฆพ Training Management", classes="screen-title"), + Static( + "Launch, monitor, and manage training jobs", classes="screen-subtitle" + ), + Container( + Container( + Label("Active Training Jobs"), + DataTable(id="job-table"), + classes="job-list-container", + ), + Container( + Button("New Training", id="new-training", variant="primary"), + Button("Resume", id="resume-training", variant="success"), + Button("Stop", id="stop-training", variant="error"), + Button("View Logs", id="view-logs", variant="default"), + Button("Clear Completed", id="clear-completed", variant="warning"), + Button("Refresh", id="refresh", variant="default"), + classes="control-panel", + ), + Container( + Container( + Static("Current Epoch", classes="metric-label"), + Static("0 / 0", id="epoch-metric", classes="metric-value"), + classes="metric-card", + ), + Container( + Static("Loss", classes="metric-label"), + Static("0.000", id="loss-metric", classes="metric-value"), + classes="metric-card", + ), + Container( + Static("Status", classes="metric-label"), + Static("Idle", id="status-metric", classes="metric-value"), + classes="metric-card", + ), + Container( + Static("Duration", classes="metric-label"), + Static( + "00:00:00", id="duration-metric", classes="metric-value" + ), + classes="metric-card", + ), + classes="metrics-panel", + ), + Container( + Label("Loss History"), + Sparkline( + [], + id="loss-sparkline", + summary_function=min, + ), + classes="sparkline-container", + ), + Container( + TabbedContent( + TabPane( + "Training Logs", + Log(id="training-logs", wrap=True, 
highlight=True), + ), + TabPane( + "System Logs", + Log(id="system-logs", wrap=True, highlight=True), + ), + TabPane( + "Validation", + Log(id="validation-logs", wrap=True, highlight=True), + ), + ), + classes="log-viewer", + ), + classes="job-details-container", + ), + classes="training-container", + id="content", + ) + yield Footer() + + def on_mount(self) -> None: + """Called when the screen is mounted.""" + self.setup_job_table() + self.start_update_timer() + + log = self.query_one("#training-logs", Log) + log.write_line( + "Training manager ready. Select a configuration to start training." + ) + + def setup_job_table(self) -> None: + """Setup the job table.""" + table = self.query_one("#job-table", DataTable) + table.add_columns("ID", "Config", "Status", "Epoch", "Loss", "Duration") + table.cursor_type = "row" + table.zebra_stripes = True + + def start_update_timer(self) -> None: + """Start the periodic update timer.""" + self.set_interval(2.0, self.update_job_status) + + @work(thread=True) + async def update_job_status(self) -> None: + """Update job status periodically.""" + for job_id, job in self.jobs.items(): + if job.status == "running" and job.process: + poll = job.process.poll() + if poll is not None: + if poll == 0: + job.status = "completed" + else: + job.status = "failed" + job.end_time = datetime.now() + + self.refresh_job_table() + self.update_selected_job_metrics() + + def refresh_job_table(self) -> None: + """Refresh the job table.""" + table = self.query_one("#job-table", DataTable) + table.clear() + + for job_id, job in self.jobs.items(): + duration = self.calculate_duration(job) + table.add_row( + job_id[:8], + Path(job.config_path).name, + job.status, + f"{job.current_epoch}/{job.total_epochs}", + f"{job.current_loss:.4f}" if job.current_loss else "N/A", + duration, + ) + + def calculate_duration(self, job: TrainingJob) -> str: + """Calculate job duration.""" + if not job.start_time: + return "00:00:00" + + end_time = job.end_time or 
datetime.now() + duration = end_time - job.start_time + hours = int(duration.total_seconds() // 3600) + minutes = int((duration.total_seconds() % 3600) // 60) + seconds = int(duration.total_seconds() % 60) + return f"{hours:02d}:{minutes:02d}:{seconds:02d}" + + def update_selected_job_metrics(self) -> None: + """Update metrics for selected job.""" + if not self.selected_job_id or self.selected_job_id not in self.jobs: + return + + job = self.jobs[self.selected_job_id] + + self.query_one("#epoch-metric", Static).update( + f"{job.current_epoch} / {job.total_epochs}" + ) + self.query_one("#loss-metric", Static).update( + f"{job.current_loss:.4f}" if job.current_loss else "N/A" + ) + self.query_one("#status-metric", Static).update(job.status.upper()) + self.query_one("#duration-metric", Static).update(self.calculate_duration(job)) + + if job.losses: + sparkline = self.query_one("#loss-sparkline", Sparkline) + sparkline.data = job.losses[-50:] # Show last 50 loss values + + @on(DataTable.RowSelected) + def handle_row_selected(self, event: DataTable.RowSelected) -> None: + """Handle job selection from table.""" + if event.row_index >= 0: + job_ids = list(self.jobs.keys()) + if event.row_index < len(job_ids): + self.selected_job_id = job_ids[event.row_index] + self.update_selected_job_metrics() + self.load_job_logs() + + def load_job_logs(self) -> None: + """Load logs for selected job.""" + if not self.selected_job_id or self.selected_job_id not in self.jobs: + return + + job = self.jobs[self.selected_job_id] + if job.log_file and Path(job.log_file).exists(): + try: + with open(job.log_file, "r") as f: + content = f.read() + log = self.query_one("#training-logs", Log) + log.clear() + for line in content.split("\n")[-100:]: # Show last 100 lines + if line.strip(): + log.write_line(line) + except Exception as e: + log = self.query_one("#training-logs", Log) + log.write_line(f"Error loading logs: {str(e)}") + + @on(Button.Pressed, "#new-training") + async def 
handle_new_training(self) -> None: + """Start a new training job.""" + from axolotl.tui.dialogs.training import NewTrainingDialog + + dialog = NewTrainingDialog() + result = await self.app.push_screen_wait(dialog) + + if result and "config_path" in result: + await self.start_training_job( + result["config_path"], result.get("launcher", "accelerate") + ) + + @work(thread=True) + async def start_training_job( + self, config_path: str, launcher: str = "accelerate" + ) -> None: + """Start a training job.""" + import uuid + from datetime import datetime + + job_id = str(uuid.uuid4()) + log_file = f"/tmp/axolotl_training_{job_id}.log" + + job = TrainingJob( + id=job_id, + config_path=config_path, + status="pending", + start_time=datetime.now(), + log_file=log_file, + total_epochs=3, # Default, should parse from config + ) + + self.jobs[job_id] = job + self.selected_job_id = job_id + + log = self.query_one("#training-logs", Log) + log.clear() + log.write_line(f"๐Ÿš€ Starting training job {job_id[:8]}...") + log.write_line(f"Config: {config_path}") + log.write_line(f"Launcher: {launcher}") + + try: + if launcher == "accelerate": + cmd = ["accelerate", "launch", "-m", "axolotl.cli.train", config_path] + else: + cmd = [ + "torchrun", + "--nproc_per_node=1", + "-m", + "axolotl.cli.train", + config_path, + ] + + with open(log_file, "w") as f: + process = subprocess.Popen( + cmd, + stdout=f, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + + job.process = process + job.status = "running" + + log.write_line("โœ… Training started successfully!") + self.refresh_job_table() + + self.monitor_training_output(job_id) + + except Exception as e: + job.status = "failed" + job.end_time = datetime.now() + log.write_line(f"โŒ Failed to start training: {str(e)}") + self.refresh_job_table() + + def monitor_training_output(self, job_id: str) -> None: + """Monitor training output and extract metrics.""" + if job_id not in self.jobs: + return + + job = self.jobs[job_id] + if not 
job.log_file: + return + + def tail_log(): + import re + import time + + with open(job.log_file, "r") as f: + f.seek(0, 2) # Go to end of file + while job.status == "running": + line = f.readline() + if line: + # Parse training metrics from log + epoch_match = re.search(r"Epoch (\d+)/(\d+)", line) + if epoch_match: + job.current_epoch = int(epoch_match.group(1)) + job.total_epochs = int(epoch_match.group(2)) + + loss_match = re.search( + r"loss['\"]?\s*:\s*([\d.]+)", line, re.IGNORECASE + ) + if loss_match: + job.current_loss = float(loss_match.group(1)) + job.losses.append(job.current_loss) + + # Update log viewer + self.call_from_thread(self.append_training_log, line.strip()) + else: + time.sleep(0.5) + + thread = threading.Thread(target=tail_log, daemon=True) + thread.start() + + def append_training_log(self, line: str) -> None: + """Append line to training log.""" + log = self.query_one("#training-logs", Log) + log.write_line(line) + + @on(Button.Pressed, "#stop-training") + def handle_stop_training(self) -> None: + """Stop selected training job.""" + if not self.selected_job_id or self.selected_job_id not in self.jobs: + log = self.query_one("#training-logs", Log) + log.write_line("โš ๏ธ No job selected") + return + + job = self.jobs[self.selected_job_id] + if job.status == "running" and job.process: + job.process.terminate() + job.status = "stopped" + job.end_time = datetime.now() + + log = self.query_one("#training-logs", Log) + log.write_line(f"๐Ÿ›‘ Training job {job.id[:8]} stopped") + self.refresh_job_table() + + @on(Button.Pressed, "#resume-training") + async def handle_resume_training(self) -> None: + """Resume a stopped training job.""" + if not self.selected_job_id or self.selected_job_id not in self.jobs: + log = self.query_one("#training-logs", Log) + log.write_line("โš ๏ธ No job selected") + return + + job = self.jobs[self.selected_job_id] + if job.status in ["stopped", "failed"]: + await self.start_training_job(job.config_path) + + 
@on(Button.Pressed, "#clear-completed")
def handle_clear_completed(self) -> None:
    """Remove every finished job from the job list.

    A job counts as finished when its status is ``completed``, ``failed`` or
    ``stopped``. The job table is redrawn afterwards and the number of
    removed jobs is written to the training log.
    """
    finished_states = ("completed", "failed", "stopped")
    # Collect the ids first: deleting from the dict while iterating it raises.
    finished_ids = [
        job_id for job_id, job in self.jobs.items() if job.status in finished_states
    ]
    for job_id in finished_ids:
        del self.jobs[job_id]

    # Do not keep a selection that points at a job that was just removed.
    if self.selected_job_id not in self.jobs:
        self.selected_job_id = None

    self.refresh_job_table()
    log = self.query_one("#training-logs", Log)
    log.write_line(f"🧹 Cleared {len(finished_ids)} completed jobs")


@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
    """Redraw the job table and metrics, and reload logs for the selected job."""
    self.refresh_job_table()
    self.update_selected_job_metrics()
    if self.selected_job_id:
        self.load_job_logs()


@on(Button.Pressed, "#view-logs")
def handle_view_logs(self) -> None:
    """Open the selected job's full log file in the ``less`` pager.

    Does nothing when no tracked job is selected or its log file is missing.
    Problems launching the pager are reported in the training log instead of
    being raised.
    """
    from textual.app import SuspendNotSupported

    job = self.jobs.get(self.selected_job_id) if self.selected_job_id else None
    if job is None or not job.log_file or not Path(job.log_file).exists():
        return

    log = self.query_one("#training-logs", Log)
    try:
        # The pager needs exclusive use of the terminal; running it while the
        # app is still drawing garbles the screen and steals key input.
        with self.app.suspend():
            subprocess.run(["less", job.log_file], check=False)
    except SuspendNotSupported:
        log.write_line("Cannot open pager: suspending is not supported here")
    except OSError as e:
        # e.g. ``less`` is not installed on this machine.
        log.write_line(f"Error opening log viewer: {str(e)}")


def action_new_training(self) -> None:
    """Key-binding action: start the new-training flow."""
    # handle_new_training is a coroutine function, so a bare call only creates
    # a coroutine object that never executes. Schedule it as a worker instead;
    # exit_on_error=False keeps a handler failure from terminating the app.
    self.run_worker(self.handle_new_training(), exit_on_error=False)


def action_stop_training(self) -> None:
    """Key-binding action: stop the selected training job."""
    self.handle_stop_training()


def action_resume_training(self) -> None:
    """Key-binding action: resume the selected training job."""
    # Same reasoning as action_new_training: the handler is a coroutine
    # function and must be scheduled to actually run.
    self.run_worker(self.handle_resume_training(), exit_on_error=False)


def action_view_logs(self) -> None:
    """Key-binding action: view full logs for the selected job.

    BINDINGS maps ``ctrl+l`` to ``view_logs``; without this method the
    binding has no action to dispatch to.
    """
    self.handle_view_logs()


def action_refresh(self) -> None:
    """Key-binding action: refresh the display."""
    self.handle_refresh()