From 59168eb623d609e75461392e6c71b64df8f048ae Mon Sep 17 00:00:00 2001
From: Synapticode Agent
Date: Sun, 17 May 2026 15:33:37 +1000
Subject: [PATCH] fix(cache): compute codebooks on CPU at fp64 for MPS compatibility
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The MPS framework doesn't support the float64 dtype.
compute_lloyd_max_codebook (line 284) and compute_online_codebook
(line 326) both hardcode dtype=torch.float64 for their optimization
grids, failing with a TypeError when callers pass
device=torch.device('mps').

Fix: force internal computation onto CPU at fp64 in both functions
(preserving the algorithms' literature-standard precision for codebook
centroid optimization), then move the final centroids/boundaries to the
caller's target device when constructing the returned Codebook
dataclass.

This fits the existing Codebook architecture: the dataclass's
quantize/dequantize methods (lines 214-223) already handle device
migration at usage time via .to(device=x.device, dtype=x.dtype). The
fix sits at the natural device-firewall: build on CPU, store on the
caller's device, use on the operand's device.

_beta_pdf and _solve_lloyd_max inherit CPU automatically through tensor
argument propagation; no edits needed there.

Discovered during gamma-seeds tern-core R-track MPS validation
(2026-05-16). Both the make_b_mse_hook and make_b_mse_hook_uniform
factories invoke TurboQuantConfig with device='mps' for KV-cache
compression hooks; both fail at TurboQuantConfig.__init__ before any
hook is applied.

Verified end-to-end with tern-core's R7-B v1.2 harness on
TinyLlama-1.1B FP16 MPS: the β1a hook now produces finite PPL output.
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/cache.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/cache.py b/src/cache.py index 653d196..6352b73 100644 --- a/src/cache.py +++ b/src/cache.py @@ -280,8 +280,15 @@ def compute_lloyd_max_codebook( lo = max(-1.0, -6.0 * sigma) hi = min(1.0, 6.0 * sigma) + # Lloyd-Max optimization requires float64 precision for stable convergence. + # Some accelerator backends (e.g., MPS) don't support float64, so the + # computation runs on CPU; the resulting Codebook tensors are moved to + # the caller's target device before return. Codebook.quantize/dequantize + # handle further device migration at usage time. + cpu_dev = torch.device("cpu") + grid_size = 16385 - grid = torch.linspace(lo, hi, grid_size, device=device, dtype=torch.float64) + grid = torch.linspace(lo, hi, grid_size, device=cpu_dev, dtype=torch.float64) if d >= 64: pdf = torch.exp(-0.5 * d * grid ** 2) * math.sqrt(d / (2.0 * math.pi)) @@ -295,7 +302,11 @@ def compute_lloyd_max_codebook( pdf = pdf / mass centroids, boundaries = _solve_lloyd_max(pdf, grid, K, max_iter, tol) - return Codebook(centroids=centroids, boundaries=boundaries, d=d, b=b, K=K) + return Codebook( + centroids=centroids.to(device), + boundaries=boundaries.to(device), + d=d, b=b, K=K, + ) def compute_online_codebook( @@ -316,24 +327,36 @@ def compute_online_codebook( max_iter: k-means iterations """ K = 2 ** b + + # 1D k-means requires float64 precision for stable convergence. Some + # accelerator backends (e.g., MPS) don't support float64, so the + # computation runs on CPU; the resulting Codebook tensors are moved to + # the caller's target device before return. Codebook.quantize/dequantize + # handle further device migration at usage time. 
+ cpu_dev = torch.device("cpu") + # Flatten all coordinates into a single 1D distribution - flat = data.reshape(-1).float().to(device) + flat = data.reshape(-1).float().to(cpu_dev) # Build empirical PDF via histogram n_bins = 16385 lo = flat.min().item() - 1e-6 hi = flat.max().item() + 1e-6 - grid = torch.linspace(lo, hi, n_bins, device=device, dtype=torch.float64) - hist = torch.histogram(flat.cpu().double(), bins=n_bins, range=(lo, hi)) - pdf = hist.hist.to(device).double() + grid = torch.linspace(lo, hi, n_bins, device=cpu_dev, dtype=torch.float64) + hist = torch.histogram(flat.double(), bins=n_bins, range=(lo, hi)) + pdf = hist.hist.double() pdf = pdf / (pdf.sum() * (hi - lo) / n_bins) # Grid for PDF (bin centers) - bin_edges = hist.bin_edges.to(device).double() + bin_edges = hist.bin_edges.double() grid = 0.5 * (bin_edges[:-1] + bin_edges[1:]) centroids, boundaries = _solve_lloyd_max(pdf, grid, K, max_iter) - return Codebook(centroids=centroids, boundaries=boundaries, d=data.shape[-1], b=b, K=K) + return Codebook( + centroids=centroids.to(device), + boundaries=boundaries.to(device), + d=data.shape[-1], b=b, K=K, + ) # ---------------------------------------------------------------------------