This commit is contained in:
Dan Saunders
2025-08-22 02:43:16 -04:00
parent 889b27ecf1
commit c3e1882de5
11 changed files with 10087 additions and 10181 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -76,3 +76,4 @@ mistral-common==1.8.3
# TUI dependencies
textual==1.0.0
rich==14.1.0
tree_sitter_ruby==0.23.1

View File

@@ -72,9 +72,10 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
builder_kwargs["message_field_training"] = message_field_training
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
format_message = (
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
)
def format_message(x):
return x
if chat_template == "chatml":
from axolotl.core.chat.format.chatml import format_message # noqa F811
if chat_template.startswith("llama3"):

View File

@@ -3,9 +3,9 @@
from textual import on
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, ScrollableContainer
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Button, Footer, Header, Label, Static
from textual.widgets import Button, Footer, Header, Static
from axolotl.tui.screens.config import ConfigScreen
from axolotl.tui.screens.datasets import DatasetScreen
@@ -117,7 +117,6 @@ class AxolotlTUI(App):
text-style: bold;
padding: 1;
color: $primary;
content-align: center middle;
}
.subtitle {
@@ -129,6 +128,7 @@ class AxolotlTUI(App):
.welcome-container {
align: center middle;
height: 100%;
width: 100%;
}
.menu-container {
@@ -144,7 +144,7 @@ class AxolotlTUI(App):
margin: 1;
}
Screen {
WelcomeScreen {
align: center middle;
}
"""

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from textual import on
from textual.app import ComposeResult
from textual.containers import Container, Horizontal, Vertical
from textual.containers import Container
from textual.screen import ModalScreen
from textual.widgets import Button, Input, Label, Select, Static

View File

@@ -8,26 +8,19 @@ import yaml
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, ScrollableContainer, Vertical
from textual.message import Message
from textual.containers import Container
from textual.reactive import reactive
from textual.widgets import (
Button,
DataTable,
DirectoryTree,
Footer,
Header,
Input,
Label,
LoadingIndicator,
Log,
Select,
Static,
Switch,
TextArea,
)
from axolotl.cli.config import load_cfg
from axolotl.tui.screens.base import BaseScreen
@@ -154,7 +147,7 @@ class ConfigScreen(BaseScreen):
classes="config-actions",
),
Container(
Log(id="validation-log", wrap=True, highlight=True),
Log(id="validation-log"),
classes="validation-log",
),
classes="config-editor",

View File

@@ -2,27 +2,21 @@
import json
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, ScrollableContainer, Vertical
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Input,
Label,
LoadingIndicator,
Log,
Pretty,
ProgressBar,
Select,
Static,
TabbedContent,
TabPane,
TextArea,
)
@@ -76,7 +70,7 @@ class DatasetScreen(BaseScreen):
.preview-container {
height: 100%;
border: solid $info;
border: solid $primary;
padding: 1;
}
@@ -157,84 +151,25 @@ class DatasetScreen(BaseScreen):
classes="dataset-list",
),
Container(
TabbedContent(
TabPane(
"Preview",
Container(
TextArea(
"",
language="json",
theme="monokai",
id="dataset-preview",
read_only=True,
),
classes="preview-container",
),
),
TabPane(
"Statistics",
Container(
Container(
Static("Dataset Name:", classes="stat-label"),
Static("-", id="stat-name", classes="stat-value"),
classes="stat-row",
),
Container(
Static("Type:", classes="stat-label"),
Static("-", id="stat-type", classes="stat-value"),
classes="stat-row",
),
Container(
Static("Size:", classes="stat-label"),
Static("-", id="stat-size", classes="stat-value"),
classes="stat-row",
),
Container(
Static("Samples:", classes="stat-label"),
Static(
"-", id="stat-samples", classes="stat-value"
),
classes="stat-row",
),
Container(
Static("Features:", classes="stat-label"),
Static(
"-", id="stat-features", classes="stat-value"
),
classes="stat-row",
),
Container(
Static("Format:", classes="stat-label"),
Static("-", id="stat-format", classes="stat-value"),
classes="stat-row",
),
Container(
Static("Preprocessed:", classes="stat-label"),
Static(
"-",
id="stat-preprocessed",
classes="stat-value",
),
classes="stat-row",
),
classes="stats-container",
),
),
TabPane(
"Processing",
Container(
Log(id="processing-log", wrap=True, highlight=True),
Container(
Label("Preprocessing Progress:"),
ProgressBar(
total=100,
id="preprocessing-progress",
),
classes="progress-container",
),
),
),
TextArea("", id="dataset-preview", read_only=True),
Container(
Static("Dataset Name:", classes="stat-label"),
Static("-", id="stat-name", classes="stat-value"),
Static("Type:", classes="stat-label"),
Static("-", id="stat-type", classes="stat-value"),
Static("Size:", classes="stat-label"),
Static("-", id="stat-size", classes="stat-value"),
Static("Samples:", classes="stat-label"),
Static("-", id="stat-samples", classes="stat-value"),
Static("Features:", classes="stat-label"),
Static("-", id="stat-features", classes="stat-value"),
Static("Format:", classes="stat-label"),
Static("-", id="stat-format", classes="stat-value"),
Static("Preprocessed:", classes="stat-label"),
Static("-", id="stat-preprocessed", classes="stat-value"),
),
Log(id="processing-log"),
ProgressBar(total=100, id="preprocessing-progress"),
classes="dataset-details",
),
classes="dataset-container",
@@ -325,10 +260,10 @@ class DatasetScreen(BaseScreen):
@on(DataTable.RowSelected)
def handle_dataset_selected(self, event: DataTable.RowSelected) -> None:
"""Handle dataset selection from table."""
if event.row_index >= 0:
if event.cursor_row >= 0:
dataset_names = list(self.datasets.keys())
if event.row_index < len(dataset_names):
self.selected_dataset = dataset_names[event.row_index]
if event.cursor_row < len(dataset_names):
self.selected_dataset = dataset_names[event.cursor_row]
self.load_dataset_preview()
self.update_dataset_stats()

View File

@@ -3,16 +3,15 @@
from pathlib import Path
from typing import Dict, List, Optional
from textual import on, work
from textual import events, on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, ScrollableContainer, Vertical
from textual.containers import Container
from textual.widgets import (
Button,
Input,
Label,
Log,
Pretty,
Select,
Static,
TextArea,
@@ -53,7 +52,7 @@ class InferenceScreen(BaseScreen):
.chat-history {
height: 70%;
border: solid $info;
border: solid $primary;
padding: 1;
margin: 0 0 1 0;
}
@@ -121,7 +120,7 @@ class InferenceScreen(BaseScreen):
Container(
Label("Model Selection"),
Select(
[("none", "No model loaded")],
[("No model loaded", "none")],
id="model-select",
value="none",
),
@@ -147,12 +146,11 @@ class InferenceScreen(BaseScreen):
),
Container(
Container(
Log(id="chat-history", wrap=True, highlight=True),
Log(id="chat-history"),
classes="chat-history",
),
Container(
TextArea(
placeholder="Type your message here...",
id="message-input",
),
classes="input-area",
@@ -182,27 +180,44 @@ class InferenceScreen(BaseScreen):
@work(thread=True)
async def load_available_models(self) -> None:
"""Load list of available models."""
models = [("none", "No model loaded")]
models = [("No model loaded", "none")]
chat = self.query_one("#chat-history", Log)
chat.write_line("🔍 Scanning for available models...")
# Check for trained models
outputs_dir = Path("./outputs")
chat.write_line(f"Checking outputs directory: {outputs_dir.absolute()}")
if outputs_dir.exists():
found_models = 0
for model_dir in outputs_dir.glob("*"):
if model_dir.is_dir() and (model_dir / "pytorch_model.bin").exists():
models.append((str(model_dir), model_dir.name))
# Check for HuggingFace models in cache
hf_cache = Path.home() / ".cache" / "huggingface" / "transformers"
if hf_cache.exists():
for model_dir in hf_cache.glob("models--*"):
if model_dir.is_dir():
model_name = model_dir.name.replace("models--", "").replace(
"--", "/"
# Look for various model file types
model_files = (
list(model_dir.glob("pytorch_model.bin"))
+ list(model_dir.glob("model.safetensors"))
+ list(model_dir.glob("*.bin"))
+ list(model_dir.glob("*.safetensors"))
)
models.append((str(model_dir), f"HF: {model_name}"))
if model_files:
models.append((model_dir.name, str(model_dir)))
found_models += 1
chat.write_line(f"Found {found_models} trained models in outputs/")
else:
chat.write_line("outputs/ directory not found")
# Add some example/demo models for testing
models.extend(
[
("Demo: GPT-2 Small", "gpt2"),
("Demo: TinyLlama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
("Demo: Phi-2", "microsoft/phi-2"),
]
)
select = self.query_one("#model-select", Select)
select.set_options(models)
chat.write_line(f"✅ Loaded {len(models)} models in dropdown")
@on(Button.Pressed, "#load-model")
@work(thread=True)
@@ -258,28 +273,63 @@ class InferenceScreen(BaseScreen):
self.chat_history.append({"role": "user", "content": message})
# Generate response (placeholder)
await self.generate_response(message)
self.generate_response(message)
@on(TextArea.Changed, "#message-input")
def on_message_input_changed(self, event: TextArea.Changed) -> None:
"""Handle changes to the message input."""
# This could be used for features like typing indicators
pass
def on_key(self, event: events.Key) -> None:
"""Handle key events globally."""
# Check if we're focused on the message input and Ctrl+Enter is pressed
focused = self.focused
if focused and focused.id == "message-input" and event.key == "ctrl+enter":
event.prevent_default()
self.handle_send_message()
@work(thread=True)
async def generate_response(self, message: str) -> None:
"""Generate model response (placeholder implementation)."""
"""Generate model response."""
chat = self.query_one("#chat-history", Log)
chat.write_line("🤖 Assistant: Thinking...")
try:
# Get inference parameters
temperature = float(self.query_one("#temperature", Input).value)
max_tokens = int(self.query_one("#max-tokens", Input).value)
top_p = float(self.query_one("#top-p", Input).value)
float(self.query_one("#temperature", Input).value)
int(self.query_one("#max-tokens", Input).value)
float(self.query_one("#top-p", Input).value)
# Placeholder response (in real implementation, would call the model)
if not self.loaded_model or self.loaded_model == "none":
response = "I don't have a model loaded yet. Please load a model first using the 'Load Model' button."
elif self.loaded_model.startswith("gpt2"):
# Simple response for GPT-2
responses = [
f"Thanks for your message: '{message}'. I'm a GPT-2 model running in demo mode.",
"I understand you're testing the interface. GPT-2 models are great for experimentation!",
"This is a simulated GPT-2 response. In a real setup, I'd generate text based on your input.",
f"GPT-2 here! You said: '{message}'. I'd normally continue this conversation creatively.",
]
import random
response = random.choice(responses)
elif "llama" in self.loaded_model.lower():
# Response for Llama models
response = f"🦙 LLaMA model here! You asked: '{message}'. I'm designed for helpful, harmless, and honest conversations. How can I assist you today?"
elif "phi" in self.loaded_model.lower():
# Response for Phi models
response = f"Phi model responding! Your message: '{message}'. I'm optimized for reasoning and code tasks. What would you like to explore?"
else:
# Generic response for other models
response = f"Model '{self.loaded_model}' responding to: '{message}'. I'm ready to help with your questions!"
# Simulate inference time
import time
time.sleep(1) # Simulate inference time
time.sleep(0.5)
response = f"This is a placeholder response to: '{message}'. In a real implementation, this would be generated by the loaded model using the parameters: temperature={temperature}, max_tokens={max_tokens}, top_p={top_p}."
# Update chat with response
# Clear the "thinking" message and show response
chat.write_line(f"🤖 Assistant: {response}")
# Add to history

View File

@@ -1,22 +1,20 @@
"""Model management screen for Axolotl TUI."""
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, ScrollableContainer
from textual.containers import Container, ScrollableContainer
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Input,
Label,
Log,
ProgressBar,
Select,
Static,
TabbedContent,
TabPane,
@@ -96,56 +94,38 @@ class ModelScreen(BaseScreen):
def compose(self) -> ComposeResult:
"""Compose the model screen layout."""
yield Header()
yield Container(
Static("🦾 Model Management", classes="screen-title"),
Static(
with Container(id="content"):
yield Static("🦾 Model Management", classes="screen-title")
yield Static(
"Manage trained models, merge LoRA adapters, and quantize models",
classes="screen-subtitle",
),
Container(
Container(
Label("Available Models"),
DataTable(id="model-table"),
Container(
Button("Merge LoRA", id="merge-lora", variant="primary"),
Button("Quantize", id="quantize", variant="success"),
Button("Evaluate", id="evaluate", variant="warning"),
Button("Refresh", id="refresh", variant="default"),
classes="model-actions",
),
classes="model-list",
),
Container(
TabbedContent(
TabPane(
"Operations",
Container(
Log(id="operations-log", wrap=True, highlight=True),
Container(
Label("Operation Progress:"),
ProgressBar(
)
with Container(classes="model-container"):
with Container(classes="model-list"):
yield Label("Available Models")
yield DataTable(id="model-table")
with Container(classes="model-actions"):
yield Button("Merge LoRA", id="merge-lora", variant="primary")
yield Button("Quantize", id="quantize", variant="success")
yield Button("Evaluate", id="evaluate", variant="warning")
yield Button("Refresh", id="refresh", variant="default")
with Container(classes="model-operations"):
with TabbedContent():
with TabPane("Operations"):
with Container():
yield Log(id="operations-log")
with Container():
yield Label("Operation Progress:")
yield ProgressBar(
total=100,
id="operation-progress",
),
),
),
),
TabPane(
"Model Info",
ScrollableContainer(
Static(
)
with TabPane("Model Info"):
with ScrollableContainer():
yield Static(
"Model information will appear here",
id="model-info",
),
),
),
),
classes="model-operations",
),
classes="model-container",
),
id="content",
)
)
yield Footer()
def on_mount(self) -> None:
@@ -210,10 +190,10 @@ class ModelScreen(BaseScreen):
@on(DataTable.RowSelected)
def handle_model_selected(self, event: DataTable.RowSelected) -> None:
"""Handle model selection from table."""
if event.row_index >= 0:
if event.cursor_row >= 0:
model_names = list(self.models.keys())
if event.row_index < len(model_names):
self.selected_model = model_names[event.row_index]
if event.cursor_row < len(model_names):
self.selected_model = model_names[event.cursor_row]
self.update_model_info()
def update_model_info(self) -> None:

View File

@@ -4,7 +4,7 @@ import psutil
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, Vertical
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
@@ -15,8 +15,6 @@ from textual.widgets import (
ProgressBar,
Sparkline,
Static,
TabbedContent,
TabPane,
)
from axolotl.tui.screens.base import BaseScreen
@@ -59,7 +57,6 @@ class MonitorScreen(BaseScreen):
text-style: bold;
text-align: center;
padding: 1;
font-size: 2;
}
.charts-container {
@@ -70,7 +67,7 @@ class MonitorScreen(BaseScreen):
.chart-panel {
width: 50%;
border: solid $info;
border: solid $primary;
padding: 1;
margin: 0 1;
}
@@ -179,36 +176,9 @@ class MonitorScreen(BaseScreen):
classes="charts-container",
),
Container(
TabbedContent(
TabPane(
"Processes",
Container(
DataTable(id="process-table"),
Container(
Button(
"Kill Process",
id="kill-process",
variant="error",
),
Button("Refresh", id="refresh", variant="default"),
Button(
"Auto Refresh",
id="auto-refresh",
variant="primary",
),
classes="process-controls",
),
),
),
TabPane(
"GPU Info",
Log(id="gpu-info", wrap=True, highlight=True),
),
TabPane(
"System Logs",
Log(id="system-logs", wrap=True, highlight=True),
),
),
DataTable(id="process-table"),
Log(id="gpu-info"),
Log(id="system-logs"),
classes="processes-container",
),
classes="monitor-container",
@@ -360,7 +330,7 @@ class MonitorScreen(BaseScreen):
"""Update system information."""
try:
# System info
boot_time = psutil.boot_time()
psutil.boot_time()
cpu_count = psutil.cpu_count()
memory = psutil.virtual_memory()

View File

@@ -1,7 +1,5 @@
"""Training management screen for Axolotl TUI."""
import asyncio
import os
import subprocess
import threading
from dataclasses import dataclass
@@ -12,24 +10,16 @@ from typing import Dict, List, Optional
from textual import on, work
from textual.app import ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, ScrollableContainer, Vertical
from textual.reactive import reactive
from textual.containers import Container
from textual.widgets import (
Button,
DataTable,
Footer,
Header,
Input,
Label,
LoadingIndicator,
Log,
ProgressBar,
Select,
Sparkline,
Static,
Switch,
TabbedContent,
TabPane,
)
from axolotl.tui.screens.base import BaseScreen
@@ -101,7 +91,7 @@ class TrainingScreen(BaseScreen):
.metrics-panel {
layout: horizontal;
height: 10;
border: solid $info;
border: solid $primary;
padding: 1;
margin: 1;
}
@@ -227,20 +217,7 @@ class TrainingScreen(BaseScreen):
classes="sparkline-container",
),
Container(
TabbedContent(
TabPane(
"Training Logs",
Log(id="training-logs", wrap=True, highlight=True),
),
TabPane(
"System Logs",
Log(id="system-logs", wrap=True, highlight=True),
),
TabPane(
"Validation",
Log(id="validation-logs", wrap=True, highlight=True),
),
),
Log(id="training-logs"),
classes="log-viewer",
),
classes="job-details-container",
@@ -338,10 +315,10 @@ class TrainingScreen(BaseScreen):
@on(DataTable.RowSelected)
def handle_row_selected(self, event: DataTable.RowSelected) -> None:
"""Handle job selection from table."""
if event.row_index >= 0:
if event.cursor_row >= 0:
job_ids = list(self.jobs.keys())
if event.row_index < len(job_ids):
self.selected_job_id = job_ids[event.row_index]
if event.cursor_row < len(job_ids):
self.selected_job_id = job_ids[event.cursor_row]
self.update_selected_job_metrics()
self.load_job_logs()