Replace deprecated torch.jit.script with torch.compile#129
Conversation
|
@saitcakmak: since you recently worked on #128, would you be willing to quickly review/approve this PR? Has only two lines of code changed 🙃 |
|
Hi @AdrianSosic. Thanks for chasing down these deprecation warnings! I am not very familiar with |
Seems your intuition is right. Ran some benchmarks with Claude directly on Details#!/usr/bin/env python3 """ Benchmark script that generates graphs comparing @torch.compile vs no decorator for linear_cg helper functions.Tests a broad range of configurations:
import importlib import torch warnings.filterwarnings("ignore") Ensure the repo root is on the path so linear_operator can be importedsys.path.insert(0, os.path.dirname(os.path.abspath(file))) import linear_operator OUTPUT_DIR = os.path.dirname(os.path.abspath(file)) def make_psd_matrix(n, batch_size=None, dtype=torch.float64, device="cpu"): def get_cg_module(): def run_linear_cg(matrix_size, rhs_cols, max_iter, batch_size=None): def run_linear_cg_no_decorator(matrix_size, rhs_cols, max_iter, batch_size=None): def benchmark_steady_state(matrix_size, rhs_cols, max_iter, batch_size=None, ============================================================print("=" * 70) ============================================================Graph 1: First-call compilation overhead (measured fresh)============================================================print("Graph 1: Measuring first-call compilation overhead...") torch._dynamo.reset() first_call_configs = [ compile_first = [] for matrix_size, rhs_cols, max_iter, batch_size, label in first_call_configs: fig, ax = plt.subplots(figsize=(9, 5)) bars1 = ax.bar(x - width/2, plain_first, width, label='No decorator', color='#2196F3') ax.set_ylabel('Time (ms)', fontsize=12) for bar in bars1: plt.tight_layout() ============================================================Graph 2: Steady-state performance (broad configurations)============================================================print("Graph 2: Steady-state performance (broad configurations)...") steady_configs = [ compile_steady = [] for matrix_size, rhs_cols, max_iter, batch_size, label in steady_configs: fig, ax = plt.subplots(figsize=(12, 6)) bars1 = ax.bar(x - width/2, plain_steady, width, label='No decorator', color='#2196F3') ax.set_ylabel('Time (ms)', fontsize=12) Add ratio labelsfor i, (p, c) in enumerate(zip(plain_steady, compile_steady)): plt.tight_layout() ============================================================Graph 3: Shape variation (recompilation) overhead============================================================print("Graph 3: Shape variation overhead...") sizes = [200, 400, 600, 800, 1000, 1500, 2000, 3000, 5000] compile_shape_times = [] for sz in sizes: fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5)) Left: absolute timesax1.plot(sizes, plain_shape_times, 'o-', color='#2196F3', linewidth=2, markersize=7, label='No decorator') Right: overhead ratioratios = [c / p if p > 0 else 0 for c, p in zip(compile_shape_times, plain_shape_times)] for i, r in enumerate(ratios): plt.tight_layout() ============================================================Graph 4: CG iteration scaling (does more work per call help torch.compile?)============================================================print("Graph 4: CG iteration scaling...") iter_counts = [50, 100, 200, 500, 1000] compile_iter_times = [] for max_iter in iter_counts: fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5)) ax1.plot(iter_counts, plain_iter_times, 'o-', color='#2196F3', linewidth=2, markersize=7, label='No decorator') ratios_iter = [c / p if p > 0 else 0 for c, p in zip(compile_iter_times, plain_iter_times)] for i, r in enumerate(ratios_iter): plt.tight_layout() ============================================================Graph 5: Per-call timeline (showing compilation spike)============================================================print("Graph 5: Per-call timeline...") matrix_size, rhs_cols, max_iter = 500, 10, 100 No decorator: all callsplain_timeline = [] torch.compile: reset dynamo to force fresh compilationtorch._dynamo.reset() fig, ax = plt.subplots(figsize=(10, 5)) ax.plot(calls, plain_timeline, 'o-', color='#2196F3', linewidth=2, markersize=7, label='No decorator') if compile_timeline[0] > 100: plt.tight_layout() ============================================================Summary============================================================print("=" * 70) all_ratios = [c / p for c, p in zip(compile_steady, plain_steady) if p > 0]
|
|
Thanks for running these. It seems like we're better off with no decorator. If anyone has a use case that would benefit from torch.compile, I think they can manually apply it. It doesn't seem worth using for generic use cases. |





Summary
Follow-up to #128, replacing the two
@torch.jit.scriptdecorators inlinear_operator/utils/linear_cg.pywith@torch.compileMotivation
torch.jit.scripthas been deprecated in recent PyTorch versions, raising aDeprecationWarningat import time: