Skip to content

Commit 726c4b9

Browse files
authored
[flang][cuda] Lower match_all_sync functions to nvvm intrinsics (#127940)
1 parent 02e8fd7 commit 726c4b9

File tree

7 files changed

+111
-1
lines changed

7 files changed

+111
-1
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ struct IntrinsicLibrary {
335335
mlir::Value genMalloc(mlir::Type, llvm::ArrayRef<mlir::Value>);
336336
template <typename Shift>
337337
mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
338+
mlir::Value genMatchAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
338339
fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
339340
fir::ExtendedValue genMatmulTranspose(mlir::Type,
340341
llvm::ArrayRef<fir::ExtendedValue>);

flang/include/flang/Semantics/tools.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ inline bool NeedCUDAAlloc(const Symbol &sym) {
231231
(*details->cudaDataAttr() == common::CUDADataAttr::Device ||
232232
*details->cudaDataAttr() == common::CUDADataAttr::Managed ||
233233
*details->cudaDataAttr() == common::CUDADataAttr::Unified ||
234+
*details->cudaDataAttr() == common::CUDADataAttr::Shared ||
234235
*details->cudaDataAttr() == common::CUDADataAttr::Pinned)) {
235236
return true;
236237
}

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,22 @@ static constexpr IntrinsicHandler handlers[]{
469469
{"malloc", &I::genMalloc},
470470
{"maskl", &I::genMask<mlir::arith::ShLIOp>},
471471
{"maskr", &I::genMask<mlir::arith::ShRUIOp>},
472+
{"match_all_syncjd",
473+
&I::genMatchAllSync,
474+
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
475+
/*isElemental=*/false},
476+
{"match_all_syncjf",
477+
&I::genMatchAllSync,
478+
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
479+
/*isElemental=*/false},
480+
{"match_all_syncjj",
481+
&I::genMatchAllSync,
482+
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
483+
/*isElemental=*/false},
484+
{"match_all_syncjx",
485+
&I::genMatchAllSync,
486+
{{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
487+
/*isElemental=*/false},
472488
{"matmul",
473489
&I::genMatmul,
474490
{{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
@@ -6044,6 +6060,42 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
60446060
return result;
60456061
}
60466062

6063+
mlir::Value
6064+
IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
6065+
llvm::ArrayRef<mlir::Value> args) {
6066+
assert(args.size() == 3);
6067+
bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
6068+
6069+
llvm::StringRef funcName =
6070+
is32 ? "llvm.nvvm.match.all.sync.i32p" : "llvm.nvvm.match.all.sync.i64p";
6071+
mlir::MLIRContext *context = builder.getContext();
6072+
mlir::Type i32Ty = builder.getI32Type();
6073+
mlir::Type i64Ty = builder.getI64Type();
6074+
mlir::Type i1Ty = builder.getI1Type();
6075+
mlir::Type retTy = mlir::TupleType::get(context, {resultType, i1Ty});
6076+
mlir::Type valTy = is32 ? i32Ty : i64Ty;
6077+
6078+
mlir::FunctionType ftype =
6079+
mlir::FunctionType::get(context, {i32Ty, valTy}, {retTy});
6080+
auto funcOp = builder.createFunction(loc, funcName, ftype);
6081+
llvm::SmallVector<mlir::Value> filteredArgs;
6082+
filteredArgs.push_back(args[0]);
6083+
if (args[1].getType().isF32() || args[1].getType().isF64())
6084+
filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
6085+
else
6086+
filteredArgs.push_back(args[1]);
6087+
auto call = builder.create<fir::CallOp>(loc, funcOp, filteredArgs);
6088+
auto zero = builder.getIntegerAttr(builder.getIndexType(), 0);
6089+
auto value = builder.create<fir::ExtractValueOp>(
6090+
loc, resultType, call.getResult(0), builder.getArrayAttr(zero));
6091+
auto one = builder.getIntegerAttr(builder.getIndexType(), 1);
6092+
auto pred = builder.create<fir::ExtractValueOp>(loc, i1Ty, call.getResult(0),
6093+
builder.getArrayAttr(one));
6094+
auto conv = builder.create<mlir::LLVM::ZExtOp>(loc, resultType, pred);
6095+
builder.create<fir::StoreOp>(loc, conv, args[2]);
6096+
return value;
6097+
}
6098+
60476099
// MATMUL
60486100
fir::ExtendedValue
60496101
IntrinsicLibrary::genMatmul(mlir::Type resultType,

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,12 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
292292
rewriter.setInsertionPointAfter(size.getDefiningOp());
293293
}
294294

295+
if (auto dataAttr = alloc->getAttrOfType<cuf::DataAttributeAttr>(
296+
cuf::getDataAttrName())) {
297+
if (dataAttr.getValue() == cuf::DataAttribute::Shared)
298+
allocaAs = 3;
299+
}
300+
295301
// NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
296302
// pointers! Only propagate pinned and bindc_name to help debugging, but
297303
// this should have no functional purpose (and passing the operand segment
@@ -316,6 +322,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
316322
rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
317323
alloc, ::getLlvmPtrType(alloc.getContext(), programAs), llvmAlloc);
318324
}
325+
319326
return mlir::success();
320327
}
321328
};

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ static llvm::LogicalResult checkCudaAttr(Op op) {
5757
if (op.getDataAttr() == cuf::DataAttribute::Device ||
5858
op.getDataAttr() == cuf::DataAttribute::Managed ||
5959
op.getDataAttr() == cuf::DataAttribute::Unified ||
60-
op.getDataAttr() == cuf::DataAttribute::Pinned)
60+
op.getDataAttr() == cuf::DataAttribute::Pinned ||
61+
op.getDataAttr() == cuf::DataAttribute::Shared)
6162
return mlir::success();
6263
return op.emitOpError()
6364
<< "expect device, managed, pinned or unified cuda attribute";

flang/module/cudadevice.f90

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,4 +562,31 @@ attributes(device) integer(8) function clock64()
562562
end function
563563
end interface
564564

565+
interface match_all_sync
566+
attributes(device) integer function match_all_syncjj(mask, val, pred)
567+
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
568+
integer(4), value :: mask
569+
integer(4), value :: val
570+
integer(4) :: pred
571+
end function
572+
attributes(device) integer function match_all_syncjx(mask, val, pred)
573+
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
574+
integer(4), value :: mask
575+
integer(8), value :: val
576+
integer(4) :: pred
577+
end function
578+
attributes(device) integer function match_all_syncjf(mask, val, pred)
579+
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
580+
integer(4), value :: mask
581+
real(4), value :: val
582+
integer(4) :: pred
583+
end function
584+
attributes(device) integer function match_all_syncjd(mask, val, pred)
585+
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
586+
integer(4), value :: mask
587+
real(8), value :: val
588+
integer(4) :: pred
589+
end function
590+
end interface
591+
565592
end module

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,25 @@ end
112112
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
113113
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
114114

115+
attributes(device) subroutine testMatch()
116+
integer :: a, ipred, mask, v32
117+
integer(8) :: v64
118+
real(4) :: r4
119+
real(8) :: r8
120+
a = match_all_sync(mask, v32, ipred)
121+
a = match_all_sync(mask, v64, ipred)
122+
a = match_all_sync(mask, r4, ipred)
123+
a = match_all_sync(mask, r8, ipred)
124+
end subroutine
125+
126+
! CHECK-LABEL: func.func @_QPtestmatch()
127+
! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
128+
! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
129+
! CHECK: fir.convert %{{.*}} : (f32) -> i32
130+
! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
131+
! CHECK: fir.convert %{{.*}} : (f64) -> i64
132+
! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
133+
115134
! CHECK: func.func private @llvm.nvvm.barrier0()
116135
! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
117136
! CHECK: func.func private @llvm.nvvm.membar.gl()
@@ -120,3 +139,5 @@ end
120139
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
121140
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
122141
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
142+
! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple<i32, i1>
143+
! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple<i32, i1>

0 commit comments

Comments
 (0)