Skip to content

Commit 197f3ec

Browse files
authored
[flang][OpenMP] lower simple array reductions (#84958)
This has been tested with arrays with compile-time constant bounds. Allocatable arrays and arrays with non-constant bounds are not yet supported. User-defined reduction functions are also not yet supported. The design is intended to work for arrays with non-constant bounds too without a lot of extra work (mostly there are bugs in OpenMPIRBuilder I haven't fixed yet). We need some way to get these runtime bounds into the reduction init and combiner regions. To keep things simple for now I opted to always box the array arguments so the box can be passed as one argument and the lower bounds and extents read from the box. This has the disadvantage of resulting in fir.box_dim operations inside of the critical section. If these prove to be a performance issue, we could follow OpenACC reading box lower bounds and extents before the reduction and passing them as block arguments to the reduction init and combiner regions. I would prefer to keep things simple for now. Note: this implementation only works when the HLFIR lowering is used. I don't think it is worth supporting FIR-only lowering because the plan is for that to be removed soon. OpenMP array reductions 6/6 Previous PR: #84957
1 parent 22f2056 commit 197f3ec

File tree

9 files changed

+586
-54
lines changed

9 files changed

+586
-54
lines changed

flang/include/flang/Optimizer/Builder/HLFIRTools.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,11 @@ std::pair<hlfir::Entity, mlir::Value>
434434
createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder,
435435
hlfir::Entity mold);
436436

437+
// TODO: this does not support polymorphic molds
438+
hlfir::Entity createStackTempFromMold(mlir::Location loc,
439+
fir::FirOpBuilder &builder,
440+
hlfir::Entity mold);
441+
437442
hlfir::EntityWithAttributes convertCharacterKind(mlir::Location loc,
438443
fir::FirOpBuilder &builder,
439444
hlfir::Entity scalarChar,

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 212 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "ReductionProcessor.h"
1414

1515
#include "flang/Lower/AbstractConverter.h"
16+
#include "flang/Optimizer/Builder/HLFIRTools.h"
1617
#include "flang/Optimizer/Builder/Todo.h"
1718
#include "flang/Optimizer/Dialect/FIRType.h"
1819
#include "flang/Optimizer/HLFIR/HLFIROps.h"
@@ -90,10 +91,42 @@ std::string ReductionProcessor::getReductionName(llvm::StringRef name,
9091
if (isByRef)
9192
byrefAddition = "_byref";
9293

93-
return (llvm::Twine(name) +
94-
(ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
95-
llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
96-
.str();
94+
if (fir::isa_trivial(ty))
95+
return (llvm::Twine(name) +
96+
(ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
97+
llvm::Twine(ty.getIntOrFloatBitWidth()) + byrefAddition)
98+
.str();
99+
100+
// creates a name like reduction_i_64_box_ux4x3
101+
if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
102+
// TODO: support for allocatable boxes:
103+
// !fir.box<!fir.heap<!fir.array<...>>>
104+
fir::SequenceType seqTy = fir::unwrapRefType(boxTy.getEleTy())
105+
.dyn_cast_or_null<fir::SequenceType>();
106+
if (!seqTy)
107+
return {};
108+
109+
std::string prefix = getReductionName(
110+
name, fir::unwrapSeqOrBoxedSeqType(ty), /*isByRef=*/false);
111+
if (prefix.empty())
112+
return {};
113+
std::stringstream tyStr;
114+
tyStr << prefix << "_box_";
115+
bool first = true;
116+
for (std::int64_t extent : seqTy.getShape()) {
117+
if (first)
118+
first = false;
119+
else
120+
tyStr << "x";
121+
if (extent == seqTy.getUnknownExtent())
122+
tyStr << 'u'; // I'm not sure that '?' is safe in symbol names
123+
else
124+
tyStr << extent;
125+
}
126+
return (tyStr.str() + byrefAddition).str();
127+
}
128+
129+
return {};
97130
}
98131

99132
std::string ReductionProcessor::getReductionName(
@@ -281,13 +314,158 @@ mlir::Value ReductionProcessor::createScalarCombiner(
281314
return reductionOp;
282315
}
283316

317+
/// Create reduction combiner region for reduction variables which are boxed
318+
/// arrays
319+
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
320+
ReductionProcessor::ReductionIdentifier redId,
321+
fir::BaseBoxType boxTy, mlir::Value lhs,
322+
mlir::Value rhs) {
323+
fir::SequenceType seqTy =
324+
mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy());
325+
// TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
326+
if (!seqTy || seqTy.hasUnknownShape())
327+
TODO(loc, "Unsupported boxed type in OpenMP reduction");
328+
329+
// load fir.ref<fir.box<...>>
330+
mlir::Value lhsAddr = lhs;
331+
lhs = builder.create<fir::LoadOp>(loc, lhs);
332+
rhs = builder.create<fir::LoadOp>(loc, rhs);
333+
334+
const unsigned rank = seqTy.getDimension();
335+
llvm::SmallVector<mlir::Value> extents;
336+
extents.reserve(rank);
337+
llvm::SmallVector<mlir::Value> lbAndExtents;
338+
lbAndExtents.reserve(rank * 2);
339+
340+
// Get box lowerbounds and extents:
341+
mlir::Type idxTy = builder.getIndexType();
342+
for (unsigned i = 0; i < rank; ++i) {
343+
// TODO: ideally we want to hoist box reads out of the critical section.
344+
// We could do this by having box dimensions in block arguments like
345+
// OpenACC does
346+
mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
347+
auto dimInfo =
348+
builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim);
349+
extents.push_back(dimInfo.getExtent());
350+
lbAndExtents.push_back(dimInfo.getLowerBound());
351+
lbAndExtents.push_back(dimInfo.getExtent());
352+
}
353+
354+
auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
355+
auto shapeShift =
356+
builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
357+
358+
// Iterate over array elements, applying the equivalent scalar reduction:
359+
360+
// A hlfir::elemental here gets inlined with a temporary so create the
361+
// loop nest directly.
362+
// This function already controls all of the code in this region so we
363+
// know this won't miss any opportuinties for clever elemental inlining
364+
hlfir::LoopNest nest =
365+
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
366+
builder.setInsertionPointToStart(nest.innerLoop.getBody());
367+
mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
368+
auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
369+
loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
370+
nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
371+
auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
372+
loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
373+
nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
374+
auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
375+
auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
376+
mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
377+
builder, loc, redId, refTy, lhsEle, rhsEle);
378+
builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
379+
380+
builder.setInsertionPointAfter(nest.outerLoop);
381+
builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
382+
}
383+
384+
// generate combiner region for reduction operations
385+
static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
386+
ReductionProcessor::ReductionIdentifier redId,
387+
mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
388+
bool isByRef) {
389+
ty = fir::unwrapRefType(ty);
390+
391+
if (fir::isa_trivial(ty)) {
392+
mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
393+
mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
394+
395+
mlir::Value result = ReductionProcessor::createScalarCombiner(
396+
builder, loc, redId, ty, lhsLoaded, rhsLoaded);
397+
if (isByRef) {
398+
builder.create<fir::StoreOp>(loc, result, lhs);
399+
builder.create<mlir::omp::YieldOp>(loc, lhs);
400+
} else {
401+
builder.create<mlir::omp::YieldOp>(loc, result);
402+
}
403+
return;
404+
}
405+
// all arrays should have been boxed
406+
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
407+
genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
408+
return;
409+
}
410+
411+
TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
412+
}
413+
414+
static mlir::Value
415+
createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
416+
const ReductionProcessor::ReductionIdentifier redId,
417+
mlir::Type type, bool isByRef) {
418+
mlir::Type ty = fir::unwrapRefType(type);
419+
mlir::Value initValue = ReductionProcessor::getReductionInitValue(
420+
loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder);
421+
422+
if (fir::isa_trivial(ty)) {
423+
if (isByRef) {
424+
mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
425+
builder.createStoreWithConvert(loc, initValue, alloca);
426+
return alloca;
427+
}
428+
// by val
429+
return initValue;
430+
}
431+
432+
// all arrays are boxed
433+
if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
434+
assert(isByRef && "passing arrays by value is unsupported");
435+
// TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
436+
mlir::Type innerTy = fir::extractSequenceType(boxTy);
437+
if (!mlir::isa<fir::SequenceType>(innerTy))
438+
TODO(loc, "Unsupported boxed type for reduction");
439+
// Create the private copy from the initial fir.box:
440+
hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
441+
442+
// TODO: if the whole reduction is nested inside of a loop, this alloca
443+
// could lead to a stack overflow (the memory is only freed at the end of
444+
// the stack frame). The reduction declare operation needs a deallocation
445+
// region to undo the init region.
446+
hlfir::Entity temp = createStackTempFromMold(loc, builder, source);
447+
448+
// Put the temporary inside of a box:
449+
hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp);
450+
builder.create<hlfir::AssignOp>(loc, initValue, box);
451+
mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
452+
builder.create<fir::StoreOp>(loc, box, boxAlloca);
453+
return boxAlloca;
454+
}
455+
456+
TODO(loc, "createReductionInitRegion for unsupported type");
457+
}
458+
284459
mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
285460
fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
286461
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
287462
bool isByRef) {
288463
mlir::OpBuilder::InsertionGuard guard(builder);
289464
mlir::ModuleOp module = builder.getModule();
290465

466+
if (reductionOpName.empty())
467+
TODO(loc, "Reduction of some types is not supported");
468+
291469
auto decl =
292470
module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
293471
if (decl)
@@ -304,14 +482,9 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
304482
decl.getInitializerRegion().end(), {type}, {loc});
305483
builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
306484

307-
mlir::Value init = getReductionInitValue(loc, type, redId, builder);
308-
if (isByRef) {
309-
mlir::Value alloca = builder.create<fir::AllocaOp>(loc, valTy);
310-
builder.createStoreWithConvert(loc, init, alloca);
311-
builder.create<mlir::omp::YieldOp>(loc, alloca);
312-
} else {
313-
builder.create<mlir::omp::YieldOp>(loc, init);
314-
}
485+
mlir::Value init =
486+
createReductionInitRegion(builder, loc, redId, type, isByRef);
487+
builder.create<mlir::omp::YieldOp>(loc, init);
315488

316489
builder.createBlock(&decl.getReductionRegion(),
317490
decl.getReductionRegion().end(), {type, type},
@@ -320,19 +493,7 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
320493
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
321494
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
322495
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
323-
mlir::Value outAddr = op1;
324-
325-
op1 = builder.loadIfRef(loc, op1);
326-
op2 = builder.loadIfRef(loc, op2);
327-
328-
mlir::Value reductionOp =
329-
createScalarCombiner(builder, loc, redId, type, op1, op2);
330-
if (isByRef) {
331-
builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
332-
builder.create<mlir::omp::YieldOp>(loc, outAddr);
333-
} else {
334-
builder.create<mlir::omp::YieldOp>(loc, reductionOp);
335-
}
496+
genCombiner(builder, loc, redId, type, op1, op2, isByRef);
336497

337498
return decl;
338499
}
@@ -387,13 +548,33 @@ void ReductionProcessor::addReductionDecl(
387548

388549
// initial pass to collect all reduction vars so we can figure out if this
389550
// should happen byref
551+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
390552
for (const Object &object : objectList) {
391553
const Fortran::semantics::Symbol *symbol = object.id();
392554
if (reductionSymbols)
393555
reductionSymbols->push_back(symbol);
394556
mlir::Value symVal = converter.getSymbolAddress(*symbol);
395-
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
557+
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
558+
559+
// all arrays must be boxed so that we have convenient access to all the
560+
// information needed to iterate over the array
561+
if (mlir::isa<fir::SequenceType>(redType.getEleTy())) {
562+
hlfir::Entity entity{symVal};
563+
entity = genVariableBox(currentLocation, builder, entity);
564+
mlir::Value box = entity.getBase();
565+
566+
// Always pass the box by reference so that the OpenMP dialect
567+
// verifiers don't need to know anything about fir.box
568+
auto alloca =
569+
builder.create<fir::AllocaOp>(currentLocation, box.getType());
570+
builder.create<fir::StoreOp>(currentLocation, box, alloca);
571+
572+
symVal = alloca;
573+
redType = mlir::cast<fir::ReferenceType>(symVal.getType());
574+
} else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
396575
symVal = declOp.getBase();
576+
}
577+
397578
reductionVars.push_back(symVal);
398579
}
399580
const bool isByRef = doReductionByRef(reductionVars);
@@ -418,24 +599,17 @@ void ReductionProcessor::addReductionDecl(
418599
break;
419600
}
420601

421-
for (const Object &object : objectList) {
422-
const Fortran::semantics::Symbol *symbol = object.id();
423-
mlir::Value symVal = converter.getSymbolAddress(*symbol);
424-
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
425-
symVal = declOp.getBase();
426-
auto redType = symVal.getType().cast<fir::ReferenceType>();
602+
for (mlir::Value symVal : reductionVars) {
603+
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
427604
if (redType.getEleTy().isa<fir::LogicalType>())
428605
decl = createReductionDecl(
429606
firOpBuilder,
430607
getReductionName(intrinsicOp, firOpBuilder.getI1Type(), isByRef),
431608
redId, redType, currentLocation, isByRef);
432-
else if (redType.getEleTy().isIntOrIndexOrFloat()) {
609+
else
433610
decl = createReductionDecl(
434611
firOpBuilder, getReductionName(intrinsicOp, redType, isByRef),
435612
redId, redType, currentLocation, isByRef);
436-
} else {
437-
TODO(currentLocation, "Reduction of some types is not supported");
438-
}
439613
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
440614
firOpBuilder.getContext(), decl.getSymName()));
441615
}
@@ -452,8 +626,8 @@ void ReductionProcessor::addReductionDecl(
452626
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
453627
symVal = declOp.getBase();
454628
auto redType = symVal.getType().cast<fir::ReferenceType>();
455-
assert(redType.getEleTy().isIntOrIndexOrFloat() &&
456-
"Unsupported reduction type");
629+
if (!redType.getEleTy().isIntOrIndexOrFloat())
630+
TODO(currentLocation, "User Defined Reduction on non-trivial type");
457631
decl = createReductionDecl(
458632
firOpBuilder,
459633
getReductionName(getRealName(*reductionIntrinsic).ToString(),

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class ReductionProcessor {
108108
/// Creates an OpenMP reduction declaration and inserts it into the provided
109109
/// symbol table. The declaration has a constant initializer with the neutral
110110
/// value `initValue`, and the reduction combiner carried over from `reduce`.
111-
/// TODO: Generalize this for non-integer types, add atomic region.
111+
/// TODO: add atomic region.
112112
static mlir::omp::ReductionDeclareOp
113113
createReductionDecl(fir::FirOpBuilder &builder,
114114
llvm::StringRef reductionOpName,

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,35 @@ hlfir::createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder,
11111111
return {hlfir::Entity{declareOp.getBase()}, isHeapAlloc};
11121112
}
11131113

1114+
hlfir::Entity hlfir::createStackTempFromMold(mlir::Location loc,
1115+
fir::FirOpBuilder &builder,
1116+
hlfir::Entity mold) {
1117+
llvm::SmallVector<mlir::Value> lenParams;
1118+
hlfir::genLengthParameters(loc, builder, mold, lenParams);
1119+
llvm::StringRef tmpName{".tmp"};
1120+
mlir::Value alloc;
1121+
mlir::Value shape{};
1122+
fir::FortranVariableFlagsAttr declAttrs;
1123+
1124+
if (mold.isPolymorphic()) {
1125+
// genAllocatableApplyMold does heap allocation
1126+
TODO(loc, "createStackTempFromMold for polymorphic type");
1127+
} else if (mold.isArray()) {
1128+
mlir::Type sequenceType =
1129+
hlfir::getFortranElementOrSequenceType(mold.getType());
1130+
shape = hlfir::genShape(loc, builder, mold);
1131+
auto extents = hlfir::getIndexExtents(loc, builder, shape);
1132+
alloc =
1133+
builder.createTemporary(loc, sequenceType, tmpName, extents, lenParams);
1134+
} else {
1135+
alloc = builder.createTemporary(loc, mold.getFortranElementType(), tmpName,
1136+
/*shape=*/std::nullopt, lenParams);
1137+
}
1138+
auto declareOp = builder.create<hlfir::DeclareOp>(loc, alloc, tmpName, shape,
1139+
lenParams, declAttrs);
1140+
return hlfir::Entity{declareOp.getBase()};
1141+
}
1142+
11141143
hlfir::EntityWithAttributes
11151144
hlfir::convertCharacterKind(mlir::Location loc, fir::FirOpBuilder &builder,
11161145
hlfir::Entity scalarChar, int toKind) {

flang/test/Lower/OpenMP/Todo/reduction-arrays.f90

Lines changed: 0 additions & 15 deletions
This file was deleted.

0 commit comments

Comments
 (0)