import asyncio
import json
import os
import shlex
import tempfile
import threading
import zipfile
from pathlib import Path
from typing import List, Optional

import httpx
import yaml
from fastapi import FastAPI, File, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from gpu import get_gpu_stats
from pipeline import STAGE_DIRS, CONFIG_PATH, ingest_cmd, create_cmd, curate_cmd, save_as_cmd, train_cmd
from ssh_client import ssh_manager

# ──────────────────────────────────────────────────────────────────────────────

app = FastAPI(title="LLM Trainer API", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434")


# ──────────────────────────────────────────────────────────────────────────────
# Pydantic models
# ──────────────────────────────────────────────────────────────────────────────
class ConnectRequest(BaseModel):
    # SSH connection parameters; ssh_manager.connect() decides whether to use
    # the password or the key file.
    host: str = "192.168.2.47"
    username: str = "tocmo0nlord"
    password: Optional[str] = None
    key_path: Optional[str] = None
    port: int = 22


class TrainRequest(BaseModel):
    # Fine-tuning parameters; mirrors the /api/train WebSocket query args.
    model_name: str = "llama3.1:8b"
    dataset_path: str
    output_dir: str = "/opt/synthetic/output"
    num_epochs: int = 3
    batch_size: int = 2
    learning_rate: float = 2e-4


# ──────────────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────────────
def _require_ssh() -> None:
    """Raise HTTP 503 unless the shared SSH session is connected."""
    if not ssh_manager.is_connected():
        raise HTTPException(status_code=503, detail="Not connected to SSH server")


async def _stream_ws(websocket: WebSocket, command: str, use_conda: bool = True) -> None:
    """Run a remote command and stream output lines over WebSocket.

    A daemon thread drives the blocking ``ssh_manager.execute_stream`` call and
    hands each line back to the event loop through an ``asyncio.Queue``; the
    coroutine relays queue items to the client until a terminal ``done`` or
    ``error`` message arrives.
    """
    await websocket.accept()
    # get_running_loop(): we are inside a coroutine, and get_event_loop() is
    # deprecated for this use since Python 3.10.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _worker() -> None:
        try:
            for line in ssh_manager.execute_stream(command, use_conda=use_conda):
                # run_coroutine_threadsafe: queue.put must run on the loop thread.
                asyncio.run_coroutine_threadsafe(
                    queue.put({"type": "log", "data": line}), loop
                )
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "done", "data": "Command completed."}), loop
            )
        except Exception as exc:
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "error", "data": str(exc)}), loop
            )

    threading.Thread(target=_worker, daemon=True).start()

    try:
        while True:
            msg = await queue.get()
            await websocket.send_json(msg)
            if msg["type"] in ("done", "error"):
                break
    except WebSocketDisconnect:
        pass  # client went away; the worker thread drains and dies on its own


# ──────────────────────────────────────────────────────────────────────────────
# Connection
# ──────────────────────────────────────────────────────────────────────────────
@app.post("/api/connect")
async def connect(req: ConnectRequest):
    """Open the SSH session used by every other endpoint."""
    try:
        ssh_manager.connect(
            host=req.host,
            username=req.username,
            password=req.password,
            key_path=req.key_path,
            port=req.port,
        )
        return {"status": "connected", "host": req.host, "username": req.username}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.post("/api/disconnect")
async def disconnect():
    """Tear down the shared SSH session (idempotent)."""
    ssh_manager.disconnect()
    return {"status": "disconnected"}


@app.get("/api/status")
async def status():
    """Report connection state plus a GPU snapshot when connected."""
    connected = ssh_manager.is_connected()
    gpu = get_gpu_stats() if connected else {"gpus": [], "error": "Not connected"}
    return {
        "connected": connected,
        "host": ssh_manager.host if connected else None,
        "username": ssh_manager.username if connected else None,
        "gpu": gpu,
    }


# ──────────────────────────────────────────────────────────────────────────────
# GPU
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/gpu")
async def gpu():
    """Return current GPU stats from the remote host."""
    _require_ssh()
    return get_gpu_stats()


# ──────────────────────────────────────────────────────────────────────────────
# File management
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/files/{stage}")
async def list_files(stage: str):
    """List regular files in the remote directory for a pipeline stage."""
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()
    out, _, _ = ssh_manager.execute(
        f"ls -la '{STAGE_DIRS[stage]}' 2>/dev/null | tail -n +2", use_conda=False
    )
    files = []
    for line in out.strip().split("\n"):
        if not line.strip() or line.startswith("total"):
            continue
        parts = line.split()
        # `ls -la` columns: perms links owner group size month day time name…
        # Skip directories (perm string starts with "d").
        if len(parts) >= 9 and not parts[0].startswith("d"):
            files.append({
                "name": " ".join(parts[8:]),  # re-join names containing spaces
                "size": int(parts[4]),
                "modified": f"{parts[5]} {parts[6]} {parts[7]}",
            })
    return {"stage": stage, "directory": STAGE_DIRS[stage], "files": files}


@app.delete("/api/files/{stage}/{filename}")
async def delete_file(stage: str, filename: str):
    """Delete one file from a stage directory on the remote host."""
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()
    path = f"{STAGE_DIRS[stage]}/{filename}"
    # shlex.quote: filename is user-controlled; naive '-quoting is injectable.
    _, err, code = ssh_manager.execute(f"rm -f {shlex.quote(path)}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"deleted": filename}


@app.get("/api/files/{stage}/{filename}/preview")
async def preview_file(stage: str, filename: str, lines: int = 120):
    """Return the first *lines* lines of a remote file as plain text."""
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()
    path = f"{STAGE_DIRS[stage]}/{filename}"
    # `lines` is int-typed by FastAPI, so interpolating it is safe; the
    # filename is quoted because it is user-controlled.
    out, err, code = ssh_manager.execute(
        f"head -n {lines} {shlex.quote(path)}", use_conda=False
    )
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"filename": filename, "content": out}


@app.post("/api/upload")
async def upload_files(files: List[UploadFile] = File(...)):
    """Upload files to the remote input directory; .zip archives are expanded.

    Each upload is spooled to a local temp file, then pushed over SFTP. Zip
    members are flattened (directory structure discarded) and hidden files
    skipped. Temp files are always removed, even on upload failure.
    """
    _require_ssh()
    results = []
    for file in files:
        suffix = Path(file.filename).suffix.lower()
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name
        try:
            if suffix == ".zip":
                extracted = []
                with zipfile.ZipFile(tmp_path) as zf:
                    for member in zf.infolist():
                        if member.is_dir():
                            continue
                        # Flatten archive paths; skip hidden/metadata entries
                        # (e.g. __MACOSX resource forks start with a dot).
                        name = Path(member.filename).name
                        if not name or name.startswith("."):
                            continue
                        with zf.open(member) as src, tempfile.NamedTemporaryFile(
                            delete=False, suffix=Path(name).suffix
                        ) as out_tmp:
                            out_tmp.write(src.read())
                            out_tmp_path = out_tmp.name
                        try:
                            remote_path = f"{STAGE_DIRS['input']}/{name}"
                            ssh_manager.upload_file(out_tmp_path, remote_path)
                            extracted.append(name)
                        finally:
                            os.unlink(out_tmp_path)
                results.append({"file": file.filename, "action": "extracted", "files": extracted})
            else:
                remote_path = f"{STAGE_DIRS['input']}/{file.filename}"
                ssh_manager.upload_file(tmp_path, remote_path)
                results.append({"file": file.filename, "action": "uploaded", "remote_path": remote_path})
        finally:
            os.unlink(tmp_path)
    return {"results": results}


# ──────────────────────────────────────────────────────────────────────────────
# Pipeline (WebSocket streaming)
# ──────────────────────────────────────────────────────────────────────────────
@app.websocket("/api/pipeline/ingest")
async def ws_ingest(websocket: WebSocket, filename: str):
    """Stream the ingest stage for one uploaded input file."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = ingest_cmd(f"{STAGE_DIRS['input']}/{filename}")
    await _stream_ws(websocket, cmd)


@app.websocket("/api/pipeline/create")
async def ws_create(websocket: WebSocket, filename: str, num_pairs: int = 50, pair_type: str = "qa"):
    """Stream generation of `num_pairs` training pairs from a parsed file."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = create_cmd(f"{STAGE_DIRS['parsed']}/{filename}", num_pairs, pair_type)
    await _stream_ws(websocket, cmd)


@app.websocket("/api/pipeline/curate")
async def ws_curate(websocket: WebSocket, filename: str, output_filename: str, threshold: float = 7.0):
    """Stream curation of generated pairs, keeping those scored >= threshold."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = curate_cmd(
        f"{STAGE_DIRS['generated']}/{filename}",
        f"{STAGE_DIRS['curated']}/{output_filename}",
        threshold,
    )
    await _stream_ws(websocket, cmd)
@app.websocket("/api/pipeline/save")
async def ws_save(websocket: WebSocket, filename: str, output_filename: str, fmt: str = "jsonl"):
    """Stream conversion of a curated file into its final training format."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = save_as_cmd(
        f"{STAGE_DIRS['curated']}/{filename}",
        f"{STAGE_DIRS['final']}/{output_filename}",
        fmt,
    )
    await _stream_ws(websocket, cmd)


# ──────────────────────────────────────────────────────────────────────────────
# QA Pairs viewer
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/pairs/{filename}")
async def get_pairs(filename: str, stage: str = "generated"):
    """Return the JSONL QA pairs in *filename*; malformed lines are skipped."""
    # Local import: the module's import block lives in another section of the
    # file; quoting the user-supplied filename prevents shell injection.
    import shlex

    _require_ssh()
    path = f"{STAGE_DIRS.get(stage, STAGE_DIRS['generated'])}/{filename}"
    out, err, code = ssh_manager.execute(f"cat {shlex.quote(path)}", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")
    pairs = []
    for line in out.strip().split("\n"):
        if not line.strip():
            continue
        try:
            pairs.append(json.loads(line))
        except json.JSONDecodeError:
            pass  # tolerate partially written / corrupt JSONL lines
    return {"filename": filename, "count": len(pairs), "pairs": pairs}


# ──────────────────────────────────────────────────────────────────────────────
# Config editor
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/config")
async def get_config():
    """Fetch the remote pipeline config, both parsed and as raw YAML text."""
    _require_ssh()
    try:
        raw = ssh_manager.read_remote_file(CONFIG_PATH)
        return {"config": yaml.safe_load(raw), "raw": raw}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.put("/api/config")
async def update_config(payload: dict):
    """Serialize *payload* to YAML and overwrite the remote pipeline config."""
    _require_ssh()
    try:
        ssh_manager.write_remote_file(CONFIG_PATH, yaml.dump(payload, default_flow_style=False))
        return {"status": "updated"}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


# ──────────────────────────────────────────────────────────────────────────────
# Training (WebSocket streaming)
# ──────────────────────────────────────────────────────────────────────────────
@app.websocket("/api/train")
async def ws_train(
    websocket: WebSocket,
    model_name: str = "llama3.1:8b",
    dataset_path: str = "",
    output_dir: str = "/opt/synthetic/output",
    num_epochs: int = 3,
    batch_size: int = 2,
    learning_rate: float = 2e-4,
):
    """Stream a remote fine-tuning run; parameters arrive as query args."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = train_cmd(model_name, dataset_path, output_dir, num_epochs, batch_size, learning_rate)
    await _stream_ws(websocket, cmd)


# ──────────────────────────────────────────────────────────────────────────────
# Interactive terminal (xterm.js ↔ SSH shell)
# ──────────────────────────────────────────────────────────────────────────────
@app.websocket("/api/terminal")
async def ws_terminal(websocket: WebSocket):
    """Bridge an interactive SSH shell channel to an xterm.js WebSocket client.

    Two tasks pump bytes in opposite directions; the session ends as soon as
    either side finishes (remote shell exit or client disconnect).
    """
    await websocket.accept()
    if not ssh_manager.is_connected():
        await websocket.send_text("\r\nNot connected to SSH server.\r\n")
        return
    channel = None
    try:
        channel = ssh_manager.open_shell_channel()

        async def ssh_to_ws():
            # Poll the channel; recv_ready() keeps recv() from blocking the loop.
            while True:
                if channel.recv_ready():
                    data = channel.recv(4096)
                    if not data:
                        break
                    await websocket.send_bytes(data)
                elif channel.exit_status_ready():
                    break
                else:
                    await asyncio.sleep(0.02)

        async def ws_to_ssh():
            try:
                while True:
                    msg = await websocket.receive()
                    if "bytes" in msg and msg["bytes"]:
                        channel.send(msg["bytes"])
                    elif "text" in msg and msg["text"]:
                        channel.send(msg["text"].encode())
            except WebSocketDisconnect:
                pass

        # FIRST_COMPLETED (rather than gather): when one direction finishes
        # (shell exits, client hangs up) the other is cancelled; gather would
        # wait forever on the still-pending receive()/poll loop.
        reader = asyncio.ensure_future(ssh_to_ws())
        writer = asyncio.ensure_future(ws_to_ssh())
        _, pending = await asyncio.wait({reader, writer}, return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        try:
            await websocket.send_text(f"\r\nError: {exc}\r\n")
        except Exception:
            pass  # socket already gone; nothing left to report to
    finally:
        if channel:
            try:
                channel.close()
            except Exception:
                pass


# ──────────────────────────────────────────────────────────────────────────────
# Model manager (Ollama)
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/models")
async def list_models():
    """List locally available Ollama models; degrades to an empty list."""
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(f"{OLLAMA_URL}/api/tags")
            resp.raise_for_status()
            return {"models": resp.json().get("models", []), "error": None}
    except Exception as exc:
        # Return empty list instead of crashing — Ollama may not be reachable yet
        return {"models": [], "error": str(exc)}


@app.websocket("/api/models/pull")
async def ws_pull_model(websocket: WebSocket, model_name: str):
    """Proxy Ollama's streaming pull progress to the WebSocket client."""
    await websocket.accept()
    try:
        async with httpx.AsyncClient(timeout=600) as client:
            async with client.stream(
                "POST", f"{OLLAMA_URL}/api/pull", json={"name": model_name, "stream": True}
            ) as resp:
                async for line in resp.aiter_lines():
                    if line.strip():
                        try:
                            await websocket.send_json(json.loads(line))
                        except json.JSONDecodeError:
                            pass  # skip any non-JSON noise in the stream
        await websocket.send_json({"status": "success"})
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        try:
            await websocket.send_json({"status": "error", "error": str(exc)})
        except Exception:
            pass  # client disconnected before the error could be delivered


@app.delete("/api/models/{model_name:path}")
async def delete_model(model_name: str):
    """Delete an Ollama model (':path' converter allows '/' in model names)."""
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            # httpx has no .delete with a json body; use .request instead.
            resp = await client.request(
                "DELETE", f"{OLLAMA_URL}/api/delete", json={"name": model_name}
            )
            resp.raise_for_status()
            return {"deleted": model_name}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=True)