Skip to content

perf: Optimize quantized matmul for small decode batches#122

Merged
Connor1996 merged 1 commit into
skyzh:mainfrom
Li0k:optimize-quantized-matmul-small-m
May 21, 2026
Merged

perf: Optimize quantized matmul for small decode batches#122
Connor1996 merged 1 commit into
skyzh:mainfrom
Li0k:optimize-quantized-matmul-small-m

Conversation

@Li0k
Copy link
Copy Markdown
Contributor

@Li0k Li0k commented May 20, 2026

Summary

This PR optimizes the reference 4-bit quantized matmul Metal launch shape for small-M decode workloads.

The existing kernel always uses a 32-row tile in the M dimension:

threads_per_group = (32, maxTotalThreadsPerThreadgroup / 32, 1)

During decode, M is often small (M = 1 for single-request decode, or small batch sizes for continuous batching). In those cases, most M-dimension lanes are idle. This PR keeps the total threadgroup size unchanged, but shifts lanes from the M dimension to the K/output-column dimension when M is small:

x_size = 1  if M <= 1
x_size = 2  if M <= 2
x_size = 4  if M <= 4
x_size = 8  if M <= 8
x_size = 16 if M <= 16
x_size = 32 otherwise

y_size = maxTotalThreadsPerThreadgroup / x_size

For larger M, this keeps the original 32-row tile.

Benchmark support

I also added a synthetic-token batch decode benchmark path to bench.py:

  • --batch-decode runs the Week 2 continuous-batching decode path.
  • It uses BatchingKvCache and batched offsets, matching the Week 2 batching task shape.
  • It excludes tokenizer/detokenizer overhead, matching the existing synthetic-token benchmark style.
  • It validates --num-seqs >= --batch-size so larger batch-size measurements are not accidentally under-filled.

The original non-batch benchmark path remains the default.

Benchmark setup

Model and command shape:

PYTHONPATH=src pdm run bench \
  --solution tiny_llm_ref \
  --loader week2 \
  --model qwen3-0.6b \
  --enable-flash-attn \
  --batch-decode \
  --batch-size <B> \
  --num-seqs 32 \
  --min-input-len 32 \
  --max-input-len 64 \
  --min-output-len 16 \
  --max-output-len 32 \
  --prefill-step 128 \
  --warmup 0 \
  --seed 0

Fixed workload:

  • prompt tokens: 1631
  • generated tokens: 783
  • local hardware: MacBook Pro, Apple M4 Pro, 12-core CPU, 48GB memory, MLX GPU backend
  • Qwen/Qwen3-0.6B-MLX-4bit
  • single run with fixed seed

Results

batch_size baseline decode tok/s this PR decode tok/s decode change baseline output tok/s this PR output tok/s
1 34.88 45.45 +30.3% 31.91 40.12
2 57.39 70.49 +22.8% 48.51 56.80
4 86.42 105.15 +21.7% 66.60 76.94
8 112.58 133.74 +18.8% 79.73 91.45
16 121.17 125.58 +3.6% 84.62 86.06

The improvement is concentrated in small-M decode cases. Local B=16 results were only a small improvement on my machine, while reviewer testing on M5 Pro showed a larger B=16 decode gain, so this PR includes the M <= 16 tile as well.

Validation

Correctness tests:

PYTHONPATH=src pdm run test-refsol --week 2 --day 2 -- -k quantized_matmul -q
PYTHONPATH=src pdm run test-refsol --week 2 --day 4 -- -k flash_attention -q
PYTHONPATH=src pdm run test-refsol --week 2 --day 6 -- -k "batching_kv_cache or qwen3_0_6b" -q

Smoke checks:

# Original non-batch benchmark path
PYTHONPATH=src pdm run bench --solution tiny_llm_ref --loader week2 --model qwen3-0.6b --enable-flash-attn --num-seqs 2 --min-input-len 8 --max-input-len 8 --min-output-len 4 --max-output-len 4 --warmup 0

# New batch-decode benchmark path
PYTHONPATH=src pdm run bench --solution tiny_llm_ref --loader week2 --model qwen3-0.6b --enable-flash-attn --batch-decode --batch-size 2 --num-seqs 2 --min-input-len 8 --max-input-len 8 --min-output-len 4 --max-output-len 4 --prefill-step 8 --warmup 0

# Guard against under-filled batch measurements
PYTHONPATH=src pdm run bench --solution tiny_llm_ref --loader week2 --model qwen3-0.6b --batch-decode --batch-size 16 --num-seqs 8 --min-input-len 8 --max-input-len 8 --min-output-len 4 --max-output-len 4 --warmup 0

@Li0k Li0k changed the title Optimize quantized matmul for small decode batches perf: Optimize quantized matmul for small decode batches May 20, 2026
@Li0k Li0k force-pushed the optimize-quantized-matmul-small-m branch from 2db7941 to 79c9c41 Compare May 20, 2026 06:07
@Li0k
Copy link
Copy Markdown
Contributor Author

Li0k commented May 20, 2026

@Connor1996 Could you take a look? Looking forward to your comments.

Copy link
Copy Markdown
Collaborator

@Connor1996 Connor1996 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. As x_size=16 still gains a little benefit. Let's add the M<=16 condition too.

I check locally on M5 pro, for M=16 decode +14.6% vs base

@Li0k
Copy link
Copy Markdown
Contributor Author

Li0k commented May 21, 2026

LGTM. As x_size=16 still gains a little benefit. Let's add the M<=16 condition too.

I check locally on M5 pro, for M=16 decode +14.6% vs base

Sure, it just didn’t show much difference when I tested it on my M4 Pro earlier.

@Li0k Li0k force-pushed the optimize-quantized-matmul-small-m branch from 79c9c41 to d8e3b4e Compare May 21, 2026 05:24
@Li0k
Copy link
Copy Markdown
Contributor Author

Li0k commented May 21, 2026

@Connor1996 Addressed in the latest push: added the M <= 16 tile and updated the B=16 benchmark note.

@Connor1996 Connor1996 merged commit 8f71501 into skyzh:main May 21, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants