[BugFix][Engram] Support SM counts not divisible by four #17
Open
JayceSu98 wants to merge 1 commit into
The README declares support for NVIDIA SM90 or SM100 architecture GPUs. The engram grad_w_reduce kernel nevertheless fixed num_batches to four and asserted that num_persistent_blocks was divisible by four, where num_persistent_blocks is derived from the device SM count through the grad_w_partial leading dimension.

That assert is an implementation assumption about a particular SM count, not an architectural requirement. H100 PCIe is an SM90 GPU with 114 SMs, so 114 % 4 == 2 and all grad_w_reduce correctness and benchmark variants fail before launch even though the device satisfies the documented requirement. Devices with 132 SMs happen to pass only because 132 is divisible by four.

Choose the largest batch count up to four that evenly partitions num_persistent_blocks. This preserves the original four-batch schedule when the SM count allows it, uses three batches on 114-SM devices, and keeps each pipeline batch rectangular without adding tail handling inside the TileLang kernel.

Verified on NVIDIA H100 PCIe (sm_90, 114 SMs) with Python 3.12.3 and local TileLang 0.1.10+cuda.git23d91c58: tests/engram/test_engram_grad_w_reduce.py correctness reported 6 passed, and the benchmark variants reported 6 passed. The benchmark command returned non-zero only because tests/benchmark_baselines.jsonl was absent, so regression comparison marked the six benchmark records as missing baselines.

JayceSu98 <jayce.su@enflame-tech.com> authored and validated this patch.
Co-author GitHub: https://github.com/dingsg

Co-authored-by: dingsg <shengge.ding@enflame-tech.com>
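The batch-count rule described in this PR (the largest count up to four that evenly divides num_persistent_blocks) can be sketched in plain Python as follows. This is a minimal illustration, not the code from the patch: the function name choose_num_batches and the max_batches parameter are hypothetical, and the real selection happens in the host-side setup of the TileLang kernel.

```python
def choose_num_batches(num_persistent_blocks: int, max_batches: int = 4) -> int:
    """Return the largest batch count <= max_batches that evenly divides
    num_persistent_blocks (hypothetical helper, for illustration only)."""
    if num_persistent_blocks <= 0:
        raise ValueError("num_persistent_blocks must be positive")
    # Try 4, 3, 2, 1 in that order; 1 always divides, so the loop always returns.
    for num_batches in range(max_batches, 0, -1):
        if num_persistent_blocks % num_batches == 0:
            return num_batches
    raise AssertionError("unreachable: 1 divides every positive integer")


# 132 SMs: 132 % 4 == 0, so the original four-batch schedule is kept.
print(choose_num_batches(132))  # 4
# 114 SMs (H100 PCIe): 114 % 4 == 2, but 114 % 3 == 0, so three batches are used.
print(choose_num_batches(114))  # 3
```

Because the chosen count always divides num_persistent_blocks exactly, every batch covers the same number of blocks (114 / 3 = 38 on H100 PCIe), which is what lets the kernel avoid tail handling.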