feat(api): /api/setup/check + /api/setup/bootstrap WS
This commit is contained in:
103
backend/main.py
103
backend/main.py
@@ -233,6 +233,109 @@ async def upload_files(files: List[UploadFile] = File(...)):
|
||||
return {"results": results}
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Setup / Bootstrap
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Where bootstrap assets live in the .deb install layout
|
||||
ASSETS_DIR = Path(os.getenv("LLM_TRAINER_ASSETS", "/opt/llm-trainer/remote"))
|
||||
|
||||
|
||||
@app.get("/api/setup/check")
|
||||
async def setup_check():
|
||||
"""Detect what's already set up on the remote GPU host."""
|
||||
_require_ssh()
|
||||
user = ssh_manager.username
|
||||
|
||||
checks = {}
|
||||
|
||||
# Probe each component with a single shell command, parse the result
|
||||
probe = (
|
||||
f"if [ -x $HOME/miniconda3/bin/conda ]; then echo 'conda=ok'; else echo 'conda=missing'; fi; "
|
||||
f"if $HOME/miniconda3/bin/conda env list 2>/dev/null | awk '{{print $1}}' | grep -qx synthetic-data; "
|
||||
f"then echo 'env=ok'; else echo 'env=missing'; fi; "
|
||||
f"if [ -x $HOME/miniconda3/envs/synthetic-data/bin/synthetic-data-kit ]; "
|
||||
f"then echo 'sdk=ok'; else echo 'sdk=missing'; fi; "
|
||||
f"if [ -f /opt/synthetic/train.py ] || [ -f $HOME/synthetic/train.py ]; "
|
||||
f"then echo 'train_py=ok'; else echo 'train_py=missing'; fi; "
|
||||
f"if [ -d /opt/synthetic/synthetic-data-kit/data/input ] || [ -d $HOME/synthetic/synthetic-data-kit/data/input ]; "
|
||||
f"then echo 'data_dirs=ok'; else echo 'data_dirs=missing'; fi; "
|
||||
f"if command -v nvidia-smi >/dev/null 2>&1 && nvidia-smi -L >/dev/null 2>&1; "
|
||||
f"then echo 'gpu=ok'; else echo 'gpu=missing'; fi; "
|
||||
f"if $HOME/miniconda3/envs/synthetic-data/bin/python -c 'import torch,transformers,peft,trl' 2>/dev/null; "
|
||||
f"then echo 'training_deps=ok'; else echo 'training_deps=missing'; fi"
|
||||
)
|
||||
out, err, _ = ssh_manager.execute(probe, use_conda=False)
|
||||
for line in out.splitlines():
|
||||
if "=" in line:
|
||||
k, v = line.strip().split("=", 1)
|
||||
checks[k] = v == "ok"
|
||||
|
||||
ready = all(checks.get(k, False) for k in
|
||||
("conda", "env", "sdk", "train_py", "data_dirs", "training_deps"))
|
||||
return {"checks": checks, "ready": ready, "user": user}
|
||||
|
||||
|
||||
@app.websocket("/api/setup/bootstrap")
|
||||
async def ws_bootstrap(websocket: WebSocket):
|
||||
"""Upload bootstrap.sh + train.py + config.yaml to the remote and run it."""
|
||||
await websocket.accept()
|
||||
if not ssh_manager.is_connected():
|
||||
await websocket.send_json({"type": "error", "data": "Not connected"})
|
||||
return
|
||||
|
||||
if not ASSETS_DIR.is_dir():
|
||||
await websocket.send_json({"type": "error",
|
||||
"data": f"Bootstrap assets not found at {ASSETS_DIR}"})
|
||||
return
|
||||
|
||||
try:
|
||||
remote_dir = "/tmp/llm-trainer-bootstrap"
|
||||
ssh_manager.execute(f"mkdir -p {remote_dir}", use_conda=False)
|
||||
for name in ("bootstrap.sh", "train.py", "config.yaml"):
|
||||
src = ASSETS_DIR / name
|
||||
if not src.is_file():
|
||||
await websocket.send_json({"type": "error",
|
||||
"data": f"Missing asset: {name}"})
|
||||
return
|
||||
ssh_manager.upload_file(str(src), f"{remote_dir}/{name}")
|
||||
ssh_manager.execute(f"chmod +x {remote_dir}/bootstrap.sh", use_conda=False)
|
||||
except Exception as exc:
|
||||
await websocket.send_json({"type": "error",
|
||||
"data": f"Asset upload failed: {exc}"})
|
||||
return
|
||||
|
||||
# Stream bootstrap output
|
||||
loop = asyncio.get_event_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
def _worker():
|
||||
try:
|
||||
for line in ssh_manager.execute_stream(
|
||||
f"bash {remote_dir}/bootstrap.sh", use_conda=False
|
||||
):
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
queue.put({"type": "log", "data": line}), loop
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
queue.put({"type": "done", "data": "Bootstrap complete."}), loop
|
||||
)
|
||||
except Exception as exc:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
queue.put({"type": "error", "data": str(exc)}), loop
|
||||
)
|
||||
|
||||
threading.Thread(target=_worker, daemon=True).start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
msg = await queue.get()
|
||||
await websocket.send_json(msg)
|
||||
if msg["type"] in ("done", "error"):
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Pipeline (WebSocket streaming)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user