Skip to content

[NVPTX] Make nvptx mma instructions convergent. #96521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

weiweichen
Copy link
Contributor

We are running into NVPTX backend generating wrong code for an input:

%0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
if laneid == 0:
  ret
else:
  store %0

The backend reorder the instruction (as an effect of MachineSink pass) to

if laneid == 0:
  ret
else:
  %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
  store %0

This is incorrect because mma is a warp instruction which needs all threads to sync before performing the operation instead of being guarded by a specific thread id. It should be similar as the shuffle instruction shfl in terms of warp level sync, and shfl is marked as isConvergent = true.

Apply isConvergent = true to mma instructions.

@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2024

@llvm/pr-subscribers-backend-nvptx

Author: weiwei chen (weiweichen)

Changes

We are running into NVPTX backend generating wrong code for an input:

%0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
if laneid == 0:
  ret
else:
  store %0

The backend reorder the instruction (as an effect of MachineSink pass) to

if laneid == 0:
  ret
else:
  %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
  store %0

This is incorrect because mma is a warp instruction which needs all threads to sync before performing the operation instead of being guarded by a specific thread id. It should be similar as the shuffle instruction shfl in terms of warp level sync, and shfl is marked as isConvergent = true.

Apply isConvergent = true to mma instructions.


Full diff: https://github.com/llvm/llvm-project/pull/96521.diff

1 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+4)
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..a19ec21826b82 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6724,6 +6724,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in {
 defset list<WMMA_INSTR> WMMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6745,6 +6746,7 @@ defset list<WMMA_INSTR> WMMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+}
 
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6774,6 +6776,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragC.regstring # ";";
 }
 
+let isConvergent = true in {
 defset list<WMMA_INSTR> MMAs  = {
   foreach layout_a = ["row", "col"] in {
     foreach layout_b = ["row", "col"] in {
@@ -6793,6 +6796,7 @@ defset list<WMMA_INSTR> MMAs  = {
     } // layout_b
   } // layout_a
 } // defset
+}
 
 //
 // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16

Copy link
Contributor

@Mogball Mogball left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense to me. These instructions can't be sunk across conditional boundaries. Please make sure to get a review from someone who normally touches the NVPTX backend!

Copy link
Collaborator

@qcolombet qcolombet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

isConvergent is indeed missing on these instructions.

Nice finding!


; COM: llvm.nvvm.mma should not sink to the next block and gets reordered to be after laneid check.
; CHECK-LABEL: no_reorder_mma_and_laneid_check
define dso_local void @no_reorder_mma_and_laneid_check(ptr %0, ptr %1, i64 %2) #0 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please get rid of the implicit variables (run opt -passes=instnamer on your input IR and update the file)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, update! Thank you for the opt tip!

@weiweichen weiweichen merged commit b0e9b00 into llvm:main Jun 25, 2024
7 checks passed
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
We are running into NVPTX backend generating wrong code for an input:
```
%0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
if laneid == 0:
  ret
else:
  store %0
```

The backend reorder the instruction (as an effect of `MachineSink` pass)
to
```
if laneid == 0:
  ret
else:
  %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
  store %0
```

This is incorrect because `mma` is a warp instruction which needs all
threads to sync before performing the operation instead of being guarded
by a specific thread id. It should be similar as the shuffle instruction
`shfl` in terms of warp level sync, and `shfl` is marked as
`isConvergent = true`.

Apply `isConvergent = true` to `mma` instructions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants