Skip to content

Commit c42952a

Browse files
authored
[MLIR][NVVM] Add support for match.sync Op (#130718)
This change adds the `match.sync` Op to the MLIR NVVM dialect to generate the `match.sync` PTX instruction. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-match-sync
1 parent 49b8d84 commit c42952a

File tree

6 files changed

+122
-0
lines changed

6 files changed

+122
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
1919
include "mlir/Interfaces/SideEffectInterfaces.td"
2020
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
2121
include "mlir/Interfaces/InferIntRangeInterface.td"
22+
include "mlir/Dialect/LLVMIR/LLVMTypes.td"
2223

2324
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
2425
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
@@ -2583,6 +2584,52 @@ def NVVM_MapaOp: NVVM_Op<"mapa",
25832584
let assemblyFormat = "$a`,` $b attr-dict `:` type($a) `->` type($res)";
25842585
}
25852586

2587+
//===----------------------------------------------------------------------===//
2588+
// NVVM match.sync Op
2589+
//===----------------------------------------------------------------------===//
2590+
2591+
def MatchSyncKindAny : I32EnumAttrCase<"any", 0>;
2592+
def MatchSyncKindAll : I32EnumAttrCase<"all", 1>;
2593+
2594+
def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
2595+
[MatchSyncKindAny, MatchSyncKindAll]> {
2596+
let genSpecializedAttr = 0;
2597+
let cppNamespace = "::mlir::NVVM";
2598+
}
2599+
2600+
def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
2601+
2602+
def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
2603+
Results<(outs AnyTypeOf<[I32, LLVMStructType]>:$res)>,
2604+
Arguments<(ins I32:$thread_mask,
2605+
AnyTypeOf<[I32, I64]>:$val,
2606+
MatchSyncKindAttr:$kind)> {
2607+
let summary = "Broadcast and compare a value across threads in warp";
2608+
let description = [{
2609+
The `match.sync` op performs broadcast and compare of operand `val` across
2610+
all non-exited threads in `thread_mask` and returns a mask depending on the
2611+
kind and an optional predicate.
2612+
2613+
The matching operation kinds are:
2614+
- `any`: Returns a mask corresponding to the non-exited threads in the
2615+
`thread_mask` that have the same value of operand `val`.
2616+
- `all`: Returns a mask and a predicate. If all non-exited threads in the
2617+
`thread_mask` have the same value of operand `val`, the predicate is set to
2618+
true and the mask corresponds to the non-exited threads in the
2619+
`thread_mask`. Otherwise, the predicate is set to false and the mask is 0.
2620+
2621+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-match-sync)
2622+
}];
2623+
string llvmBuilder = [{
2624+
auto intId = getMatchSyncIntrinsicId(
2625+
op.getVal().getType(), $kind);
2626+
$res = createIntrinsicCall(builder,
2627+
intId, {$thread_mask, $val});
2628+
}];
2629+
let assemblyFormat = "$kind $thread_mask `,` $val attr-dict `:` type($val) `->` type($res)";
2630+
let hasVerifier = 1;
2631+
}
2632+
25862633
def NVVM_Exit : NVVM_Op<"exit"> {
25872634
let summary = "Exit Op";
25882635
let description = [{

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,22 @@ LogicalResult NVVM::Tcgen05CpOp::verify() {
11381138
return success();
11391139
}
11401140

1141+
LogicalResult NVVM::MatchSyncOp::verify() {
1142+
if (getKind() == NVVM::MatchSyncKind::all) {
1143+
auto Type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
1144+
if (!Type || Type.getBody().size() != 2 ||
1145+
!Type.getBody()[0].isInteger(32) || !Type.getBody()[1].isInteger(1)) {
1146+
return emitOpError("match.sync 'all' returns a two element struct with "
1147+
"first element as i32 and second element as i1");
1148+
}
1149+
} else {
1150+
if (!getType().isInteger(32)) {
1151+
return emitOpError("match.sync 'any' returns an i32");
1152+
}
1153+
}
1154+
return success();
1155+
}
1156+
11411157
//===----------------------------------------------------------------------===//
11421158
// getIntrinsicID/getIntrinsicIDAndArgs methods
11431159
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,23 @@ static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
106106
llvm_unreachable("unknown shuffle kind");
107107
}
108108

109+
static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
110+
NVVM::MatchSyncKind kind) {
111+
switch (kind) {
112+
case NVVM::MatchSyncKind::any:
113+
return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32
114+
: llvm::Intrinsic::nvvm_match_any_sync_i64;
115+
case NVVM::MatchSyncKind::all:
116+
// match.all instruction has two variants -- one returns a single value,
117+
// another returns a pair {value, predicate}. We currently only implement
118+
// the latter as that's the variant exposed by CUDA API.
119+
return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p
120+
: llvm::Intrinsic::nvvm_match_all_sync_i64p;
121+
default:
122+
llvm_unreachable("unknown match sync kind");
123+
}
124+
}
125+
109126
/// Return the intrinsic ID associated with ldmatrix for the given paramters.
110127
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
111128
int32_t num) {

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,19 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
550550
return
551551
}
552552

553+
// CHECK-LABEL: @match_sync
554+
func.func @match_sync(%val32: i32, %val64: i64, %thread_mask: i32) {
555+
// CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
556+
%0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> i32
557+
// CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
558+
%1 = nvvm.match.sync all %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
559+
// CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i64 -> i32
560+
%2 = nvvm.match.sync any %thread_mask, %val64 : i64 -> i32
561+
// CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
562+
%3 = nvvm.match.sync all %thread_mask, %val64 : i64 -> !llvm.struct<(i32, i1)>
563+
return
564+
}
565+
553566
// -----
554567

555568
// Just check these don't emit errors.

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,19 @@ llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
152152
}
153153
llvm.return
154154
}
155+
156+
// -----
157+
158+
llvm.func @nvvm_match_sync_all(%val32: i32, %thread_mask: i32) {
159+
// expected-error @below {{match.sync 'all' returns a two element struct with first element as i32 and second element as i1}}
160+
%0 = nvvm.match.sync all %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i8)>
161+
llvm.return
162+
}
163+
164+
// -----
165+
166+
llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
167+
// expected-error @below {{match.sync 'any' returns an i32}}
168+
%0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
169+
llvm.return
170+
}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,3 +810,16 @@ llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
810810
%7 = nvvm.redux.sync fmax %value, %offset {abs = true, nan = true}: f32 -> f32
811811
llvm.return
812812
}
813+
814+
// CHECK-LABEL: @nvvm_match_sync
815+
llvm.func @nvvm_match_sync(%mask: i32, %val32: i32, %val64: i64) {
816+
// CHECK: call i32 @llvm.nvvm.match.any.sync.i32(i32 %{{.*}}, i32 %{{.*}})
817+
%0 = nvvm.match.sync any %mask, %val32 : i32 -> i32
818+
// CHECK: call { i32, i1 } @llvm.nvvm.match.all.sync.i32p(i32 %{{.*}}, i32 %{{.*}})
819+
%1 = nvvm.match.sync all %mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
820+
// CHECK: call i32 @llvm.nvvm.match.any.sync.i64(i32 %{{.*}}, i64 %{{.*}})
821+
%2 = nvvm.match.sync any %mask, %val64 : i64 -> i32
822+
// CHECK: call { i32, i1 } @llvm.nvvm.match.all.sync.i64p(i32 %{{.*}}, i64 %{{.*}})
823+
%3 = nvvm.match.sync all %mask, %val64 : i64 -> !llvm.struct<(i32, i1)>
824+
llvm.return
825+
}

0 commit comments

Comments
 (0)