feat(api): /api/setup/check + /api/setup/bootstrap WS

This commit is contained in:
2026-04-26 02:01:34 +00:00
parent bf5519ebb2
commit 25f6a07fd6

View File

@@ -233,6 +233,109 @@ async def upload_files(files: List[UploadFile] = File(...)):
return {"results": results}
# ──────────────────────────────────────────────────────────────────────────────
# Setup / Bootstrap
# ──────────────────────────────────────────────────────────────────────────────
# Where bootstrap assets live in the .deb install layout.
# Overridable via the LLM_TRAINER_ASSETS environment variable; falls back to
# the packaged default path when the variable is unset.
ASSETS_DIR = Path(os.getenv("LLM_TRAINER_ASSETS", "/opt/llm-trainer/remote"))
@app.get("/api/setup/check")
async def setup_check():
    """Detect what's already set up on the remote GPU host.

    Runs a single compound shell probe over SSH; each component prints
    "<name>=ok" or "<name>=missing", which is parsed into a boolean map.
    """
    _require_ssh()

    # One round-trip: every component check is chained into a single command.
    probe = (
        "if [ -x $HOME/miniconda3/bin/conda ]; then echo 'conda=ok'; else echo 'conda=missing'; fi; "
        "if $HOME/miniconda3/bin/conda env list 2>/dev/null | awk '{print $1}' | grep -qx synthetic-data; "
        "then echo 'env=ok'; else echo 'env=missing'; fi; "
        "if [ -x $HOME/miniconda3/envs/synthetic-data/bin/synthetic-data-kit ]; "
        "then echo 'sdk=ok'; else echo 'sdk=missing'; fi; "
        "if [ -f /opt/synthetic/train.py ] || [ -f $HOME/synthetic/train.py ]; "
        "then echo 'train_py=ok'; else echo 'train_py=missing'; fi; "
        "if [ -d /opt/synthetic/synthetic-data-kit/data/input ] || [ -d $HOME/synthetic/synthetic-data-kit/data/input ]; "
        "then echo 'data_dirs=ok'; else echo 'data_dirs=missing'; fi; "
        "if command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1; "
        "then echo 'gpu=ok'; else echo 'gpu=missing'; fi; "
        "if $HOME/miniconda3/envs/synthetic-data/bin/python -c 'import torch,transformers,peft,trl' 2>/dev/null; "
        "then echo 'training_deps=ok'; else echo 'training_deps=missing'; fi"
    )
    stdout, _stderr, _status = ssh_manager.execute(probe, use_conda=False)

    # Parse "name=ok"/"name=missing" lines into {name: bool}.
    checks = {}
    for raw in stdout.splitlines():
        name, sep, status = raw.strip().partition("=")
        if sep:
            checks[name] = status == "ok"

    # GPU presence is reported but not required for readiness.
    required = ("conda", "env", "sdk", "train_py", "data_dirs", "training_deps")
    ready = all(checks.get(name, False) for name in required)
    return {"checks": checks, "ready": ready, "user": ssh_manager.username}
@app.websocket("/api/setup/bootstrap")
async def ws_bootstrap(websocket: WebSocket):
    """Upload bootstrap.sh + train.py + config.yaml to the remote and run it.

    Messages sent to the client are JSON objects of the shape
    {"type": "log" | "done" | "error", "data": str}; the socket loop ends
    after the first "done" or "error".
    """
    await websocket.accept()
    if not ssh_manager.is_connected():
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    if not ASSETS_DIR.is_dir():
        await websocket.send_json({"type": "error",
                                   "data": f"Bootstrap assets not found at {ASSETS_DIR}"})
        return

    remote_dir = "/tmp/llm-trainer-bootstrap"
    # Stage the three bootstrap assets on the remote before executing anything.
    try:
        ssh_manager.execute(f"mkdir -p {remote_dir}", use_conda=False)
        for name in ("bootstrap.sh", "train.py", "config.yaml"):
            src = ASSETS_DIR / name
            if not src.is_file():
                await websocket.send_json({"type": "error",
                                           "data": f"Missing asset: {name}"})
                return
            ssh_manager.upload_file(str(src), f"{remote_dir}/{name}")
        ssh_manager.execute(f"chmod +x {remote_dir}/bootstrap.sh", use_conda=False)
    except Exception as exc:
        await websocket.send_json({"type": "error",
                                   "data": f"Asset upload failed: {exc}"})
        return

    # Stream bootstrap output. FIX: use get_running_loop() — get_event_loop()
    # is deprecated inside coroutines (3.10+) and may bind a loop other than
    # the one serving this websocket; get_running_loop() is guaranteed correct
    # here because we are inside a running coroutine.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _worker():
        # Runs in a plain thread; queue puts are marshalled onto the event
        # loop via run_coroutine_threadsafe (asyncio.Queue is not thread-safe).
        try:
            for line in ssh_manager.execute_stream(
                f"bash {remote_dir}/bootstrap.sh", use_conda=False
            ):
                asyncio.run_coroutine_threadsafe(
                    queue.put({"type": "log", "data": line}), loop
                )
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "done", "data": "Bootstrap complete."}), loop
            )
        except Exception as exc:
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "error", "data": str(exc)}), loop
            )

    threading.Thread(target=_worker, daemon=True).start()
    try:
        while True:
            msg = await queue.get()
            await websocket.send_json(msg)
            if msg["type"] in ("done", "error"):
                break
    except WebSocketDisconnect:
        # Client went away mid-bootstrap. NOTE(review): the daemon worker
        # thread keeps draining the remote stream; its queued messages are
        # simply never consumed. Acceptable for a one-shot bootstrap, but
        # confirm there is no need to kill the remote process here.
        pass
# ──────────────────────────────────────────────────────────────────────────────
# Pipeline (WebSocket streaming)
# ──────────────────────────────────────────────────────────────────────────────