Files
llm-trainer/backend/main.py
tocmo0nlord e5649148f7 Default to localhost for workstation install
- ConnectionPanel defaults host to localhost, blank username
- Backend OLLAMA_URL defaults to localhost:11434
- Systemd service reads /etc/llm-trainer/env for overrides

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-12 16:18:10 -04:00

452 lines
18 KiB
Python

import asyncio
import json
import os
import shlex
import tempfile
import threading
from pathlib import Path
from typing import Optional

import httpx
import yaml
from fastapi import FastAPI, File, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from gpu import get_gpu_stats
from pipeline import STAGE_DIRS, CONFIG_PATH, ingest_cmd, create_cmd, curate_cmd, save_as_cmd, train_cmd
from ssh_client import ssh_manager
# ──────────────────────────────────────────────────────────────────────────────
# FastAPI application plus global configuration.
app = FastAPI(title="LLM Trainer API", version="1.0.0")

# Wide-open CORS: the API is meant for a local workstation install where the
# frontend may be served from an arbitrary dev port.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Ollama endpoint; override via the OLLAMA_URL environment variable.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434")
# ──────────────────────────────────────────────────────────────────────────────
# Pydantic models
# ──────────────────────────────────────────────────────────────────────────────
class ConnectRequest(BaseModel):
    """Target host and credentials for opening the shared SSH session.

    NOTE(review): the host/username defaults embed a specific LAN address and
    personal account — they look like developer conveniences; confirm before
    shipping publicly.
    """
    host: str = "192.168.2.47"
    username: str = "tocmo0nlord"
    password: Optional[str] = None   # password auth when set
    key_path: Optional[str] = None   # private-key auth when set
    port: int = 22
class TrainRequest(BaseModel):
    """Fine-tuning job parameters.

    NOTE(review): no visible endpoint consumes this model (the /api/train
    WebSocket takes query parameters) — it may be used elsewhere; verify.
    """
    model_name: str = "llama3.1:8b"
    dataset_path: str
    output_dir: str = "/opt/synthetic/output"
    num_epochs: int = 3
    batch_size: int = 2
    learning_rate: float = 2e-4
# ──────────────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────────────
def _require_ssh():
    """Abort the current request with 503 unless the SSH session is connected."""
    if ssh_manager.is_connected():
        return
    raise HTTPException(status_code=503, detail="Not connected to SSH server")
async def _stream_ws(websocket: WebSocket, command: str, use_conda: bool = True):
    """Run a remote command and stream output lines over WebSocket.

    Each message is a JSON object {"type": "log"|"done"|"error", "data": str}.
    The blocking SSH read runs on a daemon thread; lines are handed back to
    the event loop through an asyncio.Queue via run_coroutine_threadsafe.
    """
    await websocket.accept()
    # get_running_loop() is the correct call inside a coroutine;
    # asyncio.get_event_loop() is deprecated for this use since 3.10.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _worker():
        # Plain thread: blocking paramiko reads must not stall the event loop.
        try:
            for line in ssh_manager.execute_stream(command, use_conda=use_conda):
                asyncio.run_coroutine_threadsafe(
                    queue.put({"type": "log", "data": line}), loop
                )
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "done", "data": "Command completed."}), loop
            )
        except Exception as exc:
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "error", "data": str(exc)}), loop
            )

    threading.Thread(target=_worker, daemon=True).start()
    try:
        while True:
            msg = await queue.get()
            await websocket.send_json(msg)
            if msg["type"] in ("done", "error"):
                break
    except WebSocketDisconnect:
        # Client dropped mid-stream; the daemon worker dies with the process.
        pass
# ──────────────────────────────────────────────────────────────────────────────
# Connection
# ──────────────────────────────────────────────────────────────────────────────
@app.post("/api/connect")
async def connect(req: ConnectRequest):
    """Open the shared SSH session using password or key-based auth.

    Returns a small status payload on success; any connection failure is
    surfaced as a 500 with the underlying error text.
    """
    try:
        ssh_manager.connect(
            host=req.host,
            username=req.username,
            password=req.password,
            key_path=req.key_path,
            port=req.port,
        )
        return {"status": "connected", "host": req.host, "username": req.username}
    except Exception as exc:
        # Chain the cause so the original traceback survives in logs (B904).
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.post("/api/disconnect")
async def disconnect():
    """Tear down the shared SSH session."""
    ssh_manager.disconnect()
    return {"status": "disconnected"}
@app.get("/api/status")
async def status():
    """Report SSH connection state, remote identity, and GPU stats."""
    is_up = ssh_manager.is_connected()
    if is_up:
        gpu_info = get_gpu_stats()
        host, user = ssh_manager.host, ssh_manager.username
    else:
        gpu_info = {"gpus": [], "error": "Not connected"}
        host = user = None
    return {
        "connected": is_up,
        "host": host,
        "username": user,
        "gpu": gpu_info,
    }
# ──────────────────────────────────────────────────────────────────────────────
# GPU
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/gpu")
async def gpu():
    """Return GPU statistics from the remote host; 503 when not connected."""
    _require_ssh()
    return get_gpu_stats()
# ──────────────────────────────────────────────────────────────────────────────
# File management
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/files/{stage}")
async def list_files(stage: str):
    """List regular files in the remote directory for a pipeline stage.

    Parses `ls -la` output; directories, the `total` header, and malformed
    lines are skipped. Raises 400 for an unknown stage, 503 when SSH is down.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()
    out, _, code = ssh_manager.execute(
        f"ls -la '{STAGE_DIRS[stage]}' 2>/dev/null | tail -n +2", use_conda=False
    )
    files = []
    for line in out.strip().split("\n"):
        if not line.strip() or line.startswith("total"):
            continue
        # Expected columns: perms links owner group size month day time name...
        parts = line.split()
        if len(parts) >= 9 and not parts[0].startswith("d"):
            try:
                size = int(parts[4])
            except ValueError:
                # Non-standard ls line (locale/format oddity) — skip, don't 500.
                continue
            files.append({
                "name": " ".join(parts[8:]),  # rejoin names containing spaces
                "size": size,
                "modified": f"{parts[5]} {parts[6]} {parts[7]}",
            })
    return {"stage": stage, "directory": STAGE_DIRS[stage], "files": files}
@app.delete("/api/files/{stage}/{filename}")
async def delete_file(stage: str, filename: str):
    """Delete a single file from a pipeline stage directory on the remote host."""
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()
    path = f"{STAGE_DIRS[stage]}/{filename}"
    # shlex.quote guards against shell metacharacters (quotes, $(), ;) in
    # the user-supplied filename; the route pattern already excludes "/".
    _, err, code = ssh_manager.execute(f"rm -f {shlex.quote(path)}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"deleted": filename}
@app.get("/api/files/{stage}/{filename}/preview")
async def preview_file(stage: str, filename: str, lines: int = 120):
    """Return the first `lines` lines of a remote stage file."""
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()
    path = f"{STAGE_DIRS[stage]}/{filename}"
    # `lines` is typed int by FastAPI; quote the path against shell metachars.
    out, err, code = ssh_manager.execute(
        f"head -n {lines} {shlex.quote(path)}", use_conda=False
    )
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"filename": filename, "content": out}
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Spool an uploaded file to a local temp path, then SCP it to the
    remote input directory."""
    _require_ssh()
    contents = await file.read()
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=Path(file.filename).suffix
    ) as tmp:
        tmp.write(contents)
        local_path = tmp.name
    try:
        remote_path = f"{STAGE_DIRS['input']}/{file.filename}"
        ssh_manager.upload_file(local_path, remote_path)
        return {"uploaded": file.filename, "remote_path": remote_path}
    finally:
        # Always remove the temp spool file, even when the upload fails.
        os.unlink(local_path)
# ──────────────────────────────────────────────────────────────────────────────
# Pipeline (WebSocket streaming)
# ──────────────────────────────────────────────────────────────────────────────
@app.websocket("/api/pipeline/ingest")
async def ws_ingest(websocket: WebSocket, filename: str):
    """Run the ingest stage on one input file, streaming log output."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    # Restored `{filename}` interpolation (was garbled in a prior revision).
    cmd = ingest_cmd(f"{STAGE_DIRS['input']}/{filename}")
    await _stream_ws(websocket, cmd)
@app.websocket("/api/pipeline/create")
async def ws_create(websocket: WebSocket, filename: str,
                    num_pairs: int = 50, pair_type: str = "qa"):
    """Generate training pairs from a parsed file, streaming log output."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = create_cmd(f"{STAGE_DIRS['parsed']}/{filename}", num_pairs, pair_type)
    await _stream_ws(websocket, cmd)
@app.websocket("/api/pipeline/curate")
async def ws_curate(websocket: WebSocket, filename: str,
                    output_filename: str, threshold: float = 7.0):
    """Filter generated pairs by score threshold, streaming log output."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = curate_cmd(
        f"{STAGE_DIRS['generated']}/{filename}",
        f"{STAGE_DIRS['curated']}/{output_filename}",
        threshold,
    )
    await _stream_ws(websocket, cmd)
@app.websocket("/api/pipeline/save")
async def ws_save(websocket: WebSocket, filename: str,
                  output_filename: str, fmt: str = "jsonl"):
    """Export a curated file to its final format, streaming log output."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = save_as_cmd(
        f"{STAGE_DIRS['curated']}/{filename}",
        f"{STAGE_DIRS['final']}/{output_filename}",
        fmt,
    )
    await _stream_ws(websocket, cmd)
# ──────────────────────────────────────────────────────────────────────────────
# QA Pairs viewer
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/pairs/{filename}")
async def get_pairs(filename: str, stage: str = "generated"):
    """Read a remote JSONL pairs file and return the parsed records.

    Unknown stages fall back to the 'generated' directory; unparsable lines
    are skipped so a partially-written file still renders.
    """
    _require_ssh()
    path = f"{STAGE_DIRS.get(stage, STAGE_DIRS['generated'])}/{filename}"
    # Quote the path — filename is user-supplied and reaches a remote shell.
    out, err, code = ssh_manager.execute(f"cat {shlex.quote(path)}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")
    pairs = []
    for line in out.strip().split("\n"):
        if not line.strip():
            continue
        try:
            pairs.append(json.loads(line))
        except json.JSONDecodeError:
            # Tolerate corrupt/partial lines rather than failing the request.
            pass
    return {"filename": filename, "count": len(pairs), "pairs": pairs}
# ──────────────────────────────────────────────────────────────────────────────
# Config editor
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/config")
async def get_config():
    """Fetch the remote pipeline YAML config.

    Returns both the parsed mapping and the raw text so the UI can edit either.
    """
    _require_ssh()
    try:
        raw = ssh_manager.read_remote_file(CONFIG_PATH)
        # safe_load: never instantiate arbitrary objects from remote YAML.
        return {"config": yaml.safe_load(raw), "raw": raw}
    except Exception as exc:
        # Chain the cause so the original traceback survives in logs (B904).
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.put("/api/config")
async def update_config(payload: dict):
    """Serialize the payload to YAML and write it to the remote config path."""
    _require_ssh()
    try:
        ssh_manager.write_remote_file(
            CONFIG_PATH, yaml.dump(payload, default_flow_style=False)
        )
        return {"status": "updated"}
    except Exception as exc:
        # Chain the cause so the original traceback survives in logs (B904).
        raise HTTPException(status_code=500, detail=str(exc)) from exc
# ──────────────────────────────────────────────────────────────────────────────
# Training (WebSocket streaming)
# ──────────────────────────────────────────────────────────────────────────────
@app.websocket("/api/train")
async def ws_train(
    websocket: WebSocket,
    model_name: str = "llama3.1:8b",
    dataset_path: str = "",
    output_dir: str = "/opt/synthetic/output",
    num_epochs: int = 3,
    batch_size: int = 2,
    learning_rate: float = 2e-4,
):
    """Launch remote fine-tuning and stream its log output over the socket."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    command = train_cmd(
        model_name, dataset_path, output_dir, num_epochs, batch_size, learning_rate
    )
    await _stream_ws(websocket, command)
# ──────────────────────────────────────────────────────────────────────────────
# Interactive terminal (xterm.js ↔ SSH shell)
# ──────────────────────────────────────────────────────────────────────────────
@app.websocket("/api/terminal")
async def ws_terminal(websocket: WebSocket):
    """Bridge an xterm.js WebSocket client to an interactive remote SSH shell.

    Two coroutines run concurrently: one polls the SSH channel and forwards
    bytes to the client, the other forwards client keystrokes to the channel.
    """
    await websocket.accept()
    if not ssh_manager.is_connected():
        await websocket.send_text("\r\nNot connected to SSH server.\r\n")
        return

    channel = None
    try:
        channel = ssh_manager.open_shell_channel()

        async def pump_ssh_output():
            # Poll the channel; sleep briefly when idle so the event loop
            # isn't spun at full speed.
            while True:
                if channel.recv_ready():
                    chunk = channel.recv(4096)
                    if not chunk:
                        break
                    await websocket.send_bytes(chunk)
                elif channel.exit_status_ready():
                    break
                else:
                    await asyncio.sleep(0.02)

        async def pump_client_input():
            try:
                while True:
                    channel.send(await websocket.receive_bytes())
            except WebSocketDisconnect:
                pass

        await asyncio.gather(pump_ssh_output(), pump_client_input())
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        try:
            await websocket.send_text(f"\r\nError: {exc}\r\n")
        except Exception:
            # Socket already gone; nothing more to report.
            pass
    finally:
        if channel:
            try:
                channel.close()
            except Exception:
                pass
# ──────────────────────────────────────────────────────────────────────────────
# Model manager (Ollama)
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/models")
async def list_models():
    """List models known to the local Ollama daemon.

    Degrades gracefully: returns an empty list plus the error text instead of
    failing when Ollama is unreachable.
    """
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(f"{OLLAMA_URL}/api/tags")
            resp.raise_for_status()
            payload = resp.json()
        return {"models": payload.get("models", []), "error": None}
    except Exception as exc:
        return {"models": [], "error": str(exc)}
@app.websocket("/api/models/pull")
async def ws_pull_model(websocket: WebSocket, model_name: str):
    """Proxy Ollama's streaming /api/pull progress lines to the WebSocket."""
    await websocket.accept()
    try:
        async with httpx.AsyncClient(timeout=600) as client:
            stream_ctx = client.stream(
                "POST",
                f"{OLLAMA_URL}/api/pull",
                json={"name": model_name, "stream": True},
            )
            async with stream_ctx as resp:
                async for raw in resp.aiter_lines():
                    if not raw.strip():
                        continue
                    try:
                        await websocket.send_json(json.loads(raw))
                    except json.JSONDecodeError:
                        # Skip malformed progress lines from the daemon.
                        pass
            await websocket.send_json({"status": "success"})
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        try:
            await websocket.send_json({"status": "error", "error": str(exc)})
        except Exception:
            # Client already disconnected; nothing to report.
            pass
@app.delete("/api/models/{model_name:path}")
async def delete_model(model_name: str):
    """Delete a model from Ollama.

    The `:path` converter lets model names contain slashes (e.g. namespaced
    tags). Failures surface as a 500 with the underlying error text.
    """
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.request(
                "DELETE", f"{OLLAMA_URL}/api/delete",
                json={"name": model_name}
            )
            resp.raise_for_status()
            return {"deleted": model_name}
    except Exception as exc:
        # Chain the cause so the original traceback survives in logs (B904).
        raise HTTPException(status_code=500, detail=str(exc)) from exc
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Dev entry point: uvicorn with auto-reload, bound on all interfaces.
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=True)