Skip to content

Commit 1ade5b5

Browse files
committed
Merge ops
1 parent b8caf55 commit 1ade5b5

File tree

7 files changed

+65
-79
lines changed

7 files changed

+65
-79
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6542,7 +6542,8 @@ IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
65426542
mlir::Value arg1 =
65436543
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), args[1]);
65446544
return builder
6545-
.create<mlir::NVVM::VoteBallotOp>(loc, resultType, args[0], arg1)
6545+
.create<mlir::NVVM::VoteSyncOp>(loc, resultType, args[0], arg1,
6546+
mlir::NVVM::VoteSyncKind::ballot)
65466547
.getResult();
65476548
}
65486549

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ end subroutine
303303
! CHECK-LABEL: func.func @_QPtestvote()
304304
! CHECK: fir.call @llvm.nvvm.vote.all.sync
305305
! CHECK: fir.call @llvm.nvvm.vote.any.sync
306-
! CHECK: %{{.*}} = nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
306+
! CHECK: %{{.*}} = nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
307307

308308
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
309309
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -808,35 +808,30 @@ def NVVM_ShflOp :
808808
let hasVerifier = 1;
809809
}
810810

811-
def NVVM_VoteBallotOp :
812-
NVVM_Op<"vote.ballot.sync">,
813-
Results<(outs LLVM_Type:$res)>,
814-
Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
815-
string llvmBuilder = [{
816-
$res = createIntrinsicCall(builder,
817-
llvm::Intrinsic::nvvm_vote_ballot_sync, {$mask, $pred});
818-
}];
819-
let hasCustomAssemblyFormat = 1;
811+
def VoteSyncKindAny : I32EnumAttrCase<"any", 0>;
812+
def VoteSyncKindAll : I32EnumAttrCase<"all", 1>;
813+
def VoteSyncKindBallot : I32EnumAttrCase<"ballot", 2>;
814+
def VoteSyncKindUni : I32EnumAttrCase<"uni", 3>;
815+
816+
def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
817+
[VoteSyncKindAny, VoteSyncKindAll,
818+
VoteSyncKindBallot, VoteSyncKindUni]> {
819+
let genSpecializedAttr = 0;
820+
let cppNamespace = "::mlir::NVVM";
820821
}
821822

822-
def NVVM_VoteAllSyncOp : NVVM_Op<"vote.all.sync">,
823-
Results<(outs LLVM_Type:$res)>,
824-
Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
825-
string llvmBuilder = [{
826-
$res = createIntrinsicCall(builder,
827-
llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
828-
}];
829-
let hasCustomAssemblyFormat = 1;
830-
}
823+
def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
831824

832-
def NVVM_VoteAnySyncOp : NVVM_Op<"vote.any.sync">,
833-
Results<(outs LLVM_Type:$res)>,
834-
Arguments<(ins LLVM_Type:$mask, LLVM_Type:$pred)> {
825+
def NVVM_VoteSyncOp
826+
: NVVM_Op<"vote.sync">,
827+
Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
828+
Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
835829
string llvmBuilder = [{
836-
$res = createIntrinsicCall(builder,
837-
llvm::Intrinsic::nvvm_vote_all_sync, {$mask, $pred});
830+
auto intId = getVoteSyncIntrinsicId($kind);
831+
$res = createIntrinsicCall(builder, intId, {$mask, $pred});
838832
}];
839-
let hasCustomAssemblyFormat = 1;
833+
let assemblyFormat = "$kind $mask `,` $pred attr-dict `->` type($res)";
834+
let hasVerifier = 1;
840835
}
841836

842837
def NVVM_SyncWarpOp :

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 13 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -48,52 +48,6 @@ using namespace NVVM;
4848
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
4949
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
5050

51-
//===----------------------------------------------------------------------===//
52-
// Printing/parsing for NVVM ops
53-
//===----------------------------------------------------------------------===//
54-
55-
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
56-
p << " " << op->getOperands();
57-
if (op->getNumResults() > 0)
58-
p << " : " << op->getResultTypes();
59-
}
60-
61-
static ParseResult parseVoteOps(OpAsmParser &parser, OperationState &result) {
62-
MLIRContext *context = parser.getContext();
63-
auto int32Ty = IntegerType::get(context, 32);
64-
auto int1Ty = IntegerType::get(context, 1);
65-
66-
SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
67-
Type type;
68-
return failure(parser.parseOperandList(ops) ||
69-
parser.parseOptionalAttrDict(result.attributes) ||
70-
parser.parseColonType(type) ||
71-
parser.addTypeToList(type, result.types) ||
72-
parser.resolveOperands(ops, {int32Ty, int1Ty},
73-
parser.getNameLoc(), result.operands));
74-
}
75-
76-
// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
77-
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
78-
return parseVoteOps(parser, result);
79-
}
80-
81-
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
82-
83-
// <operation> ::= `llvm.nvvm.vote.all.sync %mask, %pred` : result_type
84-
ParseResult VoteAllSyncOp::parse(OpAsmParser &parser, OperationState &result) {
85-
return parseVoteOps(parser, result);
86-
}
87-
88-
void VoteAllSyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
89-
90-
// <operation> ::= `llvm.nvvm.vote.any.sync %mask, %pred` : result_type
91-
ParseResult VoteAnySyncOp::parse(OpAsmParser &parser, OperationState &result) {
92-
return parseVoteOps(parser, result);
93-
}
94-
95-
void VoteAnySyncOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
96-
9751
//===----------------------------------------------------------------------===//
9852
// Verifier methods
9953
//===----------------------------------------------------------------------===//
@@ -1178,6 +1132,19 @@ LogicalResult NVVM::MatchSyncOp::verify() {
11781132
return success();
11791133
}
11801134

1135+
LogicalResult NVVM::VoteSyncOp::verify() {
1136+
if (getKind() == NVVM::VoteSyncKind::ballot) {
1137+
if (!getType().isInteger(32)) {
1138+
return emitOpError("vote.sync 'ballot' returns an i32");
1139+
}
1140+
} else {
1141+
if (!getType().isInteger(1)) {
1142+
return emitOpError("match.sync 'any', 'all' and 'uni' returns an i1");
1143+
}
1144+
}
1145+
return success();
1146+
}
1147+
11811148
//===----------------------------------------------------------------------===//
11821149
// getIntrinsicID/getIntrinsicIDAndArgs methods
11831150
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,21 @@ static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType,
121121
}
122122
}
123123

124+
static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
125+
switch (kind) {
126+
case NVVM::VoteSyncKind::any:
127+
return llvm::Intrinsic::nvvm_vote_any_sync;
128+
case NVVM::VoteSyncKind::all:
129+
return llvm::Intrinsic::nvvm_vote_all_sync;
130+
case NVVM::VoteSyncKind::ballot:
131+
return llvm::Intrinsic::nvvm_vote_ballot_sync;
132+
case NVVM::VoteSyncKind::uni:
133+
return llvm::Intrinsic::nvvm_vote_uni_sync;
134+
default:
135+
llvm_unreachable("unsupported vote kind");
136+
}
137+
}
138+
124139
/// Return the intrinsic ID associated with ldmatrix for the given paramters.
125140
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
126141
int32_t num) {

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,14 @@ func.func @nvvm_shfl_pred(
129129

130130
// CHECK-LABEL: @nvvm_vote(
131131
func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
132-
// CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
133-
%0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
134-
// CHECK: nvvm.vote.all.sync %{{.*}}, %{{.*}} : i32
135-
%1 = nvvm.vote.all.sync %arg0, %arg1 : i32
136-
// CHECK: nvvm.vote.any.sync %{{.*}}, %{{.*}} : i32
137-
%2 = nvvm.vote.any.sync %arg0, %arg1 : i32
132+
// CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
133+
%0 = nvvm.vote.sync ballot %arg0, %arg1 -> i32
134+
// CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
135+
%1 = nvvm.vote.sync all %arg0, %arg1 -> i1
136+
// CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
137+
%2 = nvvm.vote.sync any %arg0, %arg1 -> i1
138+
// CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
139+
%3 = nvvm.vote.sync uni %arg0, %arg1 -> i1
138140
llvm.return %0 : i32
139141
}
140142

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,13 @@ llvm.func @nvvm_shfl_pred(
255255
// CHECK-LABEL: @nvvm_vote
256256
llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
257257
// CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
258-
%3 = nvvm.vote.ballot.sync %0, %1 : i32
258+
%3 = nvvm.vote.sync ballot %0, %1 -> i32
259+
// CHECK: call i1 @llvm.nvvm.vote.all.sync(i32 %{{.*}}, i1 %{{.*}})
260+
%4 = nvvm.vote.sync all %0, %1 -> i1
261+
// CHECK: call i1 @llvm.nvvm.vote.any.sync(i32 %{{.*}}, i1 %{{.*}})
262+
%5 = nvvm.vote.sync any %0, %1 -> i1
263+
// CHECK: call i1 @llvm.nvvm.vote.uni.sync(i32 %{{.*}}, i1 %{{.*}})
264+
%6 = nvvm.vote.sync uni %0, %1 -> i1
259265
llvm.return %3 : i32
260266
}
261267

0 commit comments

Comments
 (0)