[CK_TILE][FMHA] Fix uninitialized sink_size in mask_info::decode() and filter redundant no-mask+sink instances#3504
[CK_TILE][FMHA] Fix uninitialized sink_size in mask_info::decode() and filter redundant no-mask+sink instances#3504poyenc wants to merge 20 commits into
Conversation
There was a problem hiding this comment.
Pull request overview
Fixes an FMHA runtime dispatch hazard caused by uninitialized sink_size for no_mask, adds a compile-time guard against invalid sink+no-mask template combinations, and reduces redundant kernel instantiations in codegen.
Changes:
- Initialize
left/right/sinkwhen decodingno_maskinmask_info::decode(). - Add
static_assert(FmhaMask::IsMasking || !kHasSink)to prevent invalid pipeline instantiations. - Filter out
no_mask + sink=truecombinations in FMHA fwd-related codegen scripts.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp | Adds compile-time validation to prevent kHasSink=true when masking is disabled. |
| example/ck_tile/01_fmha/mask.hpp | Fixes uninitialized fields for no_mask decoding (prevents bogus runtime has_sink). |
| example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py | Skips generating redundant/invalid no_mask + sink kernel variants. |
| example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | Skips generating redundant/invalid no_mask + sink kernel variants. |
| example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | Adds compatibility filtering to avoid no_mask + sink kernels in fwd generation. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| static constexpr auto QScaleEnum = Traits::QScaleEnum; | ||
| static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; | ||
| static constexpr bool kHasSink = Traits::kHasSink; | ||
| static_assert(FmhaMask::IsMasking || !kHasSink); |
There was a problem hiding this comment.
The new static_assert has no diagnostic message, while other static_asserts in this file provide one (e.g., lines 108–123). Adding a short message (e.g., that sink requires masking) would make template instantiation failures much easier to understand.
| static constexpr bool kIsPagedKV = Traits::kIsPagedKV; | ||
| static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; | ||
| static constexpr bool kHasSink = Traits::kHasSink; | ||
| static_assert(FmhaMask::IsMasking || !kHasSink); |
There was a problem hiding this comment.
The new static_assert has no diagnostic message, while other static_asserts in this file provide one (e.g., lines 108–123). Adding a short message (e.g., that sink requires masking) would make template instantiation failures much easier to understand.
| static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ; | ||
| static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; | ||
| static constexpr bool kHasSink = Traits::kHasSink; | ||
| static_assert(FmhaMask::IsMasking || !kHasSink); |
There was a problem hiding this comment.
The new static_assert has no diagnostic message, while other static_asserts in this file provide one (e.g., lines 108–123). Adding a short message (e.g., that sink requires masking) would make template instantiation failures much easier to understand.
| # sink_size is only meaningful when mask is applied | ||
| if ( | ||
| kernel_ctx.pipeline.F_mask in no_mask_keys | ||
| and kernel_ctx.pipeline.F_sink == "t" | ||
| ): | ||
| return False |
There was a problem hiding this comment.
This comment says sink_size is only meaningful when no masking is applied, but the condition directly below filters out the no-mask + sink=true combination. Please update the comment to match the logic (i.e., sink is only meaningful when masking is enabled).
| or pipeline.F_logits == "f" | ||
| ): | ||
| continue | ||
| # sink_size is only meaningful when mask is applied |
There was a problem hiding this comment.
This comment says sink_size is only meaningful when no masking is applied, but the condition directly below filters out the no-mask + sink=true combination. Please update the comment to match the logic (i.e., sink is only meaningful when masking is enabled).
| # sink_size is only meaningful when mask is applied | |
| # sink_size is only meaningful when masking is enabled, so disallow sink when no mask is applied |
| or pipeline.F_logits == "f" | ||
| ): | ||
| continue | ||
| # sink_size is only meaningful when mask is applied |
There was a problem hiding this comment.
This comment says sink_size is only meaningful when no masking is applied, but the condition directly below filters out the no-mask + sink=true combination. Please update the comment to match the logic (i.e., sink is only meaningful when masking is enabled).
| # sink_size is only meaningful when mask is applied | |
| # sink_size is only meaningful when masking is enabled; disallow sink when no mask is used |
|
LGTM @asleepzzz Please approve it. |
|
Imported to ROCm/rocm-libraries |
Problem
When
mask_info::decode()parses"0"(no_mask), it only set thetypefield but leftleft,right, andsinkuninitialized. This caused:sinkcould be arbitrary garbage valuetraits.has_sink = (mask.sink > 0)in fmha_fwd_runner.hpp:882 might evaluate to truekHasSink=trueinstantiationsSolution
left=-1,right=-1,sink=0when decoding no_mask in mask.hppstatic_assert(FmhaMask::IsMasking || !kHasSink)to pipeline problemsF_mask=no_mask + F_sink=truecombinations in codegen scripts:Impact
Testing
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered