Skip to content

Commit 5316d19

Browse files
authored
[mlir][nvvm] Introduce nvvm.stmatrix Op (#69467)
This PR adds the `nvvm.stmatrix` Op to the NVVM dialect. The Op collectively stores one or more matrices across all threads in a warp to the given address location in shared memory.
1 parent 5341d54 commit 5316d19

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,6 +1186,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
11861186
let hasVerifier = 1;
11871187
}
11881188

1189+
// `nvvm.stmatrix`: warp-collective matrix store. The op derives from
// NVVM_PTXBuilder_Op and supplies its PTX text through getPtx() below.
//   $ptr     - destination address (shared-memory pointer type).
//   $sources - variadic i32 operands; the verifier accepts 1, 2 or 4 of them.
//   $layout  - MMA layout; `col` selects the `.trans` form of the instruction.
def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
  Arguments<(ins LLVM_i8Ptr_shared:$ptr,
                 Variadic<I32>:$sources,
                 MMALayoutAttr:$layout)> {
  let summary = "cooperative matrix store";
  let description = [{
    Collectively store one or more matrices across all threads in a warp to the
    location indicated by the address operand $ptr in shared memory.
    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
  }];

  let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
  let extraClassDefinition = [{
    // Assembles the PTX instruction text for this op:
    //   stmatrix.sync.aligned.x<N>[.trans].m8n8.shared.b16 [%0], {%1, ...}
    // where N is the number of `sources` operands. Judging by the lowering
    // test, %0 binds to the pointer and %1..%N to the sources in order.
    std::string $cppClass::getPtx() {
      const int numSources = getSources().size();
      std::string ptx =
          "stmatrix.sync.aligned.x" + std::to_string(numSources);
      if (getLayout() == NVVM::MMALayout::col)
        ptx += ".trans";
      // NOTE(review): only the x4 form ends with ';'. The strings are kept
      // exactly as-is because the NVVMToLLVM lowering test matches them
      // verbatim; making them uniform requires updating that test too.
      switch (numSources) {
      case 1:
        ptx += ".m8n8.shared.b16 [%0], {%1}";
        break;
      case 2:
        ptx += ".m8n8.shared.b16 [%0], {%1, %2}";
        break;
      case 4:
        ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
        break;
      default:
        // Other counts are rejected by the verifier; nothing is appended.
        break;
      }
      return ptx;
    }
  }];
  let hasVerifier = 1;
}
1217+
11891218
def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
11901219
Results<(outs AnyType:$res)>,
11911220
Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,19 @@ LogicalResult NVVM::LdMatrixOp::verify() {
717717
return success();
718718
}
719719

720+
/// Verifies an `nvvm.stmatrix` op.
///
/// Checks that:
///   * the pointer operand lives in the shared memory address space
///     (`NVVM::kSharedMemorySpace`), and
///   * the number of `sources` operands is 1, 2 or 4.
///
/// Returns success() when both hold, otherwise emits an op error.
LogicalResult NVVM::StMatrixOp::verify() {
  // stmatrix stores through $ptr, so the pointer is the destination of the
  // operation (not a source, as it is for ldmatrix).
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected destination pointer in memory space 3");

  // The op has no `num` attribute: the matrix count is implied by how many
  // `sources` operands were supplied, so report that count directly.
  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected 1, 2 or 4 source operands, but got ")
           << numMatrix;

  return success();
}
732+
720733
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
721734
if (typeA == NVVM::WGMMATypes::tf32)
722735
return 8;

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,30 @@ func.func @elect_one_leader_sync() {
507507

508508
// -----
509509

510+
// Lowering test for nvvm.stmatrix to llvm.inline_asm.
// Six ops are exercised, in this order: 1, 2 and 4 source operands with the
// `row` layout, then the same three counts with the `col` layout. The
// expectations below pin that:
//   - the suffix `.x<N>` follows the number of source operands,
//   - the `col` layout adds `.trans` right after `.x<N>`,
//   - every operand (pointer and each i32 source) gets an "r" constraint,
//   - only the 4-source form ends with a trailing `;`.
// CHECK-LABEL: @stmatrix(
511+
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
512+
// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
513+
// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
514+
// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
515+
// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
516+
llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
517+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1}", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
518+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2}", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
519+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
520+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1}", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
521+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2}", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
522+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
523+
nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
524+
nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
525+
nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
526+
nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
527+
nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
528+
nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
529+
llvm.return
530+
}
531+
532+
// -----
533+
510534
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
511535
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
512536
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"

0 commit comments

Comments
 (0)