[mlir][nvvm] Introduce `nvvm.stmatrix` Op #69467
Conversation
This PR adds the `nvvm.stmatrix` Op to the NVVM dialect. The Op collectively stores one or more matrices across all threads in a warp to the given address location in shared memory.
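For illustration, a minimal usage sketch based on the assembly format and the test cases in the diff below (the function and value names here are hypothetical):

```mlir
llvm.func @store_fragments(%smem : !llvm.ptr<3>, %frag0 : i32, %frag1 : i32) {
  // Collectively store two 8x8 b16 matrices held in %frag0/%frag1 to shared memory at %smem.
  nvvm.stmatrix %smem, %frag0, %frag1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
  llvm.return
}
```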
@llvm/pr-subscribers-mlir-nvgpu @llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Changes: This PR adds the `nvvm.stmatrix` Op to the NVVM dialect.

Full diff: https://github.com/llvm/llvm-project/pull/69467.diff

3 Files Affected:
- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
- mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index cefdd7cc4033a11..9cda7862ccb0fe3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1186,6 +1186,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}
+def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
+ Arguments<(ins LLVM_i8Ptr_shared:$ptr,
+ Variadic<I32>:$sources,
+ MMALayoutAttr:$layout)> {
+ let summary = "cooperative matrix store";
+ let description = [{
+ Collectively store one or more matrices across all threads in a warp to the
+ location indicated by the address operand $ptr in shared memory.
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
+ }];
+
+ let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ int d = getSources().size();
+ std::string ptx = "stmatrix.sync.aligned";
+ ptx += ".x" + std::to_string(d);
+ if (getLayout() == NVVM::MMALayout::col)
+ ptx += ".trans";
+ if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1}";
+ if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2}";
+ if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
+ return ptx;
+ }
+ }];
+ let hasVerifier = 1;
+}
+
def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
Results<(outs AnyType:$res)>,
Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 92df023c797b1bc..3736978505707e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -717,6 +717,19 @@ LogicalResult NVVM::LdMatrixOp::verify() {
return success();
}
+LogicalResult NVVM::StMatrixOp::verify() {
+ unsigned addressSpace =
+ llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
+ if (addressSpace != NVVM::kSharedMemorySpace)
+ return emitOpError("expected source pointer in memory space 3");
+
+ int numMatrix = getSources().size();
+ if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
+ return emitOpError("expected num attribute to be 1, 2 or 4");
+
+ return success();
+}
+
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
if (typeA == NVVM::WGMMATypes::tf32)
return 8;
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 0d0ac9637438a95..3bb0ab90775edf5 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -507,6 +507,30 @@ func.func @elect_one_leader_sync() {
// -----
+// CHECK-LABEL: @stmatrix(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
+// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
+llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1}", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2}", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1}", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2}", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
+// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
+ nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
+ nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
+ nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
+ nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
+ nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
+ nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
ptx += ".trans"; | ||
if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1}"; | ||
if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2}"; | ||
if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};"; |
Use a `switch` here with a default case to `llvm_unreachable`.
The verifier catches that, actually.
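For reference, a rough sketch of how the `getPtx()` body from the `extraClassDefinition` above might read with the suggested `switch` (this is not the code in this PR; the PTX strings are kept exactly as in the diff, and `$cppClass` is written out as `NVVM::StMatrixOp`):

```c++
// Hypothetical rewrite following the review suggestion.
// llvm_unreachable comes from llvm/Support/ErrorHandling.h.
std::string NVVM::StMatrixOp::getPtx() {
  int d = getSources().size();
  std::string ptx = "stmatrix.sync.aligned";
  ptx += ".x" + std::to_string(d);
  if (getLayout() == NVVM::MMALayout::col)
    ptx += ".trans";
  switch (d) {
  case 1:
    ptx += ".m8n8.shared.b16 [%0], {%1}";
    break;
  case 2:
    ptx += ".m8n8.shared.b16 [%0], {%1, %2}";
    break;
  case 4:
    ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
    break;
  default:
    llvm_unreachable("expected 1, 2, or 4 source matrices");
  }
  return ptx;
}
```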