Skip to content

Commit d578694

Browse files
committed
[Flang] Maxloc elemental intrinsic lowering.
This is an extension to #74828 to handle maxloc too, to keep the minloc and maxloc symmetric.
1 parent b0b7be2 commit d578694

File tree

4 files changed

+306
-161
lines changed

4 files changed

+306
-161
lines changed

flang/include/flang/Optimizer/Support/Utils.h

Lines changed: 128 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "flang/Optimizer/Builder/Todo.h"
1919
#include "flang/Optimizer/Dialect/FIROps.h"
2020
#include "flang/Optimizer/Dialect/FIRType.h"
21+
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
2122
#include "flang/Optimizer/Support/FatalError.h"
2223
#include "mlir/Dialect/Arith/IR/Arith.h"
2324
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -144,13 +145,133 @@ using AddrGeneratorTy = llvm::function_ref<mlir::Value(
144145
mlir::Value)>;
145146

146147
// Produces a loop nest for a Minloc or Maxloc intrinsic.
147-
void genMinMaxlocReductionLoop(fir::FirOpBuilder &builder, mlir::Value array,
148-
InitValGeneratorTy initVal,
149-
MinlocBodyOpGeneratorTy genBody,
150-
fir::AddrGeneratorTy getAddrFn, unsigned rank,
151-
mlir::Type elementType, mlir::Location loc,
152-
mlir::Type maskElemType, mlir::Value resultArr,
153-
bool maskMayBeLogicalScalar);
148+
// Emits the reduction loop nest for a minloc/maxloc over `array`, followed by
// a post-loop fix-up of `resultArr`.
//
// Parameters, as used by this function:
//   builder      - creates every operation; on return its insertion point is
//                  set after the final flag-check fir.if.
//   array        - input array; converted to a fir.box of a rank-`rank`
//                  sequence of `elementType` with unknown extents.
//   initVal      - callback producing the initial reduction value; invoked
//                  once before the loops and once more for the final compare.
//   genBody      - callback invoked at the innermost loop level with the boxed
//                  array, the flag reference, the current reduction value and
//                  the loop indices (dimension 0 first); returns the updated
//                  reduction value.
//   getAddrFn    - callback returning the address of one element of
//                  `resultArr` for a given index.
//   rank         - number of nested loops to emit; asserted to be non-zero.
//   elementType  - element type of `array` and of the reduction value.
//   loc          - location attached to the created operations.
//   maskElemType - not referenced in this function body.
//   resultArr    - result array; its Fortran element type is used for the
//                  flag temporary and for the values stored at the end.
//   maskMayBeLogicalScalar - when true and the block the builder ends up in
//                  after the loop nest is owned by a fir.if, the reduction
//                  value is yielded from that fir.if and emission continues
//                  after it.
inline void genMinMaxlocReductionLoop(
149+
fir::FirOpBuilder &builder, mlir::Value array,
150+
fir::InitValGeneratorTy initVal, fir::MinlocBodyOpGeneratorTy genBody,
151+
fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
152+
mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
153+
bool maskMayBeLogicalScalar) {
154+
mlir::IndexType idxTy = builder.getIndexType();
155+
156+
mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
157+
158+
fir::SequenceType::Shape flatShape(rank,
159+
fir::SequenceType::getUnknownExtent());
160+
mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
161+
mlir::Type boxArrTy = fir::BoxType::get(arrTy);
162+
array = builder.create<fir::ConvertOp>(loc, boxArrTy, array);
163+
164+
// The flag lives in a temporary of the result element type and starts at 0;
// it is handed to genBody and read back after the loop nest.
mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
165+
mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
166+
mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
167+
mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
168+
builder.create<fir::StoreOp>(loc, zero, flagRef);
169+
170+
mlir::Value init = initVal(builder, loc, elementType);
171+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
172+
173+
assert(rank > 0 && "rank cannot be zero");
174+
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
175+
176+
// Compute all the upper bounds before the loop nest.
177+
// It is not strictly necessary for performance, since the loop nest
178+
// does not have any store operations and any LICM optimization
179+
// should be able to optimize the redundancy.
180+
for (unsigned i = 0; i < rank; ++i) {
181+
mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
182+
auto dims =
183+
builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, array, dimIdx);
184+
mlir::Value len = dims.getResult(1);
185+
// We use C indexing here, so len-1 as loopcount
186+
mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
187+
bounds.push_back(loopCount);
188+
}
189+
// Create a loop nest consisting of fir.do_loop operations.
190+
// Collect the loops' induction variables into indices array,
191+
// which will be used in the innermost loop to load the input
192+
// array's element.
193+
// The loops are generated such that the innermost loop processes
194+
// the 0 dimension.
195+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
196+
for (unsigned i = rank; 0 < i; --i) {
197+
mlir::Value step = one;
198+
mlir::Value loopCount = bounds[i - 1];
199+
auto loop =
200+
builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
201+
/*finalCountValue=*/false, init);
202+
// Thread the reduction value through the loop as its iteration argument.
init = loop.getRegionIterArgs()[0];
203+
indices.push_back(loop.getInductionVar());
204+
// Set insertion point to the loop body so that the next loop
205+
// is inserted inside the current one.
206+
builder.setInsertionPointToStart(loop.getBody());
207+
}
208+
209+
// Reverse the indices such that they are ordered as:
210+
// <dim-0-idx, dim-1-idx, ...>
211+
std::reverse(indices.begin(), indices.end());
212+
mlir::Value reductionVal =
213+
genBody(builder, loc, elementType, array, flagRef, init, indices);
214+
215+
// Unwind the loop nest and insert ResultOp on each level
216+
// to return the updated value of the reduction to the enclosing
217+
// loops.
218+
for (unsigned i = 0; i < rank; ++i) {
219+
auto result = builder.create<fir::ResultOp>(loc, reductionVal);
220+
// Proceed to the outer loop.
221+
auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
222+
reductionVal = loop.getResult(0);
223+
// Set insertion point after the loop operation that we have
224+
// just processed.
225+
builder.setInsertionPointAfter(loop.getOperation());
226+
}
227+
// End of loop nest. The insertion point is after the outermost loop.
228+
if (maskMayBeLogicalScalar) {
229+
if (fir::IfOp ifOp =
230+
mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
231+
builder.create<fir::ResultOp>(loc, reductionVal);
232+
builder.setInsertionPointAfter(ifOp);
233+
// Redefine flagSet to escape scope of ifOp
234+
flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
235+
reductionVal = ifOp.getResult(0);
236+
}
237+
}
238+
239+
// Check for the case where the reduction result still equals the initial
// value (e.g. for minloc, the array was full of max values).
240+
// flag will be 0 if mask was never true, 1 if mask was true at some point;
241+
// this is needed to avoid catching cases where we didn't access any elements
242+
// e.g. mask=.FALSE.
243+
mlir::Value flagValue =
244+
builder.create<fir::LoadOp>(loc, resultElemType, flagRef);
245+
mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>(
246+
loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet);
247+
fir::IfOp ifMaskTrueOp =
248+
builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false);
249+
builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front());
250+
251+
// Re-materialize the initial value here and compare it with the reduction
// result, using a float or an integer comparison depending on elementType.
mlir::Value testInit = initVal(builder, loc, elementType);
252+
fir::IfOp ifMinSetOp;
253+
if (elementType.isa<mlir::FloatType>()) {
254+
mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
255+
loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal);
256+
ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
257+
/*withElseRegion*/ false);
258+
} else {
259+
mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
260+
loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal);
261+
ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
262+
/*withElseRegion*/ false);
263+
}
264+
builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front());
265+
266+
// Load output array with 1s instead of 0s: store 1 into each of the `rank`
// elements of resultArr.
267+
for (unsigned int i = 0; i < rank; ++i) {
268+
mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
269+
mlir::Value resultElemAddr =
270+
getAddrFn(builder, loc, resultElemType, resultArr, index);
271+
builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr);
272+
}
273+
// Continue emitting after the outer (flag-check) fir.if.
builder.setInsertionPointAfter(ifMaskTrueOp);
274+
}
154275

155276
} // namespace fir
156277

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -812,50 +812,55 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
812812
// inlined elemental.
813813
// %e = hlfir.elemental %shape ({ ... })
814814
// %m = hlfir.minloc %array mask %e
815-
class MinMaxlocElementalConversion
816-
: public mlir::OpRewritePattern<hlfir::MinlocOp> {
815+
template <typename Op>
816+
class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
817817
public:
818-
using mlir::OpRewritePattern<hlfir::MinlocOp>::OpRewritePattern;
818+
using mlir::OpRewritePattern<Op>::OpRewritePattern;
819819

820820
mlir::LogicalResult
821-
matchAndRewrite(hlfir::MinlocOp minloc,
822-
mlir::PatternRewriter &rewriter) const override {
823-
if (!minloc.getMask() || minloc.getDim() || minloc.getBack())
824-
return rewriter.notifyMatchFailure(minloc, "Did not find valid minloc");
821+
matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override {
822+
if (!mloc.getMask() || mloc.getDim() || mloc.getBack())
823+
return rewriter.notifyMatchFailure(mloc,
824+
"Did not find valid minloc/maxloc");
825825

826-
auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
826+
constexpr bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>;
827+
828+
auto elemental =
829+
mloc.getMask().template getDefiningOp<hlfir::ElementalOp>();
827830
if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
828-
return rewriter.notifyMatchFailure(minloc, "Did not find elemental");
831+
return rewriter.notifyMatchFailure(mloc, "Did not find elemental");
829832

830-
mlir::Value array = minloc.getArray();
833+
mlir::Value array = mloc.getArray();
831834

832-
unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType()).getShape()[0];
835+
unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0];
833836
mlir::Type arrayType = array.getType();
834837
if (!arrayType.isa<fir::BoxType>())
835838
return rewriter.notifyMatchFailure(
836-
minloc, "Currently requires a boxed type input");
839+
mloc, "Currently requires a boxed type input");
837840
mlir::Type elementType = hlfir::getFortranElementType(arrayType);
838841
if (!fir::isa_trivial(elementType))
839842
return rewriter.notifyMatchFailure(
840-
minloc, "Character arrays are currently not handled");
843+
mloc, "Character arrays are currently not handled");
841844

842-
mlir::Location loc = minloc.getLoc();
843-
fir::FirOpBuilder builder{rewriter, minloc.getOperation()};
845+
mlir::Location loc = mloc.getLoc();
846+
fir::FirOpBuilder builder{rewriter, mloc.getOperation()};
844847
mlir::Value resultArr = builder.createTemporary(
845848
loc, fir::SequenceType::get(
846-
rank, hlfir::getFortranElementType(minloc.getType())));
849+
rank, hlfir::getFortranElementType(mloc.getType())));
847850

848851
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
849852
mlir::Type elementType) {
850853
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
851854
const llvm::fltSemantics &sem = ty.getFloatSemantics();
852855
return builder.createRealConstant(
853856
loc, elementType,
854-
llvm::APFloat::getLargest(sem, /*Negative=*/false));
857+
llvm::APFloat::getLargest(sem, /*Negative=*/!isMax));
855858
}
856859
unsigned bits = elementType.getIntOrFloatBitWidth();
857-
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
858-
return builder.createIntegerConstant(loc, elementType, maxInt);
860+
int64_t limitInt =
861+
isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue()
862+
: llvm::APInt::getSignedMaxValue(bits).getSExtValue();
863+
return builder.createIntegerConstant(loc, elementType, limitInt);
859864
};
860865

861866
auto genBodyOp =
@@ -899,10 +904,16 @@ class MinMaxlocElementalConversion
899904
mlir::Value cmp;
900905
if (elementType.isa<mlir::FloatType>()) {
901906
cmp = builder.create<mlir::arith::CmpFOp>(
902-
loc, mlir::arith::CmpFPredicate::OLT, elem, reduction);
907+
loc,
908+
isMax ? mlir::arith::CmpFPredicate::OGT
909+
: mlir::arith::CmpFPredicate::OLT,
910+
elem, reduction);
903911
} else if (elementType.isa<mlir::IntegerType>()) {
904912
cmp = builder.create<mlir::arith::CmpIOp>(
905-
loc, mlir::arith::CmpIPredicate::slt, elem, reduction);
913+
loc,
914+
isMax ? mlir::arith::CmpIPredicate::sgt
915+
: mlir::arith::CmpIPredicate::slt,
916+
elem, reduction);
906917
} else {
907918
llvm_unreachable("unsupported type");
908919
}
@@ -975,15 +986,15 @@ class MinMaxlocElementalConversion
975986
// AsExpr for the temporary resultArr.
976987
llvm::SmallVector<hlfir::DestroyOp> destroys;
977988
llvm::SmallVector<hlfir::AssignOp> assigns;
978-
for (auto user : minloc->getUsers()) {
989+
for (auto user : mloc->getUsers()) {
979990
if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
980991
destroys.push_back(destroy);
981992
else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
982993
assigns.push_back(assign);
983994
}
984995

985-
// Check if the minloc was the only user of the elemental (apart from a
986-
// destroy), and remove it if so.
996+
// Check if the minloc/maxloc was the only user of the elemental (apart from
997+
// a destroy), and remove it if so.
987998
mlir::Operation::user_range elemUsers = elemental->getUsers();
988999
hlfir::DestroyOp elemDestroy;
9891000
if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
@@ -996,7 +1007,7 @@ class MinMaxlocElementalConversion
9961007
rewriter.eraseOp(d);
9971008
for (auto a : assigns)
9981009
a.setOperand(0, resultArr);
999-
rewriter.replaceOp(minloc, asExpr);
1010+
rewriter.replaceOp(mloc, asExpr);
10001011
if (elemDestroy) {
10011012
rewriter.eraseOp(elemDestroy);
10021013
rewriter.eraseOp(elemental);
@@ -1030,7 +1041,8 @@ class OptimizedBufferizationPass
10301041
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
10311042
patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
10321043
patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
1033-
patterns.insert<MinMaxlocElementalConversion>(context);
1044+
patterns.insert<MinMaxlocElementalConversion<hlfir::MinlocOp>>(context);
1045+
patterns.insert<MinMaxlocElementalConversion<hlfir::MaxlocOp>>(context);
10341046

10351047
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
10361048
func, std::move(patterns), config))) {

0 commit comments

Comments
 (0)