Skip to content

Commit 3e59ff2

Browse files
authored
[flang][cuda] Fix pred type for vote functions (#134166)
1 parent cfee056 commit 3e59ff2

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6508,12 +6508,13 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
65086508
}
65096509

65106510
static mlir::Value genVoteSync(fir::FirOpBuilder &builder, mlir::Location loc,
6511-
llvm::StringRef funcName,
6511+
llvm::StringRef funcName, mlir::Type resTy,
65126512
llvm::ArrayRef<mlir::Value> args) {
65136513
mlir::MLIRContext *context = builder.getContext();
65146514
mlir::Type i32Ty = builder.getI32Type();
6515+
mlir::Type i1Ty = builder.getI1Type();
65156516
mlir::FunctionType ftype =
6516-
mlir::FunctionType::get(context, {i32Ty, i32Ty}, {i32Ty});
6517+
mlir::FunctionType::get(context, {i32Ty, i1Ty}, {resTy});
65176518
auto funcOp = builder.createFunction(loc, funcName, ftype);
65186519
llvm::SmallVector<mlir::Value> filteredArgs;
65196520
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
@@ -6523,22 +6524,25 @@ static mlir::Value genVoteSync(fir::FirOpBuilder &builder, mlir::Location loc,
65236524
mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
65246525
llvm::ArrayRef<mlir::Value> args) {
65256526
assert(args.size() == 2);
6526-
return genVoteSync(builder, loc, "llvm.nvvm.vote.all.sync", args);
6527+
return genVoteSync(builder, loc, "llvm.nvvm.vote.all.sync",
6528+
builder.getI1Type(), args);
65276529
}
65286530

65296531
// ANY_SYNC
65306532
mlir::Value IntrinsicLibrary::genVoteAnySync(mlir::Type resultType,
65316533
llvm::ArrayRef<mlir::Value> args) {
65326534
assert(args.size() == 2);
6533-
return genVoteSync(builder, loc, "llvm.nvvm.vote.any.sync", args);
6535+
return genVoteSync(builder, loc, "llvm.nvvm.vote.any.sync",
6536+
builder.getI1Type(), args);
65346537
}
65356538

65366539
// BALLOT_SYNC
65376540
mlir::Value
65386541
IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
65396542
llvm::ArrayRef<mlir::Value> args) {
65406543
assert(args.size() == 2);
6541-
return genVoteSync(builder, loc, "llvm.nvvm.vote.ballot.sync", args);
6544+
return genVoteSync(builder, loc, "llvm.nvvm.vote.ballot.sync",
6545+
builder.getI32Type(), args);
65426546
}
65436547

65446548
// MATCH_ANY_SYNC

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,11 @@ end
297297
! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
298298

299299
attributes(device) subroutine testVote()
300-
integer :: a, ipred, mask, v32
301-
a = all_sync(mask, v32)
302-
a = any_sync(mask, v32)
303-
a = ballot_sync(mask, v32)
300+
integer :: a, ipred, mask
301+
logical(4) :: pred
302+
a = all_sync(mask, pred)
303+
a = any_sync(mask, pred)
304+
a = ballot_sync(mask, pred)
304305
end subroutine
305306

306307
! CHECK-LABEL: func.func @_QPtestvote()

0 commit comments

Comments
 (0)