diff --git a/src/cache.py b/src/cache.py index 653d196..6352b73 100644 --- a/src/cache.py +++ b/src/cache.py @@ -280,8 +280,15 @@ def compute_lloyd_max_codebook( lo = max(-1.0, -6.0 * sigma) hi = min(1.0, 6.0 * sigma) + # Lloyd-Max optimization requires float64 precision for stable convergence. + # Some accelerator backends (e.g., MPS) don't support float64, so the + # computation runs on CPU; the resulting Codebook tensors are moved to + # the caller's target device before return. Codebook.quantize/dequantize + # handle further device migration at usage time. + cpu_dev = torch.device("cpu") + grid_size = 16385 - grid = torch.linspace(lo, hi, grid_size, device=device, dtype=torch.float64) + grid = torch.linspace(lo, hi, grid_size, device=cpu_dev, dtype=torch.float64) if d >= 64: pdf = torch.exp(-0.5 * d * grid ** 2) * math.sqrt(d / (2.0 * math.pi)) @@ -295,7 +302,11 @@ def compute_lloyd_max_codebook( pdf = pdf / mass centroids, boundaries = _solve_lloyd_max(pdf, grid, K, max_iter, tol) - return Codebook(centroids=centroids, boundaries=boundaries, d=d, b=b, K=K) + return Codebook( + centroids=centroids.to(device), + boundaries=boundaries.to(device), + d=d, b=b, K=K, + ) def compute_online_codebook( @@ -316,24 +327,36 @@ def compute_online_codebook( max_iter: k-means iterations """ K = 2 ** b + + # 1D k-means requires float64 precision for stable convergence. Some + # accelerator backends (e.g., MPS) don't support float64, so the + # computation runs on CPU; the resulting Codebook tensors are moved to + # the caller's target device before return. Codebook.quantize/dequantize + # handle further device migration at usage time. + cpu_dev = torch.device("cpu") + # Flatten all coordinates into a single 1D distribution - flat = data.reshape(-1).float().to(device) + flat = data.reshape(-1).float().to(cpu_dev) # Build empirical PDF via histogram n_bins = 16385 lo = flat.min().item() - 1e-6 hi = flat.max().item() + 1e-6 - grid = torch.linspace(lo, hi, n_bins, device=device, dtype=torch.float64) - hist = torch.histogram(flat.cpu().double(), bins=n_bins, range=(lo, hi)) - pdf = hist.hist.to(device).double() + grid = torch.linspace(lo, hi, n_bins, device=cpu_dev, dtype=torch.float64) + hist = torch.histogram(flat.double(), bins=n_bins, range=(lo, hi)) + pdf = hist.hist.double() pdf = pdf / (pdf.sum() * (hi - lo) / n_bins) # Grid for PDF (bin centers) - bin_edges = hist.bin_edges.to(device).double() + bin_edges = hist.bin_edges.double() grid = 0.5 * (bin_edges[:-1] + bin_edges[1:]) centroids, boundaries = _solve_lloyd_max(pdf, grid, K, max_iter) - return Codebook(centroids=centroids, boundaries=boundaries, d=data.shape[-1], b=b, K=K) + return Codebook( + centroids=centroids.to(device), + boundaries=boundaries.to(device), + d=data.shape[-1], b=b, K=K, + ) # ---------------------------------------------------------------------------