Skip to content

[Flang] Minloc elemental intrinsic lowering #74828

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions flang/include/flang/Optimizer/Support/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,25 @@ inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
fir::numericMlirTypeToFortran(builder, type, loc, intrinsicName) +
" in " + intrinsicName);
}

using MinlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
mlir::Value, mlir::Value, const llvm::SmallVectorImpl<mlir::Value> &)>;
using InitValGeneratorTy = llvm::function_ref<mlir::Value(
fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
using AddrGeneratorTy = llvm::function_ref<mlir::Value(
fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
mlir::Value)>;

// Produces a loop nest for a Minloc intrinsic.
void genMinMaxlocReductionLoop(fir::FirOpBuilder &builder, mlir::Value array,
InitValGeneratorTy initVal,
MinlocBodyOpGeneratorTy genBody,
fir::AddrGeneratorTy getAddrFn, unsigned rank,
mlir::Type elementType, mlir::Location loc,
mlir::Type maskElemType, mlir::Value resultArr,
bool maskMayBeLogicalScalar);

} // namespace fir

#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H
199 changes: 199 additions & 0 deletions flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "flang/Optimizer/Support/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -807,6 +808,203 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
}
};

// Look for minloc(mask=elemental) and generate the minloc loop with
// inlined elemental.
// %e = hlfir.elemental %shape ({ ... })
// %m = hlfir.minloc %array mask %e
class MinMaxlocElementalConversion
: public mlir::OpRewritePattern<hlfir::MinlocOp> {
public:
using mlir::OpRewritePattern<hlfir::MinlocOp>::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(hlfir::MinlocOp minloc,
mlir::PatternRewriter &rewriter) const override {
if (!minloc.getMask() || minloc.getDim() || minloc.getBack())
return rewriter.notifyMatchFailure(minloc, "Did not find valid minloc");

auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
return rewriter.notifyMatchFailure(minloc, "Did not find elemental");

mlir::Value array = minloc.getArray();

unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType()).getShape()[0];
mlir::Type arrayType = array.getType();
if (!arrayType.isa<fir::BoxType>())
return rewriter.notifyMatchFailure(
minloc, "Currently requires a boxed type input");
mlir::Type elementType = hlfir::getFortranElementType(arrayType);
if (!fir::isa_trivial(elementType))
return rewriter.notifyMatchFailure(
minloc, "Character arrays are currently not handled");

mlir::Location loc = minloc.getLoc();
fir::FirOpBuilder builder{rewriter, minloc.getOperation()};
mlir::Value resultArr = builder.createTemporary(
loc, fir::SequenceType::get(
rank, hlfir::getFortranElementType(minloc.getType())));

auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType) {
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
loc, elementType,
llvm::APFloat::getLargest(sem, /*Negative=*/false));
}
unsigned bits = elementType.getIntOrFloatBitWidth();
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
return builder.createIntegerConstant(loc, elementType, maxInt);
};

auto genBodyOp =
[&rank, &resultArr, &elemental](
fir::FirOpBuilder builder, mlir::Location loc,
mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
mlir::Value reduction,
const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
// We are in the innermost loop: generate the elemental inline
mlir::Value oneIdx =
builder.createIntegerConstant(loc, builder.getIndexType(), 1);
llvm::SmallVector<mlir::Value> oneBasedIndices;
llvm::transform(
indices, std::back_inserter(oneBasedIndices), [&](mlir::Value V) {
return builder.create<mlir::arith::AddIOp>(loc, V, oneIdx);
});
hlfir::YieldElementOp yield =
hlfir::inlineElementalOp(loc, builder, elemental, oneBasedIndices);
mlir::Value maskElem = yield.getElementValue();
yield->erase();

mlir::Type ifCompatType = builder.getI1Type();
mlir::Value ifCompatElem =
builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);

llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
fir::IfOp maskIfOp =
builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
/*withElseRegion=*/true);
builder.setInsertionPointToStart(&maskIfOp.getThenRegion().front());

// Set flag that mask was true at some point
mlir::Value flagSet = builder.createIntegerConstant(
loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array},
oneBasedIndices);
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);

// Compare with the max reduction value
mlir::Value cmp;
if (elementType.isa<mlir::FloatType>()) {
cmp = builder.create<mlir::arith::CmpFOp>(
loc, mlir::arith::CmpFPredicate::OLT, elem, reduction);
} else if (elementType.isa<mlir::IntegerType>()) {
cmp = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::slt, elem, reduction);
} else {
llvm_unreachable("unsupported type");
}

// Set the new coordinate to the result
fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
/*withElseRegion*/ true);

builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
mlir::Type resultElemTy =
hlfir::getFortranElementType(resultArr.getType());
mlir::Type returnRefTy = builder.getRefType(resultElemTy);
mlir::IndexType idxTy = builder.getIndexType();

for (unsigned int i = 0; i < rank; ++i) {
mlir::Value index = builder.createIntegerConstant(loc, idxTy, i + 1);
mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
loc, returnRefTy, resultArr, index);
mlir::Value fortranIndex = builder.create<fir::ConvertOp>(
loc, resultElemTy, oneBasedIndices[i]);
builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
}
builder.create<fir::ResultOp>(loc, elem);
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<fir::ResultOp>(loc, reduction);
builder.setInsertionPointAfter(ifOp);

// Close the mask if
builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
builder.setInsertionPointToStart(&maskIfOp.getElseRegion().front());
builder.create<fir::ResultOp>(loc, reduction);
builder.setInsertionPointAfter(maskIfOp);

return maskIfOp.getResult(0);
};
auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
const mlir::Type &resultElemType, mlir::Value resultArr,
mlir::Value index) {
mlir::Type resultRefTy = builder.getRefType(resultElemType);
mlir::Value oneIdx =
builder.createIntegerConstant(loc, builder.getIndexType(), 1);
index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
return builder.create<hlfir::DesignateOp>(loc, resultRefTy, resultArr,
index);
};

// Initialize the result
mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
mlir::Type resultRefTy = builder.getRefType(resultElemTy);
mlir::Value returnValue =
builder.createIntegerConstant(loc, resultElemTy, 0);
for (unsigned int i = 0; i < rank; ++i) {
mlir::Value index =
builder.createIntegerConstant(loc, builder.getIndexType(), i + 1);
mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
loc, resultRefTy, resultArr, index);
builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
}

fir::genMinMaxlocReductionLoop(builder, array, init, genBodyOp, getAddrFn,
rank, elementType, loc, builder.getI1Type(),
resultArr, false);

mlir::Value asExpr = builder.create<hlfir::AsExprOp>(
loc, resultArr, builder.createBool(loc, false));

// Check all the users - the destroy is no longer required, and any assign
// can use resultArr directly so that VariableAssignBufferization in this
// pass can optimize the results. Other operations are replaces with an
// AsExpr for the temporary resultArr.
llvm::SmallVector<hlfir::DestroyOp> destroys;
llvm::SmallVector<hlfir::AssignOp> assigns;
for (auto user : minloc->getUsers()) {
if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
destroys.push_back(destroy);
else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
assigns.push_back(assign);
}

// Check if the minloc was the only user of the elemental (apart from a
// destroy), and remove it if so.
mlir::Operation::user_range elemUsers = elemental->getUsers();
hlfir::DestroyOp elemDestroy;
if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
if (!elemDestroy)
elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
}

for (auto d : destroys)
rewriter.eraseOp(d);
for (auto a : assigns)
a.setOperand(0, resultArr);
rewriter.replaceOp(minloc, asExpr);
if (elemDestroy) {
rewriter.eraseOp(elemDestroy);
rewriter.eraseOp(elemental);
}
return mlir::success();
}
};

class OptimizedBufferizationPass
: public hlfir::impl::OptimizedBufferizationBase<
OptimizedBufferizationPass> {
Expand All @@ -832,6 +1030,7 @@ class OptimizedBufferizationPass
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
patterns.insert<MinMaxlocElementalConversion>(context);

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
func, std::move(patterns), config))) {
Expand Down
Loading