GPU-accelerated phash + fix discovery/takeout hang
GPU: - Switch Dockerfile base to pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime - Add gpu_hasher.py: batched 2D DCT on GPU via PyTorch matrix multiply, 256 images/batch, produces imagehash-compatible 64-bit hex hashes, auto-falls back to CPU when CUDA unavailable - Replace per-image phash loop in scanner.py with phasher.hash_files() - docker-compose.yml: add nvidia GPU device reservation Hang fix: - takeout.is_takeout_folder() now caps at 50 directories (was walking entire tree — blocked for minutes on 65k+ file libraries) - Add "Not a Takeout folder" status message so takeout phase is never silent Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
FROM python:3.12-slim
|
# PyTorch + CUDA 12.1 base — matches Ubuntu 22.04 with NVIDIA driver 525+
|
||||||
|
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
libheif-dev libjpeg-dev libpng-dev libtiff-dev libwebp-dev exiftool \
|
libheif-dev libjpeg-dev libpng-dev libtiff-dev libwebp-dev \
|
||||||
|
libgl1 libglib2.0-0 exiftool ffmpeg \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|||||||
162
app/gpu_hasher.py
Normal file
162
app/gpu_hasher.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""
|
||||||
|
GPU-accelerated perceptual hashing via PyTorch + CUDA.
|
||||||
|
|
||||||
|
Implements the same pHash algorithm as the `imagehash` library (DCT-II,
|
||||||
|
8×8 low-frequency block, 64-bit hash) so hashes produced here are
|
||||||
|
directly comparable with any existing imagehash-generated hashes in the DB.
|
||||||
|
|
||||||
|
Falls back to CPU if CUDA is not available — no code changes needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image, UnidentifiedImageError
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pillow_heif import register_heif_opener
|
||||||
|
register_heif_opener()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Must match imagehash defaults: hash_size=8, highfreq_factor=4
|
||||||
|
HASH_SIZE = 8
|
||||||
|
IMG_SIZE = HASH_SIZE * 4 # 32
|
||||||
|
BATCH_SIZE = 256 # images per GPU batch; lower if VRAM is tight
|
||||||
|
|
||||||
|
|
||||||
|
class GpuPhasher:
|
||||||
|
"""
|
||||||
|
Batched perceptual hasher. Uses CUDA when available, CPU otherwise.
|
||||||
|
|
||||||
|
The DCT is implemented as two matrix multiplications:
|
||||||
|
DCT2D(X) = D @ X @ Dᵀ
|
||||||
|
where D is the precomputed orthonormal DCT-II matrix of size IMG_SIZE.
|
||||||
|
This runs entirely on-GPU for the full batch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, batch_size: int = BATCH_SIZE):
|
||||||
|
self.batch_size = batch_size
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.device = torch.device("cuda")
|
||||||
|
dev_name = torch.cuda.get_device_name(0)
|
||||||
|
log.info("GpuPhasher: using CUDA device — %s", dev_name)
|
||||||
|
else:
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
log.info("GpuPhasher: CUDA not available, using CPU")
|
||||||
|
|
||||||
|
# Precompute orthonormal DCT-II matrix (IMG_SIZE × IMG_SIZE)
|
||||||
|
self._dct = self._build_dct_matrix(IMG_SIZE).to(self.device)
|
||||||
|
|
||||||
|
# ── DCT matrix ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_dct_matrix(n: int) -> torch.Tensor:
|
||||||
|
"""Orthonormal DCT-II matrix of size n×n."""
|
||||||
|
k = torch.arange(n, dtype=torch.float32).unsqueeze(1) # (n, 1)
|
||||||
|
i = torch.arange(n, dtype=torch.float32).unsqueeze(0) # (1, n)
|
||||||
|
mat = torch.cos(math.pi * k * (2.0 * i + 1.0) / (2.0 * n)) # (n, n)
|
||||||
|
mat[0] *= 1.0 / math.sqrt(n)
|
||||||
|
mat[1:] *= math.sqrt(2.0 / n)
|
||||||
|
return mat # (n, n)
|
||||||
|
|
||||||
|
# ── Image loading ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_image(path: str) -> np.ndarray | None:
|
||||||
|
"""Load image → greyscale float32 numpy array of shape (IMG_SIZE, IMG_SIZE)."""
|
||||||
|
try:
|
||||||
|
img = (
|
||||||
|
Image.open(path)
|
||||||
|
.convert("L")
|
||||||
|
.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.LANCZOS)
|
||||||
|
)
|
||||||
|
return np.asarray(img, dtype=np.float32)
|
||||||
|
except (UnidentifiedImageError, OSError, Exception):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ── Core GPU batch ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _phash_batch(self, arrays: list[np.ndarray]) -> list[str]:
|
||||||
|
"""
|
||||||
|
Compute pHash for a list of (IMG_SIZE, IMG_SIZE) float32 numpy arrays.
|
||||||
|
Returns a list of 16-char hex strings (64-bit hashes).
|
||||||
|
"""
|
||||||
|
# Stack into GPU tensor (B, H, W)
|
||||||
|
batch = torch.from_numpy(np.stack(arrays)).to(self.device) # (B, 32, 32)
|
||||||
|
|
||||||
|
# 2D DCT: D @ X @ Dᵀ
|
||||||
|
dct2d = self._dct @ batch @ self._dct.T # (B, 32, 32)
|
||||||
|
|
||||||
|
# Keep only top-left HASH_SIZE × HASH_SIZE block
|
||||||
|
low = dct2d[:, :HASH_SIZE, :HASH_SIZE] # (B, 8, 8)
|
||||||
|
flat = low.reshape(low.shape[0], -1) # (B, 64)
|
||||||
|
|
||||||
|
# Each bit: is value > row mean?
|
||||||
|
means = flat.mean(dim=1, keepdim=True)
|
||||||
|
bits = (flat > means).cpu().numpy() # (B, 64) bool
|
||||||
|
|
||||||
|
# Pack bits → bytes → hex (matches imagehash's __str__ format)
|
||||||
|
return [np.packbits(b).tobytes().hex() for b in bits]
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def hash_files(
|
||||||
|
self,
|
||||||
|
paths: list[str],
|
||||||
|
progress_cb=None,
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Compute pHash for every path in `paths`.
|
||||||
|
|
||||||
|
Returns {path: hex_hash_string}. Paths that fail to open are omitted.
|
||||||
|
progress_cb(n_done: int) is called after each batch.
|
||||||
|
"""
|
||||||
|
results: dict[str, str] = {}
|
||||||
|
done = 0
|
||||||
|
|
||||||
|
for i in range(0, len(paths), self.batch_size):
|
||||||
|
chunk = paths[i : i + self.batch_size]
|
||||||
|
|
||||||
|
arrays: list[np.ndarray] = []
|
||||||
|
valid: list[str] = []
|
||||||
|
|
||||||
|
for p in chunk:
|
||||||
|
arr = self._load_image(p)
|
||||||
|
if arr is not None:
|
||||||
|
arrays.append(arr)
|
||||||
|
valid.append(p)
|
||||||
|
|
||||||
|
if arrays:
|
||||||
|
try:
|
||||||
|
hashes = self._phash_batch(arrays)
|
||||||
|
results.update(zip(valid, hashes))
|
||||||
|
except Exception as exc:
|
||||||
|
log.warning("GPU batch failed (%s); skipping batch", exc)
|
||||||
|
|
||||||
|
done += len(chunk)
|
||||||
|
if progress_cb:
|
||||||
|
progress_cb(done)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@property
|
||||||
|
def using_gpu(self) -> bool:
|
||||||
|
return self.device.type == "cuda"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Module-level singleton (created once, reused across scan phases) ──────────
|
||||||
|
|
||||||
|
_phasher: GpuPhasher | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_phasher() -> GpuPhasher:
|
||||||
|
global _phasher
|
||||||
|
if _phasher is None:
|
||||||
|
_phasher = GpuPhasher()
|
||||||
|
return _phasher
|
||||||
@@ -20,6 +20,7 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
from takeout import is_takeout_folder, process_takeout
|
from takeout import is_takeout_folder, process_takeout
|
||||||
|
from gpu_hasher import get_phasher
|
||||||
|
|
||||||
|
|
||||||
PHOTO_EXT = {
|
PHOTO_EXT = {
|
||||||
@@ -516,10 +517,14 @@ def run_scan(folder_path: str, scan_id: int, mode: str = "incremental"):
|
|||||||
con.commit()
|
con.commit()
|
||||||
|
|
||||||
# ── Phase: takeout pre-processing ─────────────────────────────────
|
# ── Phase: takeout pre-processing ─────────────────────────────────
|
||||||
scan_state.update(phase="takeout", message="Checking for Google Takeout structure...")
|
# Detection samples ≤50 dirs so it never blocks on large libraries
|
||||||
|
scan_state.update(phase="takeout",
|
||||||
|
message="Checking for Google Takeout structure (sampling)...")
|
||||||
if is_takeout_folder(folder_path):
|
if is_takeout_folder(folder_path):
|
||||||
scan_state["message"] = "Processing Google Takeout sidecars..."
|
scan_state["message"] = "Processing Google Takeout sidecars..."
|
||||||
process_takeout(folder_path, DB_PATH)
|
process_takeout(folder_path, DB_PATH)
|
||||||
|
else:
|
||||||
|
scan_state["message"] = "Not a Takeout folder — skipping"
|
||||||
|
|
||||||
if scan_state["cancel_requested"]:
|
if scan_state["cancel_requested"]:
|
||||||
_mark_scan(cur, scan_id, "cancelled")
|
_mark_scan(cur, scan_id, "cancelled")
|
||||||
@@ -607,8 +612,10 @@ def run_scan(folder_path: str, scan_id: int, mode: str = "incremental"):
|
|||||||
con.commit()
|
con.commit()
|
||||||
|
|
||||||
# ── Phase: phash ──────────────────────────────────────────────────
|
# ── Phase: phash ──────────────────────────────────────────────────
|
||||||
|
phasher = get_phasher()
|
||||||
|
hw_label = "GPU" if phasher.using_gpu else "CPU"
|
||||||
scan_state.update(phase="phash", progress=0,
|
scan_state.update(phase="phash", progress=0,
|
||||||
message="Computing perceptual hashes...")
|
message=f"Computing perceptual hashes ({hw_label})...")
|
||||||
|
|
||||||
cur.execute("""
|
cur.execute("""
|
||||||
SELECT id, path FROM files
|
SELECT id, path FROM files
|
||||||
@@ -621,19 +628,35 @@ def run_scan(folder_path: str, scan_id: int, mode: str = "incremental"):
|
|||||||
photo_rows = cur.fetchall()
|
photo_rows = cur.fetchall()
|
||||||
scan_state["total"] = len(photo_rows)
|
scan_state["total"] = len(photo_rows)
|
||||||
|
|
||||||
for i, row in enumerate(photo_rows):
|
if photo_rows:
|
||||||
if scan_state["cancel_requested"]:
|
# Build id lookup so we can write results back efficiently
|
||||||
_mark_scan(cur, scan_id, "cancelled")
|
path_to_id = {row["path"]: row["id"] for row in photo_rows}
|
||||||
con.commit()
|
all_paths = list(path_to_id.keys())
|
||||||
scan_state["status"] = "cancelled"
|
|
||||||
return
|
|
||||||
|
|
||||||
scan_state["progress"] = i + 1
|
def _phash_progress(n_done: int):
|
||||||
scan_state["message"] = f"Phash: {Path(row['path']).name}"
|
if scan_state["cancel_requested"]:
|
||||||
ph = _phash(row["path"])
|
return
|
||||||
if ph:
|
scan_state["progress"] = n_done
|
||||||
cur.execute("UPDATE files SET phash=? WHERE id=?", (ph, row["id"]))
|
scan_state["message"] = (
|
||||||
if (i + 1) % 200 == 0:
|
f"Phash ({hw_label}): {n_done:,} / {len(all_paths):,}"
|
||||||
|
)
|
||||||
|
|
||||||
|
results = phasher.hash_files(all_paths, progress_cb=_phash_progress)
|
||||||
|
|
||||||
|
# Bulk write to DB in chunks of 500
|
||||||
|
items = list(results.items())
|
||||||
|
for chunk_start in range(0, len(items), 500):
|
||||||
|
if scan_state["cancel_requested"]:
|
||||||
|
_mark_scan(cur, scan_id, "cancelled")
|
||||||
|
con.commit()
|
||||||
|
scan_state["status"] = "cancelled"
|
||||||
|
return
|
||||||
|
for path, ph in items[chunk_start : chunk_start + 500]:
|
||||||
|
fid = path_to_id.get(path)
|
||||||
|
if fid and ph:
|
||||||
|
cur.execute(
|
||||||
|
"UPDATE files SET phash=? WHERE id=?", (ph, fid)
|
||||||
|
)
|
||||||
con.commit()
|
con.commit()
|
||||||
|
|
||||||
con.commit()
|
con.commit()
|
||||||
|
|||||||
@@ -50,14 +50,19 @@ def is_takeout_folder(folder_path: str) -> bool:
|
|||||||
adjacent media files. If we find at least 5 such pairs, call it Takeout.
|
adjacent media files. If we find at least 5 such pairs, call it Takeout.
|
||||||
"""
|
"""
|
||||||
count = 0
|
count = 0
|
||||||
|
dirs_checked = 0
|
||||||
|
MAX_DIRS = 50 # sample at most 50 directories — fast on any library size
|
||||||
|
|
||||||
for root, dirs, files in os.walk(folder_path):
|
for root, dirs, files in os.walk(folder_path):
|
||||||
# Skip hidden dirs
|
|
||||||
dirs[:] = [d for d in dirs if not d.startswith(".")]
|
dirs[:] = [d for d in dirs if not d.startswith(".")]
|
||||||
|
dirs_checked += 1
|
||||||
|
if dirs_checked > MAX_DIRS:
|
||||||
|
break
|
||||||
|
|
||||||
file_set = set(files)
|
file_set = set(files)
|
||||||
for f in files:
|
for f in files:
|
||||||
if not f.endswith(".json"):
|
if not f.endswith(".json"):
|
||||||
continue
|
continue
|
||||||
# Check if a media file exists that this could be a sidecar for
|
|
||||||
base = f[:-5] # strip .json
|
base = f[:-5] # strip .json
|
||||||
if base in file_set:
|
if base in file_set:
|
||||||
count += 1
|
count += 1
|
||||||
|
|||||||
@@ -13,5 +13,10 @@ services:
|
|||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
limits:
|
limits:
|
||||||
cpus: "2.0"
|
cpus: "4.0"
|
||||||
memory: 2G
|
memory: 4G
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
count: 1
|
||||||
|
capabilities: [gpu]
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
|
# torch + torchvision come pre-installed in the pytorch/pytorch base image
|
||||||
|
# (torchvision needed for image transforms)
|
||||||
|
torchvision==0.18.1
|
||||||
|
|
||||||
fastapi==0.115.6
|
fastapi==0.115.6
|
||||||
uvicorn==0.32.1
|
uvicorn==0.32.1
|
||||||
Pillow==11.0.0
|
Pillow==11.0.0
|
||||||
@@ -5,3 +9,4 @@ imagehash==4.3.1
|
|||||||
pillow-heif==0.21.0
|
pillow-heif==0.21.0
|
||||||
jinja2==3.1.4
|
jinja2==3.1.4
|
||||||
aiofiles==24.1.0
|
aiofiles==24.1.0
|
||||||
|
numpy==1.26.4
|
||||||
|
|||||||
Reference in New Issue
Block a user