[mlir][nvvm] Introduce nvvm.stmatrix Op #69467


Merged · 2 commits · Oct 19, 2023

Conversation

grypp
Member

grypp commented on Oct 18, 2023

This PR adds an `nvvm.stmatrix` Op to the NVVM dialect. The Op collectively stores one or more matrices across all threads in a warp to the given address in shared memory.
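
As a purely illustrative aside, here is a minimal C++ sketch of building the new op programmatically, assuming the default ODS-generated builder; the helper name `emitStMatrixStore` and the `smemPtr`/`fragments` values are hypothetical and not part of this PR:

```cpp
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch: store 1, 2, or 4 8x8 16-bit matrix fragments (held in i32
// registers) to a shared-memory (address space 3) pointer. The layout
// attribute selects the row-major or the column-major (.trans) PTX form.
static void emitStMatrixStore(OpBuilder &b, Location loc, Value smemPtr,
                              ValueRange fragments) {
  b.create<NVVM::StMatrixOp>(
      loc, smemPtr, fragments,
      NVVM::MMALayoutAttr::get(b.getContext(), NVVM::MMALayout::row));
}
```

The verifier added in NVVMDialect.cpp below rejects any source count other than 1, 2, or 4, and any pointer outside shared memory (address space 3).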

@llvmbot
Member

llvmbot commented Oct 18, 2023

@llvm/pr-subscribers-mlir-nvgpu
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Changes

This PR adds an `nvvm.stmatrix` Op to the NVVM dialect. The Op collectively stores one or more matrices across all threads in a warp to the given address in shared memory.


Full diff: https://github.com/llvm/llvm-project/pull/69467.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+29)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+13)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+24)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index cefdd7cc4033a11..9cda7862ccb0fe3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1186,6 +1186,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
+def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, 
+  Arguments<(ins LLVM_i8Ptr_shared:$ptr, 
+                 Variadic<I32>:$sources, 
+                 MMALayoutAttr:$layout)> {
+  let summary = "cooperative matrix store";
+  let description = [{
+    Collectively store one or more matrices across all threads in a warp to the
+    location indicated by the address operand $ptr in shared memory.
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
+  }];
+  
+  let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      int d = getSources().size();
+      std::string ptx = "stmatrix.sync.aligned";
+      ptx += ".x" + std::to_string(d);
+      if (getLayout() == NVVM::MMALayout::col)
+        ptx += ".trans";
+      if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1}";
+      if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2}";
+      if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
+      return ptx;
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
   Results<(outs AnyType:$res)>,
   Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 92df023c797b1bc..3736978505707e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -717,6 +717,19 @@ LogicalResult NVVM::LdMatrixOp::verify() {
   return success();
 }
 
+LogicalResult NVVM::StMatrixOp::verify() {
+  unsigned addressSpace =
+      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+  if (addressSpace != NVVM::kSharedMemorySpace)
+    return emitOpError("expected source pointer in memory space 3");
+
+  int numMatrix = getSources().size();
+  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
+    return emitOpError("expected num attribute to be 1, 2 or 4");
+
+  return success();
+}
+
 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
   if (typeA == NVVM::WGMMATypes::tf32)
     return 8;
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 0d0ac9637438a95..3bb0ab90775edf5 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -507,6 +507,30 @@ func.func @elect_one_leader_sync() {
 
 // -----
 
+// CHECK-LABEL: @stmatrix(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>, 
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
+llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1}", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2}", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1}", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2}", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
+  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
+  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
+  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
+  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
+  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
+  llvm.return 
+}
+
+// -----
+
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx
 llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
   //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"

ptx += ".trans";
if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1}";
if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2}";
if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
Collaborator


Use a switch here with a default case to llvm_unreachable.
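
For illustration only, a sketch of what that restructuring might look like in the op's `extraClassDefinition` (this sketch is not taken from the PR; `llvm_unreachable` comes from llvm/Support/ErrorHandling.h):

```cpp
// Sketch of getPtx() with a switch over the source count and an
// unreachable default; the PTX strings mirror the ones in the diff.
std::string $cppClass::getPtx() {
  int d = getSources().size();
  std::string ptx = "stmatrix.sync.aligned";
  ptx += ".x" + std::to_string(d);
  if (getLayout() == NVVM::MMALayout::col)
    ptx += ".trans";
  switch (d) {
  case 1:
    ptx += ".m8n8.shared.b16 [%0], {%1}";
    break;
  case 2:
    ptx += ".m8n8.shared.b16 [%0], {%1, %2}";
    break;
  case 4:
    ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
    break;
  default:
    llvm_unreachable("expected 1, 2, or 4 stmatrix sources");
  }
  return ptx;
}
```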

Member Author


The verifier catches that, actually.
