Files
llm-trainer/backend/main.py

580 lines
24 KiB
Python

import asyncio
import json
import os
import shlex
import tempfile
import threading
import zipfile
from pathlib import Path
from typing import List, Optional

import httpx
import yaml
from fastapi import FastAPI, File, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from gpu import get_gpu_stats
from pipeline import STAGE_DIRS, CONFIG_PATH, ingest_cmd, create_cmd, curate_cmd, save_as_cmd, train_cmd
from ssh_client import ssh_manager
# ──────────────────────────────────────────────────────────────────────────────
# FastAPI application instance; all routes below attach to it.
app = FastAPI(title="LLM Trainer API", version="1.0.0")

# Wide-open CORS so a local frontend on any origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Base URL of the local Ollama server; overridable via the environment.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434")
# ──────────────────────────────────────────────────────────────────────────────
# Pydantic models
# ──────────────────────────────────────────────────────────────────────────────
class ConnectRequest(BaseModel):
    """Parameters for opening an SSH connection to the GPU host."""

    # Defaults target the lab GPU box; override per request as needed.
    host: str = "192.168.2.47"
    username: str = "tocmo0nlord"
    # Either a password or a private-key path may be supplied.
    password: Optional[str] = None
    key_path: Optional[str] = None
    port: int = 22
class TrainRequest(BaseModel):
    """Fine-tuning job parameters (mirrors the /api/train query args)."""

    model_name: str = "llama3.1:8b"
    # Remote path of the training dataset; required (no default).
    dataset_path: str
    output_dir: str = "/opt/synthetic/output"
    num_epochs: int = 3
    batch_size: int = 2
    learning_rate: float = 2e-4
# ──────────────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────────────
def _require_ssh():
    """Abort the current request with 503 unless an SSH session is active."""
    if ssh_manager.is_connected():
        return
    raise HTTPException(status_code=503, detail="Not connected to SSH server")
async def _stream_ws(websocket: WebSocket, command: str, use_conda: bool = True):
    """Run a remote command and stream output lines over WebSocket.

    Accepts the socket, runs the (blocking) SSH stream in a daemon thread,
    and forwards each line as a {"type": "log", "data": ...} message,
    terminated by exactly one "done" or "error" message.
    """
    await websocket.accept()
    # get_running_loop() is the supported way to grab the loop from inside a
    # coroutine; asyncio.get_event_loop() is deprecated in this context.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _worker():
        # Runs in a plain thread because ssh_manager.execute_stream blocks;
        # results are handed back to the event loop thread-safely.
        try:
            for line in ssh_manager.execute_stream(command, use_conda=use_conda):
                asyncio.run_coroutine_threadsafe(
                    queue.put({"type": "log", "data": line}), loop
                )
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "done", "data": "Command completed."}), loop
            )
        except Exception as exc:
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "error", "data": str(exc)}), loop
            )

    threading.Thread(target=_worker, daemon=True).start()
    try:
        while True:
            msg = await queue.get()
            await websocket.send_json(msg)
            if msg["type"] in ("done", "error"):
                break
    except WebSocketDisconnect:
        # Client went away mid-stream; the daemon worker drains on its own.
        pass
# ──────────────────────────────────────────────────────────────────────────────
# Connection
# ──────────────────────────────────────────────────────────────────────────────
@app.post("/api/connect")
async def connect(req: ConnectRequest):
    """Open an SSH session to the GPU host described by *req*.

    Returns the connected host/username on success; any connection failure
    is surfaced as a 500 with the underlying error message.
    """
    try:
        ssh_manager.connect(
            host=req.host,
            username=req.username,
            password=req.password,
            key_path=req.key_path,
            port=req.port,
        )
    except Exception as exc:
        # Chain the cause so server-side tracebacks show the real SSH error.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"status": "connected", "host": req.host, "username": req.username}
@app.post("/api/disconnect")
async def disconnect():
    """Tear down the current SSH session."""
    ssh_manager.disconnect()
    return {"status": "disconnected"}
@app.get("/api/status")
async def status():
    """Report SSH connection state plus current GPU stats."""
    if ssh_manager.is_connected():
        return {
            "connected": True,
            "host": ssh_manager.host,
            "username": ssh_manager.username,
            "gpu": get_gpu_stats(),
        }
    return {
        "connected": False,
        "host": None,
        "username": None,
        "gpu": {"gpus": [], "error": "Not connected"},
    }
# ──────────────────────────────────────────────────────────────────────────────
# GPU
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/gpu")
async def gpu():
    """Return live GPU statistics from the remote host (requires SSH)."""
    _require_ssh()
    return get_gpu_stats()
# ──────────────────────────────────────────────────────────────────────────────
# File management
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/files/{stage}")
async def list_files(stage: str):
    """List regular files in the remote directory for a pipeline stage."""
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()
    # `tail -n +2` drops the `total N` header; a missing dir yields no output.
    out, _, code = ssh_manager.execute(
        f"ls -la '{STAGE_DIRS[stage]}' 2>/dev/null | tail -n +2", use_conda=False
    )
    entries = []
    for row in out.strip().split("\n"):
        if not row.strip() or row.startswith("total"):
            continue
        fields = row.split()
        # Skip directories and any truncated rows; a filename containing
        # spaces spans fields[8:] and is re-joined below.
        if len(fields) < 9 or fields[0].startswith("d"):
            continue
        entries.append({
            "name": " ".join(fields[8:]),
            "size": int(fields[4]),
            "modified": " ".join(fields[5:8]),
        })
    return {"stage": stage, "directory": STAGE_DIRS[stage], "files": entries}
@app.delete("/api/files/{stage}/{filename}")
async def delete_file(stage: str, filename: str):
    """Delete one file from a pipeline stage directory on the remote host.

    The route template was garbled in the source ("(unknown)"); restored to
    the `{filename}` path parameter the handler signature expects.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()
    # shlex.quote guards against shell metacharacters (quotes, $, ;) in the
    # client-supplied filename being interpreted by the remote shell.
    path = shlex.quote(f"{STAGE_DIRS[stage]}/{filename}")
    _, err, code = ssh_manager.execute(f"rm -f {path}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"deleted": filename}
@app.get("/api/files/{stage}/{filename}/preview")
async def preview_file(stage: str, filename: str, lines: int = 120):
    """Return the first *lines* lines of a remote stage file.

    The route template was garbled in the source ("(unknown)"); restored to
    the `{filename}` path parameter the handler signature expects.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()
    # Quote the remote path so special characters in the filename are safe;
    # `lines` is typed int by FastAPI, so interpolating it is harmless.
    path = shlex.quote(f"{STAGE_DIRS[stage]}/{filename}")
    out, err, code = ssh_manager.execute(f"head -n {lines} {path}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"filename": filename, "content": out}
@app.post("/api/upload")
async def upload_files(files: List[UploadFile] = File(...)):
    """Upload files into the remote 'input' stage directory.

    Zip archives are expanded locally and each contained file is uploaded
    individually; anything else is uploaded as-is. Uploads are staged
    through local temp files that are always removed afterwards.
    """
    _require_ssh()
    results = []
    for file in files:
        suffix = Path(file.filename).suffix.lower()
        # Spool the whole upload to a local temp file so SFTP can read it
        # from disk (delete=False: we clean up explicitly in the finally).
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name
        try:
            if suffix == ".zip":
                extracted = []
                with zipfile.ZipFile(tmp_path) as zf:
                    for member in zf.infolist():
                        if member.is_dir():
                            continue
                        # Flatten to the basename: drops zip-internal paths
                        # (which also defuses '../' zip-slip entries).
                        name = Path(member.filename).name
                        if not name or name.startswith("."):
                            continue  # skip hidden/metadata entries
                        with zf.open(member) as src, tempfile.NamedTemporaryFile(delete=False, suffix=Path(name).suffix) as out_tmp:
                            out_tmp.write(src.read())
                            out_tmp_path = out_tmp.name
                        try:
                            remote_path = f"{STAGE_DIRS['input']}/{name}"
                            ssh_manager.upload_file(out_tmp_path, remote_path)
                            extracted.append(name)
                        finally:
                            os.unlink(out_tmp_path)
                results.append({"file": file.filename, "action": "extracted", "files": extracted})
            else:
                remote_path = f"{STAGE_DIRS['input']}/{file.filename}"
                ssh_manager.upload_file(tmp_path, remote_path)
                results.append({"file": file.filename, "action": "uploaded", "remote_path": remote_path})
        finally:
            os.unlink(tmp_path)
    return {"results": results}
# ──────────────────────────────────────────────────────────────────────────────
# Setup / Bootstrap
# ──────────────────────────────────────────────────────────────────────────────
# Where bootstrap assets (bootstrap.sh, train.py, config.yaml) live in the
# .deb install layout; overridable for development via LLM_TRAINER_ASSETS.
ASSETS_DIR = Path(os.getenv("LLM_TRAINER_ASSETS", "/opt/llm-trainer/remote"))
@app.get("/api/setup/check")
async def setup_check():
    """Detect what's already set up on the remote GPU host."""
    _require_ssh()
    user = ssh_manager.username
    checks = {}
    # Probe each component with a single shell command, parse the result.
    # Each clause prints exactly one 'name=ok' or 'name=missing' line.
    # ({{...}} is f-string escaping for the literal awk braces.)
    probe = (
        f"if [ -x $HOME/miniconda3/bin/conda ]; then echo 'conda=ok'; else echo 'conda=missing'; fi; "
        f"if $HOME/miniconda3/bin/conda env list 2>/dev/null | awk '{{print $1}}' | grep -qx synthetic-data; "
        f"then echo 'env=ok'; else echo 'env=missing'; fi; "
        f"if [ -x $HOME/miniconda3/envs/synthetic-data/bin/synthetic-data-kit ]; "
        f"then echo 'sdk=ok'; else echo 'sdk=missing'; fi; "
        f"if [ -f /opt/synthetic/train.py ] || [ -f $HOME/synthetic/train.py ]; "
        f"then echo 'train_py=ok'; else echo 'train_py=missing'; fi; "
        f"if [ -d /opt/synthetic/synthetic-data-kit/data/input ] || [ -d $HOME/synthetic/synthetic-data-kit/data/input ]; "
        f"then echo 'data_dirs=ok'; else echo 'data_dirs=missing'; fi; "
        f"if command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1; "
        f"then echo 'gpu=ok'; else echo 'gpu=missing'; fi; "
        f"if $HOME/miniconda3/envs/synthetic-data/bin/python -c 'import torch,transformers,peft,trl' 2>/dev/null; "
        f"then echo 'training_deps=ok'; else echo 'training_deps=missing'; fi"
    )
    out, err, _ = ssh_manager.execute(probe, use_conda=False)
    for line in out.splitlines():
        if "=" in line:
            k, v = line.strip().split("=", 1)
            checks[k] = v == "ok"
    # NOTE(review): 'gpu' is probed but excluded from the readiness gate
    # below — looks deliberate (GPU optional for setup?); confirm.
    ready = all(checks.get(k, False) for k in
                ("conda", "env", "sdk", "train_py", "data_dirs", "training_deps"))
    return {"checks": checks, "ready": ready, "user": user}
@app.websocket("/api/setup/bootstrap")
async def ws_bootstrap(websocket: WebSocket):
    """Upload bootstrap.sh + train.py + config.yaml to the remote and run it.

    Streams the bootstrap script's output as {"type": "log"} messages,
    terminated by a single "done" or "error" message.
    """
    await websocket.accept()
    if not ssh_manager.is_connected():
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    if not ASSETS_DIR.is_dir():
        await websocket.send_json({"type": "error",
                                   "data": f"Bootstrap assets not found at {ASSETS_DIR}"})
        return
    try:
        remote_dir = "/tmp/llm-trainer-bootstrap"
        ssh_manager.execute(f"mkdir -p {remote_dir}", use_conda=False)
        for name in ("bootstrap.sh", "train.py", "config.yaml"):
            src = ASSETS_DIR / name
            if not src.is_file():
                await websocket.send_json({"type": "error",
                                           "data": f"Missing asset: {name}"})
                return
            ssh_manager.upload_file(str(src), f"{remote_dir}/{name}")
        ssh_manager.execute(f"chmod +x {remote_dir}/bootstrap.sh", use_conda=False)
    except Exception as exc:
        await websocket.send_json({"type": "error",
                                   "data": f"Asset upload failed: {exc}"})
        return
    # Stream bootstrap output from a worker thread back over the socket.
    # get_running_loop() replaces the deprecated get_event_loop() here.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _worker():
        # Blocking SSH stream runs in a plain thread; lines are handed back
        # to the event loop thread-safely via the queue.
        try:
            for line in ssh_manager.execute_stream(
                f"bash {remote_dir}/bootstrap.sh", use_conda=False
            ):
                asyncio.run_coroutine_threadsafe(
                    queue.put({"type": "log", "data": line}), loop
                )
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "done", "data": "Bootstrap complete."}), loop
            )
        except Exception as exc:
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "error", "data": str(exc)}), loop
            )

    threading.Thread(target=_worker, daemon=True).start()
    try:
        while True:
            msg = await queue.get()
            await websocket.send_json(msg)
            if msg["type"] in ("done", "error"):
                break
    except WebSocketDisconnect:
        pass
# ──────────────────────────────────────────────────────────────────────────────
# Pipeline (WebSocket streaming)
# ──────────────────────────────────────────────────────────────────────────────
@app.websocket("/api/pipeline/ingest")
async def ws_ingest(websocket: WebSocket, filename: str):
    """Stream the 'ingest' pipeline stage for one input file.

    Fix: the input path was garbled in the source ("(unknown)"); restored
    to interpolate the `filename` query parameter.
    """
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = ingest_cmd(f"{STAGE_DIRS['input']}/{filename}")
    await _stream_ws(websocket, cmd)
@app.websocket("/api/pipeline/create")
async def ws_create(websocket: WebSocket, filename: str,
                    num_pairs: int = 50, pair_type: str = "qa"):
    """Stream the 'create' pipeline stage for one parsed file.

    Fix: the input path was garbled in the source ("(unknown)"); restored
    to interpolate the `filename` query parameter.
    """
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = create_cmd(f"{STAGE_DIRS['parsed']}/{filename}", num_pairs, pair_type)
    await _stream_ws(websocket, cmd)
@app.websocket("/api/pipeline/curate")
async def ws_curate(websocket: WebSocket, filename: str,
                    output_filename: str, threshold: float = 7.0):
    """Stream the 'curate' pipeline stage for one generated file.

    Fix: the input path was garbled in the source ("(unknown)"); restored
    to interpolate the `filename` query parameter.
    """
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = curate_cmd(
        f"{STAGE_DIRS['generated']}/{filename}",
        f"{STAGE_DIRS['curated']}/{output_filename}",
        threshold,
    )
    await _stream_ws(websocket, cmd)
@app.websocket("/api/pipeline/save")
async def ws_save(websocket: WebSocket, filename: str,
                  output_filename: str, fmt: str = "jsonl"):
    """Stream the 'save-as' pipeline stage for one curated file.

    Fix: the input path was garbled in the source ("(unknown)"); restored
    to interpolate the `filename` query parameter.
    """
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = save_as_cmd(
        f"{STAGE_DIRS['curated']}/{filename}",
        f"{STAGE_DIRS['final']}/{output_filename}",
        fmt,
    )
    await _stream_ws(websocket, cmd)
# ──────────────────────────────────────────────────────────────────────────────
# QA Pairs viewer
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/pairs/{filename}")
async def get_pairs(filename: str, stage: str = "generated"):
    """Read a remote JSONL file of QA pairs and return the parsed records.

    Fix: the route template, remote path, and 404 message were garbled in
    the source ("(unknown)"); restored to use the `filename` parameter.
    """
    _require_ssh()
    # Unknown stages fall back to the 'generated' directory.
    directory = STAGE_DIRS.get(stage, STAGE_DIRS["generated"])
    # Quote the path so filenames with shell metacharacters are safe.
    path = shlex.quote(f"{directory}/{filename}")
    out, err, code = ssh_manager.execute(f"cat {path}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")
    pairs = []
    for line in out.strip().split("\n"):
        if not line.strip():
            continue
        try:
            pairs.append(json.loads(line))
        except json.JSONDecodeError:
            # Tolerate malformed lines rather than failing the whole file.
            pass
    return {"filename": filename, "count": len(pairs), "pairs": pairs}
# ──────────────────────────────────────────────────────────────────────────────
# Config editor
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/config")
async def get_config():
    """Fetch the remote pipeline config, both parsed and as raw YAML."""
    _require_ssh()
    try:
        raw = ssh_manager.read_remote_file(CONFIG_PATH)
        parsed = yaml.safe_load(raw)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"config": parsed, "raw": raw}
@app.put("/api/config")
async def update_config(payload: dict):
    """Serialize *payload* to YAML and write it to the remote config path."""
    _require_ssh()
    try:
        rendered = yaml.dump(payload, default_flow_style=False)
        ssh_manager.write_remote_file(CONFIG_PATH, rendered)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "updated"}
# ──────────────────────────────────────────────────────────────────────────────
# Training (WebSocket streaming)
# ──────────────────────────────────────────────────────────────────────────────
@app.websocket("/api/train")
async def ws_train(
    websocket: WebSocket,
    model_name: str = "llama3.1:8b",
    dataset_path: str = "",
    output_dir: str = "/opt/synthetic/output",
    num_epochs: int = 3,
    batch_size: int = 2,
    learning_rate: float = 2e-4,
):
    """Start remote fine-tuning and stream its log output over WebSocket.

    Parameters arrive as query args (WebSockets carry no body); they mirror
    the TrainRequest model's fields and defaults.
    """
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = train_cmd(model_name, dataset_path, output_dir, num_epochs, batch_size, learning_rate)
    await _stream_ws(websocket, cmd)
# ──────────────────────────────────────────────────────────────────────────────
# Interactive terminal (xterm.js ↔ SSH shell)
# ──────────────────────────────────────────────────────────────────────────────
@app.websocket("/api/terminal")
async def ws_terminal(websocket: WebSocket):
    """Bridge an xterm.js WebSocket to an interactive remote SSH shell."""
    await websocket.accept()
    if not ssh_manager.is_connected():
        await websocket.send_text("\r\nNot connected to SSH server.\r\n")
        return
    channel = None
    try:
        channel = ssh_manager.open_shell_channel()

        async def ssh_to_ws():
            # Poll the shell channel (paramiko-like API — recv_ready() is
            # non-blocking); sleep briefly when idle to avoid a busy loop.
            while True:
                if channel.recv_ready():
                    data = channel.recv(4096)
                    if not data:
                        break  # EOF from the remote shell
                    await websocket.send_bytes(data)
                elif channel.exit_status_ready():
                    break  # remote shell exited
                else:
                    await asyncio.sleep(0.02)

        async def ws_to_ssh():
            # Forward keystrokes from the browser; accept both binary and
            # text frames (text is encoded before sending).
            try:
                while True:
                    msg = await websocket.receive()
                    if "bytes" in msg and msg["bytes"]:
                        channel.send(msg["bytes"])
                    elif "text" in msg and msg["text"]:
                        channel.send(msg["text"].encode())
            except WebSocketDisconnect:
                pass

        # Run both pumps concurrently until one side finishes or raises.
        await asyncio.gather(ssh_to_ws(), ws_to_ssh())
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        # Best-effort error report; the socket may already be closed.
        try:
            await websocket.send_text(f"\r\nError: {exc}\r\n")
        except Exception:
            pass
    finally:
        # Always release the shell channel, whatever ended the session.
        if channel:
            try:
                channel.close()
            except Exception:
                pass
# ──────────────────────────────────────────────────────────────────────────────
# Model manager (Ollama)
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/models")
async def list_models():
    """List the models known to the local Ollama server."""
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(f"{OLLAMA_URL}/api/tags")
            resp.raise_for_status()
            models = resp.json().get("models", [])
        return {"models": models, "error": None}
    except Exception as exc:
        # Return empty list instead of crashing — Ollama may not be reachable yet
        return {"models": [], "error": str(exc)}
@app.websocket("/api/models/pull")
async def ws_pull_model(websocket: WebSocket, model_name: str):
    """Pull a model via Ollama, forwarding its progress JSON to the client."""
    await websocket.accept()
    try:
        async with httpx.AsyncClient(timeout=600) as client:
            request_ctx = client.stream(
                "POST", f"{OLLAMA_URL}/api/pull",
                json={"name": model_name, "stream": True}
            )
            async with request_ctx as resp:
                async for line in resp.aiter_lines():
                    if not line.strip():
                        continue
                    try:
                        payload = json.loads(line)
                    except json.JSONDecodeError:
                        continue  # skip non-JSON noise in the stream
                    await websocket.send_json(payload)
        await websocket.send_json({"status": "success"})
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        # Best-effort error report; the socket may already be gone.
        try:
            await websocket.send_json({"status": "error", "error": str(exc)})
        except Exception:
            pass
@app.delete("/api/models/{model_name:path}")
async def delete_model(model_name: str):
    """Delete a model from the local Ollama server."""
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            # Ollama's delete endpoint takes a JSON body, so use .request()
            # (httpx's .delete() helper does not accept one).
            resp = await client.request(
                "DELETE", f"{OLLAMA_URL}/api/delete",
                json={"name": model_name}
            )
            resp.raise_for_status()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"deleted": model_name}
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Dev entry point: hot-reloading server bound to all interfaces.
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=True)