Skip to content

Commit e33aec8

Browse files
authored
[MLIR][NVVM] Update the elect.sync Op to use intrinsics (#113757)
Recently, we added an intrinsic for the `elect.sync` PTX instruction (PR #104780). This patch updates the corresponding Op in the NVVM Dialect to lower to that intrinsic instead of inline PTX. The existing test under Conversion/ is migrated to check for the new pattern, and a separate test is added under the Target/ directory to verify the lowered intrinsic. Signed-off-by: Durgadoss R <[email protected]>
1 parent 7fe149c commit e33aec8

File tree

3 files changed

+28
-22
lines changed

3 files changed

+28
-22
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -783,24 +783,27 @@ def NVVM_SyncWarpOp :
783783
let assemblyFormat = "$mask attr-dict `:` type($mask)";
784784
}
785785

786-
787-
def NVVM_ElectSyncOp : NVVM_Op<"elect.sync",
788-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>
786+
def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
789787
{
788+
let summary = "Elect one leader thread";
789+
let description = [{
790+
The `elect.sync` instruction elects one predicated active leader
791+
thread from among a set of threads specified in membermask.
792+
The membermask is set to `0xFFFFFFFF` for the current version
793+
of this Op. The predicate result is set to `True` for the
794+
leader thread, and `False` for all other threads.
795+
796+
[For more information, see PTX ISA]
797+
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync)
798+
}];
799+
790800
let results = (outs I1:$pred);
791801
let assemblyFormat = "attr-dict `->` type(results)";
792-
let extraClassDefinition = [{
793-
std::string $cppClass::getPtx() {
794-
return std::string(
795-
"{ \n"
796-
".reg .u32 rx; \n"
797-
".reg .pred px; \n"
798-
" mov.pred %0, 0; \n"
799-
" elect.sync rx | px, 0xFFFFFFFF;\n"
800-
"@px mov.pred %0, 1; \n"
801-
"}\n"
802-
);
803-
}
802+
string llvmBuilder = [{
803+
auto *resultTuple = createIntrinsicCall(builder,
804+
llvm::Intrinsic::nvvm_elect_sync, {builder.getInt32(0xFFFFFFFF)});
805+
// Extract the second value into $pred
806+
$pred = builder.CreateExtractValue(resultTuple, 1);
804807
}];
805808
}
806809

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -579,13 +579,7 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
579579
// -----
580580

581581
func.func @elect_one_leader_sync() {
582-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{
583-
// CHECK-SAME: .reg .u32 rx;
584-
// CHECK-SAME: .reg .pred px;
585-
// CHECK-SAME: mov.pred $0, 0;
586-
// CHECK-SAME: elect.sync rx | px, 0xFFFFFFFF;
587-
// CHECK-SAME: @px mov.pred $0, 1;
588-
// CHECK-SAME: "=b" : () -> i1
582+
// CHECK: %[[RES:.*]] = nvvm.elect.sync -> i1
589583
%cnd = nvvm.elect.sync -> i1
590584
return
591585
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,15 @@ llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
259259
llvm.return %3 : i32
260260
}
261261

262+
// CHECK-LABEL: @nvvm_elect_sync
263+
llvm.func @nvvm_elect_sync() -> i1 {
264+
// CHECK: %[[RES:.*]] = call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1)
265+
// CHECK-NEXT: %[[PRED:.*]] = extractvalue { i32, i1 } %[[RES]], 1
266+
// CHECK-NEXT: ret i1 %[[PRED]]
267+
%0 = nvvm.elect.sync -> i1
268+
llvm.return %0 : i1
269+
}
270+
262271
// CHECK-LABEL: @nvvm_mma_mn8n8k4_row_col_f32_f32
263272
llvm.func @nvvm_mma_mn8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
264273
%b0 : vector<2xf16>, %b1 : vector<2xf16>,

0 commit comments

Comments
 (0)