Summary
mx.conv_general with 5D inputs (3D convolution) is significantly slower (2.1-4.8x in the measurements below) than decomposing the same operation into 2D convolutions driven by a Python loop over the kernel's depth positions, with frames folded into the batch dimension. This affects video generation workloads such as VAE decoders, where 3D convolutions over temporal and spatial dimensions are common.
Environment
- Hardware: Apple M4 Ultra (T6050), 128GB
- macOS: 26.5 (Tahoe)
- MLX: latest (via pip, mlx-video dependency)
- Python: 3.12
Reproduction
```python
import os
# Flag from the reporter's mlx-video setup; it does not appear to be an MLX
# setting and can likely be dropped in other environments.
os.environ["MACGEN_ALLOW_PARENT_MLX"] = "1"

import time
import mlx.core as mx

# Test: 3x3x3 conv on a [1, 41, 120, 208, 512] input,
# equivalent to a CausalConv3d in a video VAE decoder (channels-last layout).
B, T, H, W, Cin, Cout = 1, 41, 120, 208, 512, 512
kd, kh, kw = 3, 3, 3
WARMUP, ITERS = 2, 5

# Random pre-padded input and weight [Cout, kd, kh, kw, Cin]
x = mx.random.normal((B, T + kd - 1, H + kh - 1, W + kw - 1, Cin), dtype=mx.bfloat16)
w = mx.random.normal((Cout, kd, kh, kw, Cin), dtype=mx.bfloat16)
bias = mx.zeros((Cout,), dtype=mx.bfloat16)
mx.eval(x, w, bias)

def native_3d():
    # Method 1: native 3D conv (5D conv_general)
    return mx.conv_general(x, w, stride=(1, 1, 1)) + bias

def per_frame_2d():
    # Method 2: fold time into the batch, one 2D conv per depth position
    out = bias
    for d in range(kd):
        frames = x[:, d:d + T].reshape(B * T, H + kh - 1, W + kw - 1, Cin)
        y = mx.conv_general(frames, w[:, d])  # w[:, d]: [Cout, kh, kw, Cin]
        out = out + y.reshape(B, T, y.shape[1], y.shape[2], Cout)
    return out

def bench(fn):
    # Warm up so kernel compilation is excluded, then average over ITERS runs
    for _ in range(WARMUP):
        mx.eval(fn())
    mx.synchronize()
    t0 = time.perf_counter()
    for _ in range(ITERS):
        mx.eval(fn())
    mx.synchronize()
    return (time.perf_counter() - t0) / ITERS

native_time = bench(native_3d)
loop_time = bench(per_frame_2d)
print(f"Native 3D conv:    {native_time * 1000:.0f} ms")
print(f"Per-frame 2D loop: {loop_time * 1000:.0f} ms")
print(f"Slowdown:          {native_time / loop_time:.1f}x")

# Verify correctness (compared in float32; the bf16 sums round differently)
y_3d, y_2d = native_3d(), per_frame_2d()
diff = mx.max(mx.abs(y_3d.astype(mx.float32) - y_2d.astype(mx.float32))).item()
print(f"Max diff: {diff:.6f}")
```
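Method 2 relies on a simple identity. For a "valid" channels-last 3D convolution, which is a cross-correlation (MLX's conv_general computes cross-correlation when flip is left at False), output frame t equals the bias plus the sum over depth positions d of a 2D convolution of input frame t + d with the weight slice w[:, d]. The sketch below checks this identity in NumPy on small random shapes, independently of MLX. The function names are illustrative and not part of any library, and sliding_window_view requires NumPy 1.20 or newer.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv3d_ref(x, w):
    """Direct 'valid' 3D cross-correlation. x: [B,T,H,W,Cin], w: [Cout,kd,kh,kw,Cin]."""
    kd, kh, kw = w.shape[1:4]
    win = sliding_window_view(x, (kd, kh, kw), axis=(1, 2, 3))  # [B,To,Ho,Wo,Cin,kd,kh,kw]
    return np.einsum("bthwcdij,odijc->bthwo", win, w)

def conv2d_ref(x, w):
    """Direct 'valid' 2D cross-correlation. x: [N,H,W,Cin], w: [Cout,kh,kw,Cin]."""
    kh, kw = w.shape[1:3]
    win = sliding_window_view(x, (kh, kw), axis=(1, 2))  # [N,Ho,Wo,Cin,kh,kw]
    return np.einsum("nhwcij,oijc->nhwo", win, w)

def conv3d_as_2d(x, w):
    """Method 2's decomposition: fold time into the batch, one 2D conv per depth tap."""
    B, T, H, W, Cin = x.shape
    kd = w.shape[1]
    To = T - kd + 1  # number of output frames
    out = 0.0
    for d in range(kd):
        frames = x[:, d:d + To].reshape(B * To, H, W, Cin)
        y = conv2d_ref(frames, w[:, d])
        out = out + y.reshape(B, To, *y.shape[1:])
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 6, 7, 8, 4))  # small stand-in for [B, T+2, H+2, W+2, Cin]
w = rng.standard_normal((5, 3, 3, 3, 4))  # [Cout, kd, kh, kw, Cin]
assert np.allclose(conv3d_ref(x, w), conv3d_as_2d(x, w))
```

Because the decomposition is exact up to floating-point rounding, any timing gap between the two methods comes from the kernels rather than from different work being done.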
Results
Across several input shapes (matching different video VAE decoder stages):
Input shape [B, T, H, W, C]   Native 3D   Per-frame 2D   Native / 2D
─────────────────────────────────────────────────────────────────────
[1, 21, 60, 104, 1024]           488 ms        231 ms    2.1x slower
[1, 41, 120, 208, 512]           966 ms        240 ms    4.0x slower
[1, 41, 240, 416, 512]          3888 ms        815 ms    4.8x slower
The gap widens with spatial resolution (from 4.0x to 4.8x when H and W both double between the last two rows). This suggests the 3D conv Metal kernel may have less efficient memory access patterns, or may lack tiling optimizations that the 2D conv kernel has.
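Expressing the timings as achieved throughput makes it easier to compare them with the GPU's peak. A "valid" kd x kh x kw convolution performs 2 * B * T * H * W * Cout * kd * kh * kw * Cin FLOPs, where T, H, W are output sizes and each multiply-add counts as 2 FLOPs. The helper below is a hypothetical sketch. It assumes Cout = Cin for the table rows, because the table lists only one channel count.

```python
def conv3d_flops(B, T, H, W, Cin, Cout, kd=3, kh=3, kw=3):
    """FLOPs of a 'valid' 3D conv; T, H, W are output sizes, multiply-add = 2 FLOPs."""
    return 2 * B * T * H * W * Cout * kd * kh * kw * Cin

def tflops(flops, ms):
    """Achieved throughput in TFLOP/s for a run that took `ms` milliseconds."""
    return flops / (ms / 1000) / 1e12

# Rows from the results table; Cout == Cin is an assumption.
rows = [
    ((1, 21, 60, 104, 1024), 488, 231),
    ((1, 41, 120, 208, 512), 966, 240),
    ((1, 41, 240, 416, 512), 3888, 815),
]
for (B, T, H, W, C), native_ms, loop_ms in rows:
    f = conv3d_flops(B, T, H, W, C, C)
    print(f"[{B}, {T}, {H}, {W}, {C}]  {f / 1e12:6.2f} TFLOP  "
          f"native {tflops(f, native_ms):5.1f} TFLOP/s  "
          f"2D {tflops(f, loop_ms):5.1f} TFLOP/s")
```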
Additional Findings
A "depth-chunked" approach achieves a 1.03-1.72x speedup over the naive per-frame loop. It groups temporal frames into chunks of 2-8 and then runs one 2D conv per depth position on each chunk:
Input shape [B, T, H, W, C]    Chunk size   Chunked time   Speedup vs loop
──────────────────────────────────────────────────────────────────────────
[1, 11, 30, 52, 1024]          8            63 ms          1.40x
[1, 21, 60, 104, 1024]         4            199 ms         1.72x
[1, 41, 120, 208, 1024→512]    3            731 ms         1.13x
[1, 41, 240, 416, 512→256]     2            1132 ms        1.03x
This suggests that the 2D conv Metal kernel's performance depends strongly on how inputs are batched, and that a well-optimized 3D conv kernel could match or exceed these results.
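The report does not include the chunked implementation. The NumPy sketch below shows one plausible reading of it as a hypothetical conv3d_depth_chunked:
1. Split the output frames into chunks of cs.
2. For each chunk, fold its frames into the batch dimension.
3. Run one 2D conv per depth position and sum the results.
4. Concatenate the chunks along time.

It uses NumPy so the logic can be checked anywhere. An MLX version would swap in mx.conv_general and mx.concatenate. Calling mx.eval once per chunk would presumably limit peak memory to one chunk's intermediates, though that is an assumption rather than something the report measures.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv2d(x, w):
    """'Valid' 2D cross-correlation, channels-last. x: [N,H,W,Cin], w: [Cout,kh,kw,Cin]."""
    kh, kw = w.shape[1:3]
    win = sliding_window_view(x, (kh, kw), axis=(1, 2))  # [N,Ho,Wo,Cin,kh,kw]
    return np.einsum("nhwcij,oijc->nhwo", win, w)

def conv3d_depth_chunked(x, w, cs):
    """Hypothetical depth-chunked 3D conv: output frames in chunks of `cs`,
    one batched 2D conv per depth position within each chunk."""
    B, T, H, W, Cin = x.shape
    kd = w.shape[1]
    To = T - kd + 1  # number of output frames
    chunks = []
    for t0 in range(0, To, cs):
        n = min(cs, To - t0)  # the last chunk may be shorter
        acc = 0.0
        for d in range(kd):
            frames = x[:, t0 + d:t0 + d + n].reshape(B * n, H, W, Cin)
            y = conv2d(frames, w[:, d])
            acc = acc + y.reshape(B, n, *y.shape[1:])
        chunks.append(acc)
    return np.concatenate(chunks, axis=1)
```

The chunk size only changes how the work is grouped, not the result, so the same output is produced for any cs >= 1.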
Impact
This affects all video generation models using 3D convolutions in MLX:
- Wan2.2 VAE decoder: 3D conv takes ~70% of VAE decode time
- LTX-2.3 VAE: similar 3D conv usage
- Any video model with CausalConv3d or 3D convolution layers
An optimized 3D conv kernel could reduce VAE decode time by an estimated 15-40%, which would translate into meaningful end-to-end improvements for video generation pipelines on Apple Silicon.
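Amdahl's law relates a faster kernel to total decode time. If 3D conv takes a fraction p of decode time and becomes s times faster, decode time drops by p * (1 - 1/s). Assume the ~70% Wan2.2 figure and that everything else in the decoder is unchanged. Under those assumptions, the 15-40% range corresponds to 3D conv speedups of roughly 1.27x to 2.33x. The helper names below are illustrative.

```python
def decode_time_reduction(conv_fraction, conv_speedup):
    """Amdahl's law: fraction of total time saved when only the conv part speeds up."""
    return conv_fraction * (1 - 1 / conv_speedup)

def required_speedup(conv_fraction, target_reduction):
    """Inverse: conv speedup needed to cut total time by `target_reduction`."""
    return 1 / (1 - target_reduction / conv_fraction)

p = 0.70  # ~70% of Wan2.2 VAE decode time spent in 3D conv (figure from this report)
for r in (0.15, 0.40):
    print(f"{r:.0%} decode-time reduction needs a ~{required_speedup(p, r):.2f}x faster 3D conv")
```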
Related