Skip to content

Commit 84c8848

Browse files
authored
[flang][cuda] Lower match_any_sync functions to nvvm intrinsics (#127942)
1 parent d1dde17 commit 84c8848

File tree

4 files changed

+88
-0
lines changed

4 files changed

+88
-0
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ struct IntrinsicLibrary {
336336
template <typename Shift>
337337
mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
338338
mlir::Value genMatchAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
339+
mlir::Value genMatchAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
339340
fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
340341
fir::ExtendedValue genMatmulTranspose(mlir::Type,
341342
llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,22 @@ static constexpr IntrinsicHandler handlers[]{
485485
&I::genMatchAllSync,
486486
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
487487
/*isElemental=*/false},
488+
{"match_any_syncjd",
489+
&I::genMatchAnySync,
490+
{{{"mask", asValue}, {"value", asValue}}},
491+
/*isElemental=*/false},
492+
{"match_any_syncjf",
493+
&I::genMatchAnySync,
494+
{{{"mask", asValue}, {"value", asValue}}},
495+
/*isElemental=*/false},
496+
{"match_any_syncjj",
497+
&I::genMatchAnySync,
498+
{{{"mask", asValue}, {"value", asValue}}},
499+
/*isElemental=*/false},
500+
{"match_any_syncjx",
501+
&I::genMatchAnySync,
502+
{{{"mask", asValue}, {"value", asValue}}},
503+
/*isElemental=*/false},
488504
{"matmul",
489505
&I::genMatmul,
490506
{{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
@@ -6060,6 +6076,7 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
60606076
return result;
60616077
}
60626078

6079+
// MATCH_ALL_SYNC
60636080
mlir::Value
60646081
IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
60656082
llvm::ArrayRef<mlir::Value> args) {
@@ -6096,6 +6113,32 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
60966113
return value;
60976114
}
60986115

6116+
// MATCH_ANY_SYNC
6117+
mlir::Value
6118+
IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
6119+
llvm::ArrayRef<mlir::Value> args) {
6120+
assert(args.size() == 2);
6121+
bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
6122+
6123+
llvm::StringRef funcName =
6124+
is32 ? "llvm.nvvm.match.any.sync.i32p" : "llvm.nvvm.match.any.sync.i64p";
6125+
mlir::MLIRContext *context = builder.getContext();
6126+
mlir::Type i32Ty = builder.getI32Type();
6127+
mlir::Type i64Ty = builder.getI64Type();
6128+
mlir::Type valTy = is32 ? i32Ty : i64Ty;
6129+
6130+
mlir::FunctionType ftype =
6131+
mlir::FunctionType::get(context, {i32Ty, valTy}, {i32Ty});
6132+
auto funcOp = builder.createFunction(loc, funcName, ftype);
6133+
llvm::SmallVector<mlir::Value> filteredArgs;
6134+
filteredArgs.push_back(args[0]);
6135+
if (args[1].getType().isF32() || args[1].getType().isF64())
6136+
filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
6137+
else
6138+
filteredArgs.push_back(args[1]);
6139+
return builder.create<fir::CallOp>(loc, funcOp, filteredArgs).getResult(0);
6140+
}
6141+
60996142
// MATMUL
61006143
fir::ExtendedValue
61016144
IntrinsicLibrary::genMatmul(mlir::Type resultType,

flang/module/cudadevice.f90

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,4 +589,27 @@ attributes(device) integer function match_all_syncjd(mask, val, pred)
589589
end function
590590
end interface
591591

592+
interface match_any_sync
593+
attributes(device) integer function match_any_syncjj(mask, val)
594+
!dir$ ignore_tkr(d) mask, (d) val
595+
integer(4), value :: mask
596+
integer(4), value :: val
597+
end function
598+
attributes(device) integer function match_any_syncjx(mask, val)
599+
!dir$ ignore_tkr(d) mask, (d) val
600+
integer(4), value :: mask
601+
integer(8), value :: val
602+
end function
603+
attributes(device) integer function match_any_syncjf(mask, val)
604+
!dir$ ignore_tkr(d) mask, (d) val
605+
integer(4), value :: mask
606+
real(4), value :: val
607+
end function
608+
attributes(device) integer function match_any_syncjd(mask, val)
609+
!dir$ ignore_tkr(d) mask, (d) val
610+
integer(4), value :: mask
611+
real(8), value :: val
612+
end function
613+
end interface
614+
592615
end module

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,25 @@ end subroutine
131131
! CHECK: fir.convert %{{.*}} : (f64) -> i64
132132
! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
133133

134+
attributes(device) subroutine testMatchAny()
135+
integer :: a, mask, v32
136+
integer(8) :: v64
137+
real(4) :: r4
138+
real(8) :: r8
139+
a = match_any_sync(mask, v32)
140+
a = match_any_sync(mask, v64)
141+
a = match_any_sync(mask, r4)
142+
a = match_any_sync(mask, r8)
143+
end subroutine
144+
145+
! CHECK-LABEL: func.func @_QPtestmatchany()
146+
! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
147+
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
148+
! CHECK: fir.convert %{{.*}} : (f32) -> i32
149+
! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
150+
! CHECK: fir.convert %{{.*}} : (f64) -> i64
151+
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
152+
134153
! CHECK: func.func private @llvm.nvvm.barrier0()
135154
! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
136155
! CHECK: func.func private @llvm.nvvm.membar.gl()
@@ -141,3 +160,5 @@ end subroutine
141160
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
142161
! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple<i32, i1>
143162
! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple<i32, i1>
163+
! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32
164+
! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32

0 commit comments

Comments
 (0)