580 lines
24 KiB
Python
580 lines
24 KiB
Python
import asyncio
import json
import os
import shlex
import tempfile
import threading
import zipfile
from pathlib import Path
from typing import List, Optional

import httpx
import yaml
from fastapi import FastAPI, File, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from gpu import get_gpu_stats
from pipeline import STAGE_DIRS, CONFIG_PATH, ingest_cmd, create_cmd, curate_cmd, save_as_cmd, train_cmd
from ssh_client import ssh_manager
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Application object; every route in this module is registered on it.
app = FastAPI(version="1.0.0", title="LLM Trainer API")

# CORS is fully open: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Base URL of the Ollama HTTP API; the OLLAMA_URL environment variable overrides it.
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Pydantic models
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
class ConnectRequest(BaseModel):
    # Request body accepted by POST /api/connect; the values are passed
    # straight through to ssh_manager.connect().

    # Target machine and account.
    host: str = "192.168.2.47"
    username: str = "tocmo0nlord"

    # Credentials; both are optional and default to None.
    password: Optional[str] = None
    key_path: Optional[str] = None

    # TCP port used for the SSH connection.
    port: int = 22
|
|
|
|
|
|
class TrainRequest(BaseModel):
    # Training parameters. The defaults mirror the query-parameter defaults of
    # the /api/train WebSocket route; no route visible in this file takes this
    # model as a body.

    # What to train and on which data (dataset_path is required).
    model_name: str = "llama3.1:8b"
    dataset_path: str

    # Where results are written.
    output_dir: str = "/opt/synthetic/output"

    # Hyper-parameters.
    num_epochs: int = 3
    batch_size: int = 2
    learning_rate: float = 2e-4
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Helpers
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
def _require_ssh():
    """Raise HTTP 503 unless the shared SSH manager reports a live connection."""
    if ssh_manager.is_connected():
        return
    raise HTTPException(status_code=503, detail="Not connected to SSH server")
|
|
|
|
|
|
async def _stream_ws(websocket: WebSocket, command: str, use_conda: bool = True):
    """Run a remote command and stream output lines over WebSocket.

    Accepts the socket, runs ``ssh_manager.execute_stream`` on a daemon thread
    and forwards every yielded line as ``{"type": "log", "data": line}``. The
    stream always ends with exactly one terminal message: ``{"type": "done"}``
    on success or ``{"type": "error"}`` carrying the exception text.

    Args:
        websocket: Connection to stream to; it must not have been accepted yet.
        command: Command line handed to ``ssh_manager.execute_stream``.
        use_conda: Forwarded unchanged to ``ssh_manager.execute_stream``.
    """
    await websocket.accept()
    # Inside a coroutine the running loop is the one we want; get_running_loop()
    # is the documented replacement for get_event_loop() here.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _publish(kind: str, data: str) -> None:
        # asyncio.Queue is not thread-safe, so the put is scheduled on the loop
        # thread. The queue is unbounded, hence put_nowait cannot raise QueueFull.
        loop.call_soon_threadsafe(queue.put_nowait, {"type": kind, "data": data})

    def _worker() -> None:
        try:
            for line in ssh_manager.execute_stream(command, use_conda=use_conda):
                _publish("log", line)
            _publish("done", "Command completed.")
        except Exception as exc:
            _publish("error", str(exc))

    # Nothing below stops this thread: if the client disconnects it keeps
    # running until the remote command finishes.
    threading.Thread(target=_worker, daemon=True).start()

    try:
        while True:
            msg = await queue.get()
            await websocket.send_json(msg)
            if msg["type"] in ("done", "error"):
                break
    except WebSocketDisconnect:
        pass
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Connection
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
@app.post("/api/connect")
async def connect(req: ConnectRequest):
    """Open the shared SSH connection using the parameters in the request body.

    Any exception raised by ``ssh_manager.connect`` is reported as HTTP 500
    with the exception text as detail.
    """
    params = {
        "host": req.host,
        "username": req.username,
        "password": req.password,
        "key_path": req.key_path,
        "port": req.port,
    }
    try:
        ssh_manager.connect(**params)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "connected", "host": req.host, "username": req.username}
|
|
|
|
|
|
@app.post("/api/disconnect")
async def disconnect():
    """Tear down the shared SSH connection and acknowledge it."""
    ssh_manager.disconnect()
    reply = {"status": "disconnected"}
    return reply
|
|
|
|
|
|
@app.get("/api/status")
async def status():
    """Report connection state plus GPU statistics.

    While disconnected, ``host`` and ``username`` are ``None`` and the GPU entry
    is a placeholder carrying a "Not connected" error; GPU stats are only
    queried when a connection exists.
    """
    is_up = ssh_manager.is_connected()
    if is_up:
        gpu_info = get_gpu_stats()
        host = ssh_manager.host
        username = ssh_manager.username
    else:
        gpu_info = {"gpus": [], "error": "Not connected"}
        host = None
        username = None
    return {
        "connected": is_up,
        "host": host,
        "username": username,
        "gpu": gpu_info,
    }
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# GPU
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/api/gpu")
async def gpu():
    """Return the result of ``get_gpu_stats()``; HTTP 503 when SSH is not connected."""
    _require_ssh()
    stats = get_gpu_stats()
    return stats
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# File management
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/api/files/{stage}")
async def list_files(stage: str):
    """List the non-directory entries of a stage directory on the remote host.

    The listing is obtained by parsing ``ls -la`` output: columns 0-7 are the
    metadata fields and everything from column 8 on is taken as the file name.

    Raises:
        HTTPException: 400 for an unknown stage, 503 when SSH is not connected.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()

    directory = STAGE_DIRS[stage]
    listing, _, _ = ssh_manager.execute(
        f"ls -la '{directory}' 2>/dev/null | tail -n +2", use_conda=False
    )

    # Split each meaningful row into whitespace-separated columns.
    rows = (
        row.split()
        for row in listing.strip().split("\n")
        if row.strip() and not row.startswith("total")
    )
    # Keep rows with a full set of columns whose mode string is not a directory.
    files = [
        {
            "name": " ".join(cols[8:]),
            "size": int(cols[4]),
            "modified": f"{cols[5]} {cols[6]} {cols[7]}",
        }
        for cols in rows
        if len(cols) >= 9 and not cols[0].startswith("d")
    ]

    return {"stage": stage, "directory": directory, "files": files}
|
|
|
|
|
|
@app.delete("/api/files/{stage}/{filename}")
async def delete_file(stage: str, filename: str):
    """Delete one file from a stage directory on the remote host.

    Args:
        stage: Key into ``STAGE_DIRS``.
        filename: Name of the file inside that stage directory.

    Returns:
        ``{"deleted": filename}`` when ``rm`` exits with status 0.

    Raises:
        HTTPException: 400 for an unknown stage, 503 when SSH is not connected,
            500 (with the command's stderr) when ``rm`` fails.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()

    path = f"{STAGE_DIRS[stage]}/{filename}"
    # filename comes from the URL; shlex.quote keeps quotes or other shell
    # metacharacters in it from escaping the argument.
    _, err, code = ssh_manager.execute(f"rm -f {shlex.quote(path)}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"deleted": filename}
|
|
|
|
|
|
@app.get("/api/files/{stage}/{filename}/preview")
async def preview_file(stage: str, filename: str, lines: int = 120):
    """Return the first ``lines`` lines of a file in a stage directory.

    Args:
        stage: Key into ``STAGE_DIRS``.
        filename: Name of the file inside that stage directory.
        lines: Line count passed to ``head -n``.

    Returns:
        ``{"filename": filename, "content": <stdout of head>}``.

    Raises:
        HTTPException: 400 for an unknown stage, 503 when SSH is not connected,
            500 (with the command's stderr) when ``head`` fails.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()

    path = f"{STAGE_DIRS[stage]}/{filename}"
    # filename comes from the URL; quote it so it stays a single shell argument.
    # `lines` is validated as an int by the framework, so it is safe to inline.
    out, err, code = ssh_manager.execute(
        f"head -n {lines} {shlex.quote(path)}", use_conda=False
    )
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"filename": filename, "content": out}
|
|
|
|
|
|
def _upload_zip_members(zip_path: str, remote_dir: str) -> List[str]:
    """Upload each regular, non-hidden member of a zip archive; return the names uploaded.

    Members are uploaded under their base name only, so directory structure
    inside the archive is flattened. Directories and dot-files are skipped.
    """
    uploaded = []
    with zipfile.ZipFile(zip_path) as zf:
        for member in zf.infolist():
            if member.is_dir():
                continue
            name = Path(member.filename).name
            if not name or name.startswith("."):
                continue
            out_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=Path(name).suffix)
            try:
                # Close the temp file before uploading so its content is flushed.
                with out_tmp, zf.open(member) as src:
                    out_tmp.write(src.read())
                ssh_manager.upload_file(out_tmp.name, f"{remote_dir}/{name}")
                uploaded.append(name)
            finally:
                # Runs even when the write or the upload fails, so nothing leaks.
                os.unlink(out_tmp.name)
    return uploaded


@app.post("/api/upload")
async def upload_files(files: List[UploadFile] = File(...)):
    """Upload files to the remote ``input`` stage directory.

    Each upload is first spooled to a local temporary file. A ``.zip`` upload is
    unpacked and its members are uploaded individually; anything else is
    uploaded as a single file under its base name.

    Returns:
        ``{"results": [...]}`` with one entry per uploaded file: either
        ``{"file", "action": "extracted", "files"}`` for archives or
        ``{"file", "action": "uploaded", "remote_path"}`` otherwise.

    Raises:
        HTTPException: 503 when SSH is not connected; 400 when an upload has no
            usable file name or a ``.zip`` upload is not a valid archive.
    """
    _require_ssh()
    input_dir = STAGE_DIRS["input"]
    results = []
    for file in files:
        # The file name is client-controlled. Keeping only its last path
        # component stops names like "../../x" from leaving the input
        # directory, matching how zip members are already handled.
        safe_name = Path(file.filename or "").name
        if safe_name in ("", ".."):
            raise HTTPException(
                status_code=400, detail=f"Invalid filename: {file.filename!r}"
            )
        suffix = Path(safe_name).suffix.lower()

        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        try:
            with tmp:
                tmp.write(await file.read())
            if suffix == ".zip":
                try:
                    extracted = _upload_zip_members(tmp.name, input_dir)
                except zipfile.BadZipFile as exc:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Not a valid zip archive: {safe_name}",
                    ) from exc
                results.append(
                    {"file": file.filename, "action": "extracted", "files": extracted}
                )
            else:
                remote_path = f"{input_dir}/{safe_name}"
                ssh_manager.upload_file(tmp.name, remote_path)
                results.append(
                    {"file": file.filename, "action": "uploaded", "remote_path": remote_path}
                )
        finally:
            os.unlink(tmp.name)
    return {"results": results}
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Setup / Bootstrap
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
# Local directory holding the bootstrap assets (the .deb install layout by
# default); the LLM_TRAINER_ASSETS environment variable overrides it.
ASSETS_DIR = Path(os.environ.get("LLM_TRAINER_ASSETS", "/opt/llm-trainer/remote"))
|
|
|
|
|
|
@app.get("/api/setup/check")
async def setup_check():
    """Detect what's already set up on the remote GPU host.

    Runs a single shell probe over SSH that prints one ``name=ok`` or
    ``name=missing`` line per component, then converts those lines to booleans.

    Returns:
        ``{"checks": {name: bool}, "ready": bool, "user": ssh_manager.username}``.
        ``ready`` requires every probed component except ``gpu``.

    Raises:
        HTTPException: 503 when SSH is not connected.
    """
    _require_ssh()
    user = ssh_manager.username

    # One command, one output line per component. Plain string literals: the
    # text contains shell braces but no Python placeholders.
    probe = (
        "if [ -x $HOME/miniconda3/bin/conda ]; then echo 'conda=ok'; else echo 'conda=missing'; fi; "
        "if $HOME/miniconda3/bin/conda env list 2>/dev/null | awk '{print $1}' | grep -qx synthetic-data; "
        "then echo 'env=ok'; else echo 'env=missing'; fi; "
        "if [ -x $HOME/miniconda3/envs/synthetic-data/bin/synthetic-data-kit ]; "
        "then echo 'sdk=ok'; else echo 'sdk=missing'; fi; "
        "if [ -f /opt/synthetic/train.py ] || [ -f $HOME/synthetic/train.py ]; "
        "then echo 'train_py=ok'; else echo 'train_py=missing'; fi; "
        "if [ -d /opt/synthetic/synthetic-data-kit/data/input ] || [ -d $HOME/synthetic/synthetic-data-kit/data/input ]; "
        "then echo 'data_dirs=ok'; else echo 'data_dirs=missing'; fi; "
        "if command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1; "
        "then echo 'gpu=ok'; else echo 'gpu=missing'; fi; "
        "if $HOME/miniconda3/envs/synthetic-data/bin/python -c 'import torch,transformers,peft,trl' 2>/dev/null; "
        "then echo 'training_deps=ok'; else echo 'training_deps=missing'; fi"
    )
    out, _, _ = ssh_manager.execute(probe, use_conda=False)

    checks = {}
    for line in out.splitlines():
        # Lines without "=" (e.g. stray shell noise) are ignored.
        key, sep, value = line.strip().partition("=")
        if sep:
            checks[key] = value == "ok"

    # The GPU probe is reported in `checks` but does not gate readiness.
    required = ("conda", "env", "sdk", "train_py", "data_dirs", "training_deps")
    ready = all(checks.get(name, False) for name in required)
    return {"checks": checks, "ready": ready, "user": user}
|
|
|
|
|
|
@app.websocket("/api/setup/bootstrap")
async def ws_bootstrap(websocket: WebSocket):
    """Upload bootstrap.sh + train.py + config.yaml to the remote and run it.

    Messages sent to the client are ``{"type": "log" | "done" | "error",
    "data": str}``. Any problem before the script starts (no SSH connection,
    missing local assets, failed upload) produces a single ``error`` message.
    Once the script runs, each output line is a ``log`` message followed by one
    terminal ``done`` or ``error`` message.
    """
    await websocket.accept()
    if not ssh_manager.is_connected():
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return

    if not ASSETS_DIR.is_dir():
        await websocket.send_json({"type": "error",
                                   "data": f"Bootstrap assets not found at {ASSETS_DIR}"})
        return

    # Check every asset locally first so a missing file is reported before
    # anything is created or uploaded on the remote host.
    asset_names = ("bootstrap.sh", "train.py", "config.yaml")
    for name in asset_names:
        if not (ASSETS_DIR / name).is_file():
            await websocket.send_json({"type": "error",
                                       "data": f"Missing asset: {name}"})
            return

    remote_dir = "/tmp/llm-trainer-bootstrap"
    try:
        ssh_manager.execute(f"mkdir -p {remote_dir}", use_conda=False)
        for name in asset_names:
            ssh_manager.upload_file(str(ASSETS_DIR / name), f"{remote_dir}/{name}")
        ssh_manager.execute(f"chmod +x {remote_dir}/bootstrap.sh", use_conda=False)
    except Exception as exc:
        await websocket.send_json({"type": "error",
                                   "data": f"Asset upload failed: {exc}"})
        return

    # Stream bootstrap output. get_running_loop() is the documented way to get
    # the loop from inside a coroutine.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _publish(kind: str, data: str) -> None:
        # asyncio.Queue is not thread-safe; schedule the put on the loop thread.
        loop.call_soon_threadsafe(queue.put_nowait, {"type": kind, "data": data})

    def _worker() -> None:
        try:
            for line in ssh_manager.execute_stream(
                f"bash {remote_dir}/bootstrap.sh", use_conda=False
            ):
                _publish("log", line)
            _publish("done", "Bootstrap complete.")
        except Exception as exc:
            _publish("error", str(exc))

    # The thread is not stopped on client disconnect; it runs until the remote
    # script ends.
    threading.Thread(target=_worker, daemon=True).start()

    try:
        while True:
            msg = await queue.get()
            await websocket.send_json(msg)
            if msg["type"] in ("done", "error"):
                break
    except WebSocketDisconnect:
        pass
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Pipeline (WebSocket streaming)
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
@app.websocket("/api/pipeline/ingest")
async def ws_ingest(websocket: WebSocket, filename: str):
    """Stream the ingest command for one file of the ``input`` stage.

    Without an SSH connection the socket is accepted, a single error message is
    sent, and the handler returns.
    """
    if ssh_manager.is_connected():
        source = f"{STAGE_DIRS['input']}/{filename}"
        await _stream_ws(websocket, ingest_cmd(source))
        return

    await websocket.accept()
    await websocket.send_json({"type": "error", "data": "Not connected"})
|
|
|
|
|
|
@app.websocket("/api/pipeline/create")
async def ws_create(websocket: WebSocket, filename: str,
                    num_pairs: int = 50, pair_type: str = "qa"):
    """Stream the create command for one file of the ``parsed`` stage.

    ``num_pairs`` and ``pair_type`` are forwarded unchanged to ``create_cmd``.
    Without an SSH connection a single error message is sent instead.
    """
    if ssh_manager.is_connected():
        source = f"{STAGE_DIRS['parsed']}/{filename}"
        await _stream_ws(websocket, create_cmd(source, num_pairs, pair_type))
        return

    await websocket.accept()
    await websocket.send_json({"type": "error", "data": "Not connected"})
|
|
|
|
|
|
@app.websocket("/api/pipeline/curate")
async def ws_curate(websocket: WebSocket, filename: str,
                    output_filename: str, threshold: float = 7.0):
    """Stream the curate command: ``generated/<filename>`` to ``curated/<output_filename>``.

    ``threshold`` is forwarded unchanged to ``curate_cmd``. Without an SSH
    connection a single error message is sent instead.
    """
    if ssh_manager.is_connected():
        source = f"{STAGE_DIRS['generated']}/{filename}"
        target = f"{STAGE_DIRS['curated']}/{output_filename}"
        await _stream_ws(websocket, curate_cmd(source, target, threshold))
        return

    await websocket.accept()
    await websocket.send_json({"type": "error", "data": "Not connected"})
|
|
|
|
|
|
@app.websocket("/api/pipeline/save")
async def ws_save(websocket: WebSocket, filename: str,
                  output_filename: str, fmt: str = "jsonl"):
    """Stream the save-as command: ``curated/<filename>`` to ``final/<output_filename>``.

    ``fmt`` is forwarded unchanged to ``save_as_cmd``. Without an SSH
    connection a single error message is sent instead.
    """
    if ssh_manager.is_connected():
        source = f"{STAGE_DIRS['curated']}/{filename}"
        target = f"{STAGE_DIRS['final']}/{output_filename}"
        await _stream_ws(websocket, save_as_cmd(source, target, fmt))
        return

    await websocket.accept()
    await websocket.send_json({"type": "error", "data": "Not connected"})
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# QA Pairs viewer
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/api/pairs/{filename}")
async def get_pairs(filename: str, stage: str = "generated"):
    """Read a JSON-lines file from a stage directory and return its records.

    Args:
        filename: Name of the file inside the stage directory.
        stage: Key into ``STAGE_DIRS``; an unknown value falls back to the
            ``generated`` directory.

    Returns:
        ``{"filename", "count", "pairs"}`` where ``pairs`` holds every line that
        parsed as JSON. Blank and unparsable lines are skipped on purpose.

    Raises:
        HTTPException: 503 when SSH is not connected, 404 when ``cat`` exits
            with a non-zero status.
    """
    _require_ssh()
    path = f"{STAGE_DIRS.get(stage, STAGE_DIRS['generated'])}/{filename}"
    # filename comes from the URL; quote it so it stays a single shell argument.
    out, _, code = ssh_manager.execute(f"cat {shlex.quote(path)}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    pairs = []
    for line in out.strip().split("\n"):
        if not line.strip():
            continue
        try:
            pairs.append(json.loads(line))
        except json.JSONDecodeError:
            # Best-effort viewer: a malformed line must not hide the valid ones.
            continue

    return {"filename": filename, "count": len(pairs), "pairs": pairs}
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Config editor
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/api/config")
async def get_config():
    """Fetch the remote config file and return it both parsed and as raw text.

    Any failure while reading or parsing is reported as HTTP 500 with the
    exception text as detail.
    """
    _require_ssh()
    try:
        text = ssh_manager.read_remote_file(CONFIG_PATH)
        parsed = yaml.safe_load(text)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"config": parsed, "raw": text}
|
|
|
|
|
|
@app.put("/api/config")
async def update_config(payload: dict):
    """Serialize the request body as block-style YAML and overwrite the remote config.

    Any failure while serializing or writing is reported as HTTP 500 with the
    exception text as detail.
    """
    _require_ssh()
    try:
        rendered = yaml.dump(payload, default_flow_style=False)
        ssh_manager.write_remote_file(CONFIG_PATH, rendered)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "updated"}
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Training (WebSocket streaming)
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
@app.websocket("/api/train")
async def ws_train(
    websocket: WebSocket,
    model_name: str = "llama3.1:8b",
    dataset_path: str = "",
    output_dir: str = "/opt/synthetic/output",
    num_epochs: int = 3,
    batch_size: int = 2,
    learning_rate: float = 2e-4,
):
    """Build the training command from the query parameters and stream its output.

    All parameters are forwarded positionally to ``train_cmd``. Without an SSH
    connection the socket is accepted and a single error message is sent.
    """
    if ssh_manager.is_connected():
        command = train_cmd(
            model_name,
            dataset_path,
            output_dir,
            num_epochs,
            batch_size,
            learning_rate,
        )
        await _stream_ws(websocket, command)
        return

    await websocket.accept()
    await websocket.send_json({"type": "error", "data": "Not connected"})
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Interactive terminal (xterm.js ↔ SSH shell)
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
@app.websocket("/api/terminal")
async def ws_terminal(websocket: WebSocket):
    """Bridge a WebSocket to an interactive SSH shell channel.

    Two pumps run concurrently: one forwards channel output to the socket, the
    other forwards socket input to the channel. The session ends as soon as
    either side finishes (the shell exits or the client disconnects); the other
    pump is then cancelled and the channel is closed.
    """
    await websocket.accept()

    if not ssh_manager.is_connected():
        await websocket.send_text("\r\nNot connected to SSH server.\r\n")
        return

    channel = None
    try:
        # NOTE(review): the channel is used through recv_ready / recv /
        # exit_status_ready / send / close — presumably a paramiko Channel; confirm.
        channel = ssh_manager.open_shell_channel()

        async def ssh_to_ws():
            # Poll the channel; drain pending output before honouring an exit.
            while True:
                if channel.recv_ready():
                    data = channel.recv(4096)
                    if not data:
                        break
                    await websocket.send_bytes(data)
                elif channel.exit_status_ready():
                    break
                else:
                    await asyncio.sleep(0.02)

        async def ws_to_ssh():
            try:
                while True:
                    msg = await websocket.receive()
                    # The raw receive() reports a disconnect as a message
                    # instead of raising; calling receive() again afterwards
                    # would raise RuntimeError, so stop here.
                    if msg.get("type") == "websocket.disconnect":
                        break
                    if msg.get("bytes"):
                        channel.send(msg["bytes"])
                    elif msg.get("text"):
                        channel.send(msg["text"].encode())
            except WebSocketDisconnect:
                pass

        outbound = asyncio.ensure_future(ssh_to_ws())
        inbound = asyncio.ensure_future(ws_to_ssh())
        done = set()
        try:
            done, _ = await asyncio.wait(
                {outbound, inbound}, return_when=asyncio.FIRST_COMPLETED
            )
        finally:
            # Stop whichever pump is still running (cancelling a finished task
            # is a no-op) and wait for both so neither outlives the handler.
            for task in (outbound, inbound):
                task.cancel()
            await asyncio.gather(outbound, inbound, return_exceptions=True)

        # Surface a failure from the pump that ended the session.
        for task in done:
            task.result()
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        try:
            await websocket.send_text(f"\r\nError: {exc}\r\n")
        except Exception:
            # The socket may already be gone; there is nobody left to tell.
            pass
    finally:
        if channel:
            try:
                channel.close()
            except Exception:
                # Best-effort cleanup of an already broken channel.
                pass
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
# Model manager (Ollama)
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
@app.get("/api/models")
async def list_models():
    """List the models reported by Ollama's ``/api/tags`` endpoint.

    Never raises: on any failure the reply is an empty model list plus the
    error text, because Ollama may not be reachable yet.
    """
    tags_url = f"{OLLAMA_URL}/api/tags"
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(tags_url)
            resp.raise_for_status()
            models = resp.json().get("models", [])
    except Exception as exc:
        return {"models": [], "error": str(exc)}
    return {"models": models, "error": None}
|
|
|
|
|
|
@app.websocket("/api/models/pull")
async def ws_pull_model(websocket: WebSocket, model_name: str):
    """Pull a model through Ollama and relay its progress events to the client.

    Every JSON line streamed by ``POST {OLLAMA_URL}/api/pull`` is forwarded
    as-is (lines that are not valid JSON are dropped). The relay ends with
    exactly one summary message: ``{"status": "success"}`` when the stream
    finished cleanly, otherwise ``{"status": "error", "error": ...}``.
    """
    await websocket.accept()
    try:
        failure = None
        async with httpx.AsyncClient(timeout=600) as client:
            async with client.stream(
                "POST", f"{OLLAMA_URL}/api/pull",
                json={"name": model_name, "stream": True},
            ) as resp:
                async for line in resp.aiter_lines():
                    if not line.strip():
                        continue
                    try:
                        event = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    await websocket.send_json(event)
                    # Remember an error reported inside the stream so it is not
                    # followed by a misleading success message.
                    if isinstance(event, dict) and event.get("error"):
                        failure = str(event["error"])
                if failure is None and resp.status_code >= 400:
                    failure = f"Ollama returned HTTP {resp.status_code}"

        if failure is None:
            await websocket.send_json({"status": "success"})
        else:
            await websocket.send_json({"status": "error", "error": failure})
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        try:
            await websocket.send_json({"status": "error", "error": str(exc)})
        except Exception:
            # The socket may already be closed; nothing more can be reported.
            pass
|
|
|
|
|
|
@app.delete("/api/models/{model_name:path}")
async def delete_model(model_name: str):
    """Ask Ollama to delete a model.

    Any failure, including a non-2xx reply from Ollama, is reported as HTTP 500
    with the exception text as detail.
    """
    delete_url = f"{OLLAMA_URL}/api/delete"
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.request("DELETE", delete_url, json={"name": model_name})
            resp.raise_for_status()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"deleted": model_name}
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────────────
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve this module's `app` on all interfaces,
    # port 8080, with auto-reload enabled.
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8080,
        reload=True,
    )
|