-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[NVPTX] Make nvptx mma instructions convergent. #96521
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVPTX] Make nvptx mma instructions convergent. #96521
Conversation
@llvm/pr-subscribers-backend-nvptx Author: weiwei chen (weiweichen) ChangesWe are running into NVPTX backend generating wrong code for an input:
The backend reorder the instruction (as an effect of
This is incorrect because Apply Full diff: https://github.com/llvm/llvm-project/pull/96521.diff 1 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..a19ec21826b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6724,6 +6724,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> WMMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6745,6 +6746,7 @@ defset list<WMMA_INSTR> WMMAs = {
} // layout_b
} // layout_a
} // defset
+}
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6774,6 +6776,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> MMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6793,6 +6796,7 @@ defset list<WMMA_INSTR> MMAs = {
} // layout_b
} // layout_a
} // defset
+}
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes sense to me. These instructions can't be sunk across conditional boundaries. Please make sure to get a review from someone who normally touches the NVPTX backend!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
isConvergent
is indeed missing on these instructions.
Nice finding!
|
||
; COM: llvm.nvvm.mma should not sink to the next block and gets reordered to be after laneid check. | ||
; CHECK-LABEL: no_reorder_mma_and_laneid_check | ||
define dso_local void @no_reorder_mma_and_laneid_check(ptr %0, ptr %1, i64 %2) #0 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please get rid of the implicit variables (run opt -passes=instnamer
on your input IR and update the file)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, update! Thank you for the opt
tip!
…eic/mark-nvvm-mma-with-side-effect
We are running into NVPTX backend generating wrong code for an input: ``` %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...) if laneid == 0: ret else: store %0 ``` The backend reorder the instruction (as an effect of `MachineSink` pass) to ``` if laneid == 0: ret else: %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...) store %0 ``` This is incorrect because `mma` is a warp instruction which needs all threads to sync before performing the operation instead of being guarded by a specific thread id. It should be similar as the shuffle instruction `shfl` in terms of warp level sync, and `shfl` is marked as `isConvergent = true`. Apply `isConvergent = true` to `mma` instructions.
We are running into NVPTX backend generating wrong code for an input:
The backend reorder the instruction (as an effect of
MachineSink
pass) toThis is incorrect because
mma
is a warp instruction which needs all threads to sync before performing the operation instead of being guarded by a specific thread id. It should be similar as the shuffle instructionshfl
in terms of warp level sync, andshfl
is marked asisConvergent = true
.Apply
isConvergent = true
tomma
instructions.