Skip to content

Commit 5f98dd5

Browse files
authored
[MLIR][NVVM] Update Wgmma.fence Ops to use intrinsics (#120956)
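This PR updates the WgmmaFenceAlignedOp, WgmmaGroupSyncAlignedOp, and WgmmaWaitGroupSyncOp ops in the NVVM dialect to lower to the corresponding LLVM intrinsics instead of inline PTX. As part of this change, the `group` attribute of WgmmaWaitGroupSyncOp and the `waitGroup` attribute of the NVGPU WarpgroupMmaOp change from I32Attr to I64Attr, since the wait_group intrinsic call is built with an i64 operand. The existing test under Conversion/NVVMToLLVM is updated to check for the new patterns, and separate tests are added under Target/LLVMIR to verify the lowered intrinsics.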
1 parent a35640f commit 5f98dd5

File tree

4 files changed

+43
-19
lines changed

4 files changed

+43
-19
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2130,7 +2130,7 @@ def NVVM_CpAsyncBulkTensorReduceOp :
21302130
// NVVM Wgmma Ops
21312131
//===----------------------------------------------------------------------===//
21322132

2133-
def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
2133+
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
21342134
let arguments = (ins);
21352135
let description = [{
21362136
Enforce an ordering of register accesses between warpgroup level matrix
@@ -2139,34 +2139,34 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
21392139
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence)
21402140
}];
21412141
let assemblyFormat = "attr-dict";
2142-
let extraClassDefinition = [{
2143-
std::string $cppClass::getPtx() { return std::string("wgmma.fence.sync.aligned;"); }
2142+
string llvmBuilder = [{
2143+
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_fence_sync_aligned);
21442144
}];
21452145
}
21462146

2147-
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
2147+
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
21482148
Arguments<(ins )> {
21492149
let assemblyFormat = "attr-dict";
21502150
let description = [{
21512151
Commits all prior uncommitted warpgroup level matrix multiplication operations.
21522152

21532153
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group)
21542154
}];
2155-
let extraClassDefinition = [{
2156-
std::string $cppClass::getPtx() { return std::string("wgmma.commit_group.sync.aligned;"); }
2155+
string llvmBuilder = [{
2156+
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_commit_group_sync_aligned);
21572157
}];
21582158
}
21592159

2160-
def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
2161-
let arguments = (ins I32Attr:$group);
2160+
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
2161+
let arguments = (ins I64Attr:$group);
21622162
let assemblyFormat = "attr-dict $group";
21632163
let description = [{
21642164
Signal the completion of a preceding warpgroup operation.
21652165

21662166
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group)
21672167
}];
2168-
let extraClassDefinition = [{
2169-
std::string $cppClass::getPtx() { return std::string("wgmma.wait_group.sync.aligned %0;"); }
2168+
string llvmBuilder = [{
2169+
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_wait_group_sync_aligned, builder.getInt64($group));
21702170
}];
21712171
}
21722172

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
771771

772772
let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA,
773773
NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
774-
DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
774+
DefaultValuedOptionalAttr<I64Attr, "1">:$waitGroup,
775775
OptionalAttr<UnitAttr>:$transposeA,
776776
OptionalAttr<UnitAttr>:$transposeB,
777777
NVGPU_WarpgroupAccumulator:$matrixC);

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -266,19 +266,17 @@ func.func @wgmma_execute() {
266266
nvvm.wgmma.fence.aligned
267267
nvvm.wgmma.commit.group.sync.aligned
268268
nvvm.wgmma.wait.group.sync.aligned 0
269-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
270-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
271-
// CHECK: %[[S0:.+]] = llvm.mlir.constant(0 : i32) : i32
272-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S0]] : (i32)
269+
// CHECK: nvvm.wgmma.fence.aligned
270+
// CHECK: nvvm.wgmma.commit.group.sync.aligned
271+
// CHECK: nvvm.wgmma.wait.group.sync.aligned 0
273272

274273

275274
nvvm.wgmma.fence.aligned
276275
nvvm.wgmma.commit.group.sync.aligned
277276
nvvm.wgmma.wait.group.sync.aligned 5
278-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
279-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
280-
// CHECK: %[[S1:.+]] = llvm.mlir.constant(5 : i32) : i32
281-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S1]] : (i32)
277+
// CHECK: nvvm.wgmma.fence.aligned
278+
// CHECK: nvvm.wgmma.commit.group.sync.aligned
279+
// CHECK: nvvm.wgmma.wait.group.sync.aligned 5
282280
return
283281
}
284282

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,3 +714,29 @@ llvm.func @nvvm_breakpoint() {
714714
nvvm.breakpoint
715715
llvm.return
716716
}
717+
718+
// -----
719+
// CHECK-LABEL: @nvvm_wgmma_fence_aligned
720+
llvm.func @nvvm_wgmma_fence_aligned() {
721+
// CHECK: call void @llvm.nvvm.wgmma.fence.sync.aligned()
722+
nvvm.wgmma.fence.aligned
723+
llvm.return
724+
}
725+
726+
// -----
727+
// CHECK-LABEL: @nvvm_wgmma_commit_group_aligned
728+
llvm.func @nvvm_wgmma_commit_group_aligned() {
729+
// CHECK: call void @llvm.nvvm.wgmma.commit_group.sync.aligned()
730+
nvvm.wgmma.commit.group.sync.aligned
731+
llvm.return
732+
}
733+
734+
// -----
735+
// CHECK-LABEL: @nvvm_wgmma_wait_group_aligned
736+
llvm.func @nvvm_wgmma_wait_group_aligned() {
737+
// CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 0)
738+
nvvm.wgmma.wait.group.sync.aligned 0
739+
// CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 20)
740+
nvvm.wgmma.wait.group.sync.aligned 20
741+
llvm.return
742+
}

0 commit comments

Comments
 (0)