Compare commits

..

8 Commits

Author SHA1 Message Date
Dan Saunders
c3e1882de5 progress 2025-08-22 02:43:16 -04:00
Dan Saunders
889b27ecf1 tui 2025-08-22 05:08:02 +00:00
Wing Lian
0fa752e58b upgrade flash-attn to 2.8.3 for gpt-oss attn sink support (#3082) 2025-08-21 15:04:10 -04:00
Dan Saunders
08e517ea48 Update .coderabbit.yaml (#3091) [skip ci] 2025-08-20 22:14:13 -04:00
Wing Lian
07fd22f39b better handling of lora w bias with fsdp2 and handling of files when saving model checkpoint (#3090) 2025-08-20 15:17:48 -04:00
Wing Lian
06eaf6c448 misc fixes (#3085) 2025-08-20 08:52:26 -04:00
goggle
050210e637 fix: Sweep runs overwrite each other because output_dir from base config is reused (#3080)
* refactor: improve output_dir handling in generate_config_files

* fix typo

* cli: harden sweep output_dir handling with base fallback

- Ensure sweep permutations always resolve a valid output_dir
- Default to ./model-out if neither permutation nor base config sets output_dir
- Append sweepXXXX suffix consistently for each permutation
- Prevent Path(None) TypeError and improve robustness of sweep config generation

* fix typo

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-19 20:25:20 -04:00
Wing Lian
05cedbfb1e add baseten info for gpt-oss recipe (#3078)
* add bsaeten info for gpt-oss recipe

* incorporate PR review
2025-08-19 13:30:37 -04:00
40 changed files with 13186 additions and 10343 deletions

View File

@@ -12,5 +12,6 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: true
chat:
auto_reply: true

File diff suppressed because it is too large Load Diff

View File

@@ -41,6 +41,12 @@ model, and final model output, you may need at least 3TB of free disk space to k
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```
To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
training of the 120B model using Baseten Truss. You can read more about this recipe on
[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
be found on their
[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
@@ -61,9 +67,23 @@ mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
### Inferencing your fine-tuned model
#### vLLM
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
for more information about using a special vllm-openai docker image for inferencing with vLLM.
Optionally, vLLM can be installed from nightly:
```bash
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash
vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
```
#### SGLang
SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:

View File

@@ -44,7 +44,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -40,7 +40,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
@@ -41,7 +41,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
@@ -40,7 +40,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -53,7 +53,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -72,3 +72,8 @@ axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5
mistral-common==1.8.3
# TUI dependencies
textual==1.0.0
rich==14.1.0
tree_sitter_ruby==0.23.1

View File

@@ -118,9 +118,9 @@ def get_package_version():
extras_require = {
"flash-attn": ["flash-attn==2.8.2"],
"flash-attn": ["flash-attn==2.8.3"],
"ring-flash-attn": [
"flash-attn==2.8.2",
"flash-attn==2.8.3",
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
],

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.6.0"
docker_tag = "main-py3.11-cu126-2.7.1"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -200,7 +200,7 @@ class ModalCloud(Cloud):
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return modal.gpu.H100(count=count)
return f"H100:{count}"
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":

View File

@@ -64,7 +64,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer

View File

@@ -344,6 +344,26 @@ def delinearize_llama4(model: str, output: str):
cli.add_command(lm_eval)
@cli.command()
def tui():
"""
Launch the Axolotl Terminal User Interface (TUI).
Provides an interactive interface for configuration management,
training monitoring, dataset handling, and model operations.
"""
try:
from axolotl.tui.app import run
run()
except ImportError:
click.echo(
"TUI dependencies not installed. Install with: pip install textual rich"
)
except Exception as e:
click.echo(f"Error launching TUI: {e}")
def main():
cli()

View File

@@ -97,7 +97,8 @@ def do_cli(
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
parsed_cfg = load_cfg(config, **kwargs)
is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -3,11 +3,12 @@
import random
from copy import deepcopy
from itertools import product
from typing import Any
def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, list]]:
) -> list[dict[str, Any]]:
"""
Recursively generates all possible configurations by applying sweeps to the base config.

View File

@@ -4,6 +4,7 @@ import os
import subprocess # nosec
import sys
import tempfile
from pathlib import Path
from typing import Any, Iterator, Literal
import yaml
@@ -88,7 +89,12 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
# Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1
for permutation in permutations:
base_output_dir = base_config.get("output_dir", "./model-out")
for idx, permutation in enumerate(permutations, start=1):
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
mode="w",

View File

@@ -40,7 +40,6 @@ from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
StreamingDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
@@ -423,17 +422,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
is_eval=False,
**kwargs,
):
from datasets import IterableDataset
if isinstance(self.train_dataset, IterableDataset) and not is_eval:
LOG.info("Using StreamingDataCollator")
return StreamingDataCollator(
tokenizer=self.tokenizer,
cfg=self.cfg,
prompter=None,
**kwargs,
)
if training_args.pretraining:
if (
self.cfg.pretraining_sample_concatenation is False

View File

@@ -43,11 +43,7 @@ class TokenizedPromptDataset(Dataset):
)
def process(self, dataset):
# For IterableDataset, we can't access features upfront
# We'll need to infer from the first batch
features = None
if hasattr(dataset, "features") and dataset.features:
features = dataset.features.keys()
features = dataset.features.keys()
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
@@ -58,29 +54,18 @@ class TokenizedPromptDataset(Dataset):
hasattr(self.prompt_tokenizer, "filter_rows")
and self.prompt_tokenizer.filter_rows
):
filter_kwargs = {"desc": "Strategy Filtering Rows"}
# Only add num_proc for regular datasets
if features is not None:
filter_kwargs["num_proc"] = self.process_count
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
**filter_kwargs,
num_proc=self.process_count,
desc="Strategy Filtering Rows",
)
map_kwargs = {
**map_kwargs,
"desc": "Tokenizing Prompts",
}
# Only add remove_columns for regular datasets
if features is not None:
map_kwargs["remove_columns"] = features
map_kwargs["num_proc"] = self.process_count
map_kwargs["keep_in_memory"] = self.keep_in_memory
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=self.process_count,
remove_columns=features,
keep_in_memory=self.keep_in_memory,
desc="Tokenizing Prompts",
**map_kwargs,
)

View File

@@ -187,7 +187,7 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
if module.base_layer.bias is not None:
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(

View File

@@ -72,9 +72,10 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
builder_kwargs["message_field_training"] = message_field_training
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
format_message = (
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
)
def format_message(x):
return x
if chat_template == "chatml":
from axolotl.core.chat.format.chatml import format_message # noqa F811
if chat_template.startswith("llama3"):

View File

@@ -253,7 +253,9 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
if trainer.is_fsdp_enabled or cfg.fsdp_config:
if ( # pylint: disable=too-many-nested-blocks
trainer.is_fsdp_enabled or cfg.fsdp_config
):
if cfg.fsdp_config or cfg.fsdp:
if cfg.fsdp_config.final_state_dict_type:
state_dict_type = cfg.fsdp_config.final_state_dict_type
@@ -285,6 +287,8 @@ def save_trained_model(
if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir
for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207

216
src/axolotl/tui/README.md Normal file
View File

@@ -0,0 +1,216 @@
# Axolotl TUI (Terminal User Interface)
A comprehensive Terminal User Interface for Axolotl, providing an interactive way to manage configurations, training jobs, datasets, models, and system monitoring.
## Features
### 🏠 Main Dashboard
- **Welcome Screen**: Central hub with quick access to all features
- **Keyboard Navigation**: Efficient navigation with keyboard shortcuts
- **Screen Management**: Easy switching between different functional areas
### 📝 Configuration Management
- **YAML Editor**: Syntax-highlighted editor for Axolotl configurations
- **Real-time Validation**: Instant config validation with detailed error reporting
- **File Browser**: Navigate and select configuration files
- **Template Loading**: Load example configurations
- **Remote Config Support**: Load configurations from URLs
**Key Shortcuts:**
- `Ctrl+N`: New configuration
- `Ctrl+S`: Save configuration
- `Ctrl+V`: Validate configuration
- `Ctrl+E`: Toggle edit mode
### 🚀 Training Management
- **Job Launcher**: Start training with different launchers (accelerate, torchrun)
- **Real-time Monitoring**: Live training progress and metrics
- **Loss Visualization**: Sparkline charts for loss curves
- **Job Control**: Start, stop, resume, and manage multiple training jobs
- **Log Streaming**: Real-time log viewing and filtering
**Key Shortcuts:**
- `Ctrl+T`: New training job
- `Ctrl+R`: Resume training
- `Ctrl+X`: Stop training
- `R`: Refresh status
### 📊 Dataset Management
- **Dataset Browser**: Explore local and remote datasets
- **Preview & Statistics**: View dataset samples and metadata
- **Preprocessing**: Run dataset preprocessing with progress tracking
- **HuggingFace Integration**: Download and manage HF datasets
- **Format Detection**: Automatic dataset format recognition
**Key Shortcuts:**
- `Ctrl+P`: Preprocess dataset
- `Ctrl+V`: Preview dataset
- `Ctrl+I`: Dataset information
- `R`: Refresh dataset list
### 🤖 Model Management
- **Model Discovery**: Automatically find trained models
- **LoRA Operations**: Merge LoRA adapters with base models
- **Quantization**: Quantize models for deployment
- **Evaluation**: Run model evaluation benchmarks
- **Storage Info**: View model sizes and storage details
**Key Shortcuts:**
- `Ctrl+M`: Merge LoRA
- `Ctrl+Q`: Quantize model
- `Ctrl+E`: Evaluate model
- `R`: Refresh model list
### 💬 Inference & Testing
- **Interactive Chat**: Chat interface for model testing
- **Parameter Tuning**: Adjust inference parameters (temperature, top-p, max tokens)
- **Model Loading**: Load and switch between different models
- **Chat History**: Save and load conversation history
- **Gradio Integration**: Launch Gradio web interface
**Key Shortcuts:**
- `Ctrl+Enter`: Send message
- `Ctrl+C`: Clear chat
- `Ctrl+L`: Load model
- `Ctrl+S`: Save chat
### 📈 System Monitoring
- **Resource Monitoring**: Real-time CPU, GPU, and memory usage
- **Process Management**: View and manage running processes
- **Performance Graphs**: Historical usage charts with sparklines
- **GPU Information**: Detailed GPU status and memory usage
- **Temperature Monitoring**: System temperature tracking
**Key Shortcuts:**
- `R`: Refresh metrics
- `Ctrl+K`: Kill selected process
## Installation
### Dependencies
```bash
pip install textual==1.0.0 rich==14.1.0
```
### Launch TUI
```bash
# From command line
python -m axolotl.cli.main tui
# From Python code
from axolotl.tui.app import run
run()
```
## Architecture
### Screen Structure
```
AxolotlTUI (Main App)
├── WelcomeScreen (Dashboard)
├── ConfigScreen (Configuration Management)
├── TrainingScreen (Training Management)
├── DatasetScreen (Dataset Management)
├── ModelScreen (Model Management)
├── InferenceScreen (Inference & Testing)
└── MonitorScreen (System Monitoring)
```
### Key Components
- **BaseScreen**: Common functionality for all screens
- **Screen Navigation**: Stack-based screen management
- **Event Handling**: Reactive UI updates
- **Background Tasks**: Non-blocking operations
- **State Management**: Shared application state
### Integration Points
- **CLI Commands**: Seamless integration with existing axolotl CLI
- **Configuration System**: Uses axolotl's native config loading
- **Training Pipeline**: Integrates with axolotl training functions
- **Model Loading**: Compatible with axolotl model management
## Usage Examples
### 1. Creating a New Configuration
1. Launch TUI: `python -m axolotl.cli.main tui`
2. Select "Configuration Management" or press `C`
3. Press `Ctrl+N` for new configuration
4. Edit the template configuration
5. Press `Ctrl+V` to validate
6. Press `Ctrl+S` to save
### 2. Starting a Training Job
1. Navigate to "Training Management" or press `T`
2. Press `Ctrl+T` for new training job
3. Select configuration file and launcher
4. Monitor progress in real-time
5. View loss curves and logs
### 3. Interactive Model Testing
1. Go to "Inference & Testing" or press `I`
2. Load a trained model with `Ctrl+L`
3. Adjust inference parameters as needed
4. Start chatting with the model
5. Save conversation with `Ctrl+S`
## Navigation
### Global Shortcuts
- `Ctrl+Q`: Quit application
- `Escape`: Go back/close current screen
- `Tab`: Navigate between UI elements
- `Enter`: Select/activate element
- `Space`: Toggle switches/checkboxes
### Screen Shortcuts
Each screen has specific shortcuts displayed in the footer. Common patterns:
- `Ctrl+[Letter]`: Primary actions
- `R`: Refresh/reload
- `F1-F12`: Function keys for advanced features
## Customization
### Themes
The TUI uses Textual's theming system and can be customized by modifying the CSS in each screen class.
### Adding New Screens
1. Create a new screen class inheriting from `BaseScreen`
2. Implement the `compose()` method for UI layout
3. Add event handlers for user interactions
4. Register the screen in the main app navigation
### Extending Functionality
- Add new widgets to existing screens
- Implement custom data visualization
- Integrate with external tools and APIs
- Add new keyboard shortcuts
## Troubleshooting
### Common Issues
1. **Import Errors**: Ensure textual and rich are installed
2. **Permission Errors**: Check file system permissions for config directories
3. **GPU Monitoring**: Install pynvml for GPU monitoring features
4. **Config Validation**: Ensure axolotl dependencies are properly installed
### Debug Mode
Launch with debug logging:
```bash
TEXTUAL_LOG=DEBUG python -m axolotl.cli.main tui
```
### Performance
- Use `Ctrl+\` to open Textual's debug console
- Monitor memory usage with the system monitor
- Disable auto-refresh for better performance on slower systems
## Contributing
The TUI is designed to be extensible. Contributions are welcome for:
- New screen implementations
- Enhanced visualizations
- Better keyboard navigation
- Additional integrations
- Performance improvements
See the main Axolotl repository for contribution guidelines.

View File

@@ -0,0 +1 @@
"""Axolotl Terminal User Interface (TUI)."""

180
src/axolotl/tui/app.py Normal file
View File

@@ -0,0 +1,180 @@
"""Main TUI application for Axolotl."""
from textual import on
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Button, Footer, Header, Static
from axolotl.tui.screens.config import ConfigScreen
from axolotl.tui.screens.datasets import DatasetScreen
from axolotl.tui.screens.inference import InferenceScreen
from axolotl.tui.screens.models import ModelScreen
from axolotl.tui.screens.monitor import MonitorScreen
from axolotl.tui.screens.training import TrainingScreen
class WelcomeScreen(Screen):
"""Welcome screen with main menu."""
BINDINGS = [
Binding("q", "quit", "Quit"),
Binding("c", "config", "Configuration"),
Binding("t", "training", "Training"),
Binding("d", "datasets", "Datasets"),
Binding("m", "models", "Models"),
Binding("i", "inference", "Inference"),
Binding("s", "monitor", "System Monitor"),
]
def compose(self) -> ComposeResult:
"""Compose the welcome screen."""
yield Header()
yield Container(
Static("🦾 Axolotl TUI", classes="title"),
Static(
"A Terminal User Interface for fine-tuning LLMs", classes="subtitle"
),
Container(
Button("Configuration Management [C]", id="config", variant="primary"),
Button("Training Management [T]", id="training", variant="primary"),
Button("Dataset Management [D]", id="datasets", variant="primary"),
Button("Model Management [M]", id="models", variant="primary"),
Button("Inference & Testing [I]", id="inference", variant="primary"),
Button("System Monitor [S]", id="monitor", variant="primary"),
classes="menu-container",
),
classes="welcome-container",
)
yield Footer()
def action_quit(self) -> None:
"""Quit the application."""
self.app.exit()
def action_config(self) -> None:
"""Navigate to config screen."""
self.app.push_screen(ConfigScreen())
def action_training(self) -> None:
"""Navigate to training screen."""
self.app.push_screen(TrainingScreen())
def action_datasets(self) -> None:
"""Navigate to datasets screen."""
self.app.push_screen(DatasetScreen())
def action_models(self) -> None:
"""Navigate to models screen."""
self.app.push_screen(ModelScreen())
def action_inference(self) -> None:
"""Navigate to inference screen."""
self.app.push_screen(InferenceScreen())
def action_monitor(self) -> None:
"""Navigate to monitor screen."""
self.app.push_screen(MonitorScreen())
@on(Button.Pressed, "#config")
def on_config_pressed(self) -> None:
"""Handle config button press."""
self.action_config()
@on(Button.Pressed, "#training")
def on_training_pressed(self) -> None:
"""Handle training button press."""
self.action_training()
@on(Button.Pressed, "#datasets")
def on_datasets_pressed(self) -> None:
"""Handle datasets button press."""
self.action_datasets()
@on(Button.Pressed, "#models")
def on_models_pressed(self) -> None:
"""Handle models button press."""
self.action_models()
@on(Button.Pressed, "#inference")
def on_inference_pressed(self) -> None:
"""Handle inference button press."""
self.action_inference()
@on(Button.Pressed, "#monitor")
def on_monitor_pressed(self) -> None:
"""Handle monitor button press."""
self.action_monitor()
class AxolotlTUI(App):
"""Main Axolotl TUI Application."""
CSS = """
.title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.subtitle {
text-align: center;
padding: 1;
color: $text-muted;
}
.welcome-container {
align: center middle;
height: 100%;
width: 100%;
}
.menu-container {
layout: vertical;
align: center middle;
padding: 2;
width: auto;
height: auto;
}
.menu-container Button {
width: 35;
margin: 1;
}
WelcomeScreen {
align: center middle;
}
"""
BINDINGS = [
Binding("ctrl+q", "quit", "Quit", priority=True),
Binding("escape", "back", "Back", priority=True),
]
def on_mount(self) -> None:
"""Called when the app is mounted."""
self.title = "Axolotl TUI"
self.sub_title = "Fine-tuning LLMs made easy"
self.push_screen(WelcomeScreen())
def action_quit(self) -> None:
"""Quit the application."""
self.exit()
def action_back(self) -> None:
"""Go back to previous screen."""
if len(self.screen_stack) > 1:
self.pop_screen()
def run():
"""Run the Axolotl TUI application."""
app = AxolotlTUI()
app.run()
if __name__ == "__main__":
run()

View File

@@ -0,0 +1 @@
"""TUI dialogs for Axolotl."""

View File

@@ -0,0 +1,112 @@
"""Training dialogs for Axolotl TUI."""
from pathlib import Path
from textual import on
from textual.app import ComposeResult
from textual.containers import Container
from textual.screen import ModalScreen
from textual.widgets import Button, Input, Label, Select, Static
class NewTrainingDialog(ModalScreen):
"""Dialog for starting a new training job."""
CSS = """
NewTrainingDialog {
align: center middle;
}
.dialog-container {
background: $surface;
border: thick $primary;
padding: 2;
width: 60;
height: auto;
}
.dialog-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.form-field {
margin: 1 0;
}
.form-label {
margin: 0 0 1 0;
color: $text-muted;
}
.button-container {
layout: horizontal;
align: center middle;
margin: 2 0 0 0;
}
.button-container Button {
margin: 0 1;
}
"""
def compose(self) -> ComposeResult:
"""Compose the dialog."""
yield Container(
Static("Start New Training Job", classes="dialog-title"),
Container(
Label("Configuration File:", classes="form-label"),
Input(
placeholder="Path to config YAML file",
id="config-path",
value="/workspace/configs/",
),
classes="form-field",
),
Container(
Label("Launcher:", classes="form-label"),
Select(
[
("accelerate", "Accelerate (Recommended)"),
("torchrun", "TorchRun"),
("deepspeed", "DeepSpeed"),
],
id="launcher",
value="accelerate",
),
classes="form-field",
),
Container(
Button("Start Training", variant="primary", id="start"),
Button("Cancel", variant="default", id="cancel"),
classes="button-container",
),
classes="dialog-container",
)
@on(Button.Pressed, "#start")
def handle_start(self) -> None:
"""Handle start button press."""
config_input = self.query_one("#config-path", Input)
launcher_select = self.query_one("#launcher", Select)
config_path = config_input.value.strip()
if not config_path:
return
if not Path(config_path).exists():
return
result = {
"config_path": config_path,
"launcher": launcher_select.value,
}
self.dismiss(result)
@on(Button.Pressed, "#cancel")
def handle_cancel(self) -> None:
"""Handle cancel button press."""
self.dismiss(None)

View File

@@ -0,0 +1 @@
"""TUI screens for Axolotl."""

View File

@@ -0,0 +1,50 @@
"""Base screen class for Axolotl TUI screens."""
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Footer, Header, Static
class BaseScreen(Screen):
"""Base class for all Axolotl TUI screens."""
BINDINGS = [
Binding("escape", "back", "Back"),
Binding("q", "quit", "Quit"),
]
def __init__(self, title: str = "Axolotl", subtitle: str = ""):
"""Initialize the base screen.
Args:
title: The screen title
subtitle: Optional subtitle for the screen
"""
super().__init__()
self.screen_title = title
self.screen_subtitle = subtitle
def compose(self) -> ComposeResult:
"""Compose the base screen layout."""
yield Header()
yield Container(
Static(f"🦾 {self.screen_title}", classes="screen-title"),
(
Static(self.screen_subtitle, classes="screen-subtitle")
if self.screen_subtitle
else Static("")
),
Container(id="content"),
id="main-container",
)
yield Footer()
def action_back(self) -> None:
"""Go back to previous screen."""
self.app.pop_screen()
def action_quit(self) -> None:
"""Quit the application."""
self.app.exit()

View File

@@ -0,0 +1,376 @@
"""Configuration management screen for Axolotl TUI."""
import os
from pathlib import Path
from typing import Optional
import yaml
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.reactive import reactive
from textual.widgets import (
Button,
DirectoryTree,
Footer,
Header,
Label,
Log,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class ConfigScreen(BaseScreen):
"""Configuration management screen."""
BINDINGS = [
Binding("ctrl+n", "new_config", "New Config"),
Binding("ctrl+o", "open_config", "Open Config"),
Binding("ctrl+s", "save_config", "Save Config"),
Binding("ctrl+v", "validate_config", "Validate"),
Binding("ctrl+e", "edit_mode", "Toggle Edit Mode"),
]
CSS = """
.config-container {
layout: horizontal;
height: 100%;
}
.file-browser {
width: 30%;
border: solid $primary;
padding: 1;
margin: 1;
}
.config-editor {
width: 70%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.config-form {
height: 80%;
}
.config-actions {
layout: horizontal;
height: 3;
align: center middle;
padding: 1;
}
.config-actions Button {
margin: 0 1;
}
TextArea {
height: 100%;
}
.validation-log {
height: 20%;
border: solid $warning;
padding: 1;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
"""
def __init__(self):
"""Initialize the config screen."""
super().__init__(
title="Configuration Management",
subtitle="Create, edit, and validate Axolotl configurations",
)
self.current_config_path: Optional[Path] = None
self.edit_mode = reactive(False)
self.config_data = {}
def compose(self) -> ComposeResult:
"""Compose the config screen layout."""
yield Header()
yield Container(
Static("🦾 Configuration Management", classes="screen-title"),
Static(
"Create, edit, and validate Axolotl configurations",
classes="screen-subtitle",
),
Container(
Container(
Label("Config Files"),
DirectoryTree(
(
Path("/workspace/configs")
if Path("/workspace/configs").exists()
else Path.cwd()
),
id="config-tree",
),
classes="file-browser",
),
Container(
Container(
TextArea(
"",
language="yaml",
theme="monokai",
id="config-editor",
read_only=True,
),
classes="config-form",
),
Container(
Button("New", id="new-config", variant="primary"),
Button("Open", id="open-config", variant="primary"),
Button("Save", id="save-config", variant="success"),
Button("Validate", id="validate-config", variant="warning"),
Button("Edit Mode", id="toggle-edit", variant="default"),
Button("Load Example", id="load-example", variant="default"),
classes="config-actions",
),
Container(
Log(id="validation-log"),
classes="validation-log",
),
classes="config-editor",
),
classes="config-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
tree = self.query_one("#config-tree", DirectoryTree)
tree.show_root = False
tree.guide_depth = 3
log = self.query_one("#validation-log", Log)
log.write_line("Ready to load configuration files...")
@on(DirectoryTree.FileSelected)
def handle_file_selected(self, event: DirectoryTree.FileSelected) -> None:
"""Handle file selection from the directory tree."""
if event.path.suffix in [".yaml", ".yml"]:
self.load_config_file(event.path)
def load_config_file(self, path: Path) -> None:
"""Load a configuration file."""
self.current_config_path = path
try:
with open(path, "r") as f:
content = f.read()
self.config_data = yaml.safe_load(content)
editor = self.query_one("#config-editor", TextArea)
editor.load_text(content)
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line(f"✅ Loaded: {path.name}")
except Exception as e:
log = self.query_one("#validation-log", Log)
log.write_line(f"❌ Error loading {path.name}: {str(e)}")
@on(Button.Pressed, "#new-config")
def handle_new_config(self) -> None:
"""Create a new configuration."""
template = """# Axolotl Configuration
base_model:
model_type:
tokenizer_type:
# Dataset Configuration
datasets:
- path:
type:
# Training Configuration
output_dir: ./outputs
num_epochs: 3
micro_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 0.00002
warmup_steps: 100
eval_steps: 100
save_steps: 500
# LoRA Configuration (optional)
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
# Training optimizations
gradient_checkpointing: true
flash_attention: true
bf16: auto
tf32: true
# Logging
logging_steps: 10
wandb_project:
wandb_entity:
"""
editor = self.query_one("#config-editor", TextArea)
editor.load_text(template)
editor.read_only = False
self.edit_mode = True
self.update_edit_button()
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line("📝 New configuration created. Edit and save when ready.")
@on(Button.Pressed, "#save-config")
def handle_save_config(self) -> None:
"""Save the current configuration."""
editor = self.query_one("#config-editor", TextArea)
content = editor.text
if not content.strip():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ Cannot save empty configuration")
return
if not self.current_config_path:
default_path = Path("/workspace/configs/new_config.yaml")
default_path.parent.mkdir(parents=True, exist_ok=True)
self.current_config_path = default_path
try:
with open(self.current_config_path, "w") as f:
f.write(content)
log = self.query_one("#validation-log", Log)
log.write_line(f"💾 Saved: {self.current_config_path.name}")
except Exception as e:
log = self.query_one("#validation-log", Log)
log.write_line(f"❌ Error saving: {str(e)}")
@on(Button.Pressed, "#validate-config")
@work(thread=True)
async def handle_validate_config(self) -> None:
"""Validate the current configuration."""
editor = self.query_one("#config-editor", TextArea)
content = editor.text
if not content.strip():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ No configuration to validate")
return
log = self.query_one("#validation-log", Log)
log.clear()
log.write_line("🔍 Validating configuration...")
try:
import tempfile
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
f.write(content)
temp_path = f.name
from argparse import Namespace
from axolotl.cli.config import check_user_config
args = Namespace(
config=temp_path,
debug=False,
debug_text_only=False,
debug_num_examples=5,
accelerate_config=None,
multi_gpu=False,
)
check_user_config(args)
log.write_line("✅ Configuration is valid!")
os.unlink(temp_path)
except Exception as e:
log.write_line(f"❌ Validation failed: {str(e)}")
if "temp_path" in locals():
os.unlink(temp_path)
@on(Button.Pressed, "#toggle-edit")
def handle_toggle_edit(self) -> None:
"""Toggle edit mode for the configuration."""
editor = self.query_one("#config-editor", TextArea)
self.edit_mode = not self.edit_mode
editor.read_only = not self.edit_mode
self.update_edit_button()
log = self.query_one("#validation-log", Log)
if self.edit_mode:
log.write_line("✏️ Edit mode enabled")
else:
log.write_line("👁️ View mode enabled")
@on(Button.Pressed, "#load-example")
async def handle_load_example(self) -> None:
"""Load an example configuration."""
examples_dir = Path("/workspace/axolotl/examples")
if not examples_dir.exists():
log = self.query_one("#validation-log", Log)
log.write_line("⚠️ Examples directory not found")
return
yaml_files = list(examples_dir.glob("**/*.yml")) + list(
examples_dir.glob("**/*.yaml")
)
if yaml_files:
self.load_config_file(yaml_files[0])
log = self.query_one("#validation-log", Log)
log.write_line(f"📚 Loaded example: {yaml_files[0].name}")
def update_edit_button(self) -> None:
"""Update the edit button appearance."""
button = self.query_one("#toggle-edit", Button)
if self.edit_mode:
button.variant = "warning"
button.label = "Edit Mode: ON"
else:
button.variant = "default"
button.label = "Edit Mode: OFF"
def action_new_config(self) -> None:
"""Create a new configuration."""
self.handle_new_config()
def action_save_config(self) -> None:
"""Save the current configuration."""
self.handle_save_config()
def action_validate_config(self) -> None:
"""Validate the current configuration."""
self.handle_validate_config()
def action_edit_mode(self) -> None:
"""Toggle edit mode."""
self.handle_toggle_edit()

View File

@@ -0,0 +1,440 @@
"""Dataset management screen for Axolotl TUI."""
import json
from pathlib import Path
from typing import Dict, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class DatasetScreen(BaseScreen):
"""Dataset management screen."""
BINDINGS = [
Binding("ctrl+p", "preprocess", "Preprocess"),
Binding("ctrl+v", "preview", "Preview"),
Binding("ctrl+i", "info", "Info"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.dataset-container {
layout: horizontal;
height: 100%;
}
.dataset-list {
width: 40%;
border: solid $primary;
padding: 1;
margin: 1;
}
.dataset-details {
width: 60%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.dataset-actions {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.dataset-actions Button {
margin: 0 1;
}
DataTable {
height: 100%;
}
.preview-container {
height: 100%;
border: solid $primary;
padding: 1;
}
TextArea {
height: 100%;
}
.stats-container {
layout: vertical;
padding: 1;
}
.stat-row {
layout: horizontal;
padding: 0 0 1 0;
}
.stat-label {
width: 50%;
color: $text-muted;
}
.stat-value {
width: 50%;
text-align: right;
text-style: bold;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
.progress-container {
padding: 1;
border: solid $warning;
margin: 1;
}
"""
def __init__(self):
"""Initialize the dataset screen."""
super().__init__(
title="Dataset Management",
subtitle="Browse, preview, and preprocess datasets",
)
self.datasets: Dict[str, Dict] = {}
self.selected_dataset: Optional[str] = None
self.preprocessing_active = False
def compose(self) -> ComposeResult:
"""Compose the dataset screen layout."""
yield Header()
yield Container(
Static("🦾 Dataset Management", classes="screen-title"),
Static(
"Browse, preview, and preprocess datasets", classes="screen-subtitle"
),
Container(
Container(
Label("Available Datasets"),
DataTable(id="dataset-table"),
Container(
Button("Load Dataset", id="load-dataset", variant="primary"),
Button("Preprocess", id="preprocess", variant="success"),
Button("Download", id="download", variant="default"),
Button("Refresh", id="refresh", variant="default"),
classes="dataset-actions",
),
classes="dataset-list",
),
Container(
TextArea("", id="dataset-preview", read_only=True),
Container(
Static("Dataset Name:", classes="stat-label"),
Static("-", id="stat-name", classes="stat-value"),
Static("Type:", classes="stat-label"),
Static("-", id="stat-type", classes="stat-value"),
Static("Size:", classes="stat-label"),
Static("-", id="stat-size", classes="stat-value"),
Static("Samples:", classes="stat-label"),
Static("-", id="stat-samples", classes="stat-value"),
Static("Features:", classes="stat-label"),
Static("-", id="stat-features", classes="stat-value"),
Static("Format:", classes="stat-label"),
Static("-", id="stat-format", classes="stat-value"),
Static("Preprocessed:", classes="stat-label"),
Static("-", id="stat-preprocessed", classes="stat-value"),
),
Log(id="processing-log"),
ProgressBar(total=100, id="preprocessing-progress"),
classes="dataset-details",
),
classes="dataset-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_dataset_table()
self.load_datasets()
log = self.query_one("#processing-log", Log)
log.write_line("Dataset manager ready.")
def setup_dataset_table(self) -> None:
"""Setup the dataset table."""
table = self.query_one("#dataset-table", DataTable)
table.add_columns("Name", "Type", "Size", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
@work(thread=True)
async def load_datasets(self) -> None:
"""Load available datasets."""
# Check for local datasets
datasets_dir = Path("/workspace/datasets")
if datasets_dir.exists():
for dataset_path in datasets_dir.glob("*"):
if dataset_path.is_dir():
self.datasets[dataset_path.name] = {
"name": dataset_path.name,
"path": str(dataset_path),
"type": "local",
"size": self.get_dir_size(dataset_path),
"status": "available",
}
# Check for HuggingFace datasets in configs
configs_dir = Path("/workspace/configs")
if configs_dir.exists():
for config_file in configs_dir.glob("*.yaml"):
try:
import yaml
with open(config_file) as f:
config = yaml.safe_load(f)
if "datasets" in config:
for ds in config.get("datasets", []):
if "path" in ds:
ds_name = ds["path"].split("/")[-1]
self.datasets[ds_name] = {
"name": ds_name,
"path": ds["path"],
"type": ds.get("type", "huggingface"),
"size": "Unknown",
"status": "remote",
}
except Exception:
pass
self.refresh_dataset_table()
def get_dir_size(self, path: Path) -> str:
"""Get human-readable directory size."""
total_size = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
for unit in ["B", "KB", "MB", "GB"]:
if total_size < 1024.0:
return f"{total_size:.2f} {unit}"
total_size /= 1024.0
return f"{total_size:.2f} TB"
def refresh_dataset_table(self) -> None:
"""Refresh the dataset table."""
table = self.query_one("#dataset-table", DataTable)
table.clear()
for name, info in self.datasets.items():
table.add_row(
name[:30],
info["type"],
info["size"],
info["status"],
)
@on(DataTable.RowSelected)
def handle_dataset_selected(self, event: DataTable.RowSelected) -> None:
"""Handle dataset selection from table."""
if event.cursor_row >= 0:
dataset_names = list(self.datasets.keys())
if event.cursor_row < len(dataset_names):
self.selected_dataset = dataset_names[event.cursor_row]
self.load_dataset_preview()
self.update_dataset_stats()
@work(thread=True)
async def load_dataset_preview(self) -> None:
"""Load preview of selected dataset."""
if not self.selected_dataset:
return
dataset_info = self.datasets[self.selected_dataset]
preview_text = ""
try:
if dataset_info["type"] == "local" and Path(dataset_info["path"]).exists():
# Load first few samples from local dataset
sample_files = list(Path(dataset_info["path"]).glob("*.json"))[:3]
samples = []
for sample_file in sample_files:
with open(sample_file) as f:
samples.append(json.load(f))
preview_text = json.dumps(samples, indent=2)
else:
# Show dataset info for remote datasets
preview_text = json.dumps(dataset_info, indent=2)
except Exception as e:
preview_text = f"Error loading preview: {str(e)}"
preview = self.query_one("#dataset-preview", TextArea)
preview.load_text(preview_text)
def update_dataset_stats(self) -> None:
"""Update dataset statistics display."""
if not self.selected_dataset:
return
info = self.datasets[self.selected_dataset]
self.query_one("#stat-name", Static).update(info["name"])
self.query_one("#stat-type", Static).update(info["type"])
self.query_one("#stat-size", Static).update(info["size"])
self.query_one("#stat-samples", Static).update("N/A")
self.query_one("#stat-features", Static).update("N/A")
self.query_one("#stat-format", Static).update("JSON")
self.query_one("#stat-preprocessed", Static).update("No")
@on(Button.Pressed, "#preprocess")
@work(thread=True)
async def handle_preprocess(self) -> None:
"""Preprocess selected dataset."""
if not self.selected_dataset or self.preprocessing_active:
return
self.preprocessing_active = True
dataset_info = self.datasets[self.selected_dataset]
log = self.query_one("#processing-log", Log)
log.clear()
log.write_line(f"🔄 Starting preprocessing for {self.selected_dataset}...")
progress = self.query_one("#preprocessing-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
import tempfile
# Create a temporary config for preprocessing
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
config = {
"datasets": [
{
"path": dataset_info["path"],
"type": dataset_info.get("type", "alpaca"),
}
],
"output_dir": f"/tmp/preprocessed_{self.selected_dataset}",
}
import yaml
yaml.dump(config, f)
temp_config = f.name
# Run preprocessing
cmd = ["python", "-m", "axolotl.cli.preprocess", temp_config]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
# Monitor progress
for line in process.stdout:
log.write_line(line.strip())
# Update progress bar based on output
if "Processing" in line:
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ Preprocessing completed successfully!")
dataset_info["status"] = "preprocessed"
progress.update(progress=100)
else:
log.write_line(
f"❌ Preprocessing failed with code {process.returncode}"
)
import os
os.unlink(temp_config)
except Exception as e:
log.write_line(f"❌ Error during preprocessing: {str(e)}")
finally:
self.preprocessing_active = False
self.refresh_dataset_table()
@on(Button.Pressed, "#load-dataset")
async def handle_load_dataset(self) -> None:
"""Load a new dataset."""
log = self.query_one("#processing-log", Log)
log.write_line("📦 Load dataset functionality coming soon...")
@on(Button.Pressed, "#download")
@work(thread=True)
async def handle_download(self) -> None:
"""Download a remote dataset."""
if not self.selected_dataset:
return
dataset_info = self.datasets[self.selected_dataset]
if dataset_info["type"] != "huggingface":
return
log = self.query_one("#processing-log", Log)
log.clear()
log.write_line(f"📥 Downloading {self.selected_dataset} from HuggingFace...")
try:
from datasets import load_dataset
dataset = load_dataset(dataset_info["path"])
save_path = Path(f"/workspace/datasets/{self.selected_dataset}")
save_path.mkdir(parents=True, exist_ok=True)
dataset.save_to_disk(str(save_path))
log.write_line(f"✅ Downloaded to {save_path}")
dataset_info["type"] = "local"
dataset_info["status"] = "available"
dataset_info["path"] = str(save_path)
self.refresh_dataset_table()
except Exception as e:
log.write_line(f"❌ Download failed: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh dataset list."""
self.load_datasets()
def action_preprocess(self) -> None:
"""Preprocess selected dataset."""
self.handle_preprocess()
def action_refresh(self) -> None:
"""Refresh dataset list."""
self.handle_refresh()

View File

@@ -0,0 +1,445 @@
"""Inference and testing screen for Axolotl TUI."""
from pathlib import Path
from typing import Dict, List, Optional
from textual import events, on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
Input,
Label,
Log,
Select,
Static,
TextArea,
)
from axolotl.tui.screens.base import BaseScreen
class InferenceScreen(BaseScreen):
"""Inference and testing screen."""
BINDINGS = [
Binding("ctrl+enter", "send_message", "Send"),
Binding("ctrl+c", "clear_chat", "Clear"),
Binding("ctrl+l", "load_model", "Load Model"),
Binding("ctrl+s", "save_chat", "Save Chat"),
]
CSS = """
.inference-container {
layout: horizontal;
height: 100%;
}
.model-selector {
width: 30%;
border: solid $primary;
padding: 1;
margin: 1;
}
.chat-interface {
width: 70%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.chat-history {
height: 70%;
border: solid $primary;
padding: 1;
margin: 0 0 1 0;
}
.input-area {
height: 20%;
border: solid $warning;
padding: 1;
margin: 0 0 1 0;
}
.chat-controls {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.chat-controls Button {
margin: 0 1;
}
.model-info {
padding: 1;
border: solid $surface;
margin: 1 0;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
TextArea {
height: 100%;
}
Log {
height: 100%;
}
"""
def __init__(self):
"""Initialize the inference screen."""
super().__init__(
title="Inference & Testing", subtitle="Interactive chat and model testing"
)
self.loaded_model: Optional[str] = None
self.chat_history: List[Dict[str, str]] = []
def compose(self) -> ComposeResult:
"""Compose the inference screen layout."""
yield Container(
Static("🦾 Inference & Testing", classes="screen-title"),
Static("Interactive chat and model testing", classes="screen-subtitle"),
Container(
Container(
Label("Model Selection"),
Select(
[("No model loaded", "none")],
id="model-select",
value="none",
),
Container(
Button("Load Model", id="load-model", variant="primary"),
Button("Unload", id="unload-model", variant="default"),
Button("Gradio UI", id="gradio-ui", variant="success"),
),
Container(
Static("No model loaded", id="model-status"),
classes="model-info",
),
Label("Inference Parameters"),
Container(
Label("Temperature:"),
Input(value="0.7", id="temperature"),
Label("Max Tokens:"),
Input(value="256", id="max-tokens"),
Label("Top P:"),
Input(value="0.9", id="top-p"),
),
classes="model-selector",
),
Container(
Container(
Log(id="chat-history"),
classes="chat-history",
),
Container(
TextArea(
id="message-input",
),
classes="input-area",
),
Container(
Button("Send [Ctrl+Enter]", id="send", variant="primary"),
Button("Clear Chat", id="clear", variant="warning"),
Button("Save Chat", id="save-chat", variant="default"),
Button("Load Examples", id="load-examples", variant="default"),
classes="chat-controls",
),
classes="chat-interface",
),
classes="inference-container",
),
id="content",
)
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.load_available_models()
chat = self.query_one("#chat-history", Log)
chat.write_line("💬 Welcome to Axolotl Inference!")
chat.write_line("Load a model to start chatting.")
@work(thread=True)
async def load_available_models(self) -> None:
"""Load list of available models."""
models = [("No model loaded", "none")]
chat = self.query_one("#chat-history", Log)
chat.write_line("🔍 Scanning for available models...")
# Check for trained models
outputs_dir = Path("./outputs")
chat.write_line(f"Checking outputs directory: {outputs_dir.absolute()}")
if outputs_dir.exists():
found_models = 0
for model_dir in outputs_dir.glob("*"):
if model_dir.is_dir():
# Look for various model file types
model_files = (
list(model_dir.glob("pytorch_model.bin"))
+ list(model_dir.glob("model.safetensors"))
+ list(model_dir.glob("*.bin"))
+ list(model_dir.glob("*.safetensors"))
)
if model_files:
models.append((model_dir.name, str(model_dir)))
found_models += 1
chat.write_line(f"Found {found_models} trained models in outputs/")
else:
chat.write_line("outputs/ directory not found")
# Add some example/demo models for testing
models.extend(
[
("Demo: GPT-2 Small", "gpt2"),
("Demo: TinyLlama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
("Demo: Phi-2", "microsoft/phi-2"),
]
)
select = self.query_one("#model-select", Select)
select.set_options(models)
chat.write_line(f"✅ Loaded {len(models)} models in dropdown")
@on(Button.Pressed, "#load-model")
@work(thread=True)
async def handle_load_model(self) -> None:
"""Load selected model for inference."""
select = self.query_one("#model-select", Select)
if select.value == "none":
return
chat = self.query_one("#chat-history", Log)
chat.write_line(f"🔄 Loading model: {select.value}")
status = self.query_one("#model-status", Static)
status.update("Loading...")
try:
# Simulate model loading (in real implementation, would load the actual model)
import time
time.sleep(2) # Simulate loading time
self.loaded_model = select.value
status.update(f"✅ Loaded: {Path(select.value).name}")
chat.write_line("✅ Model loaded successfully!")
chat.write_line("You can now start chatting.")
except Exception as e:
status.update("❌ Failed to load")
chat.write_line(f"❌ Failed to load model: {str(e)}")
@on(Button.Pressed, "#send")
async def handle_send_message(self) -> None:
"""Send message to model."""
if not self.loaded_model:
chat = self.query_one("#chat-history", Log)
chat.write_line("⚠️ Please load a model first")
return
message_input = self.query_one("#message-input", TextArea)
message = message_input.text.strip()
if not message:
return
# Add user message to chat
chat = self.query_one("#chat-history", Log)
chat.write_line(f"👤 User: {message}")
# Clear input
message_input.clear()
# Add to history
self.chat_history.append({"role": "user", "content": message})
# Generate response (placeholder)
self.generate_response(message)
@on(TextArea.Changed, "#message-input")
def on_message_input_changed(self, event: TextArea.Changed) -> None:
"""Handle changes to the message input."""
# This could be used for features like typing indicators
pass
def on_key(self, event: events.Key) -> None:
"""Handle key events globally."""
# Check if we're focused on the message input and Ctrl+Enter is pressed
focused = self.focused
if focused and focused.id == "message-input" and event.key == "ctrl+enter":
event.prevent_default()
self.handle_send_message()
@work(thread=True)
async def generate_response(self, message: str) -> None:
"""Generate model response."""
chat = self.query_one("#chat-history", Log)
chat.write_line("🤖 Assistant: Thinking...")
try:
# Get inference parameters
float(self.query_one("#temperature", Input).value)
int(self.query_one("#max-tokens", Input).value)
float(self.query_one("#top-p", Input).value)
if not self.loaded_model or self.loaded_model == "none":
response = "I don't have a model loaded yet. Please load a model first using the 'Load Model' button."
elif self.loaded_model.startswith("gpt2"):
# Simple response for GPT-2
responses = [
f"Thanks for your message: '{message}'. I'm a GPT-2 model running in demo mode.",
"I understand you're testing the interface. GPT-2 models are great for experimentation!",
"This is a simulated GPT-2 response. In a real setup, I'd generate text based on your input.",
f"GPT-2 here! You said: '{message}'. I'd normally continue this conversation creatively.",
]
import random
response = random.choice(responses)
elif "llama" in self.loaded_model.lower():
# Response for Llama models
response = f"🦙 LLaMA model here! You asked: '{message}'. I'm designed for helpful, harmless, and honest conversations. How can I assist you today?"
elif "phi" in self.loaded_model.lower():
# Response for Phi models
response = f"Phi model responding! Your message: '{message}'. I'm optimized for reasoning and code tasks. What would you like to explore?"
else:
# Generic response for other models
response = f"Model '{self.loaded_model}' responding to: '{message}'. I'm ready to help with your questions!"
# Simulate inference time
import time
time.sleep(0.5)
# Clear the "thinking" message and show response
chat.write_line(f"🤖 Assistant: {response}")
# Add to history
self.chat_history.append({"role": "assistant", "content": response})
except Exception as e:
chat.write_line(f"❌ Error generating response: {str(e)}")
@on(Button.Pressed, "#clear")
def handle_clear_chat(self) -> None:
"""Clear chat history."""
chat = self.query_one("#chat-history", Log)
chat.clear()
self.chat_history = []
chat.write_line("💬 Chat cleared. Start a new conversation!")
@on(Button.Pressed, "#save-chat")
def handle_save_chat(self) -> None:
"""Save chat history to file."""
if not self.chat_history:
chat = self.query_one("#chat-history", Log)
chat.write_line("⚠️ No chat history to save")
return
try:
import json
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"chat_history_{timestamp}.json"
with open(filename, "w") as f:
json.dump(self.chat_history, f, indent=2)
chat = self.query_one("#chat-history", Log)
chat.write_line(f"💾 Chat saved to {filename}")
except Exception as e:
chat = self.query_one("#chat-history", Log)
chat.write_line(f"❌ Error saving chat: {str(e)}")
@on(Button.Pressed, "#load-examples")
def handle_load_examples(self) -> None:
"""Load example prompts."""
examples = [
"Explain the concept of machine learning in simple terms.",
"Write a Python function to calculate fibonacci numbers.",
"What are the benefits of fine-tuning language models?",
"Describe the difference between supervised and unsupervised learning.",
]
chat = self.query_one("#chat-history", Log)
chat.write_line("📚 Example prompts:")
for i, example in enumerate(examples, 1):
chat.write_line(f"{i}. {example}")
chat.write_line("Copy and paste any example to try it out!")
@on(Button.Pressed, "#gradio-ui")
@work(thread=True)
async def handle_gradio_ui(self) -> None:
"""Launch Gradio web interface."""
chat = self.query_one("#chat-history", Log)
chat.write_line("🌐 Launching Gradio web interface...")
try:
import subprocess
if self.loaded_model:
cmd = [
"python",
"-m",
"axolotl.cli.inference",
self.loaded_model,
"--gradio",
]
else:
chat.write_line("⚠️ No model loaded. Loading default interface...")
cmd = ["python", "-m", "axolotl.cli.inference", "--gradio"]
subprocess.Popen(cmd)
chat.write_line("✅ Gradio interface launched! Check your browser.")
except Exception as e:
chat.write_line(f"❌ Error launching Gradio: {str(e)}")
@on(Button.Pressed, "#unload-model")
def handle_unload_model(self) -> None:
"""Unload current model."""
self.loaded_model = None
status = self.query_one("#model-status", Static)
status.update("No model loaded")
select = self.query_one("#model-select", Select)
select.value = "none"
chat = self.query_one("#chat-history", Log)
chat.write_line("🔄 Model unloaded")
def action_send_message(self) -> None:
"""Send message action."""
self.handle_send_message()
def action_clear_chat(self) -> None:
"""Clear chat action."""
self.handle_clear_chat()
def action_load_model(self) -> None:
"""Load model action."""
self.handle_load_model()
def action_save_chat(self) -> None:
"""Save chat action."""
self.handle_save_chat()

View File

@@ -0,0 +1,373 @@
"""Model management screen for Axolotl TUI."""
from pathlib import Path
from typing import Dict, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, ScrollableContainer
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Static,
TabbedContent,
TabPane,
)
from axolotl.tui.screens.base import BaseScreen
class ModelScreen(BaseScreen):
"""Model management screen."""
BINDINGS = [
Binding("ctrl+m", "merge_lora", "Merge LoRA"),
Binding("ctrl+q", "quantize", "Quantize"),
Binding("ctrl+e", "evaluate", "Evaluate"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.model-container {
layout: horizontal;
height: 100%;
}
.model-list {
width: 50%;
border: solid $primary;
padding: 1;
margin: 1;
}
.model-operations {
width: 50%;
border: solid $secondary;
padding: 1;
margin: 1;
}
.model-actions {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.model-actions Button {
margin: 0 1;
}
DataTable {
height: 80%;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
"""
def __init__(self):
"""Initialize the model screen."""
super().__init__(
title="Model Management",
subtitle="Manage trained models, merge LoRA adapters, and quantize models",
)
self.models: Dict[str, Dict] = {}
self.selected_model: Optional[str] = None
def compose(self) -> ComposeResult:
"""Compose the model screen layout."""
yield Header()
with Container(id="content"):
yield Static("🦾 Model Management", classes="screen-title")
yield Static(
"Manage trained models, merge LoRA adapters, and quantize models",
classes="screen-subtitle",
)
with Container(classes="model-container"):
with Container(classes="model-list"):
yield Label("Available Models")
yield DataTable(id="model-table")
with Container(classes="model-actions"):
yield Button("Merge LoRA", id="merge-lora", variant="primary")
yield Button("Quantize", id="quantize", variant="success")
yield Button("Evaluate", id="evaluate", variant="warning")
yield Button("Refresh", id="refresh", variant="default")
with Container(classes="model-operations"):
with TabbedContent():
with TabPane("Operations"):
with Container():
yield Log(id="operations-log")
with Container():
yield Label("Operation Progress:")
yield ProgressBar(
total=100,
id="operation-progress",
)
with TabPane("Model Info"):
with ScrollableContainer():
yield Static(
"Model information will appear here",
id="model-info",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_model_table()
self.load_models()
log = self.query_one("#operations-log", Log)
log.write_line("Model manager ready.")
def setup_model_table(self) -> None:
"""Setup the model table."""
table = self.query_one("#model-table", DataTable)
table.add_columns("Name", "Type", "Size", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
@work(thread=True)
async def load_models(self) -> None:
"""Load available models."""
# Check outputs directory for trained models
outputs_dir = Path("./outputs")
if outputs_dir.exists():
for model_dir in outputs_dir.glob("*"):
if model_dir.is_dir():
self.models[model_dir.name] = {
"name": model_dir.name,
"path": str(model_dir),
"type": "checkpoint",
"size": self.get_dir_size(model_dir),
"status": "available",
}
self.refresh_model_table()
def get_dir_size(self, path: Path) -> str:
"""Get human-readable directory size."""
try:
total_size = sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
for unit in ["B", "KB", "MB", "GB"]:
if total_size < 1024.0:
return f"{total_size:.2f} {unit}"
total_size /= 1024.0
return f"{total_size:.2f} TB"
except Exception:
return "Unknown"
def refresh_model_table(self) -> None:
"""Refresh the model table."""
table = self.query_one("#model-table", DataTable)
table.clear()
for name, info in self.models.items():
table.add_row(
name[:30],
info["type"],
info["size"],
info["status"],
)
@on(DataTable.RowSelected)
def handle_model_selected(self, event: DataTable.RowSelected) -> None:
"""Handle model selection from table."""
if event.cursor_row >= 0:
model_names = list(self.models.keys())
if event.cursor_row < len(model_names):
self.selected_model = model_names[event.cursor_row]
self.update_model_info()
def update_model_info(self) -> None:
"""Update model information display."""
if not self.selected_model:
return
info = self.models[self.selected_model]
info_text = f"""
Model Name: {info['name']}
Path: {info['path']}
Type: {info['type']}
Size: {info['size']}
Status: {info['status']}
"""
self.query_one("#model-info", Static).update(info_text)
@on(Button.Pressed, "#merge-lora")
@work(thread=True)
async def handle_merge_lora(self) -> None:
"""Merge LoRA adapters with base model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Merging LoRA adapters for {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = ["python", "-m", "axolotl.cli.merge_lora", model_info["path"]]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ LoRA merge completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ LoRA merge failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during LoRA merge: {str(e)}")
@on(Button.Pressed, "#quantize")
@work(thread=True)
async def handle_quantize(self) -> None:
"""Quantize selected model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Quantizing {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = [
"python",
"-m",
"axolotl.cli.quantize",
model_info["path"],
"--output-dir",
f"{model_info['path']}_quantized",
]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(5)
process.wait()
if process.returncode == 0:
log.write_line("✅ Quantization completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ Quantization failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during quantization: {str(e)}")
@on(Button.Pressed, "#evaluate")
@work(thread=True)
async def handle_evaluate(self) -> None:
"""Evaluate selected model."""
if not self.selected_model:
log = self.query_one("#operations-log", Log)
log.write_line("⚠️ No model selected")
return
model_info = self.models[self.selected_model]
log = self.query_one("#operations-log", Log)
log.clear()
log.write_line(f"🔄 Evaluating {self.selected_model}...")
progress = self.query_one("#operation-progress", ProgressBar)
progress.update(progress=0)
try:
import subprocess
cmd = ["python", "-m", "axolotl.cli.evaluate", model_info["path"]]
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in process.stdout:
log.write_line(line.strip())
progress.advance(10)
process.wait()
if process.returncode == 0:
log.write_line("✅ Evaluation completed successfully!")
progress.update(progress=100)
else:
log.write_line(f"❌ Evaluation failed with code {process.returncode}")
except Exception as e:
log.write_line(f"❌ Error during evaluation: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh model list."""
self.load_models()
def action_merge_lora(self) -> None:
"""Merge LoRA adapters."""
self.handle_merge_lora()
def action_quantize(self) -> None:
"""Quantize model."""
self.handle_quantize()
def action_evaluate(self) -> None:
"""Evaluate model."""
self.handle_evaluate()
def action_refresh(self) -> None:
"""Refresh model list."""
self.handle_refresh()

View File

@@ -0,0 +1,414 @@
"""System monitoring screen for Axolotl TUI."""
import psutil
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
ProgressBar,
Sparkline,
Static,
)
from axolotl.tui.screens.base import BaseScreen
class MonitorScreen(BaseScreen):
"""System monitoring screen."""
BINDINGS = [
Binding("r", "refresh", "Refresh"),
Binding("ctrl+k", "kill_process", "Kill Process"),
]
CSS = """
.monitor-container {
layout: vertical;
height: 100%;
}
.metrics-grid {
layout: horizontal;
height: 20%;
padding: 1;
}
.metric-card {
width: 25%;
border: solid $surface;
padding: 1;
margin: 0 1;
}
.metric-label {
text-style: bold;
color: $text-muted;
text-align: center;
}
.metric-value {
text-style: bold;
text-align: center;
padding: 1;
}
.charts-container {
height: 40%;
layout: horizontal;
padding: 1;
}
.chart-panel {
width: 50%;
border: solid $primary;
padding: 1;
margin: 0 1;
}
.processes-container {
height: 40%;
border: solid $warning;
padding: 1;
margin: 1;
}
DataTable {
height: 90%;
}
.process-controls {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
}
.process-controls Button {
margin: 0 1;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
Sparkline {
height: 8;
}
ProgressBar {
margin: 1 0;
}
"""
def __init__(self):
"""Initialize the monitor screen."""
super().__init__(
title="System Monitor",
subtitle="Monitor system resources and running processes",
)
self.cpu_history = []
self.memory_history = []
self.gpu_history = []
def compose(self) -> ComposeResult:
"""Compose the monitor screen layout."""
yield Header()
yield Container(
Static("🦾 System Monitor", classes="screen-title"),
Static(
"Monitor system resources and running processes",
classes="screen-subtitle",
),
Container(
Container(
Container(
Static("CPU Usage", classes="metric-label"),
Static("0%", id="cpu-usage", classes="metric-value"),
ProgressBar(total=100, id="cpu-progress"),
classes="metric-card",
),
Container(
Static("Memory", classes="metric-label"),
Static("0%", id="memory-usage", classes="metric-value"),
ProgressBar(total=100, id="memory-progress"),
classes="metric-card",
),
Container(
Static("GPU Usage", classes="metric-label"),
Static("0%", id="gpu-usage", classes="metric-value"),
ProgressBar(total=100, id="gpu-progress"),
classes="metric-card",
),
Container(
Static("Temperature", classes="metric-label"),
Static("0°C", id="temperature", classes="metric-value"),
classes="metric-card",
),
classes="metrics-grid",
),
Container(
Container(
Label("CPU History"),
Sparkline([], id="cpu-sparkline"),
classes="chart-panel",
),
Container(
Label("Memory History"),
Sparkline([], id="memory-sparkline"),
classes="chart-panel",
),
classes="charts-container",
),
Container(
DataTable(id="process-table"),
Log(id="gpu-info"),
Log(id="system-logs"),
classes="processes-container",
),
classes="monitor-container",
),
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_process_table()
self.start_monitoring()
# Initial system info
self.update_system_info()
self.update_gpu_info()
def setup_process_table(self) -> None:
"""Setup the process table."""
table = self.query_one("#process-table", DataTable)
table.add_columns("PID", "Name", "CPU%", "Memory%", "Status")
table.cursor_type = "row"
table.zebra_stripes = True
def start_monitoring(self) -> None:
"""Start the monitoring timer."""
self.set_interval(2.0, self.update_system_metrics)
@work(thread=True)
async def update_system_metrics(self) -> None:
"""Update system metrics."""
try:
# CPU usage
cpu_percent = psutil.cpu_percent(interval=None)
self.cpu_history.append(cpu_percent)
if len(self.cpu_history) > 50:
self.cpu_history.pop(0)
# Memory usage
memory = psutil.virtual_memory()
memory_percent = memory.percent
self.memory_history.append(memory_percent)
if len(self.memory_history) > 50:
self.memory_history.pop(0)
# GPU usage (if available)
gpu_percent = self.get_gpu_usage()
self.gpu_history.append(gpu_percent)
if len(self.gpu_history) > 50:
self.gpu_history.pop(0)
# Temperature
temperature = self.get_temperature()
# Update UI
self.update_metrics_display(
cpu_percent, memory_percent, gpu_percent, temperature
)
self.update_sparklines()
self.update_process_table()
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error updating metrics: {str(e)}")
def get_gpu_usage(self) -> float:
"""Get GPU usage percentage."""
try:
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
return util.gpu
except Exception:
return 0.0
def get_temperature(self) -> str:
"""Get system temperature."""
try:
temps = psutil.sensors_temperatures()
if temps:
for name, entries in temps.items():
if entries:
return f"{entries[0].current:.1f}°C"
return "N/A"
except Exception:
return "N/A"
def update_metrics_display(
self, cpu: float, memory: float, gpu: float, temp: str
) -> None:
"""Update metrics display."""
self.query_one("#cpu-usage", Static).update(f"{cpu:.1f}%")
self.query_one("#memory-usage", Static).update(f"{memory:.1f}%")
self.query_one("#gpu-usage", Static).update(f"{gpu:.1f}%")
self.query_one("#temperature", Static).update(temp)
self.query_one("#cpu-progress", ProgressBar).update(progress=cpu)
self.query_one("#memory-progress", ProgressBar).update(progress=memory)
self.query_one("#gpu-progress", ProgressBar).update(progress=gpu)
def update_sparklines(self) -> None:
"""Update sparkline charts."""
if self.cpu_history:
cpu_sparkline = self.query_one("#cpu-sparkline", Sparkline)
cpu_sparkline.data = self.cpu_history
if self.memory_history:
memory_sparkline = self.query_one("#memory-sparkline", Sparkline)
memory_sparkline.data = self.memory_history
def update_process_table(self) -> None:
"""Update the process table."""
table = self.query_one("#process-table", DataTable)
table.clear()
try:
# Get top processes by CPU usage
processes = []
for proc in psutil.process_iter(
["pid", "name", "cpu_percent", "memory_percent", "status"]
):
try:
pinfo = proc.info
if pinfo["cpu_percent"] > 0.1: # Only show processes using CPU
processes.append(pinfo)
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
# Sort by CPU usage
processes.sort(key=lambda x: x["cpu_percent"], reverse=True)
# Add top 20 processes
for proc in processes[:20]:
table.add_row(
str(proc["pid"]),
proc["name"][:20],
f"{proc['cpu_percent']:.1f}%",
f"{proc['memory_percent']:.1f}%",
proc["status"],
)
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error updating process table: {str(e)}")
def update_system_info(self) -> None:
"""Update system information."""
try:
# System info
psutil.boot_time()
cpu_count = psutil.cpu_count()
memory = psutil.virtual_memory()
log = self.query_one("#system-logs", Log)
log.write_line(f"System started. CPU cores: {cpu_count}")
log.write_line(f"Total memory: {memory.total / (1024**3):.1f} GB")
log.write_line(f"Available memory: {memory.available / (1024**3):.1f} GB")
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error getting system info: {str(e)}")
def update_gpu_info(self) -> None:
"""Update GPU information."""
try:
import pynvml
pynvml.nvmlInit()
device_count = pynvml.nvmlDeviceGetCount()
log = self.query_one("#gpu-info", Log)
log.clear()
log.write_line(f"Found {device_count} GPU(s)")
for i in range(device_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
name = pynvml.nvmlDeviceGetName(handle).decode()
memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
log.write_line(f"\nGPU {i}: {name}")
log.write_line(
f"Memory: {memory_info.used / (1024**3):.1f} / {memory_info.total / (1024**3):.1f} GB"
)
log.write_line(f"Free: {memory_info.free / (1024**3):.1f} GB")
except Exception as e:
log = self.query_one("#gpu-info", Log)
log.clear()
log.write_line(f"GPU info unavailable: {str(e)}")
@on(Button.Pressed, "#kill-process")
def handle_kill_process(self) -> None:
"""Kill selected process."""
table = self.query_one("#process-table", DataTable)
if table.cursor_row >= 0:
try:
row = table.get_row_at(table.cursor_row)
pid = int(row[0])
process = psutil.Process(pid)
process.terminate()
log = self.query_one("#system-logs", Log)
log.write_line(f"Terminated process {pid}")
except Exception as e:
log = self.query_one("#system-logs", Log)
log.write_line(f"Error killing process: {str(e)}")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh all metrics."""
self.update_system_info()
self.update_gpu_info()
log = self.query_one("#system-logs", Log)
log.write_line("Metrics refreshed")
@on(Button.Pressed, "#auto-refresh")
def handle_auto_refresh(self) -> None:
"""Toggle auto refresh."""
log = self.query_one("#system-logs", Log)
log.write_line("Auto refresh is always enabled (every 2 seconds)")
def action_refresh(self) -> None:
"""Refresh action."""
self.handle_refresh()
def action_kill_process(self) -> None:
"""Kill process action."""
self.handle_kill_process()

View File

@@ -0,0 +1,545 @@
"""Training management screen for Axolotl TUI."""
import subprocess
import threading
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Label,
Log,
Sparkline,
Static,
)
from axolotl.tui.screens.base import BaseScreen
@dataclass
class TrainingJob:
"""Represents a training job."""
id: str
config_path: str
status: str # pending, running, completed, failed
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
process: Optional[subprocess.Popen] = None
log_file: Optional[str] = None
current_epoch: int = 0
total_epochs: int = 0
current_loss: float = 0.0
losses: List[float] = None
def __post_init__(self):
if self.losses is None:
self.losses = []
class TrainingScreen(BaseScreen):
"""Training management screen."""
BINDINGS = [
Binding("ctrl+t", "new_training", "New Training"),
Binding("ctrl+r", "resume_training", "Resume"),
Binding("ctrl+x", "stop_training", "Stop"),
Binding("ctrl+l", "view_logs", "View Logs"),
Binding("r", "refresh", "Refresh"),
]
CSS = """
.training-container {
layout: vertical;
height: 100%;
}
.job-list-container {
height: 40%;
border: solid $primary;
padding: 1;
margin: 1;
}
.job-details-container {
height: 60%;
padding: 1;
}
.control-panel {
layout: horizontal;
height: 4;
align: center middle;
padding: 1;
border: solid $secondary;
margin: 1;
}
.control-panel Button {
margin: 0 1;
}
.metrics-panel {
layout: horizontal;
height: 10;
border: solid $primary;
padding: 1;
margin: 1;
}
.metric-card {
width: 25%;
border: tall $surface;
padding: 1;
margin: 0 1;
}
.metric-label {
text-style: bold;
color: $text-muted;
}
.metric-value {
text-style: bold;
text-align: center;
padding: 1;
}
.log-viewer {
border: solid $warning;
padding: 1;
margin: 1;
}
#training-logs {
height: 100%;
}
DataTable {
height: 100%;
}
.screen-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.screen-subtitle {
text-align: center;
padding: 0 0 1 0;
color: $text-muted;
}
.sparkline-container {
height: 5;
border: solid $success;
padding: 1;
margin: 1;
}
"""
def __init__(self):
"""Initialize the training screen."""
super().__init__(
title="Training Management",
subtitle="Launch, monitor, and manage training jobs",
)
self.jobs: Dict[str, TrainingJob] = {}
self.selected_job_id: Optional[str] = None
self.update_timer = None
def compose(self) -> ComposeResult:
"""Compose the training screen layout."""
yield Header()
yield Container(
Static("🦾 Training Management", classes="screen-title"),
Static(
"Launch, monitor, and manage training jobs", classes="screen-subtitle"
),
Container(
Container(
Label("Active Training Jobs"),
DataTable(id="job-table"),
classes="job-list-container",
),
Container(
Button("New Training", id="new-training", variant="primary"),
Button("Resume", id="resume-training", variant="success"),
Button("Stop", id="stop-training", variant="error"),
Button("View Logs", id="view-logs", variant="default"),
Button("Clear Completed", id="clear-completed", variant="warning"),
Button("Refresh", id="refresh", variant="default"),
classes="control-panel",
),
Container(
Container(
Static("Current Epoch", classes="metric-label"),
Static("0 / 0", id="epoch-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Loss", classes="metric-label"),
Static("0.000", id="loss-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Status", classes="metric-label"),
Static("Idle", id="status-metric", classes="metric-value"),
classes="metric-card",
),
Container(
Static("Duration", classes="metric-label"),
Static(
"00:00:00", id="duration-metric", classes="metric-value"
),
classes="metric-card",
),
classes="metrics-panel",
),
Container(
Label("Loss History"),
Sparkline(
[],
id="loss-sparkline",
summary_function=min,
),
classes="sparkline-container",
),
Container(
Log(id="training-logs"),
classes="log-viewer",
),
classes="job-details-container",
),
classes="training-container",
id="content",
)
yield Footer()
def on_mount(self) -> None:
"""Called when the screen is mounted."""
self.setup_job_table()
self.start_update_timer()
log = self.query_one("#training-logs", Log)
log.write_line(
"Training manager ready. Select a configuration to start training."
)
def setup_job_table(self) -> None:
"""Setup the job table."""
table = self.query_one("#job-table", DataTable)
table.add_columns("ID", "Config", "Status", "Epoch", "Loss", "Duration")
table.cursor_type = "row"
table.zebra_stripes = True
def start_update_timer(self) -> None:
"""Start the periodic update timer."""
self.set_interval(2.0, self.update_job_status)
@work(thread=True)
async def update_job_status(self) -> None:
"""Update job status periodically."""
for job_id, job in self.jobs.items():
if job.status == "running" and job.process:
poll = job.process.poll()
if poll is not None:
if poll == 0:
job.status = "completed"
else:
job.status = "failed"
job.end_time = datetime.now()
self.refresh_job_table()
self.update_selected_job_metrics()
def refresh_job_table(self) -> None:
"""Refresh the job table."""
table = self.query_one("#job-table", DataTable)
table.clear()
for job_id, job in self.jobs.items():
duration = self.calculate_duration(job)
table.add_row(
job_id[:8],
Path(job.config_path).name,
job.status,
f"{job.current_epoch}/{job.total_epochs}",
f"{job.current_loss:.4f}" if job.current_loss else "N/A",
duration,
)
def calculate_duration(self, job: TrainingJob) -> str:
"""Calculate job duration."""
if not job.start_time:
return "00:00:00"
end_time = job.end_time or datetime.now()
duration = end_time - job.start_time
hours = int(duration.total_seconds() // 3600)
minutes = int((duration.total_seconds() % 3600) // 60)
seconds = int(duration.total_seconds() % 60)
return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
def update_selected_job_metrics(self) -> None:
"""Update metrics for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
self.query_one("#epoch-metric", Static).update(
f"{job.current_epoch} / {job.total_epochs}"
)
self.query_one("#loss-metric", Static).update(
f"{job.current_loss:.4f}" if job.current_loss else "N/A"
)
self.query_one("#status-metric", Static).update(job.status.upper())
self.query_one("#duration-metric", Static).update(self.calculate_duration(job))
if job.losses:
sparkline = self.query_one("#loss-sparkline", Sparkline)
sparkline.data = job.losses[-50:] # Show last 50 loss values
@on(DataTable.RowSelected)
def handle_row_selected(self, event: DataTable.RowSelected) -> None:
"""Handle job selection from table."""
if event.cursor_row >= 0:
job_ids = list(self.jobs.keys())
if event.cursor_row < len(job_ids):
self.selected_job_id = job_ids[event.cursor_row]
self.update_selected_job_metrics()
self.load_job_logs()
def load_job_logs(self) -> None:
"""Load logs for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
if job.log_file and Path(job.log_file).exists():
try:
with open(job.log_file, "r") as f:
content = f.read()
log = self.query_one("#training-logs", Log)
log.clear()
for line in content.split("\n")[-100:]: # Show last 100 lines
if line.strip():
log.write_line(line)
except Exception as e:
log = self.query_one("#training-logs", Log)
log.write_line(f"Error loading logs: {str(e)}")
@on(Button.Pressed, "#new-training")
async def handle_new_training(self) -> None:
"""Start a new training job."""
from axolotl.tui.dialogs.training import NewTrainingDialog
dialog = NewTrainingDialog()
result = await self.app.push_screen_wait(dialog)
if result and "config_path" in result:
await self.start_training_job(
result["config_path"], result.get("launcher", "accelerate")
)
@work(thread=True)
async def start_training_job(
self, config_path: str, launcher: str = "accelerate"
) -> None:
"""Start a training job."""
import uuid
from datetime import datetime
job_id = str(uuid.uuid4())
log_file = f"/tmp/axolotl_training_{job_id}.log"
job = TrainingJob(
id=job_id,
config_path=config_path,
status="pending",
start_time=datetime.now(),
log_file=log_file,
total_epochs=3, # Default, should parse from config
)
self.jobs[job_id] = job
self.selected_job_id = job_id
log = self.query_one("#training-logs", Log)
log.clear()
log.write_line(f"🚀 Starting training job {job_id[:8]}...")
log.write_line(f"Config: {config_path}")
log.write_line(f"Launcher: {launcher}")
try:
if launcher == "accelerate":
cmd = ["accelerate", "launch", "-m", "axolotl.cli.train", config_path]
else:
cmd = [
"torchrun",
"--nproc_per_node=1",
"-m",
"axolotl.cli.train",
config_path,
]
with open(log_file, "w") as f:
process = subprocess.Popen(
cmd,
stdout=f,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
job.process = process
job.status = "running"
log.write_line("✅ Training started successfully!")
self.refresh_job_table()
self.monitor_training_output(job_id)
except Exception as e:
job.status = "failed"
job.end_time = datetime.now()
log.write_line(f"❌ Failed to start training: {str(e)}")
self.refresh_job_table()
def monitor_training_output(self, job_id: str) -> None:
"""Monitor training output and extract metrics."""
if job_id not in self.jobs:
return
job = self.jobs[job_id]
if not job.log_file:
return
def tail_log():
import re
import time
with open(job.log_file, "r") as f:
f.seek(0, 2) # Go to end of file
while job.status == "running":
line = f.readline()
if line:
# Parse training metrics from log
epoch_match = re.search(r"Epoch (\d+)/(\d+)", line)
if epoch_match:
job.current_epoch = int(epoch_match.group(1))
job.total_epochs = int(epoch_match.group(2))
loss_match = re.search(
r"loss['\"]?\s*:\s*([\d.]+)", line, re.IGNORECASE
)
if loss_match:
job.current_loss = float(loss_match.group(1))
job.losses.append(job.current_loss)
# Update log viewer
self.call_from_thread(self.append_training_log, line.strip())
else:
time.sleep(0.5)
thread = threading.Thread(target=tail_log, daemon=True)
thread.start()
def append_training_log(self, line: str) -> None:
"""Append line to training log."""
log = self.query_one("#training-logs", Log)
log.write_line(line)
@on(Button.Pressed, "#stop-training")
def handle_stop_training(self) -> None:
"""Stop selected training job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
log = self.query_one("#training-logs", Log)
log.write_line("⚠️ No job selected")
return
job = self.jobs[self.selected_job_id]
if job.status == "running" and job.process:
job.process.terminate()
job.status = "stopped"
job.end_time = datetime.now()
log = self.query_one("#training-logs", Log)
log.write_line(f"🛑 Training job {job.id[:8]} stopped")
self.refresh_job_table()
@on(Button.Pressed, "#resume-training")
async def handle_resume_training(self) -> None:
"""Resume a stopped training job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
log = self.query_one("#training-logs", Log)
log.write_line("⚠️ No job selected")
return
job = self.jobs[self.selected_job_id]
if job.status in ["stopped", "failed"]:
await self.start_training_job(job.config_path)
@on(Button.Pressed, "#clear-completed")
def handle_clear_completed(self) -> None:
"""Clear completed jobs from the list."""
completed_jobs = [
job_id
for job_id, job in self.jobs.items()
if job.status in ["completed", "failed", "stopped"]
]
for job_id in completed_jobs:
del self.jobs[job_id]
self.refresh_job_table()
log = self.query_one("#training-logs", Log)
log.write_line(f"🧹 Cleared {len(completed_jobs)} completed jobs")
@on(Button.Pressed, "#refresh")
def handle_refresh(self) -> None:
"""Refresh the job list and metrics."""
self.refresh_job_table()
self.update_selected_job_metrics()
if self.selected_job_id:
self.load_job_logs()
@on(Button.Pressed, "#view-logs")
def handle_view_logs(self) -> None:
"""View full logs for selected job."""
if not self.selected_job_id or self.selected_job_id not in self.jobs:
return
job = self.jobs[self.selected_job_id]
if job.log_file and Path(job.log_file).exists():
import subprocess
subprocess.run(["less", job.log_file])
def action_new_training(self) -> None:
"""Start a new training job."""
self.handle_new_training()
def action_stop_training(self) -> None:
"""Stop selected training job."""
self.handle_stop_training()
def action_resume_training(self) -> None:
"""Resume selected training job."""
self.handle_resume_training()
def action_refresh(self) -> None:
"""Refresh the display."""
self.handle_refresh()

View File

@@ -1,19 +1,11 @@
"""Shared axolotl collators for multipack, mamba, multimodal, etc."""
"""
shared axolotl collators for multipack, mamba, multimodal
"""
from .batching import (
from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from .mamba import MambaDataCollator
from .streaming import StreamingDataCollator
__all__ = [
"BatchSamplerDataCollatorForSeq2Seq",
"DataCollatorForSeq2Seq",
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
"V2BatchSamplerDataCollatorForSeq2Seq",
"MambaDataCollator",
"StreamingDataCollator",
]
from .mamba import MambaDataCollator # noqa: F401

View File

@@ -1,146 +0,0 @@
from dataclasses import dataclass
from typing import Any, List
import torch
from transformers import PreTrainedTokenizerBase, default_data_collator
from transformers.utils import PaddingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.dict import DictDefault
@dataclass
class StreamingDataCollator:
tokenizer: PreTrainedTokenizerBase
cfg: DictDefault
prompter: Prompter | None = None
padding: bool | str | PaddingStrategy = True
max_length: int | None = None
pad_to_multiple_of: int | None = None
label_pad_token_id: int = -100
return_tensors: str = "pt"
def __post_init__(self):
if self.max_length is None:
self.max_length = self.cfg.sequence_len
def __call__(self, raw_batch: List[dict]) -> dict[str, Any]:
processed_samples = []
for raw_sample in raw_batch:
formatted_sample = raw_sample
if self.prompter:
formatted_sample = self._apply_prompt_formatting(raw_sample)
tokenized_sample = self._tokenize_sample(formatted_sample)
if len(tokenized_sample["input_ids"]) > self.max_length:
tokenized_sample = self._truncate_sample(tokenized_sample)
if tokenized_sample.get("input_ids"):
processed_samples.append(tokenized_sample)
return self._pad_and_batch(processed_samples)
def _apply_prompt_formatting(self, raw_sample: dict) -> dict:
formatted_text = self.prompter.build_prompt(
instruction=raw_sample.get("instruction", ""),
input=raw_sample.get("input", ""),
output=raw_sample.get("output", ""),
)
return {"text": formatted_text}
def _tokenize_sample(self, sample: dict) -> dict:
text = sample.get("text", sample.get("content", ""))
if not text:
instruction = sample.get("instruction", "")
input_text = sample.get("input", "")
output_text = sample.get("output", "")
parts = []
if instruction:
parts.append(f"Instruction: {instruction}")
if input_text:
parts.append(f"Input: {input_text}")
if output_text:
parts.append(f"Output: {output_text}")
text = "\n".join(parts)
if not text:
return {"input_ids": [], "attention_mask": [], "labels": []}
tokenized = self.tokenizer(
text,
truncation=False,
padding=False,
return_tensors=None,
)
tokenized["labels"] = tokenized["input_ids"].copy()
return tokenized
def _truncate_sample(self, tokenized_sample: dict) -> dict:
max_len = self.max_length
for key in ["input_ids", "attention_mask", "labels"]:
if key in tokenized_sample:
tokenized_sample[key] = tokenized_sample[key][:max_len]
return tokenized_sample
def _pad_and_batch(self, processed_samples: List[dict]) -> dict[str, Any]:
if not processed_samples:
processed_samples = [
{
"input_ids": [self.tokenizer.eos_token_id],
"attention_mask": [1],
"labels": [self.tokenizer.eos_token_id],
}
]
batch_samples = []
for sample in processed_samples:
batch_sample = {}
for key, value in sample.items():
if key in ["input_ids", "attention_mask", "labels"]:
batch_sample[key] = torch.tensor(value, dtype=torch.long)
batch_samples.append(batch_sample)
if self.padding:
max_len_in_batch = max(len(sample["input_ids"]) for sample in batch_samples)
for sample in batch_samples:
current_len = len(sample["input_ids"])
pad_len = max_len_in_batch - current_len
if pad_len > 0:
pad_token_id = (
self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
)
sample["input_ids"] = torch.cat(
[
sample["input_ids"],
torch.full((pad_len,), pad_token_id, dtype=torch.long),
]
)
sample["attention_mask"] = torch.cat(
[
sample["attention_mask"],
torch.zeros(pad_len, dtype=torch.long),
]
)
sample["labels"] = torch.cat(
[
sample["labels"],
torch.full(
(pad_len,), self.label_pad_token_id, dtype=torch.long
),
]
)
batch = {}
for key in ["input_ids", "attention_mask", "labels"]:
if key in batch_samples[0]:
batch[key] = torch.stack([sample[key] for sample in batch_samples])
return batch

View File

@@ -9,7 +9,6 @@ from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -44,18 +43,6 @@ from axolotl.utils.trainer import (
LOG = get_logger(__name__)
def _determine_streaming_mode(cfg: DictDefault) -> bool:
"""Determine if we should use streaming mode based on config."""
if cfg.streaming is not None:
return cfg.streaming
# Default to streaming for pretraining datasets
if cfg.pretraining_dataset:
return True
return False
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_datasets(
cfg: DictDefault,
@@ -74,52 +61,11 @@ def prepare_datasets(
Returns:
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
"""
streaming_mode = _determine_streaming_mode(cfg)
if streaming_mode:
if cfg.pretraining_dataset:
return _prepare_streaming_pretraining_dataset(cfg, tokenizer, processor)
else:
return _prepare_streaming_sft_dataset(cfg, tokenizer, processor)
else:
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable=False
)
else:
return _prepare_standard_dataset(
cfg, tokenizer, processor, preprocess_iterable=False
)
def _prepare_streaming_sft_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
LOG.info("Loading streaming datasets")
raw_datasets = _load_raw_datasets_for_streaming(cfg, split="train")
eval_dataset = None
if cfg.test_datasets:
eval_raw_datasets = _load_raw_datasets_for_streaming(
cfg, split="test", dataset_configs=cfg.test_datasets
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable
)
eval_dataset = _process_eval_dataset_minimal(
eval_raw_datasets, cfg, tokenizer, processor
)
elif cfg.val_set_size:
LOG.info("Validation splits not supported for streaming datasets")
if not cfg.max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
total_num_steps = cfg.max_steps
LOG.info(f"Maximum steps: {total_num_steps}")
prompters = [None] * len(cfg.datasets) if cfg.datasets else []
return raw_datasets, eval_dataset, total_num_steps, prompters
return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
def _prepare_standard_dataset(
@@ -427,7 +373,7 @@ def _load_and_process_single_dataset(
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
# Select the appropriate split
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
if isinstance(dataset, DatasetDict):
if dataset_config.split and dataset_config.split in dataset:
dataset = dataset[dataset_config.split]
elif split in dataset:
@@ -566,78 +512,3 @@ def _load_and_prepare_datasets(
train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg)
return train_dataset, eval_dataset, prompters
def _load_raw_datasets_for_streaming(
cfg: DictDefault, split: str = "train", dataset_configs: list | None = None
) -> IterableDataset:
configs = (
dataset_configs
if dataset_configs is not None
else (cfg.datasets if split == "train" else cfg.test_datasets)
)
if not configs:
raise ValueError(f"No dataset configurations found for split '{split}'")
datasets = []
for dataset_config in datasets_with_name_generator(configs):
raw_dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=True
)
if isinstance(raw_dataset, (DatasetDict, IterableDatasetDict)):
if dataset_config.split and dataset_config.split in raw_dataset:
raw_dataset = raw_dataset[dataset_config.split]
elif split in raw_dataset:
raw_dataset = raw_dataset[split]
else:
raise ValueError(
f"no {split} split found for dataset {dataset_config.path}, "
"you may specify a split with 'split: ...'"
)
datasets.append(raw_dataset)
if len(datasets) == 1:
return datasets[0]
else:
return merge_datasets(datasets, cfg)
def _process_eval_dataset_minimal(
raw_dataset: IterableDataset,
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
) -> Dataset | None:
LOG.info("Eval dataset processing skipped for streaming")
return None
def _prepare_streaming_pretraining_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
pretraining_config = _extract_pretraining_config(cfg)
train_dataset = load_dataset_with_config(
pretraining_config, cfg.hf_use_auth_token, streaming=True
)
if isinstance(train_dataset, (DatasetDict, IterableDatasetDict)):
if pretraining_config.split and pretraining_config.split in train_dataset:
train_dataset = train_dataset[pretraining_config.split]
elif "train" in train_dataset:
train_dataset = train_dataset["train"]
else:
raise ValueError("no train split found for pretraining dataset")
if not cfg.max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
total_num_steps = cfg.max_steps
LOG.info(f"Maximum steps: {total_num_steps}")
return train_dataset, None, total_num_steps, []

View File

@@ -190,18 +190,12 @@ def handle_long_seq_in_dataset(
Returns:
Filtered dataset with long sequences removed.
"""
if hasattr(dataset, "column_names") and dataset.column_names:
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
else:
# For IterableDataset, we can't check columns upfront, so skip for streaming
if isinstance(dataset, IterableDataset):
LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)")
return dataset
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
drop_long = functools.partial(
drop_long_seq,

View File

@@ -932,34 +932,6 @@ class AxolotlInputConfig(
fix_untrained_tokens: int | list[int] | None = None
streaming: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use streaming datasets (IterableDataset) for processing large datasets that don't fit in memory. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False."
},
)
streaming_dataset_mixing_strategy: str | None = Field(
default="round_robin",
json_schema_extra={
"description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)."
},
)
streaming_mixing_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'."
},
)
streaming_buffer_per_dataset: int | None = Field(
default=1000,
json_schema_extra={
"description": "Buffer size per dataset when mixing multiple streaming datasets. Higher values may improve mixing quality but use more memory."
},
)
# INTERNALS - document for now, generally not set externally
is_preprocess: bool | None = None
preprocess_iterable: bool | None = None

View File

@@ -1337,30 +1337,6 @@ class GRPOVllmValidationMixin:
# pylint: disable=too-many-ancestors
class StreamingValidationMixin:
"""Validation methods related to streaming datasets."""
@model_validator(mode="after")
def check_streaming_requires_max_steps(self):
"""Ensure max_steps is set when using streaming datasets."""
# Check if streaming is explicitly enabled
streaming_enabled = getattr(self, "streaming", None) is True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming_default_for_pretraining = (
has_pretraining and getattr(self, "streaming", None) is None
)
# If streaming is enabled (explicitly or by default for pretraining)
if streaming_enabled or streaming_default_for_pretraining:
max_steps = getattr(self, "max_steps", None)
if not max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
return self
class ValidationMixin(
DatasetValidationMixin,
AttentionValidationMixin,
@@ -1371,7 +1347,6 @@ class ValidationMixin(
SystemValidationMixin,
ChatTemplateValidationMixin,
PretrainingValidationMixin,
StreamingValidationMixin,
ModelCompatibilityValidationMixin,
ComplexValidationMixin,
GRPOVllmValidationMixin,