[Performance] conv_general 5D (3D convolution) is 2-5x slower than equivalent per-frame 2D conv #3625

@kimjj81

Summary

mx.conv_general with 5D inputs (3D convolution) is 2-5x slower than decomposing the same operation into per-frame 2D convolutions in a Python loop. This affects video generation workloads such as VAE decoders, where 3D convolutions over the temporal and spatial dimensions are common.
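For readers who want to see why the decomposition is valid, here is a minimal pure-Python sketch (no MLX, single channel, tiny integer tensors so the comparison is exact). It checks that a stride-1, unpadded 3D cross-correlation equals the sum, over kernel depth positions, of 2D cross-correlations applied to temporally shifted frames. The function names are illustrative and are not part of MLX.

```python
import random

def conv2d_valid(frame, kernel):
    """Single-channel 2D cross-correlation, no padding, stride 1."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(frame) - kh + 1
    out_w = len(frame[0]) - kw + 1
    return [
        [
            sum(frame[i + a][j + b] * kernel[a][b]
                for a in range(kh) for b in range(kw))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]

def conv3d_valid(volume, kernel):
    """Single-channel 3D cross-correlation computed directly."""
    kd, kh, kw = len(kernel), len(kernel[0]), len(kernel[0][0])
    out_t = len(volume) - kd + 1
    out_h = len(volume[0]) - kh + 1
    out_w = len(volume[0][0]) - kw + 1
    return [
        [
            [
                sum(volume[t + d][i + a][j + b] * kernel[d][a][b]
                    for d in range(kd) for a in range(kh) for b in range(kw))
                for j in range(out_w)
            ]
            for i in range(out_h)
        ]
        for t in range(out_t)
    ]

def conv3d_via_2d(volume, kernel):
    """Same result, built from per-frame 2D convolutions summed over depth."""
    kd = len(kernel)
    out_t = len(volume) - kd + 1
    out = None
    for d in range(kd):
        # Output frame t receives frame t + d convolved with kernel slice d.
        partial = [conv2d_valid(volume[t + d], kernel[d]) for t in range(out_t)]
        if out is None:
            out = partial
        else:
            out = [
                [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(f1, f2)]
                for f1, f2 in zip(out, partial)
            ]
    return out

random.seed(0)
T, H, W = 5, 6, 7
volume = [[[random.randint(-3, 3) for _ in range(W)] for _ in range(H)]
          for _ in range(T)]
kernel = [[[random.randint(-2, 2) for _ in range(3)] for _ in range(3)]
          for _ in range(3)]

direct = conv3d_valid(volume, kernel)
decomposed = conv3d_via_2d(volume, kernel)
print(direct == decomposed)
```

With real channels, the same argument applies per (Cin, Cout) pair. The 2D path only reorders the summation, so any difference seen in floating point should come from accumulation order.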

Environment

  • Hardware: Apple M4 Ultra (T6050), 128GB
  • macOS: 26.5 (Tahoe)
  • MLX: latest (via pip, mlx-video dependency)
  • Python: 3.12

Reproduction

import os
os.environ["MACGEN_ALLOW_PARENT_MLX"] = "1"
import time
import mlx.core as mx

# Test: 3x3x3 conv on [1, 41, 120, 208, 512] input
# Equivalent to a CausalConv3d in video VAE decoder

B, T, H, W, Cin, Cout = 1, 41, 120, 208, 512, 512
kd, kh, kw = 3, 3, 3

# Random input and weight
x = mx.random.normal((B, T + 2, H + 2, W + 2, Cin), dtype=mx.bfloat16)  # pre-padded
w = mx.random.normal((Cout, kd, kh, kw, Cin), dtype=mx.bfloat16)
bias = mx.zeros((Cout,), dtype=mx.bfloat16)
mx.eval(x, w, bias)

# Method 1: Native 3D conv (5D conv_general)
mx.synchronize()
t0 = time.perf_counter()
y_3d = mx.conv_general(x, w, stride=(1, 1, 1)) + bias
mx.eval(y_3d)
mx.synchronize()
native_time = time.perf_counter() - t0

# Method 2: Per-frame 2D conv (Python loop, 3 depth positions)
mx.synchronize()
t0 = time.perf_counter()
outputs_d = []
for d in range(kd):
    frames = x[:, d:d+T].reshape(B * T, H + 2, W + 2, Cin)
    w2d = w[:, d, :, :, :]  # [Cout, kh, kw, Cin]
    conv_out = mx.conv_general(frames, w2d)
    outputs_d.append(conv_out.reshape(B, T, conv_out.shape[1], conv_out.shape[2], Cout))
y_2d = outputs_d[0] + outputs_d[1] + outputs_d[2] + bias
mx.eval(y_2d)
mx.synchronize()
loop_time = time.perf_counter() - t0

print(f"Native 3D conv: {native_time*1000:.0f}ms")
print(f"Per-frame 2D loop: {loop_time*1000:.0f}ms")
print(f"Slowdown: {native_time/loop_time:.1f}x")

# Verify correctness
diff = mx.max(mx.abs(y_3d[:, :T] - y_2d)).item()
print(f"Max diff: {diff:.6f}")
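The script above times a single cold run of each method, which can include one-time costs. A possible way to make the comparison more robust is a small harness with untimed warm-up runs and a median over several repeats. The sketch below is an assumption about how one might do this, not the method used for the numbers in this issue. `benchmark` is a hypothetical helper, and the workload is a stand-in so the snippet runs without MLX.

```python
import statistics
import time

def benchmark(fn, warmup=2, repeats=5):
    """Return the median wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()  # untimed: absorbs one-time costs before measurement starts
    samples = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000.0)
    return statistics.median(samples)

# Stand-in workload (pure Python) so the sketch is runnable anywhere.
ms = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"{ms:.3f} ms")
```

With MLX, `fn` would need to build the convolution and call `mx.eval` on the result inside the function, so that lazy evaluation is forced within the timed region.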

Results

Across multiple spatial resolutions (relevant to video VAE decoder stages):

Input shape                      Native 3D    Per-frame 2D    Ratio
──────────────────────────────────────────────────────────────────────
[1, 21, 60, 104, 1024]           488ms        231ms           2.1x slower
[1, 41, 120, 208, 512]           966ms        240ms           4.0x slower
[1, 41, 240, 416, 512]          3888ms        815ms           4.8x slower

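The gap widens with spatial resolution (2.1x at 60x104, 4.0x at 120x208, 4.8x at 240x416; note that the first row also uses 1024 channels rather than 512). This suggests the 3D conv Metal kernel has suboptimal memory access patterns or lacks tiling optimizations that the 2D conv kernel has.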

Additional Findings

A "depth-chunked" approach (batching temporal frames into groups of 2-8, then running one 2D conv per kernel depth position on each group) achieves a 1.03-1.72x speedup over the naive per-frame loop:

Input shape                      Chunk size   Optimized    Speedup vs loop
────────────────────────────────────────────────────────────────────────────
[1, 11, 30, 52, 1024]              cs=8        63ms         1.40x
[1, 21, 60, 104, 1024]             cs=4       199ms         1.72x
[1, 41, 120, 208, 1024→512]        cs=3       731ms         1.13x
[1, 41, 240, 416, 512→256]         cs=2      1132ms         1.03x

This suggests the 2D conv Metal kernel benefits from batched inputs, although the gain shrinks from 1.72x at 60x104 to 1.03x at 240x416, and that a well-optimized 3D conv kernel could match or exceed this.
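The implementation of the depth-chunked variant is not included in this issue. The sketch below is one possible reading of the description: split the output frames into chunks of `chunk_size`, and for each chunk issue one batched 2D conv per kernel depth position. It only plans the work (which input frames feed which output frames) and runs no convolution. `plan_depth_chunks` and the tuple layout are hypothetical.

```python
def plan_depth_chunks(num_out_frames, kernel_depth, chunk_size):
    """Yield (out_start, out_stop, d, in_start, in_stop) work items.

    Each item stands for one batched 2D conv: frames in_start..in_stop-1 of
    the temporally pre-padded input are stacked along the batch axis and
    convolved with depth slice d of the 3D kernel. The result is accumulated
    into output frames out_start..out_stop-1.
    """
    for out_start in range(0, num_out_frames, chunk_size):
        out_stop = min(out_start + chunk_size, num_out_frames)
        for d in range(kernel_depth):
            yield out_start, out_stop, d, out_start + d, out_stop + d

# T=41 output frames, 3-deep kernel, chunk size 3 (as in the cs=3 row above).
plan = list(plan_depth_chunks(num_out_frames=41, kernel_depth=3, chunk_size=3))
print(len(plan))   # number of batched 2D conv calls
print(plan[0])     # first work item
print(plan[-1])    # last work item (a short chunk of 2 frames)
```

Under this reading, chunk size trades the number of kernel launches against the size of each batched call, which may be why the best `cs` in the table drops as spatial resolution grows.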

Impact

This affects all video generation models using 3D convolutions in MLX:

  • Wan2.2 VAE decoder: 3D conv takes ~70% of VAE decode time
  • LTX-2.3 VAE: similar 3D conv usage
  • Any video model with CausalConv3d or 3D convolution layers

An optimized 3D conv kernel could reduce VAE decode time by an estimated 15-40%, translating to meaningful end-to-end pipeline improvements for video generation on Apple Silicon.
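As a rough sanity check on that estimate, an Amdahl's-law style calculation relates the kernel-level speedup to the overall decode-time reduction. The sketch assumes the ~70% share quoted for the Wan2.2 VAE decoder and that all other work is unchanged. Both helper names are illustrative, and real layers may not see the same ratio as the microbenchmark.

```python
def decode_time_reduction(conv_fraction, conv_speedup):
    """Fraction of total decode time saved if 3D conv gets conv_speedup x faster."""
    return conv_fraction * (1.0 - 1.0 / conv_speedup)

def required_conv_speedup(conv_fraction, target_reduction):
    """3D conv speedup needed to cut total decode time by target_reduction."""
    return 1.0 / (1.0 - target_reduction / conv_fraction)

CONV_FRACTION = 0.70  # assumption: share of decode time spent in 3D conv

for target in (0.15, 0.40):
    s = required_conv_speedup(CONV_FRACTION, target)
    print(f"{target:.0%} decode-time reduction needs ~{s:.2f}x faster 3D conv")

for speedup in (2.1, 4.0, 4.8):  # ratios measured in the Results table
    r = decode_time_reduction(CONV_FRACTION, speedup)
    print(f"{speedup}x faster 3D conv -> ~{r:.0%} decode-time reduction")
```

Under these assumptions, the 15-40% range corresponds to 3D conv getting roughly 1.3x to 2.3x faster in the full decoder.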
