From 25f6a07fd644fbe7da20f190ffb82ad6075146c2 Mon Sep 17 00:00:00 2001
From: tocmo0nlord
Date: Sun, 26 Apr 2026 02:01:34 +0000
Subject: [PATCH] feat(api): /api/setup/check + /api/setup/bootstrap WS

---
 backend/main.py | 142 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 142 insertions(+)

diff --git a/backend/main.py b/backend/main.py
index 0e83661..d925ea4 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -233,6 +233,148 @@ async def upload_files(files: List[UploadFile] = File(...)):
     return {"results": results}
 
 
+# ──────────────────────────────────────────────────────────────────────────────
+# Setup / Bootstrap
+# ──────────────────────────────────────────────────────────────────────────────
+
+# Where bootstrap assets live in the .deb install layout
+# (override with the LLM_TRAINER_ASSETS environment variable).
+ASSETS_DIR = Path(os.getenv("LLM_TRAINER_ASSETS", "/opt/llm-trainer/remote"))
+
+
+@app.get("/api/setup/check")
+async def setup_check():
+    """Detect what's already set up on the remote GPU host.
+
+    Runs one shell command over SSH that prints a ``name=ok`` or
+    ``name=missing`` line per component, then parses those lines.
+
+    Returns a dict with ``checks`` (component name -> bool), ``ready``
+    (True when every required component is present; ``gpu`` is reported
+    but not required) and ``user`` (the SSH username).
+    """
+    _require_ssh()
+    user = ssh_manager.username
+
+    required = ("conda", "env", "sdk", "train_py", "data_dirs", "training_deps")
+    reported = ("conda", "env", "sdk", "train_py", "data_dirs", "gpu", "training_deps")
+    # Start every component at False so the response shape is stable even
+    # when the probe prints nothing for it.
+    checks = dict.fromkeys(reported, False)
+
+    # Probe each component with a single shell command, parse the result.
+    # Plain strings: nothing is interpolated, so no f-prefix is needed.
+    probe = (
+        "if [ -x $HOME/miniconda3/bin/conda ]; then echo 'conda=ok'; else echo 'conda=missing'; fi; "
+        "if $HOME/miniconda3/bin/conda env list 2>/dev/null | awk '{print $1}' | grep -qx synthetic-data; "
+        "then echo 'env=ok'; else echo 'env=missing'; fi; "
+        "if [ -x $HOME/miniconda3/envs/synthetic-data/bin/synthetic-data-kit ]; "
+        "then echo 'sdk=ok'; else echo 'sdk=missing'; fi; "
+        "if [ -f /opt/synthetic/train.py ] || [ -f $HOME/synthetic/train.py ]; "
+        "then echo 'train_py=ok'; else echo 'train_py=missing'; fi; "
+        "if [ -d /opt/synthetic/synthetic-data-kit/data/input ] || [ -d $HOME/synthetic/synthetic-data-kit/data/input ]; "
+        "then echo 'data_dirs=ok'; else echo 'data_dirs=missing'; fi; "
+        "if command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1; "
+        "then echo 'gpu=ok'; else echo 'gpu=missing'; fi; "
+        "if $HOME/miniconda3/envs/synthetic-data/bin/python -c 'import torch,transformers,peft,trl' 2>/dev/null; "
+        "then echo 'training_deps=ok'; else echo 'training_deps=missing'; fi"
+    )
+    # ssh_manager.execute() is a blocking call; run it in the default
+    # executor so the event loop keeps serving other requests.
+    loop = asyncio.get_running_loop()
+    out, _err, _ = await loop.run_in_executor(
+        None, lambda: ssh_manager.execute(probe, use_conda=False)
+    )
+    for line in out.splitlines():
+        key, sep, value = line.strip().partition("=")
+        # Ignore anything that is not one of our markers (e.g. a login
+        # banner that happens to contain "=").
+        if sep and key in checks:
+            checks[key] = value == "ok"
+
+    ready = all(checks[name] for name in required)
+    return {"checks": checks, "ready": ready, "user": user}
+
+
+@app.websocket("/api/setup/bootstrap")
+async def ws_bootstrap(websocket: WebSocket):
+    """Upload bootstrap.sh + train.py + config.yaml to the remote and run it.
+
+    Every message sent to the client is a JSON object with a ``type`` of
+    "log" (one line of bootstrap output), "done" or "error", plus a
+    ``data`` string. The handler returns after the first "done"/"error"
+    message or when the client disconnects.
+    """
+    await websocket.accept()
+    if not ssh_manager.is_connected():
+        await websocket.send_json({"type": "error", "data": "Not connected"})
+        return
+
+    if not ASSETS_DIR.is_dir():
+        await websocket.send_json({"type": "error",
+                                   "data": f"Bootstrap assets not found at {ASSETS_DIR}"})
+        return
+
+    # Check every asset locally before touching the remote host, so a
+    # missing file cannot leave a half-uploaded bootstrap directory behind.
+    asset_names = ("bootstrap.sh", "train.py", "config.yaml")
+    for name in asset_names:
+        if not (ASSETS_DIR / name).is_file():
+            await websocket.send_json({"type": "error",
+                                       "data": f"Missing asset: {name}"})
+            return
+
+    loop = asyncio.get_running_loop()
+    # NOTE(review): fixed, predictable path under /tmp; consider a
+    # per-run directory (mktemp -d) if the remote host is shared.
+    remote_dir = "/tmp/llm-trainer-bootstrap"
+
+    def _upload_assets():
+        # Blocking SSH work; runs in the default executor (see below).
+        ssh_manager.execute(f"mkdir -p {remote_dir}", use_conda=False)
+        for name in asset_names:
+            ssh_manager.upload_file(str(ASSETS_DIR / name), f"{remote_dir}/{name}")
+        ssh_manager.execute(f"chmod +x {remote_dir}/bootstrap.sh", use_conda=False)
+
+    try:
+        # Keep the event loop responsive while the files are transferred.
+        await loop.run_in_executor(None, _upload_assets)
+    except Exception as exc:
+        await websocket.send_json({"type": "error",
+                                   "data": f"Asset upload failed: {exc}"})
+        return
+
+    # Stream bootstrap output
+    queue: asyncio.Queue = asyncio.Queue()
+
+    def _publish(kind, data):
+        # asyncio.Queue is not thread-safe: schedule the put on the loop.
+        loop.call_soon_threadsafe(queue.put_nowait, {"type": kind, "data": data})
+
+    def _worker():
+        try:
+            for line in ssh_manager.execute_stream(
+                f"bash {remote_dir}/bootstrap.sh", use_conda=False
+            ):
+                _publish("log", line)
+            # NOTE(review): "done" is sent whenever the stream ends without
+            # raising; confirm execute_stream raises on a non-zero exit.
+            _publish("done", "Bootstrap complete.")
+        except Exception as exc:
+            _publish("error", str(exc))
+
+    threading.Thread(target=_worker, daemon=True).start()
+
+    try:
+        while True:
+            msg = await queue.get()
+            await websocket.send_json(msg)
+            if msg["type"] in ("done", "error"):
+                break
+    except WebSocketDisconnect:
+        # Client went away; the worker thread is left to finish on its own.
+        pass
+
 # ──────────────────────────────────────────────────────────────────────────────
 # Pipeline (WebSocket streaming)
 # ──────────────────────────────────────────────────────────────────────────────