Initial scaffold: LLM Trainer Dashboard
Full-stack app with FastAPI backend (SSH/paramiko, pipeline streaming, GPU stats, xterm.js terminal, Ollama model manager) and React + Tailwind frontend (8 panels: Connection, Documents, Pipeline, QA Pairs, Training, Terminal, Models, Config). Docker Compose included. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
12
backend/Dockerfile
Normal file
12
backend/Dockerfile
Normal file
@@ -0,0 +1,12 @@
|
||||
# Slim Python base image matching the backend's tested interpreter version.
FROM python:3.11-slim

WORKDIR /app

# Install dependencies first so the pip layer is cached across source edits.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8080

# Bind to all interfaces so the container port mapping is reachable.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
|
||||
39
backend/gpu.py
Normal file
39
backend/gpu.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from ssh_client import ssh_manager
|
||||
|
||||
|
||||
def get_gpu_stats() -> dict:
    """Query nvidia-smi on the remote host and return parsed GPU info.

    Returns:
        dict with:
            "gpus": list of dicts (name, utilization, memory_used,
                    memory_total, temperature, power_draw or None).
            "error": error string, or None on success.
    Never raises; all failures are reported through "error".
    """
    try:
        if not ssh_manager.is_connected():
            return {"gpus": [], "error": "Not connected"}

        out, err, code = ssh_manager.execute(
            "nvidia-smi --query-gpu=name,utilization.gpu,memory.used,memory.total,"
            "temperature.gpu,power.draw --format=csv,noheader,nounits",
            use_conda=False
        )

        if code != 0:
            return {"gpus": [], "error": err.strip() or "nvidia-smi failed"}

        gpus = []
        for line in out.strip().split("\n"):
            if not line.strip():
                continue
            parts = [p.strip() for p in line.split(",")]
            if len(parts) < 5:
                continue
            try:
                gpu = {
                    "name": parts[0],
                    "utilization": int(parts[1]),
                    "memory_used": int(parts[2]),
                    "memory_total": int(parts[3]),
                    "temperature": int(parts[4]),
                }
            except ValueError:
                # Required field malformed -- skip just this line.
                continue
            # power.draw can be reported as "[N/A]" on some boards. Bug fix:
            # the original wrapped the whole record in one try, so a bad
            # power value silently discarded the entire GPU entry.
            try:
                gpu["power_draw"] = float(parts[5]) if len(parts) > 5 else None
            except ValueError:
                gpu["power_draw"] = None
            gpus.append(gpu)

        return {"gpus": gpus, "error": None}
    except Exception as e:
        return {"gpus": [], "error": str(e)}
|
||||
450
backend/main.py
Normal file
450
backend/main.py
Normal file
@@ -0,0 +1,450 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import FastAPI, File, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
from gpu import get_gpu_stats
|
||||
from pipeline import STAGE_DIRS, CONFIG_PATH, ingest_cmd, create_cmd, curate_cmd, save_as_cmd, train_cmd
|
||||
from ssh_client import ssh_manager
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
app = FastAPI(title="LLM Trainer API", version="1.0.0")

# NOTE(review): wildcard origins combined with allow_credentials=True is
# ignored by browsers per the CORS spec. Acceptable for a trusted LAN tool,
# but pin the frontend origin if cookie/credential auth is ever added.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Base URL of the Ollama server; overridable via the OLLAMA_URL env var.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://192.168.2.47:11434")
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Pydantic models
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class ConnectRequest(BaseModel):
    """Payload for POST /api/connect.

    Defaults target the lab training box. Either password or key_path may
    be supplied; with neither, paramiko falls back to agent/default keys.
    """
    host: str = "192.168.2.47"
    username: str = "tocmo0nlord"
    password: Optional[str] = None   # plain password auth (optional)
    key_path: Optional[str] = None   # path to a private key file (optional)
    port: int = 22
|
||||
|
||||
|
||||
class TrainRequest(BaseModel):
    """Payload describing a fine-tuning run (mirrors the /api/train params)."""
    model_name: str = "llama3.1:8b"
    dataset_path: str                          # remote path to the training set
    output_dir: str = "/opt/synthetic/output"  # where checkpoints land
    num_epochs: int = 3
    batch_size: int = 2
    learning_rate: float = 2e-4
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Helpers
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _require_ssh():
    """Abort the current request with 503 unless an SSH session is active."""
    if ssh_manager.is_connected():
        return
    raise HTTPException(status_code=503, detail="Not connected to SSH server")
|
||||
|
||||
|
||||
async def _stream_ws(websocket: WebSocket, command: str, use_conda: bool = True):
    """Run a remote command and stream output lines over WebSocket.

    Accepts the socket, runs `command` via the blocking paramiko stream in a
    daemon thread, and forwards each line as {"type": "log", "data": ...},
    terminated by exactly one "done" or "error" message.
    """
    await websocket.accept()
    # Fix: get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated for this use and can bind a wrong loop.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _worker():
        # Runs in a plain thread: hop results back onto the event loop.
        try:
            for line in ssh_manager.execute_stream(command, use_conda=use_conda):
                asyncio.run_coroutine_threadsafe(
                    queue.put({"type": "log", "data": line}), loop
                )
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "done", "data": "Command completed."}), loop
            )
        except Exception as exc:
            asyncio.run_coroutine_threadsafe(
                queue.put({"type": "error", "data": str(exc)}), loop
            )

    threading.Thread(target=_worker, daemon=True).start()

    try:
        while True:
            msg = await queue.get()
            await websocket.send_json(msg)
            if msg["type"] in ("done", "error"):
                break
    except WebSocketDisconnect:
        # Client went away mid-stream; the daemon worker drains on its own.
        pass
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Connection
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.post("/api/connect")
async def connect(req: ConnectRequest):
    """Open (or replace) the SSH session to the training host."""
    try:
        ssh_manager.connect(
            host=req.host,
            username=req.username,
            password=req.password,
            key_path=req.key_path,
            port=req.port,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "connected", "host": req.host, "username": req.username}
|
||||
|
||||
|
||||
@app.post("/api/disconnect")
async def disconnect():
    """Tear down the SSH session (safe to call when already disconnected)."""
    ssh_manager.disconnect()
    return {"status": "disconnected"}
|
||||
|
||||
|
||||
@app.get("/api/status")
async def status():
    """Report SSH connection state plus a GPU snapshot when connected."""
    if ssh_manager.is_connected():
        return {
            "connected": True,
            "host": ssh_manager.host,
            "username": ssh_manager.username,
            "gpu": get_gpu_stats(),
        }
    return {
        "connected": False,
        "host": None,
        "username": None,
        "gpu": {"gpus": [], "error": "Not connected"},
    }
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# GPU
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/gpu")
async def gpu():
    """Return the current nvidia-smi snapshot (503 when not connected)."""
    _require_ssh()
    return get_gpu_stats()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# File management
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/files/{stage}")
async def list_files(stage: str):
    """List regular files in the remote directory for a pipeline stage.

    Parses `ls -la` output. Directories and unparsable lines are skipped so
    a single odd entry cannot fail the whole listing (the original's bare
    int() could raise and 500 the request).
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()

    out, _, _ = ssh_manager.execute(
        f"ls -la '{STAGE_DIRS[stage]}' 2>/dev/null | tail -n +2", use_conda=False
    )

    files = []
    for line in out.strip().split("\n"):
        if not line.strip() or line.startswith("total"):
            continue
        parts = line.split()
        # Expected columns: mode links owner group size month day time name...
        if len(parts) < 9 or parts[0].startswith("d"):
            continue
        try:
            size = int(parts[4])
        except ValueError:
            # Non-standard ls line (e.g. device major,minor); skip it.
            continue
        files.append({
            "name": " ".join(parts[8:]),   # re-join names containing spaces
            "size": size,
            "modified": f"{parts[5]} {parts[6]} {parts[7]}",
        })

    return {"stage": stage, "directory": STAGE_DIRS[stage], "files": files}
|
||||
|
||||
|
||||
@app.delete("/api/files/{stage}/{filename}")
async def delete_file(stage: str, filename: str):
    """Delete a single file from a pipeline stage directory.

    Fix: the `{filename}` path parameter was garbled to a literal
    "(unknown)" in both the route and the remote path.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()

    # NOTE(review): filename is interpolated into a shell command; a name
    # containing a single quote could escape the quoting. Acceptable on a
    # trusted LAN, but sanitizing here would be safer.
    path = f"{STAGE_DIRS[stage]}/{filename}"
    _, err, code = ssh_manager.execute(f"rm -f '{path}'", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"deleted": filename}
|
||||
|
||||
|
||||
@app.get("/api/files/{stage}/{filename}/preview")
async def preview_file(stage: str, filename: str, lines: int = 120):
    """Return the first `lines` lines of a remote file for a quick preview.

    Fix: the `{filename}` path parameter was garbled to a literal
    "(unknown)" in both the route and the remote path.
    """
    if stage not in STAGE_DIRS:
        raise HTTPException(status_code=400, detail=f"Unknown stage: {stage}")
    _require_ssh()

    path = f"{STAGE_DIRS[stage]}/{filename}"
    out, err, code = ssh_manager.execute(f"head -n {lines} '{path}'", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=500, detail=err)
    return {"filename": filename, "content": out}
|
||||
|
||||
|
||||
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Stage an uploaded document into the remote `input` directory.

    The payload is spooled to a local temp file, pushed over SFTP, and the
    temp file is always removed afterwards.
    """
    _require_ssh()

    suffix = Path(file.filename).suffix
    payload = await file.read()
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(payload)
        tmp_path = tmp.name

    try:
        remote_path = f"{STAGE_DIRS['input']}/{file.filename}"
        ssh_manager.upload_file(tmp_path, remote_path)
        return {"uploaded": file.filename, "remote_path": remote_path}
    finally:
        os.unlink(tmp_path)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Pipeline (WebSocket streaming)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.websocket("/api/pipeline/ingest")
async def ws_ingest(websocket: WebSocket, filename: str):
    """Stream `synthetic-data-kit ingest` output for one input file.

    Fix: the `{filename}` placeholder in the remote path was garbled to a
    literal "(unknown)", so the command always targeted a bogus file.
    """
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = ingest_cmd(f"{STAGE_DIRS['input']}/{filename}")
    await _stream_ws(websocket, cmd)
|
||||
|
||||
|
||||
@app.websocket("/api/pipeline/create")
async def ws_create(websocket: WebSocket, filename: str,
                    num_pairs: int = 50, pair_type: str = "qa"):
    """Stream QA-pair generation for one parsed file.

    Fix: the `{filename}` placeholder in the remote path was garbled to a
    literal "(unknown)".
    """
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = create_cmd(f"{STAGE_DIRS['parsed']}/{filename}", num_pairs, pair_type)
    await _stream_ws(websocket, cmd)
|
||||
|
||||
|
||||
@app.websocket("/api/pipeline/curate")
async def ws_curate(websocket: WebSocket, filename: str,
                    output_filename: str, threshold: float = 7.0):
    """Stream curation (quality filtering) of generated pairs.

    Fix: the `{filename}` placeholder in the input path was garbled to a
    literal "(unknown)".
    """
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = curate_cmd(
        f"{STAGE_DIRS['generated']}/{filename}",
        f"{STAGE_DIRS['curated']}/{output_filename}",
        threshold,
    )
    await _stream_ws(websocket, cmd)
|
||||
|
||||
|
||||
@app.websocket("/api/pipeline/save")
async def ws_save(websocket: WebSocket, filename: str,
                  output_filename: str, fmt: str = "jsonl"):
    """Stream conversion of curated pairs into the final training format.

    Fix: the `{filename}` placeholder in the input path was garbled to a
    literal "(unknown)".
    """
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return
    cmd = save_as_cmd(
        f"{STAGE_DIRS['curated']}/{filename}",
        f"{STAGE_DIRS['final']}/{output_filename}",
        fmt,
    )
    await _stream_ws(websocket, cmd)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# QA Pairs viewer
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/pairs/{filename}")
async def get_pairs(filename: str, stage: str = "generated"):
    """Read a JSONL file of QA pairs from a stage directory.

    Unknown stages fall back to `generated`; lines that are not valid JSON
    are skipped rather than failing the whole request.

    Fix: the `{filename}` placeholder was garbled to a literal "(unknown)"
    in the route, the remote path, and the 404 detail message.
    """
    _require_ssh()
    path = f"{STAGE_DIRS.get(stage, STAGE_DIRS['generated'])}/{filename}"
    out, err, code = ssh_manager.execute(f"cat '{path}'", use_conda=False)
    if code != 0:
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    pairs = []
    for line in out.strip().split("\n"):
        if not line.strip():
            continue
        try:
            pairs.append(json.loads(line))
        except json.JSONDecodeError:
            pass

    return {"filename": filename, "count": len(pairs), "pairs": pairs}
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Config editor
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/config")
async def get_config():
    """Fetch the remote synthetic-data-kit config as parsed YAML plus raw text."""
    _require_ssh()
    try:
        raw = ssh_manager.read_remote_file(CONFIG_PATH)
        parsed = yaml.safe_load(raw)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"config": parsed, "raw": raw}
|
||||
|
||||
|
||||
@app.put("/api/config")
async def update_config(payload: dict):
    """Serialize the request body to YAML and write it to the remote config."""
    _require_ssh()
    try:
        ssh_manager.write_remote_file(
            CONFIG_PATH, yaml.dump(payload, default_flow_style=False)
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "updated"}
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Training (WebSocket streaming)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.websocket("/api/train")
async def ws_train(
    websocket: WebSocket,
    model_name: str = "llama3.1:8b",
    dataset_path: str = "",
    output_dir: str = "/opt/synthetic/output",
    num_epochs: int = 3,
    batch_size: int = 2,
    learning_rate: float = 2e-4,
):
    """Kick off a fine-tuning run and stream its stdout to the client."""
    if not ssh_manager.is_connected():
        await websocket.accept()
        await websocket.send_json({"type": "error", "data": "Not connected"})
        return

    await _stream_ws(
        websocket,
        train_cmd(model_name, dataset_path, output_dir,
                  num_epochs, batch_size, learning_rate),
    )
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Interactive terminal (xterm.js ↔ SSH shell)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.websocket("/api/terminal")
async def ws_terminal(websocket: WebSocket):
    """Bridge an xterm.js WebSocket to an interactive remote SSH shell.

    Bytes from the SSH channel are pushed to the client; bytes received
    from the client are written to the channel. Runs until either side
    closes; the channel is always closed on the way out.
    """
    await websocket.accept()

    if not ssh_manager.is_connected():
        await websocket.send_text("\r\nNot connected to SSH server.\r\n")
        return

    channel = None
    try:
        channel = ssh_manager.open_shell_channel()

        async def ssh_to_ws():
            # Poll the paramiko channel: it has no async interface, so we
            # sleep briefly when idle instead of blocking the event loop.
            while True:
                if channel.recv_ready():
                    data = channel.recv(4096)
                    if not data:
                        break
                    await websocket.send_bytes(data)
                elif channel.exit_status_ready():
                    break
                else:
                    await asyncio.sleep(0.02)

        async def ws_to_ssh():
            # NOTE(review): receive_bytes() assumes the frontend sends
            # binary frames; a text frame would raise here. Confirm the
            # xterm.js attach configuration.
            try:
                while True:
                    data = await websocket.receive_bytes()
                    channel.send(data)
            except WebSocketDisconnect:
                pass

        # NOTE(review): gather() waits for BOTH coroutines; if only the
        # shell exits, ws_to_ssh keeps this handler alive until the socket
        # drops. Cancelling the sibling on first completion would be tidier.
        await asyncio.gather(ssh_to_ws(), ws_to_ssh())
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        try:
            await websocket.send_text(f"\r\nError: {exc}\r\n")
        except Exception:
            pass
    finally:
        if channel:
            try:
                channel.close()
            except Exception:
                pass
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Model manager (Ollama)
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/api/models")
async def list_models():
    """List the models currently available on the Ollama server."""
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(f"{OLLAMA_URL}/api/tags")
            resp.raise_for_status()
            models = resp.json().get("models", [])
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"models": models}
|
||||
|
||||
|
||||
@app.websocket("/api/models/pull")
async def ws_pull_model(websocket: WebSocket, model_name: str):
    """Proxy an Ollama model pull, relaying progress JSON to the client.

    Each progress line from Ollama's streaming /api/pull endpoint is
    forwarded verbatim; a final {"status": "success"} marks completion.
    Errors are reported as {"status": "error", "error": ...} best-effort.
    """
    await websocket.accept()
    try:
        # Model pulls can take a long time; allow up to 10 minutes.
        async with httpx.AsyncClient(timeout=600) as client:
            async with client.stream(
                "POST", f"{OLLAMA_URL}/api/pull",
                json={"name": model_name, "stream": True}
            ) as resp:
                async for line in resp.aiter_lines():
                    if line.strip():
                        try:
                            await websocket.send_json(json.loads(line))
                        except json.JSONDecodeError:
                            # Ignore partial/garbled progress lines.
                            pass
            await websocket.send_json({"status": "success"})
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        try:
            await websocket.send_json({"status": "error", "error": str(exc)})
        except Exception:
            pass
|
||||
|
||||
|
||||
@app.delete("/api/models/{model_name:path}")
async def delete_model(model_name: str):
    """Remove a model from the Ollama server.

    The `:path` converter lets names containing slashes (registry-style
    tags) round-trip through the URL.
    """
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.request(
                "DELETE", f"{OLLAMA_URL}/api/delete", json={"name": model_name}
            )
            resp.raise_for_status()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"deleted": model_name}
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn
    # Fix: reload=True requires an application *import string*; passing the
    # app object makes uvicorn warn "You must pass the application as an
    # import string to enable 'reload'" and disables reloading.
    uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=True)
|
||||
73
backend/pipeline.py
Normal file
73
backend/pipeline.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# ──────────────────────────────────────────────────────────────────────────────
# Pipeline paths & command builders
# These match the remote Ubuntu server layout from LLM_TRAINER_APP_SCOPE.md
# ──────────────────────────────────────────────────────────────────────────────

SDK_BIN = (
    "/home/tocmo0nlord/miniconda3/envs/synthetic-data/bin/synthetic-data-kit"
)
CONFIG_PATH = "/opt/synthetic/synthetic-data-kit/config.yaml"
DATA_BASE = "/opt/synthetic/synthetic-data-kit/data"

# One remote directory per pipeline stage, in processing order.
STAGE_DIRS = {
    "input": f"{DATA_BASE}/input",
    "parsed": f"{DATA_BASE}/parsed",
    "generated": f"{DATA_BASE}/generated",
    "curated": f"{DATA_BASE}/curated",
    "final": f"{DATA_BASE}/final",
}

TRAIN_SCRIPT = "/opt/synthetic/train.py"
OUTPUT_BASE = "/opt/synthetic/output"


def _sdk(subcommand: str, *args) -> str:
    """Build a synthetic-data-kit invocation carrying the shared --config flag."""
    tail = " ".join(args)
    return f"{SDK_BIN} --config {CONFIG_PATH} {subcommand} {tail}"


def ingest_cmd(input_file: str) -> str:
    """Command that parses a raw document into the `parsed` stage."""
    return _sdk("ingest", f"'{input_file}'", "-o", STAGE_DIRS["parsed"])


def create_cmd(parsed_file: str, num_pairs: int = 50, pair_type: str = "qa") -> str:
    """Command that generates `num_pairs` training pairs of `pair_type`."""
    flags = [
        f"'{parsed_file}'",
        "-o", STAGE_DIRS["generated"],
        "--type", pair_type,
        "--num-pairs", str(num_pairs),
    ]
    return _sdk("create", *flags)


def curate_cmd(generated_file: str, output_file: str, threshold: float = 7.0) -> str:
    """Command that filters generated pairs below a quality threshold."""
    flags = [
        f"'{generated_file}'",
        "-o", f"'{output_file}'",
        "--threshold", str(threshold),
    ]
    return _sdk("curate", *flags)


def save_as_cmd(curated_file: str, output_file: str, fmt: str = "jsonl") -> str:
    """Command that converts curated pairs into the final training format."""
    flags = [f"'{curated_file}'", "-f", fmt, "-o", f"'{output_file}'"]
    return _sdk("save-as", *flags)


def train_cmd(
    model_name: str,
    dataset_path: str,
    output_dir: str = OUTPUT_BASE,
    num_epochs: int = 3,
    batch_size: int = 2,
    learning_rate: float = 2e-4,
) -> str:
    """Full command line for the remote fine-tuning script."""
    pieces = [
        f"python3 {TRAIN_SCRIPT}",
        f"--model '{model_name}'",
        f"--dataset '{dataset_path}'",
        f"--output '{output_dir}'",
        f"--epochs {num_epochs}",
        f"--batch-size {batch_size}",
        f"--lr {learning_rate}",
    ]
    return " ".join(pieces)
|
||||
7
backend/requirements.txt
Normal file
7
backend/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# Backend runtime dependencies (pinned for reproducible Docker builds).
fastapi==0.111.0
uvicorn[standard]==0.29.0  # ASGI server; [standard] adds uvloop/websockets extras
paramiko==3.4.0            # SSH/SFTP client for the remote training box
httpx==0.27.0              # async HTTP client for the Ollama API
pyyaml==6.0.1              # remote config.yaml parsing/editing
python-multipart==0.0.9    # required by FastAPI for file uploads
websockets==12.0
|
||||
177
backend/ssh_client.py
Normal file
177
backend/ssh_client.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import base64
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import paramiko
|
||||
|
||||
|
||||
class SSHClient:
    """Wrapper around a single paramiko SSH session to the training host.

    Provides blocking and streaming command execution, an interactive shell
    channel, SFTP upload, and small remote-file read/write helpers. A daemon
    keepalive thread pings the transport every 30 seconds.
    """

    def __init__(self):
        self.client: Optional[paramiko.SSHClient] = None
        # Cached state flag; is_connected() re-checks the live transport.
        self.connected = False
        self.host = ""
        self.username = ""
        self.port = 22
        self._keepalive_thread: Optional[threading.Thread] = None
        self._stop_keepalive = threading.Event()
        self._lock = threading.Lock()  # serializes concurrent connect() calls

    def connect(self, host: str, username: str, password: str = None,
                key_path: str = None, port: int = 22) -> bool:
        """Open a new SSH session, replacing any existing one.

        Supply `password` and/or `key_path`; with neither, paramiko falls
        back to agent/default keys. Returns True on success; raises the
        underlying paramiko error on failure.
        """
        with self._lock:
            try:
                if self.client:
                    self.client.close()

                self.client = paramiko.SSHClient()
                # Trusted-LAN convenience: auto-accept unknown host keys.
                self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

                kwargs = {"hostname": host, "port": port, "username": username, "timeout": 10}
                if key_path:
                    kwargs["key_filename"] = key_path
                if password:
                    kwargs["password"] = password

                self.client.connect(**kwargs)
                self.connected = True
                self.host = host
                self.username = username
                self.port = port

                # (Re)start the keepalive pinger for the fresh transport.
                self._stop_keepalive.clear()
                self._keepalive_thread = threading.Thread(target=self._keepalive_loop, daemon=True)
                self._keepalive_thread.start()
                return True
            except Exception as e:
                self.connected = False
                raise e

    def disconnect(self):
        """Signal the keepalive thread to stop and close the session."""
        self._stop_keepalive.set()
        if self.client:
            self.client.close()
        self.connected = False

    def _keepalive_loop(self):
        """Ping the transport every 30 s; exit when it dies or on disconnect."""
        while not self._stop_keepalive.wait(30):
            try:
                transport = self.client.get_transport()
                if transport and transport.is_active():
                    transport.send_ignore()  # SSH-protocol-level no-op ping
                else:
                    self.connected = False
                    break
            except Exception:
                self.connected = False
                break

    def execute(self, command: str, use_conda: bool = True) -> tuple:
        """Run a command to completion; return (stdout, stderr, exit_code).

        With use_conda=True the command is prefixed to run inside the
        `synthetic-data` conda environment. NOTE(review): the `source`
        builtin assumes the remote shell is bash -- confirm on the host.
        """
        if not self.is_connected():
            raise Exception("Not connected to SSH server")

        if use_conda:
            full_cmd = (
                f"source /home/{self.username}/miniconda3/etc/profile.d/conda.sh && "
                f"conda activate synthetic-data && {command}"
            )
        else:
            full_cmd = command

        _, stdout, stderr = self.client.exec_command(full_cmd)
        out = stdout.read().decode("utf-8", errors="replace")
        err = stderr.read().decode("utf-8", errors="replace")
        exit_code = stdout.channel.recv_exit_status()
        return out, err, exit_code

    def execute_stream(self, command: str, use_conda: bool = True):
        """Generator that yields output lines from a command.

        Requests a PTY (so interactive tools emit progress output) and
        yields each complete line as it arrives; any trailing partial line
        is flushed after the command exits.
        """
        if not self.is_connected():
            raise Exception("Not connected to SSH server")

        if use_conda:
            full_cmd = (
                f"source /home/{self.username}/miniconda3/etc/profile.d/conda.sh && "
                f"conda activate synthetic-data && {command}"
            )
        else:
            full_cmd = command

        transport = self.client.get_transport()
        channel = transport.open_session()
        channel.get_pty()
        channel.exec_command(full_cmd)

        buffer = b""
        while True:
            if channel.recv_ready():
                data = channel.recv(4096)
                if not data:
                    break
                buffer += data
                # Emit every complete line; keep the remainder buffered.
                while b"\n" in buffer:
                    line, buffer = buffer.split(b"\n", 1)
                    yield line.decode("utf-8", errors="replace") + "\n"
            elif channel.exit_status_ready():
                if buffer:
                    # Flush a trailing partial line after the command exits.
                    yield buffer.decode("utf-8", errors="replace")
                break
            else:
                time.sleep(0.05)  # idle poll; paramiko has no async recv

        channel.close()

    def open_shell_channel(self, term: str = "xterm-256color", width: int = 220, height: int = 50):
        """Open an interactive shell channel for the terminal panel."""
        if not self.is_connected():
            raise Exception("Not connected to SSH server")

        transport = self.client.get_transport()
        channel = transport.open_session()
        channel.get_pty(term=term, width=width, height=height)
        channel.invoke_shell()

        # Auto-activate conda env
        activate = (
            f"source /home/{self.username}/miniconda3/etc/profile.d/conda.sh && "
            f"conda activate synthetic-data\n"
        )
        channel.send(activate)
        return channel

    def upload_file(self, local_path: str, remote_path: str):
        """Copy a local file to the remote host over SFTP."""
        if not self.is_connected():
            raise Exception("Not connected to SSH server")
        sftp = self.client.open_sftp()
        try:
            sftp.put(local_path, remote_path)
        finally:
            sftp.close()

    def read_remote_file(self, remote_path: str) -> str:
        """Return a remote text file's contents; raise if unreadable."""
        out, err, code = self.execute(f"cat '{remote_path}'", use_conda=False)
        if code != 0:
            raise Exception(f"Failed to read file: {err}")
        return out

    def write_remote_file(self, remote_path: str, content: str):
        """Write text to a remote file (base64-encoded to survive shell quoting)."""
        encoded = base64.b64encode(content.encode()).decode()
        cmd = f"echo '{encoded}' | base64 -d > '{remote_path}'"
        out, err, code = self.execute(cmd, use_conda=False)
        if code != 0:
            raise Exception(f"Failed to write file: {err}")

    def is_connected(self) -> bool:
        """True when the underlying transport is alive; refreshes `connected`."""
        try:
            if self.client:
                transport = self.client.get_transport()
                if transport and transport.is_active():
                    return True
        except Exception:
            pass
        self.connected = False
        return False


# Singleton shared across all routes
ssh_manager = SSHClient()
|
||||
Reference in New Issue
Block a user