Compare commits


2 Commits

Author         SHA1         Message    Date
Dan Saunders   c3e1882de5   progress   2025-08-22 02:43:16 -04:00
Dan Saunders   889b27ecf1   tui        2025-08-22 05:08:02 +00:00
32 changed files with 13408 additions and 10810 deletions

File diff suppressed because it is too large.


@@ -13,8 +13,8 @@ liger-kernel==0.6.1
 packaging==23.2
 huggingface_hub>=0.33.0
-peft>=0.17.0
-transformers==4.55.3
+peft==0.17.0
+transformers==4.55.2
 tokenizers>=0.21.1
 accelerate==1.10.0
 datasets==4.0.0
@@ -72,3 +72,8 @@ axolotl-contribs-lgpl==0.0.6
 axolotl-contribs-mit==0.0.5
 mistral-common==1.8.3
+
+# TUI dependencies
+textual==1.0.0
+rich==14.1.0
+tree_sitter_ruby==0.23.1


@@ -14,13 +14,9 @@ class PreprocessCliArgs:
     prompter: Optional[str] = field(default=None)
     download: Optional[bool] = field(default=True)
     iterable: Optional[bool] = field(
-        default=False,
+        default=None,
         metadata={
-            "help": (
-                "[DEPRECATED] No longer supported. For streaming datasets, use "
-                "'axolotl train' and set 'streaming: true' in your YAML config, or "
-                "pass --streaming instead in the CLI."
-            )
+            "help": "Use IterableDataset for streaming processing of large datasets"
         },
     )


@@ -344,6 +344,26 @@ def delinearize_llama4(model: str, output: str):
 cli.add_command(lm_eval)
 
 
+@cli.command()
+def tui():
+    """
+    Launch the Axolotl Terminal User Interface (TUI).
+
+    Provides an interactive interface for configuration management,
+    training monitoring, dataset handling, and model operations.
+    """
+    try:
+        from axolotl.tui.app import run
+
+        run()
+    except ImportError:
+        click.echo(
+            "TUI dependencies not installed. Install with: pip install textual rich"
+        )
+    except Exception as e:
+        click.echo(f"Error launching TUI: {e}")
+
+
 def main():
     cli()


@@ -35,20 +35,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
     check_accelerate_default_config()
     check_user_token()
 
-    if cli_args.iterable:
-        LOG.error(
-            "The --iterable CLI argument for 'axolotl preprocess' is no longer "
-            "supported. For training, set 'streaming: true' in your YAML config or "
-            "pass '--streaming' in your 'axolotl train' command for on-the-fly "
-            "preprocessing."
-        )
-        return
-
     for key in ["skip_prepare_dataset", "pretraining_dataset"]:
         if cfg.get(key):
             LOG.error(
-                f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
-                "train' CLI directly instead."
+                f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
             )
             return


@@ -55,11 +55,13 @@ def load_datasets(
     """
     tokenizer = load_tokenizer(cfg)
     processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
+    preprocess_iterable = getattr(cli_args, "iterable", False)
 
     train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
         cfg,
         tokenizer,
         processor=processor,
+        preprocess_iterable=preprocess_iterable,
     )
 
     if (


@@ -1,19 +1,18 @@
-"""
-Module containing dataset functionality.
-
-We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
-concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
-datasets.
-"""
-
-from typing import Any
+"""Module containing Dataset functionality"""
 
 import torch
 from datasets import Dataset, IterableDataset
 
 from axolotl.utils.logging import get_logger
 
 from .prompt_tokenizers import PromptTokenizingStrategy
 
+# We want this to be a wrapper for an existing dataset that we have loaded
+# lets use the concept of middlewares to wrap each dataset, for example
+# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
+# let's check to ensure we don't truncate an item in the middle, we'll use
+# the collators later on to pad the datasets
+
 LOG = get_logger(__name__)
@@ -43,13 +42,10 @@ class TokenizedPromptDataset(Dataset):
             **kwargs,
         )
 
-    def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
-        """Apply filtering and tokenization."""
-        features = None
-        if not isinstance(dataset, IterableDataset):
-            features = dataset.features.keys()
-
-        map_kwargs: dict[str, Any] = {}
+    def process(self, dataset):
+        features = dataset.features.keys()
+        map_kwargs = {}
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
             map_kwargs["batch_size"] = 1_000
@@ -58,28 +54,18 @@ class TokenizedPromptDataset(Dataset):
             hasattr(self.prompt_tokenizer, "filter_rows")
             and self.prompt_tokenizer.filter_rows
         ):
-            filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"}
-            if not isinstance(dataset, IterableDataset):
-                filter_kwargs["num_proc"] = self.process_count
             dataset = dataset.filter(
                 self.prompt_tokenizer.filter_rows,
-                **filter_kwargs,
+                num_proc=self.process_count,
+                desc="Strategy Filtering Rows",
             )
-
-        map_kwargs = {
-            **map_kwargs,
-            "desc": "Tokenizing Prompts",
-        }
-
-        # Only add remove_columns for regular datasets
-        if not isinstance(dataset, IterableDataset):
-            map_kwargs["remove_columns"] = features
-            map_kwargs["num_proc"] = self.process_count
-            map_kwargs["keep_in_memory"] = self.keep_in_memory
-
         return dataset.map(
             self.prompt_tokenizer.tokenize_prompt,
+            num_proc=self.process_count,
+            remove_columns=features,
+            keep_in_memory=self.keep_in_memory,
+            desc="Tokenizing Prompts",
             **map_kwargs,
         )
@@ -93,16 +79,140 @@ def wrap_dataset_for_tokenized_prompt(
         map_kwargs = {}
         if prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
-
-        # Map the dataset and remove original columns
-        # For IterableDataset, features might be None until first iteration
-        remove_columns = None
-        if dataset.features is not None:
-            remove_columns = list(dataset.features.keys())
+        features = list(dataset.features.keys())
         return dataset.map(
             prompt_tokenizer.tokenize_prompt,
-            remove_columns=remove_columns,
+            remove_columns=features,
             **map_kwargs,
         )
     return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
self,
tokenizer,
datasets,
seq_length=2048,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
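# Choose the smallest integer dtype that can represent every token id:
# packed buffers accumulate many sequences, so narrower dtypes save memory.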
vocab_size = len(tokenizer.get_vocab())
if vocab_size <= torch.iinfo(torch.int16).max:
self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
self.tokens_dtype = torch.int32
else:
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)
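# A minimal usage sketch (assumptions: `tokenizer` is a Hugging Face tokenizer
# with an `eos_token_id`, and `tokenized_ds` yields dicts with "input_ids",
# "attention_mask", and "labels" token lists, the fields consumed above):
#
#     packed = ConstantLengthDataset(tokenizer, [tokenized_ds], seq_length=2048)
#     for sample in packed:
#         # each sample is a dict of 1-D tensors, at most seq_length long
#         model_inputs = {k: v.unsqueeze(0) for k, v in sample.items()}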


@@ -72,9 +72,10 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         builder_kwargs["message_field_training"] = message_field_training
 
     chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
-    format_message = (
-        lambda x: x  # noqa E731 # pylint: disable=unnecessary-lambda-assignment
-    )
+
+    def format_message(x):
+        return x
+
     if chat_template == "chatml":
         from axolotl.core.chat.format.chatml import format_message  # noqa F811
     if chat_template.startswith("llama3"):

src/axolotl/tui/README.md (new file, 216 lines)

@@ -0,0 +1,216 @@
# Axolotl TUI (Terminal User Interface)
A comprehensive Terminal User Interface for Axolotl, providing an interactive way to manage configurations, training jobs, datasets, models, and system monitoring.
## Features
### 🏠 Main Dashboard
- **Welcome Screen**: Central hub with quick access to all features
- **Keyboard Navigation**: Efficient navigation with keyboard shortcuts
- **Screen Management**: Easy switching between different functional areas
### 📝 Configuration Management
- **YAML Editor**: Syntax-highlighted editor for Axolotl configurations
- **Real-time Validation**: Instant config validation with detailed error reporting
- **File Browser**: Navigate and select configuration files
- **Template Loading**: Load example configurations
- **Remote Config Support**: Load configurations from URLs
**Key Shortcuts:**
- `Ctrl+N`: New configuration
- `Ctrl+S`: Save configuration
- `Ctrl+V`: Validate configuration
- `Ctrl+E`: Toggle edit mode
### 🚀 Training Management
- **Job Launcher**: Start training with different launchers (accelerate, torchrun)
- **Real-time Monitoring**: Live training progress and metrics
- **Loss Visualization**: Sparkline charts for loss curves
- **Job Control**: Start, stop, resume, and manage multiple training jobs
- **Log Streaming**: Real-time log viewing and filtering
**Key Shortcuts:**
- `Ctrl+T`: New training job
- `Ctrl+R`: Resume training
- `Ctrl+X`: Stop training
- `R`: Refresh status
### 📊 Dataset Management
- **Dataset Browser**: Explore local and remote datasets
- **Preview & Statistics**: View dataset samples and metadata
- **Preprocessing**: Run dataset preprocessing with progress tracking
- **HuggingFace Integration**: Download and manage HF datasets
- **Format Detection**: Automatic dataset format recognition
**Key Shortcuts:**
- `Ctrl+P`: Preprocess dataset
- `Ctrl+V`: Preview dataset
- `Ctrl+I`: Dataset information
- `R`: Refresh dataset list
### 🤖 Model Management
- **Model Discovery**: Automatically find trained models
- **LoRA Operations**: Merge LoRA adapters with base models
- **Quantization**: Quantize models for deployment
- **Evaluation**: Run model evaluation benchmarks
- **Storage Info**: View model sizes and storage details
**Key Shortcuts:**
- `Ctrl+M`: Merge LoRA
- `Ctrl+Q`: Quantize model
- `Ctrl+E`: Evaluate model
- `R`: Refresh model list
### 💬 Inference & Testing
- **Interactive Chat**: Chat interface for model testing
- **Parameter Tuning**: Adjust inference parameters (temperature, top-p, max tokens)
- **Model Loading**: Load and switch between different models
- **Chat History**: Save and load conversation history
- **Gradio Integration**: Launch Gradio web interface
**Key Shortcuts:**
- `Ctrl+Enter`: Send message
- `Ctrl+C`: Clear chat
- `Ctrl+L`: Load model
- `Ctrl+S`: Save chat
### 📈 System Monitoring
- **Resource Monitoring**: Real-time CPU, GPU, and memory usage
- **Process Management**: View and manage running processes
- **Performance Graphs**: Historical usage charts with sparklines
- **GPU Information**: Detailed GPU status and memory usage
- **Temperature Monitoring**: System temperature tracking
**Key Shortcuts:**
- `R`: Refresh metrics
- `Ctrl+K`: Kill selected process
## Installation
### Dependencies
```bash
pip install textual==1.0.0 rich==14.1.0
```
### Launch TUI
```bash
# From command line
python -m axolotl.cli.main tui
# From Python code
from axolotl.tui.app import run
run()
```
## Architecture
### Screen Structure
```
AxolotlTUI (Main App)
├── WelcomeScreen (Dashboard)
├── ConfigScreen (Configuration Management)
├── TrainingScreen (Training Management)
├── DatasetScreen (Dataset Management)
├── ModelScreen (Model Management)
├── InferenceScreen (Inference & Testing)
└── MonitorScreen (System Monitoring)
```
### Key Components
- **BaseScreen**: Common functionality for all screens
- **Screen Navigation**: Stack-based screen management
- **Event Handling**: Reactive UI updates
- **Background Tasks**: Non-blocking operations
- **State Management**: Shared application state
### Integration Points
- **CLI Commands**: Seamless integration with existing axolotl CLI
- **Configuration System**: Uses axolotl's native config loading
- **Training Pipeline**: Integrates with axolotl training functions
- **Model Loading**: Compatible with axolotl model management
## Usage Examples
### 1. Creating a New Configuration
1. Launch TUI: `python -m axolotl.cli.main tui`
2. Select "Configuration Management" or press `C`
3. Press `Ctrl+N` for new configuration
4. Edit the template configuration
5. Press `Ctrl+V` to validate
6. Press `Ctrl+S` to save
### 2. Starting a Training Job
1. Navigate to "Training Management" or press `T`
2. Press `Ctrl+T` for new training job
3. Select configuration file and launcher
4. Monitor progress in real-time
5. View loss curves and logs
### 3. Interactive Model Testing
1. Go to "Inference & Testing" or press `I`
2. Load a trained model with `Ctrl+L`
3. Adjust inference parameters as needed
4. Start chatting with the model
5. Save conversation with `Ctrl+S`
## Navigation
### Global Shortcuts
- `Ctrl+Q`: Quit application
- `Escape`: Go back/close current screen
- `Tab`: Navigate between UI elements
- `Enter`: Select/activate element
- `Space`: Toggle switches/checkboxes
### Screen Shortcuts
Each screen has specific shortcuts displayed in the footer. Common patterns:
- `Ctrl+[Letter]`: Primary actions
- `R`: Refresh/reload
- `F1-F12`: Function keys for advanced features
## Customization
### Themes
The TUI uses Textual's theming system and can be customized by modifying the CSS in each screen class.
### Adding New Screens
1. Create a new screen class inheriting from `BaseScreen`
2. Implement the `compose()` method for UI layout
3. Add event handlers for user interactions
4. Register the screen in the main app navigation
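A minimal sketch of these four steps, assuming the `BaseScreen` from `src/axolotl/tui/screens/base.py`; the screen name `LogsScreen` and its contents are hypothetical:
```python
from textual.widgets import Static

from axolotl.tui.screens.base import BaseScreen


class LogsScreen(BaseScreen):
    """Hypothetical screen that lists recent run logs."""

    # Per-screen CSS is also how themes are customized (see "Themes" above).
    CSS = """
    .screen-title { color: $primary; }
    """

    def __init__(self):
        super().__init__(title="Logs", subtitle="Browse recent run logs")

    def on_mount(self) -> None:
        # BaseScreen.compose() yields a Container with id="content";
        # mount this screen's widgets into it.
        self.query_one("#content").mount(Static("No logs yet."))
```
Step 4 then amounts to pushing the screen from the app or another screen, e.g. `self.app.push_screen(LogsScreen())`.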
### Extending Functionality
- Add new widgets to existing screens
- Implement custom data visualization
- Integrate with external tools and APIs
- Add new keyboard shortcuts
## Troubleshooting
### Common Issues
1. **Import Errors**: Ensure textual and rich are installed
2. **Permission Errors**: Check file system permissions for config directories
3. **GPU Monitoring**: Install pynvml for GPU monitoring features
4. **Config Validation**: Ensure axolotl dependencies are properly installed
### Debug Mode
Launch with debug logging:
```bash
TEXTUAL_LOG=DEBUG python -m axolotl.cli.main tui
```
### Performance
- Use `Ctrl+\` to open Textual's debug console
- Monitor memory usage with the system monitor
- Disable auto-refresh for better performance on slower systems
## Contributing
The TUI is designed to be extensible. Contributions are welcome for:
- New screen implementations
- Enhanced visualizations
- Better keyboard navigation
- Additional integrations
- Performance improvements
See the main Axolotl repository for contribution guidelines.


@@ -0,0 +1 @@
"""Axolotl Terminal User Interface (TUI)."""

src/axolotl/tui/app.py (new file, 180 lines)

@@ -0,0 +1,180 @@
"""Main TUI application for Axolotl."""
from textual import on
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Button, Footer, Header, Static
from axolotl.tui.screens.config import ConfigScreen
from axolotl.tui.screens.datasets import DatasetScreen
from axolotl.tui.screens.inference import InferenceScreen
from axolotl.tui.screens.models import ModelScreen
from axolotl.tui.screens.monitor import MonitorScreen
from axolotl.tui.screens.training import TrainingScreen
class WelcomeScreen(Screen):
"""Welcome screen with main menu."""
BINDINGS = [
Binding("q", "quit", "Quit"),
Binding("c", "config", "Configuration"),
Binding("t", "training", "Training"),
Binding("d", "datasets", "Datasets"),
Binding("m", "models", "Models"),
Binding("i", "inference", "Inference"),
Binding("s", "monitor", "System Monitor"),
]
def compose(self) -> ComposeResult:
"""Compose the welcome screen."""
yield Header()
yield Container(
Static("🦾 Axolotl TUI", classes="title"),
Static(
"A Terminal User Interface for fine-tuning LLMs", classes="subtitle"
),
Container(
Button("Configuration Management [C]", id="config", variant="primary"),
Button("Training Management [T]", id="training", variant="primary"),
Button("Dataset Management [D]", id="datasets", variant="primary"),
Button("Model Management [M]", id="models", variant="primary"),
Button("Inference & Testing [I]", id="inference", variant="primary"),
Button("System Monitor [S]", id="monitor", variant="primary"),
classes="menu-container",
),
classes="welcome-container",
)
yield Footer()
def action_quit(self) -> None:
"""Quit the application."""
self.app.exit()
def action_config(self) -> None:
"""Navigate to config screen."""
self.app.push_screen(ConfigScreen())
def action_training(self) -> None:
"""Navigate to training screen."""
self.app.push_screen(TrainingScreen())
def action_datasets(self) -> None:
"""Navigate to datasets screen."""
self.app.push_screen(DatasetScreen())
def action_models(self) -> None:
"""Navigate to models screen."""
self.app.push_screen(ModelScreen())
def action_inference(self) -> None:
"""Navigate to inference screen."""
self.app.push_screen(InferenceScreen())
def action_monitor(self) -> None:
"""Navigate to monitor screen."""
self.app.push_screen(MonitorScreen())
@on(Button.Pressed, "#config")
def on_config_pressed(self) -> None:
"""Handle config button press."""
self.action_config()
@on(Button.Pressed, "#training")
def on_training_pressed(self) -> None:
"""Handle training button press."""
self.action_training()
@on(Button.Pressed, "#datasets")
def on_datasets_pressed(self) -> None:
"""Handle datasets button press."""
self.action_datasets()
@on(Button.Pressed, "#models")
def on_models_pressed(self) -> None:
"""Handle models button press."""
self.action_models()
@on(Button.Pressed, "#inference")
def on_inference_pressed(self) -> None:
"""Handle inference button press."""
self.action_inference()
@on(Button.Pressed, "#monitor")
def on_monitor_pressed(self) -> None:
"""Handle monitor button press."""
self.action_monitor()
class AxolotlTUI(App):
"""Main Axolotl TUI Application."""
CSS = """
.title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.subtitle {
text-align: center;
padding: 1;
color: $text-muted;
}
.welcome-container {
align: center middle;
height: 100%;
width: 100%;
}
.menu-container {
layout: vertical;
align: center middle;
padding: 2;
width: auto;
height: auto;
}
.menu-container Button {
width: 35;
margin: 1;
}
WelcomeScreen {
align: center middle;
}
"""
BINDINGS = [
Binding("ctrl+q", "quit", "Quit", priority=True),
Binding("escape", "back", "Back", priority=True),
]
def on_mount(self) -> None:
"""Called when the app is mounted."""
self.title = "Axolotl TUI"
self.sub_title = "Fine-tuning LLMs made easy"
self.push_screen(WelcomeScreen())
def action_quit(self) -> None:
"""Quit the application."""
self.exit()
def action_back(self) -> None:
"""Go back to previous screen."""
if len(self.screen_stack) > 1:
self.pop_screen()
def run():
"""Run the Axolotl TUI application."""
app = AxolotlTUI()
app.run()
if __name__ == "__main__":
run()


@@ -0,0 +1 @@
"""TUI dialogs for Axolotl."""


@@ -0,0 +1,112 @@
"""Training dialogs for Axolotl TUI."""
from pathlib import Path
from textual import on
from textual.app import ComposeResult
from textual.containers import Container
from textual.screen import ModalScreen
from textual.widgets import Button, Input, Label, Select, Static
class NewTrainingDialog(ModalScreen):
"""Dialog for starting a new training job."""
CSS = """
NewTrainingDialog {
align: center middle;
}
.dialog-container {
background: $surface;
border: thick $primary;
padding: 2;
width: 60;
height: auto;
}
.dialog-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.form-field {
margin: 1 0;
}
.form-label {
margin: 0 0 1 0;
color: $text-muted;
}
.button-container {
layout: horizontal;
align: center middle;
margin: 2 0 0 0;
}
.button-container Button {
margin: 0 1;
}
"""
def compose(self) -> ComposeResult:
"""Compose the dialog."""
yield Container(
Static("Start New Training Job", classes="dialog-title"),
Container(
Label("Configuration File:", classes="form-label"),
Input(
placeholder="Path to config YAML file",
id="config-path",
value="/workspace/configs/",
),
classes="form-field",
),
Container(
Label("Launcher:", classes="form-label"),
Select(
[
("accelerate", "Accelerate (Recommended)"),
("torchrun", "TorchRun"),
("deepspeed", "DeepSpeed"),
],
id="launcher",
value="accelerate",
),
classes="form-field",
),
Container(
Button("Start Training", variant="primary", id="start"),
Button("Cancel", variant="default", id="cancel"),
classes="button-container",
),
classes="dialog-container",
)
@on(Button.Pressed, "#start")
def handle_start(self) -> None:
"""Handle start button press."""
config_input = self.query_one("#config-path", Input)
launcher_select = self.query_one("#launcher", Select)
config_path = config_input.value.strip()
if not config_path:
return
if not Path(config_path).exists():
return
result = {
"config_path": config_path,
"launcher": launcher_select.value,
}
self.dismiss(result)
@on(Button.Pressed, "#cancel")
def handle_cancel(self) -> None:
"""Handle cancel button press."""
self.dismiss(None)
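# A minimal consumption sketch (assumption, not shown in this diff): the dialog's
# result dict arrives via a `push_screen` callback on the calling screen, where
# `start_training` is a hypothetical helper:
#
#     def _on_dialog_result(result: dict | None) -> None:
#         if result:
#             start_training(result["config_path"], result["launcher"])
#
#     self.app.push_screen(NewTrainingDialog(), _on_dialog_result)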


@@ -0,0 +1 @@
"""TUI screens for Axolotl."""


@@ -0,0 +1,50 @@
"""Base screen class for Axolotl TUI screens."""
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Footer, Header, Static
class BaseScreen(Screen):
"""Base class for all Axolotl TUI screens."""
BINDINGS = [
Binding("escape", "back", "Back"),
Binding("q", "quit", "Quit"),
]
def __init__(self, title: str = "Axolotl", subtitle: str = ""):
"""Initialize the base screen.
Args:
title: The screen title
subtitle: Optional subtitle for the screen
"""
super().__init__()
self.screen_title = title
self.screen_subtitle = subtitle
def compose(self) -> ComposeResult:
"""Compose the base screen layout."""
yield Header()
yield Container(
Static(f"🦾 {self.screen_title}", classes="screen-title"),
(
Static(self.screen_subtitle, classes="screen-subtitle")
if self.screen_subtitle
else Static("")
),
Container(id="content"),
id="main-container",
)
yield Footer()
def action_back(self) -> None:
"""Go back to previous screen."""
self.app.pop_screen()
def action_quit(self) -> None:
"""Quit the application."""
self.app.exit()


@@ -0,0 +1,376 @@
"""Configuration management screen for Axolotl TUI."""
import os
from pathlib import Path
from typing import Optional
import yaml
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.reactive import reactive
from textual.widgets import (
Button,
DirectoryTree,
Footer,
Header,
Label,
Log,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class ConfigScreen(BaseScreen):
"""Configuration management screen."""
BINDINGS = [
Binding("ctrl+n", "new_config", "New Config"),
Binding("ctrl+o", "open_config", "Open Config"),
Binding("ctrl+s", "save_config", "Save Config"),
Binding("ctrl+v", "validate_config", "Validate"),
Binding("ctrl+e", "edit_mode", "Toggle Edit Mode"),
]
CSS = """
.config-container {
layout: horizontal;
height: 100%;
}
.file-browser {
width: 30%;
border: solid $primary;
padding: 1;
margin: 1;
}
.config-editor {
width: 70%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.config-form {
height: 80%;
}
.config-actions {
layout: horizontal;
height: 3;
align: center middle;
padding: 1;
}
.config-actions Button {
margin: 0 1;
}
TextArea {
height: 100%;
}
.validation-log {
height: 20%;
border: solid $warning;
padding: 1;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
"""
def __init__(self):
"""Initialize the config screen."""
super().__init__(
title="Configuration Management",
subtitle="Create, edit, and validate Axolotl configurations",
)
self.current_config_path: Optional[Path] = None
self.edit_mode = False  # plain bool; Textual's reactive() only works as a class attribute
self.config_data = {}
def compose(self) -> ComposeResult:
"""Compose the config screen layout."""
yield Header()
yield Container(
Static("🦾 Configuration Management", classes="screen-title"),
Static(
"Create, edit, and validate Axolotl configurations",
classes="screen-subtitle",
),
Container(
Container(
Label("Config Files"),
DirectoryTree(
(
Path("/workspace/configs")
if Path("/workspace/configs").exists()
else Path.cwd()
),
id="config-tree",
),
classes="file-browser",
),
Container(
Container(
TextArea(
"",
language="yaml",
theme="monokai",
id="config-editor",
read_only=True,
),
classes="config-form",
),
Container(
Button("New", id="new-config", variant="primary"),
Button("Open", id="open-config", variant="primary"),
Button("Save", id="save-config", variant="success"),
Button("Validate", id="validate-config", variant="warning"),
Button("Edit Mode", id="toggle-edit", variant="default"),
Button("Load Example", id="load-example", variant="default"),
classes="config-actions",
),
Container(
Log(id="validation-log"),
classes="validation-log",
),
classes="config-editor",
),
classes="config-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
tree = self.query_one("#config-tree", DirectoryTree)
tree.show_root = False
tree.guide_depth = 3
log = self.query_one("#validation-log", Log)
log.write_line("Ready to load configuration files...")
@on(DirectoryTree.FileSelected)
def handle_file_selected(self, event: DirectoryTree.FileSelected) -> None:
"""Handle file selection from the directory tree."""
if event.path.suffix in [".yaml", ".yml"]:
self.load_config_file(event.path)
def load_config_file(self, path: Path) -> None:
"""Load a configuration file."""
self.current_config_path = path
try:
with open(path, "r") as f:
content = f.read()
self.config_data = yaml.safe_load(content)
editor = self.query_one("#config-editor", TextArea)
editor.load_text(content)
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line(f"✅ Loaded: {path.name}")
except Exception as e:
log = self.query_one("#validation-log", Log)
log.write_line(f"❌ Error loading {path.name}: {str(e)}")
@on(Button.Pressed, "#new-config")
def handle_new_config(self) -> None:
"""Create a new configuration."""
template = """# Axolotl Configuration
base_model:
model_type:
tokenizer_type:
# Dataset Configuration
datasets:
- path:
type:
# Training Configuration
output_dir: ./outputs
num_epochs: 3
micro_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 0.00002
warmup_steps: 100
eval_steps: 100
save_steps: 500
# LoRA Configuration (optional)
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
# Training optimizations
gradient_checkpointing: true
flash_attention: true
bf16: auto
tf32: true
# Logging
logging_steps: 10
wandb_project:
wandb_entity:
"""
editor = self.query_one("#config-editor", TextArea)
editor.load_text(template)
editor.read_only = False
self.edit_mode = True
self.update_edit_button()
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line("📝 New configuration created. Edit and save when ready.")
@on(Button.Pressed, "#save-config")
def handle_save_config(self) -> None:
"""Save the current configuration."""
editor = self.query_one("#config-editor", TextArea)
content = editor.text
if not content.strip():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ Cannot save empty configuration")
return
if not self.current_config_path:
default_path = Path("/workspace/configs/new_config.yaml")
default_path.parent.mkdir(parents=True, exist_ok=True)
self.current_config_path = default_path
try:
with open(self.current_config_path, "w") as f:
f.write(content)
log = self.query_one("#validation-log", Log)
log.write_line(f"💾 Saved: {self.current_config_path.name}")
except Exception as e:
log = self.query_one("#validation-log", Log)
log.write_line(f"❌ Error saving: {str(e)}")
@on(Button.Pressed, "#validate-config")
@work(thread=True)
def handle_validate_config(self) -> None:  # @work(thread=True) requires a plain (non-async) function
"""Validate the current configuration."""
editor = self.query_one("#config-editor", TextArea)
content = editor.text
if not content.strip():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ No configuration to validate")
return
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line("🔍 Validating configuration...")
try:
import tempfile
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
f.write(content)
temp_path = f.name
from argparse import Namespace
from axolotl.cli.config import check_user_config
args = Namespace(
config=temp_path,
debug=False,
debug_text_only=False,
debug_num_examples=5,
accelerate_config=None,
multi_gpu=False,
)
check_user_config(args)
log.write_line("✅ Configuration is valid!")
os.unlink(temp_path)
except Exception as e:
log.write_line(f"❌ Validation failed: {str(e)}")
if "temp_path" in locals():
os.unlink(temp_path)
@on(Button.Pressed, "#toggle-edit")
def handle_toggle_edit(self) -> None:
"""Toggle edit mode for the configuration."""
editor = self.query_one("#config-editor", TextArea)
self.edit_mode = not self.edit_mode
editor.read_only = not self.edit_mode
self.update_edit_button()
log = self.query_one("#validation-log", Log)
if self.edit_mode:
log.write_line("✏️ Edit mode enabled")
else:
log.write_line("👁️ View mode enabled")
@on(Button.Pressed, "#load-example")
async def handle_load_example(self) -> None:
"""Load an example configuration."""
examples_dir = Path("/workspace/axolotl/examples")
if not examples_dir.exists():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ Examples directory not found")
return
yaml_files = list(examples_dir.glob("**/*.yml")) + list(
examples_dir.glob("**/*.yaml")
)
if yaml_files:
self.load_config_file(yaml_files[0])
log = self.query_one("#validation-log", Log)
log.write_line(f"📚 Loaded example: {yaml_files[0].name}")
def update_edit_button(self) -> None:
"""Update the edit button appearance."""
button = self.query_one("#toggle-edit", Button)
if self.edit_mode:
button.variant = "warning"
button.label = "Edit Mode: ON"
else:
button.variant = "default"
button.label = "Edit Mode: OFF"
def action_new_config(self) -> None:
"""Create a new configuration."""
self.handle_new_config()
def action_save_config(self) -> None:
"""Save the current configuration."""
self.handle_save_config()
def action_validate_config(self) -> None:
"""Validate the current configuration."""
self.handle_validate_config()
def action_edit_mode(self) -> None:
"""Toggle edit mode."""
self.handle_toggle_edit()


@@ -0,0 +1,440 @@
"""Dataset management screen for Axolotl TUI."""
import json
from pathlib import Path
from typing import Dict, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class DatasetScreen(BaseScreen):
"""Dataset management screen."""
BINDINGS = [
Binding("ctrl+p", "preprocess", "Preprocess"),
Binding("ctrl+v", "preview", "Preview"),
Binding("ctrl+i", "info", "Info"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.dataset-container {
layout: horizontal;
height: 100%;
}
.dataset-list {
width: 40%;
border: solid $primary;
padding: 1;
margin: 1;
}
.dataset-details {
width: 60%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.dataset-actions {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.dataset-actions Button {
margin: 0 1;
}
DataTable {
height: 100%;
}
.preview-container {
height: 100%;
border: solid $primary;
padding: 1;
}
TextArea {
height: 100%;
}
.stats-container {
layout: vertical;
padding: 1;
}
.stat-row {
layout: horizontal;
padding: 0 0 1 0;
}
.stat-label {
width: 50%;
color: $text-muted;
}
.stat-value {
width: 50%;
text-align: right;
text-style: bold;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
.progress-container {
padding: 1;
border: solid $warning;
margin: 1;
}
"""
def __init__(self):
"""Initialize the dataset screen."""
super().__init__(
title="Dataset Management",
subtitle="Browse, preview, and preprocess datasets",
)
self.datasets: Dict[str, Dict] = {}
self.selected_dataset: Optional[str] = None
self.preprocessing_active = False
def compose(self) -> ComposeResult:
"""Compose the dataset screen layout."""
yield Header()
yield Container(
Static("🦾 Dataset Management", classes="screen-title"),
Static(
"Browse, preview, and preprocess datasets", classes="screen-subtitle"
),
Container(
Container(
Label("Available Datasets"),
DataTable(id="dataset-table"),
Container(
Button("Load Dataset", id="load-dataset", variant="primary"),
Button("Preprocess", id="preprocess", variant="success"),
Button("Download", id="download", variant="default"),
Button("Refresh", id="refresh", variant="default"),
classes="dataset-actions",
),
classes="dataset-list",
),
Container(
TextArea("", id="dataset-preview", read_only=True),
Container(
Static("Dataset Name:", classes="stat-label"),
Static("-", id="stat-name", classes="stat-value"),
Static("Type:", classes="stat-label"),
Static("-", id="stat-type", classes="stat-value"),
Static("Size:", classes="stat-label"),
Static("-", id="stat-size", classes="stat-value"),
Static("Samples:", classes="stat-label"),
Static("-", id="stat-samples", classes="stat-value"),
Static("Features:", classes="stat-label"),
Static("-", id="stat-features", classes="stat-value"),
Static("Format:", classes="stat-label"),
Static("-", id="stat-format", classes="stat-value"),
Static("Preprocessed:", classes="stat-label"),
Static("-", id="stat-preprocessed", classes="stat-value"),
),
Log(id="processing-log"),
ProgressBar(total=100, id="preprocessing-progress"),
classes="dataset-details",
),
classes="dataset-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_dataset_table()
self.load_datasets()
log = self.query_one("#processing-log", Log)
log.write_line("Dataset manager ready.")
def setup_dataset_table(self) -> None:
"""Setup the dataset table."""
table = self.query_one("#dataset-table", DataTable)
table.add_columns("Name", "Type", "Size", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
@work(thread=True)
def load_datasets(self) -> None:
"""Load available datasets."""
# Check for local datasets
datasets_dir = Path("/workspace/datasets")
if datasets_dir.exists():
for dataset_path in datasets_dir.glob("*"):
if dataset_path.is_dir():
self.datasets[dataset_path.name] = {
"name": dataset_path.name,
"path": str(dataset_path),
"type": "local",
"size": self.get_dir_size(dataset_path),
"status": "available",
}
# Check for HuggingFace datasets in configs
configs_dir = Path("/workspace/configs")
if configs_dir.exists():
for config_file in configs_dir.glob("*.yaml"):
try:
import yaml
with open(config_file) as f:
config = yaml.safe_load(f)
if "datasets" in config:
for ds in config.get("datasets", []):
if "path" in ds:
ds_name = ds["path"].split("/")[-1]
self.datasets[ds_name] = {
"name": ds_name,
"path": ds["path"],
"type": ds.get("type", "huggingface"),
"size": "Unknown",
"status": "remote",
}
except Exception:
pass
self.refresh_dataset_table()
def get_dir_size(self, path: Path) -> str:
"""Get human-readable directory size."""
total_size = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
for unit in ["B", "KB", "MB", "GB"]:
if total_size < 1024.0:
return f"{total_size:.2f} {unit}"
total_size /= 1024.0
return f"{total_size:.2f} TB"
def refresh_dataset_table(self) -> None:
"""Refresh the dataset table."""
table = self.query_one("#dataset-table", DataTable)
table.clear()
for name, info in self.datasets.items():
table.add_row(
name[:30],
info["type"],
info["size"],
info["status"],
)
@on(DataTable.RowSelected)
def handle_dataset_selected(self, event: DataTable.RowSelected) -> None:
"""Handle dataset selection from table."""
if event.cursor_row >= 0:
dataset_names = list(self.datasets.keys())
if event.cursor_row < len(dataset_names):
self.selected_dataset = dataset_names[event.cursor_row]
self.load_dataset_preview()
self.update_dataset_stats()
@work(thread=True)
def load_dataset_preview(self) -> None:
"""Load preview of selected dataset."""
if not self.selected_dataset:
return
dataset_info = self.datasets[self.selected_dataset]
preview_text = ""
try:
if dataset_info["type"] == "local" and Path(dataset_info["path"]).exists():
# Load first few samples from local dataset
sample_files = list(Path(dataset_info["path"]).glob("*.json"))[:3]
samples = []
for sample_file in sample_files:
with open(sample_file) as f:
samples.append(json.load(f))
preview_text = json.dumps(samples, indent=2)
else:
# Show dataset info for remote datasets
preview_text = json.dumps(dataset_info, indent=2)
except Exception as e:
preview_text = f"Error loading preview: {str(e)}"
preview = self.query_one("#dataset-preview", TextArea)
preview.load_text(preview_text)
def update_dataset_stats(self) -> None:
"""Update dataset statistics display."""
if not self.selected_dataset:
return
info = self.datasets[self.selected_dataset]
self.query_one("#stat-name", Static).update(info["name"])
self.query_one("#stat-type", Static).update(info["type"])
self.query_one("#stat-size", Static).update(info["size"])
self.query_one("#stat-samples", Static).update("N/A")
self.query_one("#stat-features", Static).update("N/A")
self.query_one("#stat-format", Static).update("JSON")
self.query_one("#stat-preprocessed", Static).update("No")
@on(Button.Pressed, "#preprocess")
@work(thread=True)
def handle_preprocess(self) -> None:
"""Preprocess selected dataset."""
if not self.selected_dataset or self.preprocessing_active:
return
self.preprocessing_active = True
dataset_info = self.datasets[self.selected_dataset]
log = self.query_one("#processing-log", Log)
log.clear()
log.write_line(f"🔄 Starting preprocessing for {self.selected_dataset}...")
progress = self.query_one("#preprocessing-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
import tempfile
# Create a temporary config for preprocessing
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
config = {
"datasets": [
{
"path": dataset_info["path"],
"type": dataset_info.get("type", "alpaca"),
}
],
"output_dir": f"/tmp/preprocessed_{self.selected_dataset}",
}
import yaml
yaml.dump(config, f)
temp_config = f.name
# Run preprocessing
cmd = ["python", "-m", "axolotl.cli.preprocess", temp_config]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
# Monitor progress
for line in process.stdout:
log.write_line(line.strip())
# Update progress bar based on output
if "Processing" in line:
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ Preprocessing completed successfully!")
dataset_info["status"] = "preprocessed"
progress.update(progress=100)
else:
log.write_line(
f"❌ Preprocessing failed with code {process.returncode}"
)
import os
os.unlink(temp_config)
except Exception as e:
log.write_line(f"❌ Error during preprocessing: {str(e)}")
finally:
self.preprocessing_active = False
self.refresh_dataset_table()
@on(Button.Pressed, "#load-dataset")
async def handle_load_dataset(self) -> None:
"""Load a new dataset."""
log = self.query_one("#processing-log", Log)
log.write_line("📦 Load dataset functionality coming soon...")
@on(Button.Pressed, "#download")
@work(thread=True)
def handle_download(self) -> None:
"""Download a remote dataset."""
if not self.selected_dataset:
return
dataset_info = self.datasets[self.selected_dataset]
if dataset_info["type"] != "huggingface":
return
log = self.query_one("#processing-log", Log)
log.clear()
log.write_line(f"📥 Downloading {self.selected_dataset} from HuggingFace...")
try:
from datasets import load_dataset
dataset = load_dataset(dataset_info["path"])
save_path = Path(f"/workspace/datasets/{self.selected_dataset}")
save_path.mkdir(parents=True, exist_ok=True)
dataset.save_to_disk(str(save_path))
log.write_line(f"✅ Downloaded to {save_path}")
dataset_info["type"] = "local"
dataset_info["status"] = "available"
dataset_info["path"] = str(save_path)
self.refresh_dataset_table()
except Exception as e:
log.write_line(f"❌ Download failed: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh dataset list."""
self.load_datasets()
def action_preprocess(self) -> None:
"""Preprocess selected dataset."""
self.handle_preprocess()
def action_refresh(self) -> None:
"""Refresh dataset list."""
self.handle_refresh()


@@ -0,0 +1,445 @@
"""Inference and testing screen for Axolotl TUI."""
from pathlib import Path
from typing import Dict, List, Optional
from textual import events, on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
Input,
Label,
Log,
Select,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class InferenceScreen(BaseScreen):
"""Inference and testing screen."""
BINDINGS = [
Binding("ctrl+enter", "send_message", "Send"),
Binding("ctrl+c", "clear_chat", "Clear"),
Binding("ctrl+l", "load_model", "Load Model"),
Binding("ctrl+s", "save_chat", "Save Chat"),
]
CSS = """
.inference-container {
layout: horizontal;
height: 100%;
}
.model-selector {
width: 30%;
border: solid $primary;
padding: 1;
margin: 1;
}
.chat-interface {
width: 70%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.chat-history {
height: 70%;
border: solid $primary;
padding: 1;
margin: 0 0 1 0;
}
.input-area {
height: 20%;
border: solid $warning;
padding: 1;
margin: 0 0 1 0;
}
.chat-controls {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.chat-controls Button {
margin: 0 1;
}
.model-info {
padding: 1;
border: solid $surface;
margin: 1 0;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
TextArea {
height: 100%;
}
Log {
height: 100%;
}
"""
def __init__(self):
"""Initialize the inference screen."""
super().__init__(
title="Inference & Testing", subtitle="Interactive chat and model testing"
)
self.loaded_model: Optional[str] = None
self.chat_history: List[Dict[str, str]] = []
def compose(self) -> ComposeResult:
"""Compose the inference screen layout."""
yield Container(
Static("🦾 Inference & Testing", classes="screen-title"),
Static("Interactive chat and model testing", classes="screen-subtitle"),
Container(
Container(
Label("Model Selection"),
Select(
[("No model loaded", "none")],
id="model-select",
value="none",
),
Container(
Button("Load Model", id="load-model", variant="primary"),
Button("Unload", id="unload-model", variant="default"),
Button("Gradio UI", id="gradio-ui", variant="success"),
),
Container(
Static("No model loaded", id="model-status"),
classes="model-info",
),
Label("Inference Parameters"),
Container(
Label("Temperature:"),
Input(value="0.7", id="temperature"),
Label("Max Tokens:"),
Input(value="256", id="max-tokens"),
Label("Top P:"),
Input(value="0.9", id="top-p"),
),
classes="model-selector",
),
Container(
Container(
Log(id="chat-history"),
classes="chat-history",
),
Container(
TextArea(
id="message-input",
),
classes="input-area",
),
Container(
Button("Send [Ctrl+Enter]", id="send", variant="primary"),
Button("Clear Chat", id="clear", variant="warning"),
Button("Save Chat", id="save-chat", variant="default"),
Button("Load Examples", id="load-examples", variant="default"),
classes="chat-controls",
),
classes="chat-interface",
),
classes="inference-container",
),
id="content",
)
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.load_available_models()
chat = self.query_one("#chat-history", Log)
chat.write_line("💬 Welcome to Axolotl Inference!")
chat.write_line("Load a model to start chatting.")
@work(thread=True)
def load_available_models(self) -> None:
"""Load list of available models."""
models = [("No model loaded", "none")]
chat = self.query_one("#chat-history", Log)
chat.write_line("🔍 Scanning for available models...")
# Check for trained models
outputs_dir = Path("./outputs")
chat.write_line(f"Checking outputs directory: {outputs_dir.absolute()}")
if outputs_dir.exists():
found_models = 0
for model_dir in outputs_dir.glob("*"):
if model_dir.is_dir():
# Look for various model file types
model_files = (
list(model_dir.glob("pytorch_model.bin"))
+ list(model_dir.glob("model.safetensors"))
+ list(model_dir.glob("*.bin"))
+ list(model_dir.glob("*.safetensors"))
)
if model_files:
models.append((model_dir.name, str(model_dir)))
found_models += 1
chat.write_line(f"Found {found_models} trained models in outputs/")
else:
chat.write_line("outputs/ directory not found")
# Add some example/demo models for testing
models.extend(
[
("Demo: GPT-2 Small", "gpt2"),
("Demo: TinyLlama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
("Demo: Phi-2", "microsoft/phi-2"),
]
)
select = self.query_one("#model-select", Select)
select.set_options(models)
chat.write_line(f"✅ Loaded {len(models)} models in dropdown")
@on(Button.Pressed, "#load-model")
@work(thread=True)
def handle_load_model(self) -> None:
"""Load selected model for inference."""
select = self.query_one("#model-select", Select)
if select.value == "none":
return
chat = self.query_one("#chat-history", Log)
chat.write_line(f"🔄 Loading model: {select.value}")
status = self.query_one("#model-status", Static)
status.update("Loading...")
try:
# Simulate model loading (in real implementation, would load the actual model)
import time
time.sleep(2) # Simulate loading time
self.loaded_model = select.value
status.update(f"✅ Loaded: {Path(select.value).name}")
chat.write_line("✅ Model loaded successfully!")
chat.write_line("You can now start chatting.")
except Exception as e:
status.update("❌ Failed to load")
chat.write_line(f"❌ Failed to load model: {str(e)}")
@on(Button.Pressed, "#send")
def handle_send_message(self) -> None:  # sync so on_key and action_send_message can call it directly
"""Send message to model."""
if not self.loaded_model:
chat = self.query_one("#chat-history", Log)
chat.write_line("⚠️ Please load a model first")
return
message_input = self.query_one("#message-input", TextArea)
message = message_input.text.strip()
if not message:
return
# Add user message to chat
chat = self.query_one("#chat-history", Log)
chat.write_line(f"👤 User: {message}")
# Clear input
message_input.clear()
# Add to history
self.chat_history.append({"role": "user", "content": message})
# Generate response (placeholder)
self.generate_response(message)
@on(TextArea.Changed, "#message-input")
def on_message_input_changed(self, event: TextArea.Changed) -> None:
"""Handle changes to the message input."""
# This could be used for features like typing indicators
pass
def on_key(self, event: events.Key) -> None:
"""Handle key events globally."""
# Check if we're focused on the message input and Ctrl+Enter is pressed
focused = self.focused
if focused and focused.id == "message-input" and event.key == "ctrl+enter":
event.prevent_default()
self.handle_send_message()
@work(thread=True)
def generate_response(self, message: str) -> None:
"""Generate model response."""
chat = self.query_one("#chat-history", Log)
chat.write_line("🤖 Assistant: Thinking...")
try:
# Get inference parameters
float(self.query_one("#temperature", Input).value)
int(self.query_one("#max-tokens", Input).value)
float(self.query_one("#top-p", Input).value)
if not self.loaded_model or self.loaded_model == "none":
response = "I don't have a model loaded yet. Please load a model first using the 'Load Model' button."
elif self.loaded_model.startswith("gpt2"):
# Simple response for GPT-2
responses = [
f"Thanks for your message: '{message}'. I'm a GPT-2 model running in demo mode.",
"I understand you're testing the interface. GPT-2 models are great for experimentation!",
"This is a simulated GPT-2 response. In a real setup, I'd generate text based on your input.",
f"GPT-2 here! You said: '{message}'. I'd normally continue this conversation creatively.",
]
import random
response = random.choice(responses)
elif "llama" in self.loaded_model.lower():
# Response for Llama models
response = f"🦙 LLaMA model here! You asked: '{message}'. I'm designed for helpful, harmless, and honest conversations. How can I assist you today?"
elif "phi" in self.loaded_model.lower():
# Response for Phi models
response = f"Phi model responding! Your message: '{message}'. I'm optimized for reasoning and code tasks. What would you like to explore?"
else:
# Generic response for other models
response = f"Model '{self.loaded_model}' responding to: '{message}'. I'm ready to help with your questions!"
# Simulate inference time
import time
time.sleep(0.5)
# Clear the "thinking" message and show response
chat.write_line(f"🤖 Assistant: {response}")
# Add to history
self.chat_history.append({"role": "assistant", "content": response})
except Exception as e:
chat.write_line(f"❌ Error generating response: {str(e)}")
@on(Button.Pressed, "#clear")
def handle_clear_chat(self) -> None:
"""Clear chat history."""
chat = self.query_one("#chat-history", Log)
chat.clear()
self.chat_history = []
chat.write_line("💬 Chat cleared. Start a new conversation!")
@on(Button.Pressed, "#save-chat")
def handle_save_chat(self) -> None:
"""Save chat history to file."""
if not self.chat_history:
chat = self.query_one("#chat-history", Log)
chat.write_line("⚠️ No chat history to save")
return
try:
import json
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"chat_history_{timestamp}.json"
with open(filename, "w") as f:
json.dump(self.chat_history, f, indent=2)
chat = self.query_one("#chat-history", Log)
chat.write_line(f"💾 Chat saved to {filename}")
except Exception as e:
chat = self.query_one("#chat-history", Log)
chat.write_line(f"❌ Error saving chat: {str(e)}")
@on(Button.Pressed, "#load-examples")
def handle_load_examples(self) -> None:
"""Load example prompts."""
examples = [
"Explain the concept of machine learning in simple terms.",
"Write a Python function to calculate fibonacci numbers.",
"What are the benefits of fine-tuning language models?",
"Describe the difference between supervised and unsupervised learning.",
]
chat = self.query_one("#chat-history", Log)
chat.write_line("📚 Example prompts:")
for i, example in enumerate(examples, 1):
chat.write_line(f"{i}. {example}")
chat.write_line("Copy and paste any example to try it out!")
@on(Button.Pressed, "#gradio-ui")
@work(thread=True)
def handle_gradio_ui(self) -> None:
"""Launch Gradio web interface."""
chat = self.query_one("#chat-history", Log)
chat.write_line("🌐 Launching Gradio web interface...")
try:
import subprocess
if self.loaded_model:
cmd = [
"python",
"-m",
"axolotl.cli.inference",
self.loaded_model,
"--gradio",
]
else:
chat.write_line("⚠️ No model loaded. Loading default interface...")
cmd = ["python", "-m", "axolotl.cli.inference", "--gradio"]
subprocess.Popen(cmd)
chat.write_line("✅ Gradio interface launched! Check your browser.")
except Exception as e:
chat.write_line(f"❌ Error launching Gradio: {str(e)}")
@on(Button.Pressed, "#unload-model")
def handle_unload_model(self) -> None:
"""Unload current model."""
self.loaded_model = None
status = self.query_one("#model-status", Static)
status.update("No model loaded")
select = self.query_one("#model-select", Select)
select.value = "none"
chat = self.query_one("#chat-history", Log)
chat.write_line("🔄 Model unloaded")
def action_send_message(self) -> None:
"""Send message action."""
self.handle_send_message()
def action_clear_chat(self) -> None:
"""Clear chat action."""
self.handle_clear_chat()
def action_load_model(self) -> None:
"""Load model action."""
self.handle_load_model()
def action_save_chat(self) -> None:
"""Save chat action."""
self.handle_save_chat()


@@ -0,0 +1,373 @@
"""Model management screen for Axolotl TUI."""
from pathlib import Path
from typing import Dict, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, ScrollableContainer
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Static,
TabbedContent,
TabPane,
)
from axolotl.tui.screens.base import BaseScreen
class ModelScreen(BaseScreen):
"""Model management screen."""
BINDINGS = [
Binding("ctrl+m", "merge_lora", "Merge LoRA"),
Binding("ctrl+q", "quantize", "Quantize"),
Binding("ctrl+e", "evaluate", "Evaluate"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.model-container {
layout: horizontal;
height: 100%;
}
.model-list {
width: 50%;
border: solid $primary;
padding: 1;
margin: 1;
}
.model-operations {
width: 50%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.model-actions {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.model-actions Button {
margin: 0 1;
}
DataTable {
height: 80%;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
"""
def __init__(self):
"""Initialize the model screen."""
super().__init__(
title="Model Management",
subtitle="Manage trained models, merge LoRA adapters, and quantize models",
)
self.models: Dict[str, Dict] = {}
self.selected_model: Optional[str] = None
def compose(self) -> ComposeResult:
"""Compose the model screen layout."""
yield Header()
with Container(id="content"):
yield Static("🦾 Model Management", classes="screen-title")
yield Static(
"Manage trained models, merge LoRA adapters, and quantize models",
classes="screen-subtitle",
)
with Container(classes="model-container"):
with Container(classes="model-list"):
yield Label("Available Models")
yield DataTable(id="model-table")
with Container(classes="model-actions"):
yield Button("Merge LoRA", id="merge-lora", variant="primary")
yield Button("Quantize", id="quantize", variant="success")
yield Button("Evaluate", id="evaluate", variant="warning")
yield Button("Refresh", id="refresh", variant="default")
with Container(classes="model-operations"):
with TabbedContent():
with TabPane("Operations"):
with Container():
yield Log(id="operations-log")
with Container():
yield Label("Operation Progress:")
yield ProgressBar(
total=100,
id="operation-progress",
)
with TabPane("Model Info"):
with ScrollableContainer():
yield Static(
"Model information will appear here",
id="model-info",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_model_table()
self.load_models()
log = self.query_one("#operations-log", Log)
log.write_line("Model manager ready.")
def setup_model_table(self) -> None:
"""Setup the model table."""
table = self.query_one("#model-table", DataTable)
table.add_columns("Name", "Type", "Size", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
@work(thread=True)
def load_models(self) -> None:
"""Load available models."""
# Check outputs directory for trained models
outputs_dir = Path("./outputs")
if outputs_dir.exists():
for model_dir in outputs_dir.glob("*"):
if model_dir.is_dir():
self.models[model_dir.name] = {
"name": model_dir.name,
"path": str(model_dir),
"type": "checkpoint",
"size": self.get_dir_size(model_dir),
"status": "available",
}
self.refresh_model_table()
def get_dir_size(self, path: Path) -> str:
"""Get human-readable directory size."""
try:
total_size = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
for unit in ["B", "KB", "MB", "GB"]:
if total_size < 1024.0:
return f"{total_size:.2f} {unit}"
total_size /= 1024.0
return f"{total_size:.2f} TB"
except Exception:
return "Unknown"
def refresh_model_table(self) -> None:
"""Refresh the model table."""
table = self.query_one("#model-table", DataTable)
table.clear()
for name, info in self.models.items():
table.add_row(
name[:30],
info["type"],
info["size"],
info["status"],
)
@on(DataTable.RowSelected)
def handle_model_selected(self, event: DataTable.RowSelected) -> None:
"""Handle model selection from table."""
if event.cursor_row >= 0:
model_names = list(self.models.keys())
if event.cursor_row < len(model_names):
self.selected_model = model_names[event.cursor_row]
self.update_model_info()
def update_model_info(self) -> None:
"""Update model information display."""
if not self.selected_model:
return
info = self.models[self.selected_model]
info_text = f"""
Model Name: {info['name']}
Path: {info['path']}
Type: {info['type']}
Size: {info['size']}
Status: {info['status']}
"""
self.query_one("#model-info", Static).update(info_text)
@on(Button.Pressed, "#merge-lora")
@work(thread=True)
async def handle_merge_lora(self) -> None:
"""Merge LoRA adapters with base model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Merging LoRA adapters for {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = ["python", "-m", "axolotl.cli.merge_lora", model_info["path"]]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ LoRA merge completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ LoRA merge failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during LoRA merge: {str(e)}")
@on(Button.Pressed, "#quantize")
@work(thread=True)
async def handle_quantize(self) -> None:
"""Quantize selected model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Quantizing {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = [
"python",
"-m",
"axolotl.cli.quantize",
model_info["path"],
"--output-dir",
f"{model_info['path']}_quantized",
]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(5)
process.wait()
if process.returncode == 0:
log.write_line("✅ Quantization completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ Quantization failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during quantization: {str(e)}")
@on(Button.Pressed, "#evaluate")
@work(thread=True)
async def handle_evaluate(self) -> None:
"""Evaluate selected model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Evaluating {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = ["python", "-m", "axolotl.cli.evaluate", model_info["path"]]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ Evaluation completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ Evaluation failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during evaluation: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh model list."""
self.load_models()
def action_merge_lora(self) -> None:
"""Merge LoRA adapters."""
self.handle_merge_lora()
def action_quantize(self) -> None:
"""Quantize model."""
self.handle_quantize()
def action_evaluate(self) -> None:
"""Evaluate model."""
self.handle_evaluate()
def action_refresh(self) -> None:
"""Refresh model list."""
self.handle_refresh()
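A minimal sketch of how a screen like ModelScreen might be wired into the Textual app. The actual axolotl/tui/app.py is not shown in this diff, so the app class name and screen registry below are assumptions, not the PR's implementation:

from textual.app import App

class AxolotlApp(App):
    # Named screens; Textual instantiates the class on first push.
    SCREENS = {"models": ModelScreen}

    def on_mount(self) -> None:
        self.push_screen("models")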

View File

@@ -0,0 +1,414 @@
"""System monitoring screen for Axolotl TUI."""
import psutil
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Sparkline,
Static,
)
from axolotl.tui.screens.base import BaseScreen
class MonitorScreen(BaseScreen):
"""System monitoring screen."""
BINDINGS = [
Binding("r", "refresh", "Refresh"),
Binding("ctrl+k", "kill_process", "Kill Process"),
]
CSS = """
.monitor-container {
layout: vertical;
height: 100%;
}
.metrics-grid {
layout: horizontal;
height: 20%;
padding: 1;
}
.metric-card {
width: 25%;
border: solid $surface;
padding: 1;
margin: 0 1;
}
.metric-label {
text-style: bold;
color: $text-muted;
text-align: center;
}
.metric-value {
text-style: bold;
text-align: center;
padding: 1;
}
.charts-container {
height: 40%;
layout: horizontal;
padding: 1;
}
.chart-panel {
width: 50%;
border: solid $primary;
padding: 1;
margin: 0 1;
}
.processes-container {
height: 40%;
border: solid $warning;
padding: 1;
margin: 1;
}
DataTable {
height: 90%;
}
.process-controls {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.process-controls Button {
margin: 0 1;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
Sparkline {
height: 8;
}
ProgressBar {
margin: 1 0;
}
"""
def __init__(self):
"""Initialize the monitor screen."""
super().__init__(
title="System Monitor",
subtitle="Monitor system resources and running processes",
)
self.cpu_history = []
self.memory_history = []
self.gpu_history = []
def compose(self) -> ComposeResult:
"""Compose the monitor screen layout."""
yield Header()
yield Container(
Static("🦾 System Monitor", classes="screen-title"),
Static(
"Monitor system resources and running processes",
classes="screen-subtitle",
),
Container(
Container(
Container(
Static("CPU Usage", classes="metric-label"),
Static("0%", id="cpu-usage", classes="metric-value"),
ProgressBar(total=100, id="cpu-progress"),
classes="metric-card",
),
Container(
Static("Memory", classes="metric-label"),
Static("0%", id="memory-usage", classes="metric-value"),
ProgressBar(total=100, id="memory-progress"),
classes="metric-card",
),
Container(
Static("GPU Usage", classes="metric-label"),
Static("0%", id="gpu-usage", classes="metric-value"),
ProgressBar(total=100, id="gpu-progress"),
classes="metric-card",
),
Container(
Static("Temperature", classes="metric-label"),
Static("0°C", id="temperature", classes="metric-value"),
classes="metric-card",
),
classes="metrics-grid",
),
Container(
Container(
Label("CPU History"),
Sparkline([], id="cpu-sparkline"),
classes="chart-panel",
),
Container(
Label("Memory History"),
Sparkline([], id="memory-sparkline"),
classes="chart-panel",
),
classes="charts-container",
),
Container(
DataTable(id="process-table"),
Log(id="gpu-info"),
Log(id="system-logs"),
classes="processes-container",
),
classes="monitor-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_process_table()
self.start_monitoring()
# Initial system info
self.update_system_info()
self.update_gpu_info()
def setup_process_table(self) -> None:
"""Setup the process table."""
table = self.query_one("#process-table", DataTable)
table.add_columns("PID", "Name", "CPU%", "Memory%", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
def start_monitoring(self) -> None:
"""Start the monitoring timer."""
self.set_interval(2.0, self.update_system_metrics)
@work(thread=True)
async def update_system_metrics(self) -> None:
"""Update system metrics."""
try:
# CPU usage
cpu_percent = psutil.cpu_percent(interval=None)
self.cpu_history.append(cpu_percent)
if len(self.cpu_history) > 50:
self.cpu_history.pop(0)
# Memory usage
memory = psutil.virtual_memory()
memory_percent = memory.percent
self.memory_history.append(memory_percent)
if len(self.memory_history) > 50:
self.memory_history.pop(0)
# GPU usage (if available)
gpu_percent = self.get_gpu_usage()
self.gpu_history.append(gpu_percent)
if len(self.gpu_history) > 50:
self.gpu_history.pop(0)
# Temperature
temperature = self.get_temperature()
# Update UI
self.update_metrics_display(
cpu_percent, memory_percent, gpu_percent, temperature
)
self.update_sparklines()
self.update_process_table()
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error updating metrics: {str(e)}")
def get_gpu_usage(self) -> float:
"""Get GPU usage percentage."""
try:
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
return util.gpu
except Exception:
return 0.0
def get_temperature(self) -> str:
"""Get system temperature."""
try:
temps = psutil.sensors_temperatures()
if temps:
            for entries in temps.values():
if entries:
return f"{entries[0].current:.1f}°C"
return "N/A"
except Exception:
return "N/A"
def update_metrics_display(
self, cpu: float, memory: float, gpu: float, temp: str
) -> None:
"""Update metrics display."""
self.query_one("#cpu-usage", Static).update(f"{cpu:.1f}%")
self.query_one("#memory-usage", Static).update(f"{memory:.1f}%")
self.query_one("#gpu-usage", Static).update(f"{gpu:.1f}%")
self.query_one("#temperature", Static).update(temp)
self.query_one("#cpu-progress", ProgressBar).update(progress=cpu)
self.query_one("#memory-progress", ProgressBar).update(progress=memory)
self.query_one("#gpu-progress", ProgressBar).update(progress=gpu)
def update_sparklines(self) -> None:
"""Update sparkline charts."""
if self.cpu_history:
cpu_sparkline = self.query_one("#cpu-sparkline", Sparkline)
cpu_sparkline.data = self.cpu_history
if self.memory_history:
memory_sparkline = self.query_one("#memory-sparkline", Sparkline)
memory_sparkline.data = self.memory_history
def update_process_table(self) -> None:
"""Update the process table."""
table = self.query_one("#process-table", DataTable)
table.clear()
try:
# Get top processes by CPU usage
processes = []
for proc in psutil.process_iter(
["pid", "name", "cpu_percent", "memory_percent", "status"]
):
try:
pinfo = proc.info
if pinfo["cpu_percent"] > 0.1: # Only show processes using CPU
processes.append(pinfo)
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
# Sort by CPU usage
processes.sort(key=lambda x: x["cpu_percent"], reverse=True)
# Add top 20 processes
for proc in processes[:20]:
table.add_row(
str(proc["pid"]),
proc["name"][:20],
f"{proc['cpu_percent']:.1f}%",
f"{proc['memory_percent']:.1f}%",
proc["status"],
)
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error updating process table: {str(e)}")
def update_system_info(self) -> None:
"""Update system information."""
try:
# System info
cpu_count = psutil.cpu_count()
memory = psutil.virtual_memory()
log = self.query_one("#system-logs", Log)
log.write_line(f"System started. CPU cores: {cpu_count}")
log.write_line(f"Total memory: {memory.total / (1024**3):.1f} GB")
log.write_line(f"Available memory: {memory.available / (1024**3):.1f} GB")
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error getting system info: {str(e)}")
def update_gpu_info(self) -> None:
"""Update GPU information."""
try:
import pynvml
pynvml.nvmlInit()
device_count = pynvml.nvmlDeviceGetCount()
log = self.query_one("#gpu-info", Log)
log.clear()
log.write_line(f"Found {device_count} GPU(s)")
for i in range(device_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                name = pynvml.nvmlDeviceGetName(handle)
                if isinstance(name, bytes):
                    # Older pynvml versions return bytes; newer ones return str
                    name = name.decode()
memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
log.write_line(f"\nGPU {i}: {name}")
log.write_line(
f"Memory: {memory_info.used / (1024**3):.1f} / {memory_info.total / (1024**3):.1f} GB"
)
log.write_line(f"Free: {memory_info.free / (1024**3):.1f} GB")
except Exception as e:
log = self.query_one("#gpu-info", Log)
log.clear()
log.write_line(f"GPU info unavailable: {str(e)}")
@on(Button.Pressed, "#kill-process")
def handle_kill_process(self) -> None:
"""Kill selected process."""
table = self.query_one("#process-table", DataTable)
if table.cursor_row >= 0:
try:
row = table.get_row_at(table.cursor_row)
pid = int(row[0])
process = psutil.Process(pid)
process.terminate()
log = self.query_one("#system-logs", Log)
log.write_line(f"Terminated process {pid}")
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error killing process: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh all metrics."""
self.update_system_info()
self.update_gpu_info()
log = self.query_one("#system-logs", Log)
log.write_line("Metrics refreshed")
@on(Button.Pressed, "#auto-refresh")
def handle_auto_refresh(self) -> None:
"""Toggle auto refresh."""
log = self.query_one("#system-logs", Log)
log.write_line("Auto refresh is always enabled (every 2 seconds)")
def action_refresh(self) -> None:
"""Refresh action."""
self.handle_refresh()
def action_kill_process(self) -> None:
"""Kill process action."""
self.handle_kill_process()
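The GPU helpers above call pynvml.nvmlInit() on every poll. A guard that initializes NVML at most once per process is sketched below; this is an assumed helper, not part of the diff:

from typing import Optional

_NVML_READY: Optional[bool] = None

def nvml_available() -> bool:
    """Initialize NVML once and cache whether GPU queries are possible."""
    global _NVML_READY
    if _NVML_READY is None:
        try:
            import pynvml

            pynvml.nvmlInit()
            _NVML_READY = True
        except Exception:
            _NVML_READY = False
    return _NVML_READY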

View File

@@ -0,0 +1,545 @@
"""Training management screen for Axolotl TUI."""
import subprocess
import threading
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
Sparkline,
Static,
)
from axolotl.tui.screens.base import BaseScreen
@dataclass
class TrainingJob:
"""Represents a training job."""
id: str
config_path: str
status: str # pending, running, completed, failed
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
process: Optional[subprocess.Popen] = None
log_file: Optional[str] = None
current_epoch: int = 0
total_epochs: int = 0
current_loss: float = 0.0
    losses: List[float] = field(default_factory=list)
class TrainingScreen(BaseScreen):
"""Training management screen."""
BINDINGS = [
Binding("ctrl+t", "new_training", "New Training"),
Binding("ctrl+r", "resume_training", "Resume"),
Binding("ctrl+x", "stop_training", "Stop"),
Binding("ctrl+l", "view_logs", "View Logs"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.training-container {
layout: vertical;
height: 100%;
}
.job-list-container {
height: 40%;
border: solid $primary;
padding: 1;
margin: 1;
}
.job-details-container {
height: 60%;
padding: 1;
}
.control-panel {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
border: solid $secondary;
margin: 1;
}
.control-panel Button {
margin: 0 1;
}
.metrics-panel {
layout: horizontal;
height: 10;
border: solid $primary;
padding: 1;
margin: 1;
}
.metric-card {
width: 25%;
border: tall $surface;
padding: 1;
margin: 0 1;
}
.metric-label {
text-style: bold;
color: $text-muted;
}
.metric-value {
text-style: bold;
text-align: center;
padding: 1;
}
.log-viewer {
border: solid $warning;
padding: 1;
margin: 1;
}
#training-logs {
height: 100%;
}
DataTable {
height: 100%;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
.sparkline-container {
height: 5;
border: solid $success;
padding: 1;
margin: 1;
}
"""
def __init__(self):
"""Initialize the training screen."""
super().__init__(
title="Training Management",
subtitle="Launch, monitor, and manage training jobs",
)
self.jobs: Dict[str, TrainingJob] = {}
self.selected_job_id: Optional[str] = None
self.update_timer = None
def compose(self) -> ComposeResult:
"""Compose the training screen layout."""
yield Header()
yield Container(
Static("🦾 Training Management", classes="screen-title"),
Static(
"Launch, monitor, and manage training jobs", classes="screen-subtitle"
),
Container(
Container(
Label("Active Training Jobs"),
DataTable(id="job-table"),
classes="job-list-container",
),
Container(
Button("New Training", id="new-training", variant="primary"),
Button("Resume", id="resume-training", variant="success"),
Button("Stop", id="stop-training", variant="error"),
Button("View Logs", id="view-logs", variant="default"),
Button("Clear Completed", id="clear-completed", variant="warning"),
Button("Refresh", id="refresh", variant="default"),
classes="control-panel",
),
Container(
Container(
Static("Current Epoch", classes="metric-label"),
Static("0 / 0", id="epoch-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Loss", classes="metric-label"),
Static("0.000", id="loss-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Status", classes="metric-label"),
Static("Idle", id="status-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Duration", classes="metric-label"),
Static(
"00:00:00", id="duration-metric", classes="metric-value"
),
classes="metric-card",
),
classes="metrics-panel",
),
Container(
Label("Loss History"),
Sparkline(
[],
id="loss-sparkline",
summary_function=min,
),
classes="sparkline-container",
),
Container(
Log(id="training-logs"),
classes="log-viewer",
),
classes="job-details-container",
),
classes="training-container",
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_job_table()
self.start_update_timer()
log = self.query_one("#training-logs", Log)
log.write_line(
"Training manager ready. Select a configuration to start training."
)
def setup_job_table(self) -> None:
"""Setup the job table."""
table = self.query_one("#job-table", DataTable)
table.add_columns("ID", "Config", "Status", "Epoch", "Loss", "Duration")
table.cursor_type = "row"
table.zebra_stripes = True
def start_update_timer(self) -> None:
"""Start the periodic update timer."""
self.set_interval(2.0, self.update_job_status)
@work(thread=True)
async def update_job_status(self) -> None:
"""Update job status periodically."""
for job_id, job in self.jobs.items():
if job.status == "running" and job.process:
poll = job.process.poll()
if poll is not None:
if poll == 0:
job.status = "completed"
else:
job.status = "failed"
job.end_time = datetime.now()
self.refresh_job_table()
self.update_selected_job_metrics()
def refresh_job_table(self) -> None:
"""Refresh the job table."""
table = self.query_one("#job-table", DataTable)
table.clear()
for job_id, job in self.jobs.items():
duration = self.calculate_duration(job)
table.add_row(
job_id[:8],
Path(job.config_path).name,
job.status,
f"{job.current_epoch}/{job.total_epochs}",
f"{job.current_loss:.4f}" if job.current_loss else "N/A",
duration,
)
def calculate_duration(self, job: TrainingJob) -> str:
"""Calculate job duration."""
if not job.start_time:
return "00:00:00"
end_time = job.end_time or datetime.now()
duration = end_time - job.start_time
hours = int(duration.total_seconds() // 3600)
minutes = int((duration.total_seconds() % 3600) // 60)
seconds = int(duration.total_seconds() % 60)
return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
def update_selected_job_metrics(self) -> None:
"""Update metrics for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
self.query_one("#epoch-metric", Static).update(
f"{job.current_epoch} / {job.total_epochs}"
)
self.query_one("#loss-metric", Static).update(
f"{job.current_loss:.4f}" if job.current_loss else "N/A"
)
self.query_one("#status-metric", Static).update(job.status.upper())
self.query_one("#duration-metric", Static).update(self.calculate_duration(job))
if job.losses:
sparkline = self.query_one("#loss-sparkline", Sparkline)
sparkline.data = job.losses[-50:] # Show last 50 loss values
@on(DataTable.RowSelected)
def handle_row_selected(self, event: DataTable.RowSelected) -> None:
"""Handle job selection from table."""
if event.cursor_row >= 0:
job_ids = list(self.jobs.keys())
if event.cursor_row < len(job_ids):
self.selected_job_id = job_ids[event.cursor_row]
self.update_selected_job_metrics()
self.load_job_logs()
def load_job_logs(self) -> None:
"""Load logs for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
if job.log_file and Path(job.log_file).exists():
try:
with open(job.log_file, "r") as f:
content = f.read()
log = self.query_one("#training-logs", Log)
log.clear()
for line in content.split("\n")[-100:]: # Show last 100 lines
if line.strip():
log.write_line(line)
except Exception as e:
log = self.query_one("#training-logs", Log)
log.write_line(f"Error loading logs: {str(e)}")
@on(Button.Pressed, "#new-training")
    @work
    async def handle_new_training(self) -> None:
"""Start a new training job."""
from axolotl.tui.dialogs.training import NewTrainingDialog
dialog = NewTrainingDialog()
result = await self.app.push_screen_wait(dialog)
if result and "config_path" in result:
            self.start_training_job(
result["config_path"], result.get("launcher", "accelerate")
)
@work(thread=True)
async def start_training_job(
self, config_path: str, launcher: str = "accelerate"
) -> None:
"""Start a training job."""
        import uuid
job_id = str(uuid.uuid4())
log_file = f"/tmp/axolotl_training_{job_id}.log"
job = TrainingJob(
id=job_id,
config_path=config_path,
status="pending",
start_time=datetime.now(),
log_file=log_file,
total_epochs=3, # Default, should parse from config
)
self.jobs[job_id] = job
self.selected_job_id = job_id
log = self.query_one("#training-logs", Log)
log.clear()
log.write_line(f"🚀 Starting training job {job_id[:8]}...")
log.write_line(f"Config: {config_path}")
log.write_line(f"Launcher: {launcher}")
try:
if launcher == "accelerate":
cmd = ["accelerate", "launch", "-m", "axolotl.cli.train", config_path]
else:
cmd = [
"torchrun",
"--nproc_per_node=1",
"-m",
"axolotl.cli.train",
config_path,
]
with open(log_file, "w") as f:
process = subprocess.Popen(
cmd,
stdout=f,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
job.process = process
job.status = "running"
log.write_line("✅ Training started successfully!")
self.refresh_job_table()
self.monitor_training_output(job_id)
except Exception as e:
job.status = "failed"
job.end_time = datetime.now()
log.write_line(f"❌ Failed to start training: {str(e)}")
self.refresh_job_table()
def monitor_training_output(self, job_id: str) -> None:
"""Monitor training output and extract metrics."""
if job_id not in self.jobs:
return
job = self.jobs[job_id]
if not job.log_file:
return
def tail_log():
import re
import time
with open(job.log_file, "r") as f:
f.seek(0, 2) # Go to end of file
while job.status == "running":
line = f.readline()
if line:
# Parse training metrics from log
epoch_match = re.search(r"Epoch (\d+)/(\d+)", line)
if epoch_match:
job.current_epoch = int(epoch_match.group(1))
job.total_epochs = int(epoch_match.group(2))
loss_match = re.search(
r"loss['\"]?\s*:\s*([\d.]+)", line, re.IGNORECASE
)
if loss_match:
job.current_loss = float(loss_match.group(1))
job.losses.append(job.current_loss)
# Update log viewer
                        self.app.call_from_thread(self.append_training_log, line.strip())
else:
time.sleep(0.5)
thread = threading.Thread(target=tail_log, daemon=True)
thread.start()
def append_training_log(self, line: str) -> None:
"""Append line to training log."""
log = self.query_one("#training-logs", Log)
log.write_line(line)
@on(Button.Pressed, "#stop-training")
def handle_stop_training(self) -> None:
"""Stop selected training job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
log = self.query_one("#training-logs", Log)
log.write_line("⚠️ No job selected")
return
job = self.jobs[self.selected_job_id]
if job.status == "running" and job.process:
job.process.terminate()
job.status = "stopped"
job.end_time = datetime.now()
log = self.query_one("#training-logs", Log)
log.write_line(f"🛑 Training job {job.id[:8]} stopped")
self.refresh_job_table()
@on(Button.Pressed, "#resume-training")
async def handle_resume_training(self) -> None:
"""Resume a stopped training job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
log = self.query_one("#training-logs", Log)
log.write_line("⚠️ No job selected")
return
job = self.jobs[self.selected_job_id]
if job.status in ["stopped", "failed"]:
            self.start_training_job(job.config_path)
@on(Button.Pressed, "#clear-completed")
def handle_clear_completed(self) -> None:
"""Clear completed jobs from the list."""
completed_jobs = [
job_id
for job_id, job in self.jobs.items()
if job.status in ["completed", "failed", "stopped"]
]
for job_id in completed_jobs:
del self.jobs[job_id]
self.refresh_job_table()
log = self.query_one("#training-logs", Log)
log.write_line(f"🧹 Cleared {len(completed_jobs)} completed jobs")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh the job list and metrics."""
self.refresh_job_table()
self.update_selected_job_metrics()
if self.selected_job_id:
self.load_job_logs()
@on(Button.Pressed, "#view-logs")
def handle_view_logs(self) -> None:
"""View full logs for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
if job.log_file and Path(job.log_file).exists():
            # Suspend the TUI so the pager can take over the terminal
            with self.app.suspend():
                subprocess.run(["less", job.log_file], check=False)
def action_new_training(self) -> None:
"""Start a new training job."""
self.handle_new_training()
def action_stop_training(self) -> None:
"""Stop selected training job."""
self.handle_stop_training()
def action_resume_training(self) -> None:
"""Resume selected training job."""
self.handle_resume_training()
def action_refresh(self) -> None:
"""Refresh the display."""
self.handle_refresh()
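The epoch and loss regexes in monitor_training_output above are heuristics over whatever the trainer happens to log. A quick self-contained check of what they match; the sample line is illustrative, not actual axolotl output:

import re

sample = "Epoch 2/3 {'loss': 1.2345, 'learning_rate': 1e-05}"
epoch = re.search(r"Epoch (\d+)/(\d+)", sample)
loss = re.search(r"loss['\"]?\s*:\s*([\d.]+)", sample, re.IGNORECASE)
assert epoch and epoch.group(1) == "2" and epoch.group(2) == "3"
assert loss and float(loss.group(1)) == 1.2345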

View File

@@ -9,7 +9,6 @@ from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -44,24 +43,12 @@ from axolotl.utils.trainer import (
LOG = get_logger(__name__)
def _is_streaming_enabled(cfg: DictDefault) -> bool:
"""Check if streaming is enabled for a specific split."""
streaming = cfg.get("streaming")
if streaming is True:
return True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = cfg.get("pretraining_dataset") is not None
streaming = has_pretraining and streaming is None
return streaming
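# The streaming check above reduces to a small decision table:
#   streaming is True                            -> True
#   streaming unset and pretraining_dataset set  -> True (pretraining defaults to streaming)
#   anything else (including streaming: false)   -> False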
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_datasets(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare training and evaluation datasets based on configuration.
@@ -69,19 +56,23 @@ def prepare_datasets(
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: Tokenizer to use for processing text.
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
"""
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(cfg, tokenizer, processor)
return _prepare_standard_dataset(cfg, tokenizer, processor)
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable
)
return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
def _prepare_standard_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare standard (non-pretraining) datasets."""
@@ -92,6 +83,7 @@ def _prepare_standard_dataset(
cfg,
split="train",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
# Overwrite eval_dataset if test data exists
@@ -101,6 +93,7 @@ def _prepare_standard_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
return train_dataset, eval_dataset, prompters
@@ -116,12 +109,7 @@ def _prepare_standard_dataset(
return train_dataset, eval_dataset, -1, prompters
# Validate sample packing configuration for evaluation
if (
eval_dataset
and cfg.sample_packing
and cfg.eval_sample_packing is not False
and not isinstance(eval_dataset, IterableDataset)
):
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
raise ValueError(
@@ -129,17 +117,13 @@ def _prepare_standard_dataset(
"You should set `eval_sample_packing: False` in your config."
)
# Set total_num_steps for training
if isinstance(train_dataset, IterableDataset):
total_num_steps = cfg.max_steps
# Calculate total number of training steps
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
else:
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
return train_dataset, eval_dataset, total_num_steps, prompters
@@ -148,6 +132,7 @@ def _prepare_pretraining_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
"""
Prepare dataset for pretraining mode.
@@ -168,6 +153,7 @@ def _prepare_pretraining_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if cfg.dataset_exact_deduplication:
@@ -270,6 +256,7 @@ def _load_tokenized_prepared_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
"""Load or create tokenized and prepared datasets for training or testing.
@@ -278,51 +265,39 @@ def _load_tokenized_prepared_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (dataset, prompters list).
"""
# Select correct dataset configuration based on split
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# If not found on hub, try loading from disk
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# If not found on disk or skipping prepared dataset, load and process raw datasets
prompters: list[Prompter | None] = []
use_streaming = False
if split == "train":
use_streaming = _is_streaming_enabled(cfg)
if use_streaming:
# For streaming datasets, skip caching and load raw datasets directly
if dataset is None:
dataset, prompters = _load_raw_datasets(
cfg,
datasets_configs,
tokenizer,
split,
processor,
preprocess_iterable,
)
else:
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# If not found on hub, try loading from disk
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# If not found on disk or skipping prepared dataset, load and process raw
# datasets
if dataset is None:
dataset, prompters = _load_raw_datasets(
cfg,
datasets_configs,
tokenizer,
split,
processor,
)
return dataset, prompters
@@ -331,8 +306,9 @@ def _load_raw_datasets(
cfg: DictDefault,
datasets_configs: list,
tokenizer: PreTrainedTokenizer,
split: Literal["train", "test"],
split: str,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset, list[Prompter | None]]:
"""Load, process, merge, and save raw datasets."""
LOG.info("Loading raw datasets...", main_process_only=False)
@@ -353,6 +329,7 @@ def _load_raw_datasets(
split=split,
seed=cfg.seed,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
@@ -368,12 +345,11 @@ def _load_raw_datasets(
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Only save regular datasets to disk, not streaming datasets
if not isinstance(dataset, IterableDataset):
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
return dataset, prompters
@@ -382,22 +358,22 @@ def _load_and_process_single_dataset(
dataset_config: DictDefault,
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
split: Literal["train", "test"],
split: str,
seed: int,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
use_streaming = False
if split == "train":
use_streaming = _is_streaming_enabled(cfg)
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, use_streaming
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
)
# Parse dataset type
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
# Select the appropriate split
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
if isinstance(dataset, DatasetDict):
if dataset_config.split and dataset_config.split in dataset:
dataset = dataset[dataset_config.split]
elif split in dataset:
@@ -442,13 +418,11 @@ def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]:
def _handle_train_dataset_split(
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]:
dataset: Dataset, cfg: DictDefault
) -> tuple[Dataset, Dataset | None]:
"""Handle processing for train split, including validation set creation."""
val_set_size = (
int(cfg.val_set_size)
if cfg.val_set_size and cfg.val_set_size > 1
else float(cfg.val_set_size or 0.0)
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
)
if val_set_size:
@@ -459,33 +433,27 @@ def _handle_train_dataset_split(
return train_dataset, eval_dataset
# No validation split - apply deduplication if needed and return as train dataset
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
if cfg.dataset_exact_deduplication:
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
train_dataset = dataset
return train_dataset, None
def _handle_test_dataset_split(
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> tuple[None, Dataset | IterableDataset | None]:
dataset: Dataset, cfg: DictDefault
) -> tuple[None, Dataset | None]:
"""Handle processing for test split."""
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
if cfg.dataset_exact_deduplication:
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
eval_dataset = dataset
return None, eval_dataset
def _apply_dataset_sharding(
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> Dataset | IterableDataset:
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
"""Apply dataset sharding if configured.
Args:
@@ -511,6 +479,7 @@ def _load_and_prepare_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]:
"""Load and prepare datasets with optional validation split and sharding.
@@ -519,6 +488,7 @@ def _load_and_prepare_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, prompters).
@@ -529,6 +499,7 @@ def _load_and_prepare_datasets(
cfg,
split=split,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
# Apply dataset sharding if configured using shared function

View File

@@ -13,7 +13,6 @@ from datasets import (
IterableDataset,
IterableDatasetDict,
concatenate_datasets,
interleave_datasets,
load_dataset,
load_from_disk,
)
@@ -525,9 +524,7 @@ def generate_dataset_hash_from_config(
return str(md5(config_str))
def merge_datasets(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
"""Merge multiple datasets into one with optional shuffling.
Args:
@@ -540,23 +537,23 @@ def merge_datasets(
if len(datasets) == 1:
ds = datasets[0]
if (
cfg.curriculum_sampling
or not cfg.shuffle_merged_datasets
or isinstance(ds, IterableDataset)
):
# Do not shuffle if curriculum sampling is enabled or
# shuffle_merged_datasets is disabled
if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
return ds
return ds.shuffle(seed=cfg.seed)
if cfg.shuffle_before_merging_datasets and all(
isinstance(ds, Dataset) for ds in datasets
):
# If enabled, shuffle each dataset independently before merging.
# This allows curriculum learning strategies to be applied at the dataset level.
if cfg.shuffle_before_merging_datasets:
LOG.info("Shuffling each dataset individually before merging...")
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
merged_dataset = _merge_datasets_with_strategy(datasets, cfg)
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged datasets...")
if cfg.curriculum_sampling:
LOG.warning(
@@ -565,45 +562,6 @@ def merge_datasets(
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
else:
if isinstance(merged_dataset, IterableDataset):
LOG.debug("Skipping shuffle for streaming datasets.")
else:
LOG.debug("Not shuffling merged datasets.")
LOG.debug("Not shuffling merged datasets.")
return merged_dataset
def _merge_datasets_with_strategy(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
"""
Merge datasets using the configured mixing strategy. Works with streaming and non-
streaming datasets.
Args:
datasets: List of datasets to merge.
cfg: Configuration object containing mixing settings.
Returns:
Merged dataset (Dataset or IterableDataset depending on inputs).
"""
strategy = cfg.get("dataset_mixing_strategy", "concatenate")
weights = cfg.get("mixing_weights", None)
LOG.info(f"Merging datasets with mixing strategy: {strategy}...")
if strategy == "concatenate":
if not all(isinstance(ds, Dataset) for ds in datasets):
raise ValueError(
"Cannot concatenate streaming datasets. Use 'round_robin', 'weighted', "
"or 'random' instead."
)
return concatenate_datasets(datasets)
if strategy == "round_robin":
return interleave_datasets(datasets, seed=cfg.seed)
if strategy == "weighted":
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
if strategy == "random":
equal_weights = [1.0 / len(datasets)] * len(datasets)
return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed)
raise ValueError(f"Unknown dataset mixing strategy: {strategy}")

View File

@@ -190,15 +190,11 @@ def handle_long_seq_in_dataset(
Returns:
Filtered dataset with long sequences removed.
"""
if hasattr(dataset, "column_names") and dataset.column_names:
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This "
"is expected for reward modeling."
)
return dataset
elif isinstance(dataset, IterableDataset):
LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)")
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
drop_long = functools.partial(

View File

@@ -932,27 +932,9 @@ class AxolotlInputConfig(
fix_untrained_tokens: int | list[int] | None = None
streaming: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use streaming datasets (IterableDataset) for training datasets. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False."
},
)
dataset_mixing_strategy: str | None = Field(
default="round_robin",
json_schema_extra={
"description": "Strategy for mixing multiple datasets: 'concatenate', 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
},
)
mixing_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for weighted mixing strategy when using multiple datasets. Must sum to 1.0 and have same length as datasets list. Only used when dataset_mixing_strategy='weighted'."
},
)
# INTERNALS - document for now, generally not set externally
is_preprocess: bool | None = None
preprocess_iterable: bool | None = None
total_num_tokens: int | None = Field(
default=None,

View File

@@ -161,12 +161,7 @@ class HyperparametersConfig(BaseModel):
max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)
num_epochs: float = Field(
default=1.0,
json_schema_extra={
"description": "Number of iterations over dataset for training"
},
)
num_epochs: float = Field(default=1.0)
@field_validator("batch_size")
@classmethod

View File

@@ -3,7 +3,6 @@
# pylint: disable=too-many-boolean-expressions
import json
import os
import sys
import tempfile
from pathlib import Path
@@ -193,7 +192,6 @@ class AttentionValidationMixin:
return data
# pylint: disable=too-many-public-methods
class TrainingValidationMixin:
"""Validation methods related to training configuration."""
@@ -510,58 +508,11 @@ class TrainingValidationMixin:
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
# due to trying to count the number of tokens total in the dataset
raise ValueError(
"pretraining_dataset and include_tokens_per_second cannot be used "
"together."
"pretraining_dataset and include_tokens_per_second cannot be used together."
)
return data
@model_validator(mode="before")
@classmethod
def check_max_steps_num_epochs_conflict(cls, data):
"""Handle max_steps and num_epochs configuration and auto-set defaults."""
max_steps = data.get("max_steps")
num_epochs = data.get("num_epochs")
# Auto-set num_epochs to 1 if neither max_steps nor num_epochs are set
if max_steps is None and num_epochs is None:
data["num_epochs"] = 1.0
return data
@model_validator(mode="before")
@classmethod
def check_saves_per_epoch_conflicts(cls, data):
"""Ensure saves_per_epoch is compatible with training configuration."""
saves_per_epoch = data.get("saves_per_epoch")
num_epochs = data.get("num_epochs")
if saves_per_epoch is not None:
# Check if saves_per_epoch is set but num_epochs is unset
if num_epochs is None:
raise ValueError(
"saves_per_epoch requires num_epochs to be set to calculate save "
"intervals."
)
return data
@model_validator(mode="before")
@classmethod
def check_evals_per_epoch_conflicts(cls, data):
"""Ensure evals_per_epoch is compatible with training configuration."""
evals_per_epoch = data.get("evals_per_epoch")
num_epochs = data.get("num_epochs")
if evals_per_epoch is not None:
if num_epochs is None:
raise ValueError(
"evals_per_epoch requires num_epochs to be set to calculate "
"evaluation intervals."
)
return data
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""
@@ -1127,27 +1078,6 @@ class PretrainingValidationMixin:
data["accelerator_config"]["dispatch_batches"] = False
return data
@model_validator(mode="before")
@classmethod
def check_streaming_split_batches_accelerate(cls, data):
# Check if streaming is enabled for training
streaming = data.get("streaming", False)
# If streaming is enabled, configure accelerator
if streaming:
accelerator_config = data.get("accelerator_config", {})
if not accelerator_config:
data["accelerator_config"] = {
"split_batches": False,
"dispatch_batches": False,
}
else:
if accelerator_config.get("split_batches") is None:
data["accelerator_config"]["split_batches"] = False
if accelerator_config.get("dispatch_batches") is None:
data["accelerator_config"]["dispatch_batches"] = False
return data
class ModelCompatibilityValidationMixin:
"""Validation methods for specific model compatibility."""
@@ -1406,128 +1336,6 @@ class GRPOVllmValidationMixin:
return self
class StreamingValidationMixin:
"""Validation methods related to streaming datasets."""
def _is_streaming_enabled(self) -> bool:
"""Check if streaming is enabled."""
# Fall back to main streaming setting
streaming = getattr(self, "streaming", None)
if streaming is True:
return True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming = has_pretraining and streaming is None
return streaming
@model_validator(mode="after")
def check_streaming_requires_max_steps(self):
"""Ensure max_steps is set when using streaming datasets."""
# Check if streaming is enabled for training datasets
if self._is_streaming_enabled():
max_steps = getattr(self, "max_steps", None)
if not max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
return self
@model_validator(mode="after")
def check_streaming_validation_splits_conflict(self):
"""Ensure validation splits are not used with streaming datasets."""
# Check if streaming is enabled for training datasets
if self._is_streaming_enabled():
val_set_size = getattr(self, "val_set_size", 0.0)
if val_set_size and val_set_size > 0:
raise ValueError(
"Validation splits not supported for streaming datasets, please "
"use test_datasets: ... instead"
)
return self
@model_validator(mode="after")
def check_streaming_preprocessing_conflict(self):
"""Ensure preprocessing is not enabled with streaming datasets."""
# Check if streaming is enabled for training datasets
if self._is_streaming_enabled():
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
raise ValueError("preprocess is not supported for streaming datasets")
return self
@model_validator(mode="after")
def check_dataset_mixing_weights(self):
"""Validate dataset mixing weights configuration."""
valid_strategies = ["concatenate", "round_robin", "weighted", "random"]
# Get datasets to validate length against
datasets = getattr(self, "datasets", None)
# Check main strategy and weights
strategy = getattr(self, "dataset_mixing_strategy", "concatenate")
weights = getattr(self, "mixing_weights", None)
dataset_count = len(datasets) if datasets else 0
self._validate_dataset_strategy_and_weights(
strategy,
weights,
"dataset_mixing_strategy",
"mixing_weights",
valid_strategies,
dataset_count,
)
return self
def _validate_dataset_strategy_and_weights(
self,
strategy,
weights,
strategy_field,
weights_field,
valid_strategies,
dataset_count,
):
"""Helper method to validate dataset mixing strategy and weights pair."""
if strategy not in valid_strategies:
raise ValueError(
f"{strategy_field} must be one of {valid_strategies}, "
f"got '{strategy}'"
)
if strategy == "weighted":
if weights is None:
raise ValueError(
f"{weights_field} must be provided when "
f"{strategy_field}='weighted'"
)
if not isinstance(weights, list) or not all(
isinstance(w, (int, float)) for w in weights
):
raise ValueError(f"{weights_field} must be a list of numbers")
if any(w < 0 for w in weights):
raise ValueError(f"{weights_field} must be non-negative")
if abs(sum(weights) - 1.0) > 1e-6:
raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}")
# Validate weights length against dataset count
if dataset_count > 0 and len(weights) != dataset_count:
raise ValueError(
f"{weights_field} length ({len(weights)}) must match number of datasets ({dataset_count})"
)
elif weights is not None and strategy != "weighted":
LOG.warning(
f"{weights_field} provided but {strategy_field} is '{strategy}'. "
"Weights will be ignored."
)
# pylint: disable=too-many-ancestors
class ValidationMixin(
DatasetValidationMixin,
@@ -1539,7 +1347,6 @@ class ValidationMixin(
SystemValidationMixin,
ChatTemplateValidationMixin,
PretrainingValidationMixin,
StreamingValidationMixin,
ModelCompatibilityValidationMixin,
ComplexValidationMixin,
GRPOVllmValidationMixin,

View File

@@ -10,6 +10,7 @@ from typing import List, Optional
import numpy as np
import torch
import torch.cuda
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
@@ -22,65 +23,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
def _create_filtered_iterable_dataset(dataset, filter_fn, batched=False):
"""
Create a filtered IterableDataset that works around a HuggingFace datasets
limitation.
"""
def filtered_generator():
"""Generator that yields only samples that pass the filter function."""
if batched:
batch = []
batch_size = 1000 # Process in batches of 1000
for sample in dataset:
batch.append(sample)
if len(batch) >= batch_size:
# Create a batch dict from list of samples
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [sample[key] for sample in batch]
# Apply filter function to batch
keep_mask = filter_fn(batch_dict)
# Yield samples that should be kept
for i, keep in enumerate(keep_mask):
if keep:
yield batch[i]
batch = []
# Process remaining samples in batch
if batch:
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [sample[key] for sample in batch]
keep_mask = filter_fn(batch_dict)
for i, keep in enumerate(keep_mask):
if keep:
yield batch[i]
else:
# For non-batched filtering, apply filter to each sample individually
for sample in dataset:
if filter_fn(sample):
yield sample
# Create new IterableDataset from the filtered generator
filtered_dataset = IterableDataset.from_generator(filtered_generator)
# Preserve the original features if they exist
# pylint:disable=protected-access
if hasattr(dataset, "_info") and dataset._info.features is not None:
filtered_dataset._info.features = dataset._info.features
return filtered_dataset
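A sketch of how the removed workaround above was meant to be called; stream_ds and the filter function are illustrative, not from the diff:

def keep_nonempty(batch):
    # Batched filter: return one boolean per sample in the batch dict.
    return [len(ids) > 0 for ids in batch["input_ids"]]

filtered = _create_filtered_iterable_dataset(stream_ds, keep_nonempty, batched=True)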
@torch.jit.script
def weighted_cross_entropy(
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
@@ -340,21 +282,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
# For IterableDatasets, always use custom filtering to avoid features issues
if isinstance(train_dataset, IterableDataset):
# IterableDatasets often have None features after transformations,
# so we use our custom filter implementation that doesn't rely on features
train_dataset = _create_filtered_iterable_dataset(
train_dataset, drop_no_trainable_tokens, batched=True
)
else:
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped:
@@ -539,7 +472,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
)
data_loader = DataLoader(
train_dataset,
train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
@@ -614,7 +547,7 @@ def setup_deepspeed_env(cfg, stage=None):
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
# NOTE(djsaunde): The distributed state cannot be initialized prior to the
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
if (

View File

@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
"liger_rms_norm": True,
"liger_glu_activation": True,
"torch_compile": True,
"chat_template": "qwen3",
"chat_template": "llama3",
"kd_trainer": True,
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,

View File

@@ -1,185 +0,0 @@
"""E2E tests for streaming dataset functionality"""
# pylint: disable=duplicate-code
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
class TestStreamingDatasets:
"""Test case for streaming datasets with different mixing strategies"""
@pytest.mark.parametrize(
("dataset_mixing_strategy", "mixing_weights"),
[
("round_robin", None),
("weighted", [0.7, 0.3]),
("random", None),
],
)
def test_streaming_dataset_mixing_strategies(
self, temp_dir, dataset_mixing_strategy, mixing_weights
):
"""Test different mixing strategies with streaming datasets"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": False,
"dataset_processes": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
# Streaming config
"streaming": True,
"max_steps": 3, # Very small for smoke test
"dataset_mixing_strategy": dataset_mixing_strategy,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
# Add mixing weights if specified
if mixing_weights:
cfg["mixing_weights"] = mixing_weights
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# Verify training actually happened by checking loss decrease
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
2.5, # Loss should be reasonable for a smoke test (higher threshold for streaming)
"Train Loss (%s) is too high",
)
def test_streaming_validation_error(self, temp_dir):
"""Test that pydantic validation catches invalid streaming configs"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"streaming": True,
"max_steps": 3,
# Invalid: wrong number of weights for datasets
"dataset_mixing_strategy": "weighted",
"mixing_weights": [1.0], # Should be [0.x, 0.y] for 2 datasets
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
)
# This should raise a validation error
with pytest.raises(Exception) as exc_info:
validate_config(cfg)
# Verify it's the right validation error
assert "mixing_weights length" in str(exc_info.value)
assert "must match number of datasets" in str(exc_info.value)
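The error strings asserted above imply a length check roughly like the following (a hypothetical sketch; the real check is a pydantic validator in Axolotl's config models):

def check_mixing_weights(cfg):
    weights = cfg.get("mixing_weights")
    if weights is not None and len(weights) != len(cfg["datasets"]):
        raise ValueError(
            f"mixing_weights length ({len(weights)}) must match "
            f"number of datasets ({len(cfg['datasets'])})"
        )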
def test_streaming_three_datasets_weighted(self, temp_dir):
"""Test weighted mixing with three datasets"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 512,
"sample_packing": False,
"dataset_processes": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
{
"path": "yahma/alpaca-cleaned",
"type": "alpaca",
},
],
# Streaming config
"streaming": True,
"max_steps": 3,
"dataset_mixing_strategy": "weighted",
"mixing_weights": [0.5, 0.3, 0.2],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
2.5,
"Train Loss (%s) is too high",
)

View File

@@ -7,13 +7,13 @@ from typing import Any, Generator
from unittest.mock import patch
import pytest
from datasets import Dataset, IterableDataset
from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets, prepare_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
from axolotl.utils.dict import DictDefault
from tests.constants import (
@@ -24,7 +24,6 @@ from tests.constants import (
from tests.hf_offline_utils import enable_hf_offline
# pylint: disable=too-many-public-methods
class TestDatasetPreparation:
"""Test a configured dataloader."""
@@ -47,24 +46,6 @@ class TestDatasetPreparation:
]
)
@pytest.fixture
def streaming_dataset_fixture(self):
"""Create a streaming dataset fixture for testing."""
def generator():
yield {
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
"input": "He finnished his meal and left the resturant",
"output": "He finished his meal and left the restaurant.",
}
yield {
"instruction": "What is the capital of France?",
"input": "",
"output": "The capital of France is Paris.",
}
return IterableDataset.from_generator(generator)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub(self, tokenizer):
@@ -505,162 +486,3 @@ class TestDatasetPreparation:
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
def test_streaming_sft_dataset(self, tokenizer, streaming_dataset_fixture):
"""Test streaming SFT dataset preparation with IterableDataset."""
with patch("axolotl.utils.data.sft.load_dataset_with_config") as mock_load:
mock_load.return_value = streaming_dataset_fixture
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 256,
"streaming": True,
"max_steps": 100, # Required for streaming datasets
"datasets": [
{
"path": "dummy/path",
"type": "alpaca",
},
],
}
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg, tokenizer
)
# Verify it returns an IterableDataset
assert isinstance(train_dataset, IterableDataset)
assert eval_dataset is None # No eval split for streaming
assert total_num_steps == 100 # Should use max_steps
assert len(prompters) == 1
# Test that we can iterate through the dataset
sample_count = 0
for sample in train_dataset:
assert "input_ids" in sample
assert "attention_mask" in sample
assert "labels" in sample
sample_count += 1
if sample_count >= 2: # Just test first few samples
break
assert sample_count == 2
def test_dataset_mixing_strategy_validation(self):
"""Test validation of dataset mixing strategy configuration."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Test valid strategies work
valid_strategies = ["round_robin", "weighted", "random"]
dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]})
dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]})
for strategy in valid_strategies:
cfg = DictDefault(
{
"dataset_mixing_strategy": strategy,
"mixing_weights": [0.5, 0.5] if strategy == "weighted" else None,
"seed": 42,
}
)
# Should not raise an error
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
assert len(merged) >= 1
def test_regular_dataset_round_robin_mixing(self):
"""Test round-robin mixing for regular datasets."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test datasets
dataset1 = Dataset.from_dict(
{"text": ["ds1_item1", "ds1_item2"], "source": ["ds1", "ds1"]}
)
dataset2 = Dataset.from_dict(
{"text": ["ds2_item1", "ds2_item2"], "source": ["ds2", "ds2"]}
)
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
# Should have all samples from both datasets
assert len(merged) == 4
assert isinstance(merged, Dataset)
# Check that samples are interleaved (not just concatenated)
sources = [sample["source"] for sample in merged]
# Round-robin should alternate between datasets
assert sources != ["ds1", "ds1", "ds2", "ds2"] # Not concatenated
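One way round-robin could be realized for map-style datasets, sketched under the assumption that it is index interleaving rather than plain concatenation (exactly the distinction the assertion above draws):

from datasets import concatenate_datasets

def round_robin(ds_list):
    merged = concatenate_datasets(ds_list)
    # per-dataset start offsets inside the concatenated dataset
    offsets = [0]
    for ds in ds_list[:-1]:
        offsets.append(offsets[-1] + len(ds))
    order = []
    for i in range(max(len(ds) for ds in ds_list)):
        for off, ds in zip(offsets, ds_list):
            if i < len(ds):
                order.append(off + i)  # take the i-th row of each dataset in turn
    return merged.select(order)

For the two 2-row datasets in the test this yields sources ds1, ds2, ds1, ds2 rather than the concatenated ds1, ds1, ds2, ds2.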
def test_regular_dataset_weighted_mixing(self):
"""Test weighted mixing for regular datasets."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test datasets
dataset1 = Dataset.from_dict(
{
"text": ["ds1_item1", "ds1_item2", "ds1_item3", "ds1_item4"],
"source": ["ds1"] * 4,
}
)
dataset2 = Dataset.from_dict(
{
"text": ["ds2_item1", "ds2_item2", "ds2_item3", "ds2_item4"],
"source": ["ds2"] * 4,
}
)
cfg = DictDefault(
{
"dataset_mixing_strategy": "weighted",
"mixing_weights": [0.75, 0.25], # 3:1 ratio
"seed": 42,
}
)
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
# Should have samples proportional to weights
assert len(merged) > 0
assert isinstance(merged, Dataset)
# Count samples from each dataset
sources = [sample["source"] for sample in merged]
ds1_count = sources.count("ds1")
ds2_count = sources.count("ds2")
# Should have samples from both datasets
assert ds1_count > 0 and ds2_count > 0 # Both datasets should be represented
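For the weighted case, interleave_datasets also accepts map-style Datasets, so a 3:1 mix can be sketched directly; with the default first_exhausted stopping, the merged size depends on when the less-weighted source runs dry, which is why the test only asserts that both sources appear:

from datasets import Dataset, interleave_datasets

ds1 = Dataset.from_dict(
    {"text": [f"ds1_item{i}" for i in range(1, 5)], "source": ["ds1"] * 4}
)
ds2 = Dataset.from_dict(
    {"text": [f"ds2_item{i}" for i in range(1, 5)], "source": ["ds2"] * 4}
)
merged = interleave_datasets([ds1, ds2], probabilities=[0.75, 0.25], seed=42)
print([row["source"] for row in merged])  # roughly three ds1 rows per ds2 row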
def test_streaming_dataset_mixing(self):
"""Test that streaming datasets use HuggingFace interleave_datasets."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test streaming datasets
def gen1():
yield {"text": "stream1_item1", "source": "stream1"}
yield {"text": "stream1_item2", "source": "stream1"}
def gen2():
yield {"text": "stream2_item1", "source": "stream2"}
yield {"text": "stream2_item2", "source": "stream2"}
stream1 = IterableDataset.from_generator(gen1)
stream2 = IterableDataset.from_generator(gen2)
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
merged = _merge_datasets_with_strategy([stream1, stream2], cfg)
# Should return an IterableDataset
assert isinstance(merged, IterableDataset)
# Test that we can iterate and get samples
samples = list(merged.take(3))
assert len(samples) >= 2 # Should get at least 2 samples
# Should have samples from both datasets
sources = [sample["source"] for sample in samples]
assert len(set(sources)) >= 1 # At least one unique source
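The streaming path can defer entirely to interleave_datasets, which alternates IterableDatasets lazily; a usage sketch mirroring the generators above:

from datasets import IterableDataset, interleave_datasets

def gen1():
    yield {"text": "stream1_item1", "source": "stream1"}
    yield {"text": "stream1_item2", "source": "stream1"}

def gen2():
    yield {"text": "stream2_item1", "source": "stream2"}
    yield {"text": "stream2_item2", "source": "stream2"}

mixed = interleave_datasets(
    [IterableDataset.from_generator(gen1), IterableDataset.from_generator(gen2)]
)
# lazy: nothing is read until iteration; take() previews without materializing
print([ex["source"] for ex in mixed.take(3)])  # ['stream1', 'stream2', 'stream1']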

View File

@@ -1,11 +1,16 @@
"""Module for testing dataset sequence packing"""
import unittest
from pathlib import Path
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
from axolotl.train import setup_model_and_trainer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@@ -31,6 +36,43 @@ class TestPacking(unittest.TestCase):
}
)
def test_increments_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
dataset = load_dataset(
"json",
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
)["train"]
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dataset)))
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
[dataset],
seq_length=2048,
)
packed_dataset = Dataset.from_list(list(constant_len_dataset))
example = packed_dataset[0]
next_bos_index = (
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
) # add one since we sliced
# first example doesn't have mask reset
assert example["input_ids"][0] == self.tokenizer.bos_token_id
assert example["attention_mask"][0] == 1
assert example["position_ids"][0] == 0
assert example["position_ids"][1] == 1
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 2
assert example["position_ids"][next_bos_index] == 0
assert example["position_ids"][next_bos_index + 1] == 1
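These assertions encode the packing contract: each example packed into a row gets its own attention_mask segment id, and position_ids restart at zero at each boundary. An illustrative sketch of that contract, eliding BOS/EOS handling (not ConstantLengthDataset's actual code):

def pack_examples(examples, seq_length):
    input_ids, attention_mask, position_ids = [], [], []
    for seg_id, ex in enumerate(examples, start=1):
        ids = ex["input_ids"][: seq_length - len(input_ids)]
        input_ids.extend(ids)
        attention_mask.extend([seg_id] * len(ids))  # 1 for first example, 2 for next, ...
        position_ids.extend(range(len(ids)))        # restart at 0 per packed example
        if len(input_ids) >= seq_length:
            break
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }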
@with_temp_dir
def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code