feat(api): /api/setup/check + /api/setup/bootstrap WS
This commit is contained in:
103
backend/main.py
103
backend/main.py
@@ -233,6 +233,109 @@ async def upload_files(files: List[UploadFile] = File(...)):
|
|||||||
return {"results": results}
|
return {"results": results}
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Setup / Bootstrap
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Local directory holding the bootstrap assets. Defaults to the location used
# by the .deb install layout; override it with the LLM_TRAINER_ASSETS
# environment variable.
ASSETS_DIR = Path(os.environ.get("LLM_TRAINER_ASSETS", "/opt/llm-trainer/remote"))
|
||||||
|
|
||||||
|
|
||||||
|
# Components reported by the remote probe, in the order they are probed.
_SETUP_PROBE_KEYS = (
    "conda", "env", "sdk", "train_py", "data_dirs", "gpu", "training_deps",
)

# Components that must all be present for the host to count as ready.
# "gpu" is reported in the response but is not part of the readiness gate.
_SETUP_REQUIRED_KEYS = (
    "conda", "env", "sdk", "train_py", "data_dirs", "training_deps",
)

# One shell command that prints a "<key>=ok" or "<key>=missing" line for every
# component, so the whole check needs a single remote round trip.
_SETUP_PROBE = (
    "if [ -x $HOME/miniconda3/bin/conda ]; then echo 'conda=ok'; else echo 'conda=missing'; fi; "
    "if $HOME/miniconda3/bin/conda env list 2>/dev/null | awk '{print $1}' | grep -qx synthetic-data; "
    "then echo 'env=ok'; else echo 'env=missing'; fi; "
    "if [ -x $HOME/miniconda3/envs/synthetic-data/bin/synthetic-data-kit ]; "
    "then echo 'sdk=ok'; else echo 'sdk=missing'; fi; "
    "if [ -f /opt/synthetic/train.py ] || [ -f $HOME/synthetic/train.py ]; "
    "then echo 'train_py=ok'; else echo 'train_py=missing'; fi; "
    "if [ -d /opt/synthetic/synthetic-data-kit/data/input ] || [ -d $HOME/synthetic/synthetic-data-kit/data/input ]; "
    "then echo 'data_dirs=ok'; else echo 'data_dirs=missing'; fi; "
    "if command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1; "
    "then echo 'gpu=ok'; else echo 'gpu=missing'; fi; "
    "if $HOME/miniconda3/envs/synthetic-data/bin/python -c 'import torch,transformers,peft,trl' 2>/dev/null; "
    "then echo 'training_deps=ok'; else echo 'training_deps=missing'; fi"
)


def _parse_setup_probe(output: str) -> dict:
    """Turn the probe's ``key=value`` lines into a ``{key: bool}`` mapping.

    Every known component starts out as ``False``, so a component the probe
    never reported is shown as missing rather than being absent from the
    result. Lines whose key is not a known component are ignored, which keeps
    unrelated shell output that happens to contain ``=`` out of the result.
    """
    checks = dict.fromkeys(_SETUP_PROBE_KEYS, False)
    for raw_line in (output or "").splitlines():
        key, sep, value = raw_line.strip().partition("=")
        if sep and key in checks:
            checks[key] = value == "ok"
    return checks


@app.get("/api/setup/check")
async def setup_check():
    """Detect what's already set up on the remote GPU host.

    Returns:
        A dict with:
          * ``checks`` - one boolean per probed component (see
            ``_SETUP_PROBE_KEYS``), ``True`` when the probe printed ``ok``.
          * ``ready`` - ``True`` only when every component listed in
            ``_SETUP_REQUIRED_KEYS`` is present.
          * ``user`` - the username held by ``ssh_manager``.
    """
    _require_ssh()

    # Only stdout carries the probe result; the other two values returned by
    # execute() are not needed here.
    out, _, _ = ssh_manager.execute(_SETUP_PROBE, use_conda=False)
    checks = _parse_setup_probe(out)

    ready = all(checks[key] for key in _SETUP_REQUIRED_KEYS)
    return {"checks": checks, "ready": ready, "user": ssh_manager.username}
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/api/setup/bootstrap")
async def ws_bootstrap(websocket: WebSocket):
    """Upload bootstrap.sh + train.py + config.yaml to the remote and run it.

    Messages sent to the client are JSON objects of the form
    ``{"type": ..., "data": ...}`` where ``type`` is one of:

      * ``"log"``   - one line of output from the running bootstrap script
      * ``"done"``  - the output stream ended without raising
      * ``"error"`` - a precondition failed, the upload failed, or streaming
        raised; no further messages follow

    The handler returns after the first ``"done"`` or ``"error"`` message, or
    when the client disconnects.
    """
    await websocket.accept()

    if not ssh_manager.is_connected():
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return

    if not ASSETS_DIR.is_dir():
        await websocket.send_json({"type": "error",
                                   "data": f"Bootstrap assets not found at {ASSETS_DIR}"})
        return

    # Validate every local asset before touching the remote host, so a missing
    # file cannot leave a half-populated bootstrap directory behind.
    asset_names = ("bootstrap.sh", "train.py", "config.yaml")
    for name in asset_names:
        if not (ASSETS_DIR / name).is_file():
            await websocket.send_json({"type": "error",
                                       "data": f"Missing asset: {name}"})
            return

    remote_dir = "/tmp/llm-trainer-bootstrap"
    try:
        ssh_manager.execute(f"mkdir -p {remote_dir}", use_conda=False)
        for name in asset_names:
            ssh_manager.upload_file(str(ASSETS_DIR / name), f"{remote_dir}/{name}")
        ssh_manager.execute(f"chmod +x {remote_dir}/bootstrap.sh", use_conda=False)
    except Exception as exc:
        # Boundary handler: report any upload/exec failure to the client
        # instead of letting the socket die without an explanation.
        await websocket.send_json({"type": "error",
                                   "data": f"Asset upload failed: {exc}"})
        return

    # Stream bootstrap output. The SSH stream is consumed in a worker thread,
    # which hands messages to this coroutine through an asyncio queue.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _emit(kind: str, data: str) -> None:
        """Enqueue one message for the websocket from the worker thread."""
        asyncio.run_coroutine_threadsafe(
            queue.put({"type": kind, "data": data}), loop
        )

    def _worker() -> None:
        """Run the bootstrap script and forward its output line by line."""
        try:
            for line in ssh_manager.execute_stream(
                f"bash {remote_dir}/bootstrap.sh", use_conda=False
            ):
                _emit("log", line)
            # NOTE(review): "done" is sent whenever the stream ends without an
            # exception; whether execute_stream raises on a non-zero exit
            # status is not visible here - confirm.
            _emit("done", "Bootstrap complete.")
        except Exception as exc:
            _emit("error", str(exc))

    threading.Thread(target=_worker, daemon=True).start()

    try:
        while True:
            msg = await queue.get()
            await websocket.send_json(msg)
            if msg["type"] in ("done", "error"):
                break
    except WebSocketDisconnect:
        # The client went away, so there is nobody left to report to. The
        # worker thread is not interrupted here.
        pass
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────────
|
||||||
# Pipeline (WebSocket streaming)
|
# Pipeline (WebSocket streaming)
|
||||||
# ──────────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────────
|
||||||
|
|||||||
Reference in New Issue
Block a user