Skip to content

Replace deprecated torch.jit.script with torch.compile#129

Open
AdrianSosic wants to merge 1 commit into
cornellius-gp:mainfrom
AdrianSosic:replace-jit-with-compile
Open

Replace deprecated torch.jit.script with torch.compile#129
AdrianSosic wants to merge 1 commit into
cornellius-gp:mainfrom
AdrianSosic:replace-jit-with-compile

Conversation

@AdrianSosic
Copy link
Copy Markdown

Summary

Follow-up to #128, replacing the two `@torch.jit.script` decorators in `linear_operator/utils/linear_cg.py` with `@torch.compile`.

Motivation

torch.jit.script has been deprecated in recent PyTorch versions, raising a DeprecationWarning at import time:

DeprecationWarning: torch.jit.script is deprecated. Please switch to torch.compile or torch.export.

@AdrianSosic
Copy link
Copy Markdown
Author

@saitcakmak: since you recently worked on #128, would you be willing to quickly review/approve this PR? Has only two lines of code changed 🙃

@saitcakmak
Copy link
Copy Markdown
Collaborator

Hi @AdrianSosic. Thanks for chasing down these deprecation warnings!

I am not very familiar with `torch.jit.script`, but I know that `torch.compile` often has a non-negligible compilation overhead on the first call, or on re-compilation when tensor shapes change. It'd be good to run some profiling (e.g. an ExactGP with a large training set, setting `linear_operator.settings.max_cholesky_size` to a small number to force CG) to see how this change affects things e2e. If the overhead is noticeable, we could also consider deleting the decorator altogether rather than replacing it with `torch.compile`.

@AdrianSosic
Copy link
Copy Markdown
Author

> Hi @AdrianSosic. Thanks for chasing down these deprecation warnings!
>
> I am not very familiar with `torch.jit.script`, but I know that `torch.compile` often has a non-negligible compilation overhead on the first call, or on re-compilation when tensor shapes change. It'd be good to run some profiling (e.g. an ExactGP with a large training set, setting `linear_operator.settings.max_cholesky_size` to a small number to force CG) to see how this change affects things e2e. If the overhead is noticeable, we could also consider deleting the decorator altogether rather than replacing it with `torch.compile`.

Seems your intuition is right. Ran some benchmarks with Claude directly on linear_operator (i.e. not using gpytorch), code details shared in collapsed section below. Do you think the coverage is sufficient or do we need to look at other orders of magnitude for the input parameters?

Details

#!/usr/bin/env python3
"""
Benchmark script that generates graphs comparing @torch.compile vs no decorator
for linear_cg helper functions.

Tests a broad range of configurations:

  - Matrix sizes: 200 to 5000 (covering typical GP training set sizes)
  - RHS columns: 1 (single solve), 10 (multi-output)
  - CG iterations: 50, 100, 200, 500, 1000 (default max)
  - Batch dimensions: non-batched and batched (batch=5)
"""

import importlib
import os
import time
import sys
import warnings

import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Silence all warnings so the benchmark output stays readable. Note that this
# also hides any deprecation or compilation warnings raised during the runs.
warnings.filterwarnings("ignore")

# Ensure the repo root is on the path so linear_operator can be imported
# (this inserts the directory that contains this script).
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import linear_operator
from linear_operator import settings

# Directory that receives the generated PNG files: the script's own directory.
OUTPUT_DIR = os.path.dirname(os.path.abspath(__file__))

def make_psd_matrix(n, batch_size=None, dtype=torch.float64, device="cpu"):
    """Create a random symmetric positive-definite matrix.

    Draws a standard-normal matrix ``A`` and returns ``A @ A^T + n * I``.

    Args:
        n: Side length of the square matrix.
        batch_size: Optional number of matrices to generate. When falsy
            (``None`` or ``0``) a single ``(n, n)`` matrix is returned,
            otherwise a ``(batch_size, n, n)`` batch.
        dtype: Floating-point dtype of the result.
        device: Device on which the tensors are allocated.

    Returns:
        A tensor of shape ``(n, n)`` or ``(batch_size, n, n)``.
    """
    shape = (batch_size, n, n) if batch_size else (n, n)
    A = torch.randn(*shape, dtype=dtype, device=device)
    # transpose(-1, -2) is the plain transpose for a 2-D tensor and the
    # per-matrix transpose for a batch; the identity broadcasts over the batch.
    return A @ A.transpose(-1, -2) + n * torch.eye(n, dtype=dtype, device=device)

def get_cg_module():
    """Return the ``linear_operator.utils.linear_cg`` module object.

    Having the module object (rather than names imported from it) lets callers
    read and reassign attributes on the module itself.
    """
    return importlib.import_module('linear_operator.utils.linear_cg')

def run_linear_cg(matrix_size, rhs_cols, max_iter, batch_size=None):
    """Run linear_cg once with the module's current helpers and time it.

    With an unmodified ``linear_cg`` module this measures the ``@torch.compile``
    variant. Only the ``linear_cg`` call itself is timed; building the matrix
    and the right-hand side happens before the clock starts.

    Args:
        matrix_size: Side length ``n`` of the random system matrix.
        rhs_cols: Number of right-hand-side columns.
        max_iter: Iteration cap passed to ``linear_cg`` and to
            ``settings.max_cg_iterations``.
        batch_size: Optional batch dimension; falsy means non-batched.

    Returns:
        Elapsed wall-clock time of the solve in seconds.
    """
    from linear_operator.utils.linear_cg import linear_cg

    A = make_psd_matrix(matrix_size, batch_size=batch_size)
    rhs_shape = (batch_size, matrix_size, rhs_cols) if batch_size else (matrix_size, rhs_cols)
    rhs = torch.randn(*rhs_shape, dtype=torch.float64)
    matmul_closure = A.matmul
    with settings.max_cholesky_size(0), settings.max_cg_iterations(max_iter):
        start = time.perf_counter()
        linear_cg(matmul_closure, rhs, max_iter=max_iter)
        elapsed = time.perf_counter() - start
    return elapsed

def run_linear_cg_no_decorator(matrix_size, rhs_cols, max_iter, batch_size=None):
    """Run linear_cg once with plain (non-compiled) helpers and time it.

    Temporarily replaces ``_jit_linear_cg_updates`` and
    ``_jit_linear_cg_updates_no_precond`` on the ``linear_cg`` module with
    undecorated Python functions, times one solve through ``run_linear_cg``
    (so both variants share the exact same measurement code), and restores the
    original helpers afterwards, even if the solve raises.

    NOTE(review): the plain helpers below are presumably copies of the bodies
    of the library's decorated helpers -- verify against linear_cg.py whenever
    that file changes.

    Args:
        matrix_size: Side length ``n`` of the random system matrix.
        rhs_cols: Number of right-hand-side columns.
        max_iter: Iteration cap for the solve.
        batch_size: Optional batch dimension; falsy means non-batched.

    Returns:
        Elapsed wall-clock time of the solve in seconds.
    """
    cg_module = get_cg_module()
    orig_updates = cg_module._jit_linear_cg_updates
    orig_no_precond = cg_module._jit_linear_cg_updates_no_precond

    def _plain_updates(result, alpha, residual_inner_prod, eps, beta, residual,
                       precond_residual, mul_storage, is_zero, curr_conjugate_vec):
        # In-place: result += alpha * curr_conjugate_vec
        result = torch.addcmul(result, alpha, curr_conjugate_vec, out=result)

        # Keep the previous inner product in beta, then recompute
        # residual_inner_prod = sum(residual * precond_residual) over dim -2.
        beta.resize_as_(residual_inner_prod).copy_(residual_inner_prod)
        torch.mul(residual, precond_residual, out=mul_storage)
        torch.sum(mul_storage, -2, keepdim=True, out=residual_inner_prod)

        # Safe division: entries with beta < eps are divided by 1 and then zeroed.
        torch.lt(beta, eps, out=is_zero)
        beta.masked_fill_(is_zero, 1)
        torch.div(residual_inner_prod, beta, out=beta)
        beta.masked_fill_(is_zero, 0)

        # In-place: curr_conjugate_vec = beta * curr_conjugate_vec + precond_residual
        curr_conjugate_vec.mul_(beta).add_(precond_residual)

    def _plain_no_precond(mvms, result, has_converged, alpha, residual_inner_prod, eps, beta,
                          residual, precond_residual, mul_storage, is_zero, curr_conjugate_vec):
        # alpha = sum(curr_conjugate_vec * mvms) over dim -2
        torch.mul(curr_conjugate_vec, mvms, out=mul_storage)
        torch.sum(mul_storage, dim=-2, keepdim=True, out=alpha)

        # Safe division: entries with alpha < eps are divided by 1 and then zeroed.
        torch.lt(alpha, eps, out=is_zero)
        alpha.masked_fill_(is_zero, 1)
        torch.div(residual_inner_prod, alpha, out=alpha)
        alpha.masked_fill_(is_zero, 0)

        # A zero alpha cancels the update for entries flagged in has_converged.
        alpha.masked_fill_(has_converged, 0)

        # In-place: residual -= alpha * mvms
        torch.addcmul(residual, -alpha, mvms, out=residual)

        # No preconditioner here: a copy of the residual is passed on in its place.
        precond_residual = residual.clone()
        _plain_updates(result, alpha, residual_inner_prod, eps, beta, residual,
                       precond_residual, mul_storage, is_zero, curr_conjugate_vec)

    cg_module._jit_linear_cg_updates = _plain_updates
    cg_module._jit_linear_cg_updates_no_precond = _plain_no_precond
    try:
        # Like the original patching approach, this relies on linear_cg
        # resolving its helpers through the module namespace at call time.
        return run_linear_cg(matrix_size, rhs_cols, max_iter, batch_size)
    finally:
        cg_module._jit_linear_cg_updates = orig_updates
        cg_module._jit_linear_cg_updates_no_precond = orig_no_precond

def benchmark_steady_state(matrix_size, rhs_cols, max_iter, batch_size=None,
                           n_warmup=2, n_measure=8):
    """Run multiple calls, discard warm-up, return mean of measurement calls (in ms).

    The plain (no decorator) variant is measured first, then the
    ``@torch.compile`` variant. Each variant is called
    ``n_warmup + n_measure`` times and only the last ``n_measure`` timings
    enter the mean.

    Args:
        matrix_size: Side length ``n`` of the random system matrix.
        rhs_cols: Number of right-hand-side columns.
        max_iter: Iteration cap for each solve.
        batch_size: Optional batch dimension; falsy means non-batched.
        n_warmup: Number of leading calls whose timings are discarded.
        n_measure: Number of calls whose timings are averaged.

    Returns:
        Tuple ``(plain_ms, compile_ms)`` of mean per-call times in milliseconds.
    """

    def _mean_ms(run_once):
        # Time every call, then drop the warm-up prefix before averaging.
        samples = []
        for _ in range(n_warmup + n_measure):
            samples.append(run_once(matrix_size, rhs_cols, max_iter, batch_size))
        return np.mean(samples[n_warmup:]) * 1000

    # No decorator
    plain_ms = _mean_ms(run_linear_cg_no_decorator)

    # torch.compile
    compile_ms = _mean_ms(run_linear_cg)

    return plain_ms, compile_ms

# ============================================================

# Banner: identify the benchmark and the PyTorch build it ran against.
print("=" * 70)
print("Linear CG Benchmark: @torch.compile vs No Decorator")
print("=" * 70)
# torch.__version__ is the version string; torch.version is a module object.
print(f"PyTorch version: {torch.__version__}")
print("Device: CPU | dtype: float64")
print()

# ============================================================
# Graph 1: First-call compilation overhead (measured fresh)
# ============================================================

print("Graph 1: Measuring first-call compilation overhead...")
print(" (Resetting torch._dynamo to force fresh compilation)")

torch._dynamo.reset()

first_call_configs = [
    # (matrix_size, rhs_cols, max_iter, batch_size, label)
    (500, 10, 100, None, "500x500\n10 RHS, 100 it"),
    (1000, 10, 200, None, "1000x1000\n10 RHS, 200 it"),
    (2000, 1, 500, None, "2000x2000\n1 RHS, 500 it"),
    (1000, 10, 200, 5, "1000x1000\nbatch=5\n10 RHS, 200 it"),
]

# Single-call timings in milliseconds, one entry per configuration.
compile_first = []
plain_first = []
first_labels = []

for matrix_size, rhs_cols, max_iter, batch_size, label in first_call_configs:
    desc = label.replace('\n', ', ')
    print(f" {desc}...")
    # Reset dynamo to force recompilation each time
    torch._dynamo.reset()
    t_compile = run_linear_cg(matrix_size, rhs_cols, max_iter, batch_size) * 1000
    t_plain = run_linear_cg_no_decorator(matrix_size, rhs_cols, max_iter, batch_size) * 1000
    compile_first.append(t_compile)
    plain_first.append(t_plain)
    first_labels.append(label)
    print(f" compile: {t_compile:.1f} ms | plain: {t_plain:.1f} ms | ratio: {t_compile/t_plain:.0f}x")

# Grouped bar chart on a log scale: one pair of bars per configuration.
fig, ax = plt.subplots(figsize=(9, 5))
x = np.arange(len(first_labels))
width = 0.35

bars1 = ax.bar(x - width/2, plain_first, width, label='No decorator', color='#2196F3')
bars2 = ax.bar(x + width/2, compile_first, width, label='@torch.compile', color='#F44336')

ax.set_ylabel('Time (ms)', fontsize=12)
ax.set_title('First-Call Overhead: @torch.compile vs No Decorator', fontsize=13, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(first_labels, fontsize=9)
ax.set_yscale('log')
ax.legend(fontsize=11)
ax.grid(axis='y', alpha=0.3)

# Label every bar with its height; plain bars keep one decimal place.
for bar in bars1:
    h = bar.get_height()
    ax.annotate(f'{h:.1f} ms', xy=(bar.get_x() + bar.get_width() / 2, h),
                xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=8)
for bar in bars2:
    h = bar.get_height()
    ax.annotate(f'{h:.0f} ms', xy=(bar.get_x() + bar.get_width() / 2, h),
                xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/bench_first_call_overhead.png", dpi=150)
plt.close()
print(" Saved: bench_first_call_overhead.png\n")

# ============================================================
# Graph 2: Steady-state performance (broad configurations)
# ============================================================

print("Graph 2: Steady-state performance (broad configurations)...")

steady_configs = [
    # (matrix_size, rhs_cols, max_iter, batch_size, label)
    (500, 1, 100, None, "500x500\n1 RHS\n100 it"),
    (500, 10, 100, None, "500x500\n10 RHS\n100 it"),
    (1000, 10, 200, None, "1000x1000\n10 RHS\n200 it"),
    (1000, 10, 1000, None, "1000x1000\n10 RHS\n1000 it"),
    (2000, 10, 200, None, "2000x2000\n10 RHS\n200 it"),
    (2000, 1, 500, None, "2000x2000\n1 RHS\n500 it"),
    (5000, 10, 200, None, "5000x5000\n10 RHS\n200 it"),
    (1000, 10, 200, 5, "1000x1000\nbatch=5\n10 RHS, 200 it"),
]

# Mean per-call timings in milliseconds, one entry per configuration.
compile_steady = []
plain_steady = []
steady_labels = []

for matrix_size, rhs_cols, max_iter, batch_size, label in steady_configs:
    desc = label.replace('\n', ', ')
    print(f" Benchmarking: {desc}...")
    plain_ms, compile_ms = benchmark_steady_state(matrix_size, rhs_cols, max_iter, batch_size)
    plain_steady.append(plain_ms)
    compile_steady.append(compile_ms)
    steady_labels.append(label)
    ratio = compile_ms / plain_ms if plain_ms > 0 else float('inf')
    print(f" plain: {plain_ms:.2f} ms | compile: {compile_ms:.2f} ms | ratio: {ratio:.2f}x")

fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(steady_labels))
width = 0.35

bars1 = ax.bar(x - width/2, plain_steady, width, label='No decorator', color='#2196F3')
bars2 = ax.bar(x + width/2, compile_steady, width, label='@torch.compile', color='#F44336')

ax.set_ylabel('Time (ms)', fontsize=12)
ax.set_title('Steady-State Performance (avg after warm-up)\nBroad Configuration Sweep', fontsize=13, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(steady_labels, fontsize=8)
ax.legend(fontsize=11)
ax.grid(axis='y', alpha=0.3)

# Add ratio labels (red when compile is slower than plain, green otherwise)
for i, (p, c) in enumerate(zip(plain_steady, compile_steady)):
    ratio = c / p if p > 0 else 0
    color = '#D32F2F' if ratio > 1.0 else '#2E7D32'
    ax.annotate(f'{ratio:.2f}x', xy=(x[i], max(p, c)),
                xytext=(0, 8), textcoords="offset points", ha='center', va='bottom',
                fontsize=9, fontweight='bold', color=color)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/bench_steady_state.png", dpi=150)
plt.close()
print(" Saved: bench_steady_state.png\n")

# ============================================================
# Graph 3: Shape variation (recompilation) overhead
# ============================================================

print("Graph 3: Shape variation overhead...")

sizes = [200, 400, 600, 800, 1000, 1500, 2000, 3000, 5000]
rhs_cols_shape = 10
max_iter_cap = 100

# Median single-call timings in milliseconds, one entry per matrix size.
compile_shape_times = []
plain_shape_times = []

for sz in sizes:
    mi = min(sz, max_iter_cap)
    print(f" Size {sz}...")
    # Run 3 times each and take median
    # NOTE(review): a median of three probably discards a one-off
    # recompilation spike on the first call for a new shape, so this plot may
    # show warmed-up timings rather than recompilation cost -- confirm intent.
    ct = []
    pt = []
    for _ in range(3):
        ct.append(run_linear_cg(sz, rhs_cols_shape, mi) * 1000)
        pt.append(run_linear_cg_no_decorator(sz, rhs_cols_shape, mi) * 1000)
    compile_shape_times.append(np.median(ct))
    plain_shape_times.append(np.median(pt))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))

# Left: absolute times
ax1.plot(sizes, plain_shape_times, 'o-', color='#2196F3', linewidth=2, markersize=7, label='No decorator')
ax1.plot(sizes, compile_shape_times, 's-', color='#F44336', linewidth=2, markersize=7, label='@torch.compile')
ax1.set_xlabel('Matrix size (n x n)', fontsize=11)
ax1.set_ylabel('Time (ms)', fontsize=11)
ax1.set_title('Single-Call Time vs Matrix Size\n(10 RHS, 100 CG iters)', fontsize=12, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(alpha=0.3)

# Right: overhead ratio (red above 1.5x, orange above 1.1x, green otherwise)
ratios = [c / p if p > 0 else 0 for c, p in zip(compile_shape_times, plain_shape_times)]
colors = ['#F44336' if r > 1.5 else '#FF9800' if r > 1.1 else '#4CAF50' for r in ratios]
ax2.bar(range(len(sizes)), ratios, color=colors)
ax2.set_xticks(range(len(sizes)))
ax2.set_xticklabels([str(s) for s in sizes])
ax2.set_xlabel('Matrix size (n x n)', fontsize=11)
ax2.set_ylabel('Overhead ratio (compile / plain)', fontsize=11)
ax2.set_title('Overhead Ratio by Matrix Size', fontsize=12, fontweight='bold')
ax2.axhline(y=1.0, color='black', linestyle='--', alpha=0.5, label='No overhead (1.0x)')
ax2.legend(fontsize=9)
ax2.grid(axis='y', alpha=0.3)

for i, r in enumerate(ratios):
    ax2.annotate(f'{r:.2f}x', xy=(i, r), xytext=(0, 5),
                 textcoords="offset points", ha='center', va='bottom', fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/bench_shape_variation.png", dpi=150)
plt.close()
print(" Saved: bench_shape_variation.png\n")

# ============================================================
# Graph 4: CG iteration scaling (does more work per call help torch.compile?)
# ============================================================

print("Graph 4: CG iteration scaling...")

iter_counts = [50, 100, 200, 500, 1000]
matrix_size_iter = 1000
rhs_cols_iter = 10
n_warmup_iter = 2
n_measure_iter = 5

# Mean per-call timings in milliseconds, one entry per iteration cap.
compile_iter_times = []
plain_iter_times = []

for max_iter in iter_counts:
    print(f" {max_iter} iterations...")
    plain_ms, compile_ms = benchmark_steady_state(
        matrix_size_iter, rhs_cols_iter, max_iter,
        n_warmup=n_warmup_iter, n_measure=n_measure_iter
    )
    plain_iter_times.append(plain_ms)
    compile_iter_times.append(compile_ms)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))

# Left: absolute times
ax1.plot(iter_counts, plain_iter_times, 'o-', color='#2196F3', linewidth=2, markersize=7, label='No decorator')
ax1.plot(iter_counts, compile_iter_times, 's-', color='#F44336', linewidth=2, markersize=7, label='@torch.compile')
ax1.set_xlabel('Max CG iterations', fontsize=11)
ax1.set_ylabel('Time (ms)', fontsize=11)
ax1.set_title('Time vs CG Iterations\n(1000x1000 matrix, 10 RHS)', fontsize=12, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(alpha=0.3)

# Right: overhead ratio (red above 1.5x, orange above 1.1x, green otherwise)
ratios_iter = [c / p if p > 0 else 0 for c, p in zip(compile_iter_times, plain_iter_times)]
colors_iter = ['#F44336' if r > 1.5 else '#FF9800' if r > 1.1 else '#4CAF50' for r in ratios_iter]
ax2.bar(range(len(iter_counts)), ratios_iter, color=colors_iter)
ax2.set_xticks(range(len(iter_counts)))
ax2.set_xticklabels([str(i) for i in iter_counts])
ax2.set_xlabel('Max CG iterations', fontsize=11)
ax2.set_ylabel('Overhead ratio (compile / plain)', fontsize=11)
ax2.set_title('Overhead Ratio vs CG Iterations', fontsize=12, fontweight='bold')
ax2.axhline(y=1.0, color='black', linestyle='--', alpha=0.5, label='No overhead (1.0x)')
ax2.legend(fontsize=9)
ax2.grid(axis='y', alpha=0.3)

for i, r in enumerate(ratios_iter):
    ax2.annotate(f'{r:.2f}x', xy=(i, r), xytext=(0, 5),
                 textcoords="offset points", ha='center', va='bottom', fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/bench_iteration_scaling.png", dpi=150)
plt.close()
print(" Saved: bench_iteration_scaling.png\n")

# ============================================================
# Graph 5: Per-call timeline (showing compilation spike)
# ============================================================

print("Graph 5: Per-call timeline...")

matrix_size, rhs_cols, max_iter = 500, 10, 100
n_calls = 15

# No decorator: all calls
plain_timeline = []
for _ in range(n_calls):
    plain_timeline.append(run_linear_cg_no_decorator(matrix_size, rhs_cols, max_iter) * 1000)

# torch.compile: reset dynamo to force fresh compilation
torch._dynamo.reset()
compile_timeline = []
for _ in range(n_calls):
    compile_timeline.append(run_linear_cg(matrix_size, rhs_cols, max_iter) * 1000)

fig, ax = plt.subplots(figsize=(10, 5))
calls = range(1, n_calls + 1)

ax.plot(calls, plain_timeline, 'o-', color='#2196F3', linewidth=2, markersize=7, label='No decorator')
ax.plot(calls, compile_timeline, 's-', color='#F44336', linewidth=2, markersize=7, label='@torch.compile')
ax.set_xlabel('Call number', fontsize=11)
ax.set_ylabel('Time (ms)', fontsize=11)
ax.set_title('Per-Call Latency Over Successive Calls\n(500x500 matrix, 10 RHS, 100 CG iterations)', fontsize=12, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(alpha=0.3)

# Point out the first compiled call only when it took longer than 100 ms.
if compile_timeline[0] > 100:
    ax.annotate(f'Compilation\n{compile_timeline[0]:.0f} ms',
                xy=(1, compile_timeline[0]),
                xytext=(3, compile_timeline[0] * 0.7),
                arrowprops=dict(arrowstyle='->', color='#D32F2F'),
                fontsize=10, color='#D32F2F', fontweight='bold')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/bench_per_call_timeline.png", dpi=150)
plt.close()
print(" Saved: bench_per_call_timeline.png\n")

# ============================================================
# Summary
# ============================================================

print("=" * 70)
print("SUMMARY")
print("=" * 70)
print()
print("Across all tested configurations:")
print(" - Matrix sizes: 200 to 5000")
# Only 1 and 10 right-hand-side columns appear in the configurations above.
print(" - RHS columns: 1, 10")
print(" - CG iterations: 50 to 1000")
print(" - Batch sizes: non-batched and batch=5")
print()

# Steady-state compile/plain ratios from Graph 2.
all_ratios = [c / p for c, p in zip(compile_steady, plain_steady) if p > 0]
print(f" Steady-state overhead ratios: {min(all_ratios):.2f}x - {max(all_ratios):.2f}x")
print(" (>1.0 means @torch.compile is SLOWER)")
print()
print(f" First-call overhead: {min(compile_first):.0f} - {max(compile_first):.0f} ms")
print(f" vs plain first-call: {min(plain_first):.1f} - {max(plain_first):.1f} ms")
print()

# Derive the conclusion from the measurements instead of asserting it
# unconditionally, so the printed text cannot contradict the numbers above.
if all(r > 1.0 for r in all_ratios):
    print("Conclusion: @torch.compile adds overhead in ALL configurations tested.")
    print("Recommendation: Remove the decorators entirely.")
else:
    not_slower = sum(1 for r in all_ratios if r <= 1.0)
    print(f"Conclusion: @torch.compile is not slower in {not_slower} of "
          f"{len(all_ratios)} steady-state configurations.")
    print("Recommendation: Review the per-configuration results before removing the decorators.")
print()
print("Output files:")
print(f" {OUTPUT_DIR}/bench_first_call_overhead.png")
print(f" {OUTPUT_DIR}/bench_steady_state.png")
print(f" {OUTPUT_DIR}/bench_shape_variation.png")
print(f" {OUTPUT_DIR}/bench_iteration_scaling.png")
print(f" {OUTPUT_DIR}/bench_per_call_timeline.png")

bench_first_call_overhead bench_iteration_scaling bench_per_call_timeline bench_shape_variation bench_steady_state

@saitcakmak
Copy link
Copy Markdown
Collaborator

Thanks for running these. It seems like we're better off with no decorator. If anyone has a use case that would benefit from torch.compile, I think they can manually apply it. It doesn't seem worth using for generic use cases.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants