[PERFORMANCE] Templated kernels for grouped Conv1x1/Conv1D #271
Open
rhaist wants to merge 1 commit into
Compile-time-specialized GEMM kernels for the (out_channels, in_channels, groups) shapes used by WaveNet models. Generalizes the depthwise-only fast path from sdatkinson#217 to all grouped (and small dense) cases, addressing sdatkinson#215. Both the default Eigen path and NAM_USE_INLINE_GEMM build benefit; unknown shapes fall through to existing behavior. Render output is bit-identical to main on 33 production models including the v4 baseline a1-{pico,nano,feather,lite,standard} set.
Generalizes the depthwise-only fast path from #217 to all grouped (and small dense) Conv1x1/Conv1D shapes using compile-time-specialized kernels. Targets the "compile-time optimizations" path hinted at by #215.
Closes #215.
Approach
templated_conv1x1_kernel<OutCh, InCh, Groups> and templated_conv1d_tap_kernel<OutCh, InCh, Groups> carry all shape information as template parameters. With constexpr loop bounds, the compiler can unroll every loop and fold the index arithmetic, and the kernels never visit off-block-diagonal zeros.
Dispatch is a function pointer (_kernel / _tap_kernel) set at construction by pick_*_kernel(out, in, groups). Unknown shapes return nullptr and fall through to the existing inline-GEMM / Eigen path, so there is no regression risk for them. Both the default Eigen and NAM_USE_INLINE_GEMM builds benefit (no #ifdef gate around dispatch).
Depthwise (groups == channels) is intentionally not registered; it is already handled by the existing _is_depthwise fast path from #217.
Registered square shapes: (4,4), (6,6), (8,8), (12,12), (16,16) at groups in {1, 2, 3, 4, 6, 8}.
Microbenchmark (Conv1x1, 64-frame buffer, best of 3 x 2M iters)
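For orientation, here is a minimal sketch of the kind of kernel this microbenchmark exercises, together with the function-pointer dispatch described under Approach. It is not the PR's code. The raw-pointer signature, the compact weight layout, the channel-major buffers, the concrete name pick_conv1x1_kernel, and the two registered shapes are all assumptions made for illustration.

```cpp
#include <cassert>

// Hypothetical kernel signature (an assumption, not the PR's actual one).
// weight: OutCh rows of InCh/Groups values (only the block-diagonal entries).
// input/output: channel-major, num_frames contiguous samples per channel.
using Conv1x1Kernel = void (*)(const float* weight, const float* input, float* output, int num_frames);

template <int OutCh, int InCh, int Groups>
void templated_conv1x1_kernel(const float* weight, const float* input, float* output, int num_frames)
{
  static_assert(OutCh % Groups == 0 && InCh % Groups == 0, "channels must divide evenly into groups");
  constexpr int OutPerGroup = OutCh / Groups;
  constexpr int InPerGroup = InCh / Groups;
  assert(num_frames >= 0);
  // Every bound except num_frames is a compile-time constant, so the
  // channel loops are candidates for full unrolling.
  for (int oc = 0; oc < OutCh; ++oc)
  {
    const int g = oc / OutPerGroup; // group this output channel belongs to
    float* out_row = output + oc * num_frames;
    for (int f = 0; f < num_frames; ++f)
      out_row[f] = 0.0f;
    // Only the inputs of the same group contribute; zeros outside the
    // diagonal block are never read or multiplied.
    for (int i = 0; i < InPerGroup; ++i)
    {
      const float w = weight[oc * InPerGroup + i];
      const float* in_row = input + (g * InPerGroup + i) * num_frames;
      for (int f = 0; f < num_frames; ++f)
        out_row[f] += w * in_row[f];
    }
  }
}

// Hypothetical dispatch table with two illustrative entries.
// Returns nullptr for unregistered shapes so the caller can keep its existing path.
Conv1x1Kernel pick_conv1x1_kernel(int out_ch, int in_ch, int groups)
{
  if (out_ch == 4 && in_ch == 4 && groups == 2) return &templated_conv1x1_kernel<4, 4, 2>;
  if (out_ch == 8 && in_ch == 8 && groups == 4) return &templated_conv1x1_kernel<8, 8, 4>;
  return nullptr;
}
```

A caller would store the returned pointer once at construction and, on each process call, use it when it is non-null and otherwise run the generic path unchanged.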
End-to-end (benchmodel, best of 5 runs, STD build, Apple M-series Release)
v4/1x1_groups/*.nam (varies layer1x1.groups, 8-channel WaveNet):
v4/input_groups/*.nam (varies Conv1D groups_input):
v4/channels/*.nam (varies channel width, groups=1; shows dense gains from bypassing Eigen overhead on small matrices):
Wins are independent of Eigen version: verified against both this repo's pinned Eigen 3.4-dev (87300c93) and a separately-tracked Eigen 5.0.0 bump.
Correctness
- run_tests x {STD, NAM_USE_INLINE_GEMM}
- Render output bit-identical to main on 33 production models (v4 baseline a1-{pico,nano,feather,lite,standard}, channels/{1..16}, bottleneck_sizes, 1x1_groups, input_groups, head1x1_groups)
- git-clang-format --diff HEAD
Diff
- NAM/dsp.h +10, NAM/dsp.cpp +120/-2
- NAM/conv1d.h +11, NAM/conv1d.cpp +132
- tools/CMakeLists.txt +36
- tools/bench_conv1x1_groups.cpp (microbench + correctness gate)
- tools/check_conv1d_grouped.cpp (Conv1D correctness gate across 198 shapes)
Notes for reviewers
- pick_*_kernel tables are intentionally narrow: only square shapes that appear in the v4 model sweep. Trivial to extend later. Anything unregistered keeps the existing behavior exactly.
- The new tools are not part of run_tests because they need -O2/-O3 (run_tests is -O0 for allocation tracking). They run cleanly as standalone CI steps if you want them gated.
Test plan
- run_tests passes on STD + NAM_USE_INLINE_GEMM
- tools/bench_conv1x1_groups correctness gate passes (22 shapes)
- tools/check_conv1d_grouped passes (198 shapes)
- git-clang-format --diff HEAD clean
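As an illustration of what a per-shape correctness gate could look like, the sketch below runs a compile-time-specialized grouped 1x1 kernel and a runtime-shape reference on the same data and requires byte-for-byte equal output. This is an assumption about the general technique, not a description of tools/bench_conv1x1_groups or tools/check_conv1d_grouped. All function names, the data layout, and the test data are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Reference path: shape is only known at run time (stand-in for a generic path).
void reference_grouped_conv1x1(int out_ch, int in_ch, int groups, const float* weight,
                               const float* input, float* output, int num_frames)
{
  const int out_per_group = out_ch / groups;
  const int in_per_group = in_ch / groups;
  for (int oc = 0; oc < out_ch; ++oc)
  {
    const int g = oc / out_per_group;
    for (int f = 0; f < num_frames; ++f)
    {
      float acc = 0.0f;
      for (int i = 0; i < in_per_group; ++i)
        acc += weight[oc * in_per_group + i] * input[(g * in_per_group + i) * num_frames + f];
      output[oc * num_frames + f] = acc;
    }
  }
}

// Specialized path: same arithmetic in the same order, with constexpr bounds.
template <int OutCh, int InCh, int Groups>
void specialized_grouped_conv1x1(const float* weight, const float* input, float* output, int num_frames)
{
  constexpr int OutPerGroup = OutCh / Groups;
  constexpr int InPerGroup = InCh / Groups;
  for (int oc = 0; oc < OutCh; ++oc)
  {
    const int g = oc / OutPerGroup;
    for (int f = 0; f < num_frames; ++f)
    {
      float acc = 0.0f;
      for (int i = 0; i < InPerGroup; ++i)
        acc += weight[oc * InPerGroup + i] * input[(g * InPerGroup + i) * num_frames + f];
      output[oc * num_frames + f] = acc;
    }
  }
}

// Returns true when both paths produce byte-identical output for one shape.
template <int OutCh, int InCh, int Groups>
bool check_shape(int num_frames)
{
  assert(num_frames > 0);
  constexpr int InPerGroup = InCh / Groups;
  std::vector<float> weight(static_cast<std::size_t>(OutCh) * InPerGroup);
  std::vector<float> input(static_cast<std::size_t>(InCh) * num_frames);
  std::vector<float> expected(static_cast<std::size_t>(OutCh) * num_frames);
  std::vector<float> actual(expected.size());
  // Deterministic small-integer data (an arbitrary choice for this sketch).
  for (std::size_t k = 0; k < weight.size(); ++k)
    weight[k] = static_cast<float>(static_cast<int>(k % 7) - 3);
  for (std::size_t k = 0; k < input.size(); ++k)
    input[k] = static_cast<float>(static_cast<int>(k % 5) - 2);
  reference_grouped_conv1x1(OutCh, InCh, Groups, weight.data(), input.data(), expected.data(), num_frames);
  specialized_grouped_conv1x1<OutCh, InCh, Groups>(weight.data(), input.data(), actual.data(), num_frames);
  return std::memcmp(expected.data(), actual.data(), expected.size() * sizeof(float)) == 0;
}
```

Byte-level comparison is stricter than a tolerance check. It only passes in general when both paths accumulate in the same order, which is why the two kernels above keep identical inner loops.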