Initial scaffold: LLM Trainer Dashboard
Full-stack app with a FastAPI backend (SSH/paramiko, pipeline streaming, GPU stats, xterm.js terminal, Ollama model manager) and a React + Tailwind frontend (8 panels: Connection, Documents, Pipeline, QA Pairs, Training, Terminal, Models, Config). Docker Compose included.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
450
backend/main.py
Normal file
450
backend/main.py
Normal file
@@ -0,0 +1,450 @@
|
||||
import asyncio
import json
import os
import shlex
import tempfile
import threading
from pathlib import Path
from typing import Optional

import httpx
import yaml
from fastapi import FastAPI, File, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from gpu import get_gpu_stats
from pipeline import STAGE_DIRS, CONFIG_PATH, ingest_cmd, create_cmd, curate_cmd, save_as_cmd, train_cmd
from ssh_client import ssh_manager
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Base URL of the Ollama HTTP API; the OLLAMA_URL environment variable
# overrides the built-in default.
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://192.168.2.47:11434")

app = FastAPI(title="LLM Trainer API", version="1.0.0")

# NOTE(review): every origin, method and header is allowed while credentials
# are enabled. Confirm this wide-open CORS policy is intended outside of
# local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Pydantic models
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class ConnectRequest(BaseModel):
    """Request body accepted by ``POST /api/connect``.

    Every field is forwarded unchanged to ``ssh_manager.connect``.
    """

    # Remote endpoint and account; the defaults are hard-coded for one host.
    host: str = "192.168.2.47"
    username: str = "tocmo0nlord"
    # Both credentials are optional; how they are combined is decided by
    # ``ssh_manager.connect``, which is not visible in this file.
    password: Optional[str] = None
    key_path: Optional[str] = None
    port: int = 22
|
||||
|
||||
|
||||
class TrainRequest(BaseModel):
    """Parameters describing a training run.

    NOTE(review): no endpoint visible in this file references this model;
    ``ws_train`` takes the same fields as WebSocket query parameters instead.
    Confirm whether this class is still needed.
    """

    model_name: str = "llama3.1:8b"
    # The only field without a default, so callers must always supply it.
    dataset_path: str
    output_dir: str = "/opt/synthetic/output"
    num_epochs: int = 3
    batch_size: int = 2
    learning_rate: float = 2e-4
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Helpers
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _require_ssh():
    """Raise HTTP 503 unless ``ssh_manager`` reports an active connection."""
    if ssh_manager.is_connected():
        return
    raise HTTPException(status_code=503, detail="Not connected to SSH server")
|
||||
|
||||
|
||||
async def _stream_ws(websocket: WebSocket, command: str, use_conda: bool = True):
    """Run a remote command and stream its output lines over a WebSocket.

    The socket is accepted here. ``ssh_manager.execute_stream`` is iterated in
    a daemon thread and every yielded line is forwarded to the client as
    ``{"type": "log", "data": line}``. The stream ends with one terminal
    message: ``{"type": "done", ...}`` on success, or ``{"type": "error", ...}``
    carrying ``str(exc)`` if the iteration raised.

    Args:
        websocket: Connection that has not been accepted yet.
        command: Command string handed to ``ssh_manager.execute_stream``.
        use_conda: Passed through to ``ssh_manager.execute_stream``.
    """
    await websocket.accept()
    # We are inside a coroutine, so a running loop is guaranteed to exist.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    # Set once the consumer below is finished, so the worker stops feeding a
    # queue that nobody reads (otherwise it grows without bound after a
    # client disconnects from a long-running command).
    client_gone = threading.Event()

    def _publish(kind: str, data: str) -> None:
        # asyncio.Queue is not thread-safe; enqueue from the loop's thread.
        loop.call_soon_threadsafe(queue.put_nowait, {"type": kind, "data": data})

    def _worker() -> None:
        try:
            for line in ssh_manager.execute_stream(command, use_conda=use_conda):
                if client_gone.is_set():
                    # NOTE(review): this only stops consuming the output;
                    # whether the remote command itself is terminated depends
                    # on execute_stream, which is not visible here.
                    return
                _publish("log", line)
            _publish("done", "Command completed.")
        except Exception as exc:
            if not client_gone.is_set():
                _publish("error", str(exc))

    threading.Thread(target=_worker, daemon=True).start()

    try:
        while True:
            msg = await queue.get()
            await websocket.send_json(msg)
            if msg["type"] in ("done", "error"):
                break
    except WebSocketDisconnect:
        pass
    finally:
        client_gone.set()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Connection
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.post("/api/connect")
async def connect(req: ConnectRequest):
    """Open the shared SSH connection described by ``req``.

    Returns a status payload echoing the host and username. Any exception
    raised by ``ssh_manager.connect`` is reported as HTTP 500 with the
    exception text as the detail.
    """
    connect_kwargs = {
        "host": req.host,
        "username": req.username,
        "password": req.password,
        "key_path": req.key_path,
        "port": req.port,
    }
    try:
        ssh_manager.connect(**connect_kwargs)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "connected", "host": req.host, "username": req.username}
|
||||
|
||||
|
||||
@app.post("/api/disconnect")
async def disconnect():
    """Close the shared SSH connection and report the new state."""
    ssh_manager.disconnect()
    result = {"status": "disconnected"}
    return result
|
||||
|
||||
|
||||
@app.get("/api/status")
async def status():
    """Report connection state plus GPU statistics.

    When disconnected, ``host`` and ``username`` are ``None`` and the ``gpu``
    entry is a placeholder with an empty GPU list and an error message.
    """
    connected = ssh_manager.is_connected()
    if connected:
        gpu_info = get_gpu_stats()
        host = ssh_manager.host
        username = ssh_manager.username
    else:
        gpu_info = {"gpus": [], "error": "Not connected"}
        host = None
        username = None
    return {
        "connected": connected,
        "host": host,
        "username": username,
        "gpu": gpu_info,
    }
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# GPU
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/gpu")
async def gpu():
    """Return the result of ``get_gpu_stats()``; HTTP 503 when SSH is down."""
    _require_ssh()
    stats = get_gpu_stats()
    return stats
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# File management
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/files/{stage}")
async def list_files(stage: str):
    """List the non-directory entries of a pipeline stage directory.

    Runs ``ls -la`` on the remote host and parses its long-format output.

    Args:
        stage: Key into ``STAGE_DIRS``.

    Returns:
        ``{"stage", "directory", "files"}`` where each file is a dict with
        ``name``, ``size`` (int, taken from the fifth ``ls`` column) and
        ``modified`` (the three date/time columns joined by spaces).

    Raises:
        HTTPException: 400 for an unknown stage, 503 when SSH is down.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()

    directory = STAGE_DIRS[stage]
    # stderr is discarded remotely, so a missing directory simply yields an
    # empty listing; the exit status is therefore not inspected.
    out, _, _ = ssh_manager.execute(
        f"ls -la '{directory}' 2>/dev/null | tail -n +2", use_conda=False
    )

    files = []
    for line in out.strip().split("\n"):
        if not line.strip() or line.startswith("total"):
            continue
        # Limit the split so the 9th field keeps the file name verbatim,
        # including any runs of whitespace inside it.
        parts = line.split(None, 8)
        if len(parts) < 9 or parts[0].startswith("d"):
            continue
        try:
            size = int(parts[4])
        except ValueError:
            # Entries whose size column is not a plain integer (for example
            # device nodes) cannot be represented; skip them instead of
            # failing the whole request.
            continue
        files.append({
            "name": parts[8],
            "size": size,
            "modified": f"{parts[5]} {parts[6]} {parts[7]}",
        })

    return {"stage": stage, "directory": directory, "files": files}
|
||||
|
||||
|
||||
@app.delete("/api/files/{stage}/{filename}")
async def delete_file(stage: str, filename: str):
    """Delete one file from a pipeline stage directory on the remote host.

    Args:
        stage: Key into ``STAGE_DIRS``.
        filename: Name of a file directly inside that stage directory.

    Returns:
        ``{"deleted": filename}`` once ``rm -f`` exits with status 0.

    Raises:
        HTTPException: 400 for an unknown stage or a filename that is not a
            plain file name, 503 when SSH is down, 500 (with the remote
            stderr as detail) when the command fails.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    # The name comes straight from the URL: keep it inside the stage directory.
    if "/" in filename or filename in (".", ".."):
        raise HTTPException(status_code=400, detail=f"Invalid filename: {filename}")
    _require_ssh()

    path = f"{STAGE_DIRS[stage]}/{filename}"
    # shlex.quote keeps quotes or other shell metacharacters in the name from
    # being interpreted by the remote shell.
    _, err, code = ssh_manager.execute(f"rm -f {shlex.quote(path)}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"deleted": filename}
|
||||
|
||||
|
||||
@app.get("/api/files/{stage}/{filename}/preview")
async def preview_file(stage: str, filename: str, lines: int = 120):
    """Return the first ``lines`` lines of a remote stage file via ``head``.

    Args:
        stage: Key into ``STAGE_DIRS``.
        filename: Name of a file directly inside that stage directory.
        lines: Line count passed to ``head -n``.

    Returns:
        ``{"filename": filename, "content": <stdout of head>}``.

    Raises:
        HTTPException: 400 for an unknown stage or a filename that is not a
            plain file name, 503 when SSH is down, 500 (with the remote
            stderr as detail) when the command fails.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    # The name comes straight from the URL: keep it inside the stage directory.
    if "/" in filename or filename in (".", ".."):
        raise HTTPException(status_code=400, detail=f"Invalid filename: {filename}")
    _require_ssh()

    path = f"{STAGE_DIRS[stage]}/{filename}"
    # ``lines`` is already validated as an int by the framework; the path is
    # quoted so shell metacharacters in the name are not interpreted.
    out, err, code = ssh_manager.execute(
        f"head -n {lines} {shlex.quote(path)}", use_conda=False
    )
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"filename": filename, "content": out}
|
||||
|
||||
|
||||
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Upload a file into the remote ``input`` stage directory.

    The upload is spooled to a local temporary file, pushed with
    ``ssh_manager.upload_file`` and the temporary file is always removed.

    Returns:
        ``{"uploaded": <file name>, "remote_path": <destination path>}``.

    Raises:
        HTTPException: 503 when SSH is down, 400 when the upload carries no
            usable file name.
    """
    _require_ssh()

    # The client controls the file name: keep only its final component so
    # the destination cannot escape the input directory.
    safe_name = Path(file.filename or "").name
    if safe_name in ("", ".."):
        raise HTTPException(status_code=400, detail="Invalid upload filename")

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=Path(safe_name).suffix)
    try:
        with tmp:
            tmp.write(await file.read())
        remote_path = f"{STAGE_DIRS['input']}/{safe_name}"
        ssh_manager.upload_file(tmp.name, remote_path)
        return {"uploaded": safe_name, "remote_path": remote_path}
    finally:
        # Runs even when spooling the upload fails, so no temp file is left.
        os.unlink(tmp.name)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Pipeline (WebSocket streaming)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.websocket("/api/pipeline/ingest")
async def ws_ingest(websocket: WebSocket, filename: str):
    """Stream the command built by ``ingest_cmd`` for a file in the input stage.

    Without an SSH connection the socket is accepted, a single error message
    is sent and the handler returns.
    """
    if ssh_manager.is_connected():
        source = f"{STAGE_DIRS['input']}/{filename}"
        await _stream_ws(websocket, ingest_cmd(source))
        return
    await websocket.accept()
    await websocket.send_json({"type": "error", "data": "Not connected"})
|
||||
|
||||
|
||||
@app.websocket("/api/pipeline/create")
async def ws_create(websocket: WebSocket, filename: str,
                    num_pairs: int = 50, pair_type: str = "qa"):
    """Stream the command built by ``create_cmd`` for a file in the parsed stage.

    ``num_pairs`` and ``pair_type`` are passed straight to ``create_cmd``.
    Without an SSH connection the socket is accepted, a single error message
    is sent and the handler returns.
    """
    if ssh_manager.is_connected():
        source = f"{STAGE_DIRS['parsed']}/{filename}"
        await _stream_ws(websocket, create_cmd(source, num_pairs, pair_type))
        return
    await websocket.accept()
    await websocket.send_json({"type": "error", "data": "Not connected"})
|
||||
|
||||
|
||||
@app.websocket("/api/pipeline/curate")
async def ws_curate(websocket: WebSocket, filename: str,
                    output_filename: str, threshold: float = 7.0):
    """Stream the command built by ``curate_cmd``.

    The input path lives in the generated stage directory and the output path
    in the curated one; ``threshold`` is passed straight to ``curate_cmd``.
    Without an SSH connection the socket is accepted, a single error message
    is sent and the handler returns.
    """
    if ssh_manager.is_connected():
        source = f"{STAGE_DIRS['generated']}/{filename}"
        target = f"{STAGE_DIRS['curated']}/{output_filename}"
        await _stream_ws(websocket, curate_cmd(source, target, threshold))
        return
    await websocket.accept()
    await websocket.send_json({"type": "error", "data": "Not connected"})
|
||||
|
||||
|
||||
@app.websocket("/api/pipeline/save")
async def ws_save(websocket: WebSocket, filename: str,
                  output_filename: str, fmt: str = "jsonl"):
    """Stream the command built by ``save_as_cmd``.

    The input path lives in the curated stage directory and the output path
    in the final one; ``fmt`` is passed straight to ``save_as_cmd``.
    Without an SSH connection the socket is accepted, a single error message
    is sent and the handler returns.
    """
    if ssh_manager.is_connected():
        source = f"{STAGE_DIRS['curated']}/{filename}"
        target = f"{STAGE_DIRS['final']}/{output_filename}"
        await _stream_ws(websocket, save_as_cmd(source, target, fmt))
        return
    await websocket.accept()
    await websocket.send_json({"type": "error", "data": "Not connected"})
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# QA Pairs viewer
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/pairs/{filename}")
async def get_pairs(filename: str, stage: str = "generated"):
    """Read a remote JSON-lines file and return its decoded records.

    Args:
        filename: Name of a file directly inside the stage directory.
        stage: Key into ``STAGE_DIRS``; an unknown key falls back to the
            ``generated`` directory.

    Returns:
        ``{"filename", "count", "pairs"}``. Blank lines and lines that are
        not valid JSON are skipped.

    Raises:
        HTTPException: 503 when SSH is down, 400 when the filename is not a
            plain file name, 404 when ``cat`` exits non-zero.
    """
    _require_ssh()
    # The name comes straight from the URL: keep it inside the stage directory.
    if "/" in filename or filename in (".", ".."):
        raise HTTPException(status_code=400, detail=f"Invalid filename: {filename}")

    directory = STAGE_DIRS.get(stage, STAGE_DIRS["generated"])
    path = f"{directory}/{filename}"
    # Quote the path so shell metacharacters in the name are not interpreted.
    out, _, code = ssh_manager.execute(f"cat {shlex.quote(path)}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    pairs = []
    # Split on "\n" only: JSON strings may legally contain other Unicode
    # line separators that str.splitlines() would break on.
    for line in out.strip().split("\n"):
        if not line.strip():
            continue
        try:
            pairs.append(json.loads(line))
        except json.JSONDecodeError:
            # Best effort: a malformed record must not hide the valid ones.
            continue

    return {"filename": filename, "count": len(pairs), "pairs": pairs}
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Config editor
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/config")
async def get_config():
    """Return the remote config file both parsed (``config``) and verbatim (``raw``).

    Any failure while reading or parsing is reported as HTTP 500 with the
    exception text as the detail; HTTP 503 when SSH is down.
    """
    _require_ssh()
    try:
        raw_text = ssh_manager.read_remote_file(CONFIG_PATH)
        parsed = yaml.safe_load(raw_text)
        return {"config": parsed, "raw": raw_text}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@app.put("/api/config")
async def update_config(payload: dict):
    """Serialise ``payload`` as block-style YAML and overwrite the remote config.

    Any failure while dumping or writing is reported as HTTP 500 with the
    exception text as the detail; HTTP 503 when SSH is down.
    """
    _require_ssh()
    try:
        document = yaml.dump(payload, default_flow_style=False)
        ssh_manager.write_remote_file(CONFIG_PATH, document)
        return {"status": "updated"}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Training (WebSocket streaming)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.websocket("/api/train")
async def ws_train(
    websocket: WebSocket,
    model_name: str = "llama3.1:8b",
    dataset_path: str = "",
    output_dir: str = "/opt/synthetic/output",
    num_epochs: int = 3,
    batch_size: int = 2,
    learning_rate: float = 2e-4,
):
    """Stream the command built by ``train_cmd`` from the query parameters.

    All parameters are passed to ``train_cmd`` positionally in declaration
    order. Without an SSH connection the socket is accepted, a single error
    message is sent and the handler returns.
    """
    if ssh_manager.is_connected():
        command = train_cmd(
            model_name, dataset_path, output_dir, num_epochs, batch_size, learning_rate
        )
        await _stream_ws(websocket, command)
        return
    await websocket.accept()
    await websocket.send_json({"type": "error", "data": "Not connected"})
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Interactive terminal (xterm.js ↔ SSH shell)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.websocket("/api/terminal")
async def ws_terminal(websocket: WebSocket):
    """Bridge a WebSocket to an interactive SSH shell channel.

    Two pumps run concurrently: one forwards channel output to the socket as
    binary frames, the other forwards binary frames from the socket into the
    channel. As soon as either pump finishes (client disconnect, shell exit
    or an error) the other one is cancelled and the channel is closed.

    Without an SSH connection a short text notice is sent and the handler
    returns.
    """
    await websocket.accept()

    if not ssh_manager.is_connected():
        await websocket.send_text("\r\nNot connected to SSH server.\r\n")
        return

    channel = None
    try:
        channel = ssh_manager.open_shell_channel()

        async def ssh_to_ws():
            # Poll rather than block so the event loop stays responsive.
            while True:
                if channel.recv_ready():
                    data = channel.recv(4096)
                    if not data:
                        break
                    await websocket.send_bytes(data)
                elif channel.exit_status_ready():
                    break
                else:
                    await asyncio.sleep(0.02)

        async def ws_to_ssh():
            try:
                while True:
                    data = await websocket.receive_bytes()
                    channel.send(data)
            except WebSocketDisconnect:
                pass

        pumps = [asyncio.ensure_future(ssh_to_ws()), asyncio.ensure_future(ws_to_ssh())]
        try:
            # Stop at the first finished pump. Waiting for both would leave
            # ssh_to_ws polling an idle shell forever after the client has
            # gone, so the channel would never be closed.
            done, _ = await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for pump in pumps:
                pump.cancel()  # no effect on a pump that already finished
            await asyncio.gather(*pumps, return_exceptions=True)
        for pump in done:
            # Re-raise a pump failure so the handlers below report it.
            pump.result()
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        try:
            await websocket.send_text(f"\r\nError: {exc}\r\n")
        except Exception:
            # The socket may already be closed; nothing more can be reported.
            pass
    finally:
        if channel:
            try:
                channel.close()
            except Exception:
                # Best-effort cleanup of an already broken channel.
                pass
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Model manager (Ollama)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/models")
async def list_models():
    """Return the ``models`` list from the Ollama ``/api/tags`` endpoint.

    The list is empty when the response has no ``models`` key. Any failure
    (network error, non-2xx status, invalid JSON) is reported as HTTP 500
    with the exception text as the detail.
    """
    tags_url = f"{OLLAMA_URL}/api/tags"
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(tags_url)
            resp.raise_for_status()
            payload = resp.json()
            return {"models": payload.get("models", [])}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@app.websocket("/api/models/pull")
async def ws_pull_model(websocket: WebSocket, model_name: str):
    """Relay the streamed response of the Ollama ``/api/pull`` call to the client.

    Every non-blank response line that parses as JSON is forwarded as-is;
    other lines are dropped. ``{"status": "success"}`` is sent once the
    stream ends. Any other exception is reported to the client as
    ``{"status": "error", "error": ...}`` on a best-effort basis.
    """
    await websocket.accept()
    pull_url = f"{OLLAMA_URL}/api/pull"
    request_body = {"name": model_name, "stream": True}
    try:
        async with httpx.AsyncClient(timeout=600) as client:
            async with client.stream("POST", pull_url, json=request_body) as resp:
                async for line in resp.aiter_lines():
                    if not line.strip():
                        continue
                    try:
                        await websocket.send_json(json.loads(line))
                    except json.JSONDecodeError:
                        # Non-JSON progress lines are ignored.
                        pass
        await websocket.send_json({"status": "success"})
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        try:
            await websocket.send_json({"status": "error", "error": str(exc)})
        except Exception:
            # The socket may already be closed.
            pass
|
||||
|
||||
|
||||
@app.delete("/api/models/{model_name:path}")
async def delete_model(model_name: str):
    """Ask Ollama to delete ``model_name`` through its ``/api/delete`` endpoint.

    The ``:path`` converter lets the name contain slashes. Any failure is
    reported as HTTP 500 with the exception text as the detail.
    """
    delete_url = f"{OLLAMA_URL}/api/delete"
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            # A generic request() call is used so a JSON body can accompany
            # the DELETE verb.
            resp = await client.request("DELETE", delete_url, json={"name": model_name})
            resp.raise_for_status()
            return {"deleted": model_name}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # With reload=True uvicorn needs the application as an import string so
    # the reloader can re-import it; given an app object it logs a warning
    # and exits. "main:app" assumes this file is main.py and is started from
    # its own directory.
    uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=True)
|
||||
Reference in New Issue
Block a user