Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions src/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,15 @@ def compute_lloyd_max_codebook(
lo = max(-1.0, -6.0 * sigma)
hi = min(1.0, 6.0 * sigma)

# Lloyd-Max optimization requires float64 precision for stable convergence.
# Some accelerator backends (e.g., MPS) don't support float64, so the
# computation runs on CPU; the resulting Codebook tensors are moved to
# the caller's target device before return. Codebook.quantize/dequantize
# handle further device migration at usage time.
cpu_dev = torch.device("cpu")

grid_size = 16385
grid = torch.linspace(lo, hi, grid_size, device=device, dtype=torch.float64)
grid = torch.linspace(lo, hi, grid_size, device=cpu_dev, dtype=torch.float64)

if d >= 64:
pdf = torch.exp(-0.5 * d * grid ** 2) * math.sqrt(d / (2.0 * math.pi))
Expand All @@ -295,7 +302,11 @@ def compute_lloyd_max_codebook(
pdf = pdf / mass

centroids, boundaries = _solve_lloyd_max(pdf, grid, K, max_iter, tol)
return Codebook(centroids=centroids, boundaries=boundaries, d=d, b=b, K=K)
return Codebook(
centroids=centroids.to(device),
boundaries=boundaries.to(device),
d=d, b=b, K=K,
)


def compute_online_codebook(
Expand All @@ -316,24 +327,36 @@ def compute_online_codebook(
max_iter: k-means iterations
"""
K = 2 ** b

# 1D k-means requires float64 precision for stable convergence. Some
# accelerator backends (e.g., MPS) don't support float64, so the
# computation runs on CPU; the resulting Codebook tensors are moved to
# the caller's target device before return. Codebook.quantize/dequantize
# handle further device migration at usage time.
cpu_dev = torch.device("cpu")

# Flatten all coordinates into a single 1D distribution
flat = data.reshape(-1).float().to(device)
flat = data.reshape(-1).float().to(cpu_dev)

# Build empirical PDF via histogram
n_bins = 16385
lo = flat.min().item() - 1e-6
hi = flat.max().item() + 1e-6
grid = torch.linspace(lo, hi, n_bins, device=device, dtype=torch.float64)
hist = torch.histogram(flat.cpu().double(), bins=n_bins, range=(lo, hi))
pdf = hist.hist.to(device).double()
grid = torch.linspace(lo, hi, n_bins, device=cpu_dev, dtype=torch.float64)
hist = torch.histogram(flat.double(), bins=n_bins, range=(lo, hi))
pdf = hist.hist.double()
pdf = pdf / (pdf.sum() * (hi - lo) / n_bins)

# Grid for PDF (bin centers)
bin_edges = hist.bin_edges.to(device).double()
bin_edges = hist.bin_edges.double()
grid = 0.5 * (bin_edges[:-1] + bin_edges[1:])

centroids, boundaries = _solve_lloyd_max(pdf, grid, K, max_iter)
return Codebook(centroids=centroids, boundaries=boundaries, d=data.shape[-1], b=b, K=K)
return Codebook(
centroids=centroids.to(device),
boundaries=boundaries.to(device),
d=data.shape[-1], b=b, K=K,
)


# ---------------------------------------------------------------------------
Expand Down