Skip to content

Commit db21ae7

Browse files
authored
[flang][cuda] Support any_sync and ballot_sync (#134135)
1 parent 066787b commit db21ae7

File tree

4 files changed

+53
-8
lines changed

4 files changed

+53
-8
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,8 @@ struct IntrinsicLibrary {
442442
fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
443443
fir::ExtendedValue genVerify(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
444444
mlir::Value genVoteAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
445+
mlir::Value genVoteAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
446+
mlir::Value genVoteBallotSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
445447

446448
/// Implement all conversion functions like DBLE, the first argument is
447449
/// the value to convert. There may be an additional KIND arguments that

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,10 @@ static constexpr IntrinsicHandler handlers[]{
273273
&I::genAny,
274274
{{{"mask", asAddr}, {"dim", asValue}}},
275275
/*isElemental=*/false},
276+
{"any_sync",
277+
&I::genVoteAnySync,
278+
{{{"mask", asValue}, {"pred", asValue}}},
279+
/*isElemental=*/false},
276280
{"asind", &I::genAsind},
277281
{"associated",
278282
&I::genAssociated,
@@ -335,6 +339,10 @@ static constexpr IntrinsicHandler handlers[]{
335339
{"atomicsubi", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
336340
{"atomicsubl", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
337341
{"atomicxori", &I::genAtomicXor, {{{"a", asAddr}, {"v", asValue}}}, false},
342+
{"ballot_sync",
343+
&I::genVoteBallotSync,
344+
{{{"mask", asValue}, {"pred", asValue}}},
345+
/*isElemental=*/false},
338346
{"bessel_jn",
339347
&I::genBesselJn,
340348
{{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
@@ -6499,12 +6507,9 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
64996507
return value;
65006508
}
65016509

6502-
// ALL_SYNC
6503-
mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
6504-
llvm::ArrayRef<mlir::Value> args) {
6505-
assert(args.size() == 2);
6506-
6507-
llvm::StringRef funcName = "llvm.nvvm.vote.all.sync";
6510+
static mlir::Value genVoteSync(fir::FirOpBuilder &builder, mlir::Location loc,
6511+
llvm::StringRef funcName,
6512+
llvm::ArrayRef<mlir::Value> args) {
65086513
mlir::MLIRContext *context = builder.getContext();
65096514
mlir::Type i32Ty = builder.getI32Type();
65106515
mlir::FunctionType ftype =
@@ -6514,6 +6519,28 @@ mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
65146519
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
65156520
}
65166521

6522+
// ALL_SYNC
6523+
mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
6524+
llvm::ArrayRef<mlir::Value> args) {
6525+
assert(args.size() == 2);
6526+
return genVoteSync(builder, loc, "llvm.nvvm.vote.all.sync", args);
6527+
}
6528+
6529+
// ANY_SYNC
6530+
mlir::Value IntrinsicLibrary::genVoteAnySync(mlir::Type resultType,
6531+
llvm::ArrayRef<mlir::Value> args) {
6532+
assert(args.size() == 2);
6533+
return genVoteSync(builder, loc, "llvm.nvvm.vote.any.sync", args);
6534+
}
6535+
6536+
// BALLOT_SYNC
6537+
mlir::Value
6538+
IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
6539+
llvm::ArrayRef<mlir::Value> args) {
6540+
assert(args.size() == 2);
6541+
return genVoteSync(builder, loc, "llvm.nvvm.vote.ballot.sync", args);
6542+
}
6543+
65176544
// MATCH_ANY_SYNC
65186545
mlir::Value
65196546
IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,

flang/module/cudadevice.f90

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,20 @@ attributes(device) integer function all_sync(mask, pred)
10221022
end function
10231023
end interface
10241024

1025+
interface any_sync
1026+
attributes(device) integer function any_sync(mask, pred)
1027+
!dir$ ignore_tkr(d) mask, (td) pred
1028+
integer, value :: mask, pred
1029+
end function
1030+
end interface
1031+
1032+
interface ballot_sync
1033+
attributes(device) integer function ballot_sync(mask, pred)
1034+
!dir$ ignore_tkr(d) mask, (td) pred
1035+
integer, value :: mask, pred
1036+
end function
1037+
end interface
1038+
10251039
! LDCG
10261040
interface __ldcg
10271041
attributes(device) pure integer(4) function __ldcg_i4(x) bind(c)

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,14 @@ end
299299
attributes(device) subroutine testVote()
300300
integer :: a, ipred, mask, v32
301301
a = all_sync(mask, v32)
302-
302+
a = any_sync(mask, v32)
303+
a = ballot_sync(mask, v32)
303304
end subroutine
304305

305306
! CHECK-LABEL: func.func @_QPtestvote()
306307
! CHECK: fir.call @llvm.nvvm.vote.all.sync
307-
308+
! CHECK: fir.call @llvm.nvvm.vote.any.sync
309+
! CHECK: fir.call @llvm.nvvm.vote.ballot.sync
308310

309311
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
310312
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)

0 commit comments

Comments
 (0)