Skip to content

Commit 223d3da

Browse files
authored
[Flang] Minloc elemental intrinsic lowering (#74828)
Currently the lowering of a minloc intrinsic with a mask will look something like: %e = hlfir.elemental %shape ({ ... }) %m = hlfir.minloc %array mask %e hlfir.assign %m to %result hlfir.destroy %m The elemental will be expanded into a temporary+loop, the minloc into a FortranAMinloc call (which hopefully gets simplified to a specialized call that can be inlined at the call site), and the assign might get expanded to a FortranAAssign. It would be better to generate the entire construct as single loop if we can - one that performs the minloc calculation with the mask elemental computed inline. This patch attempt to do that, adding a hlfir version of the expansion code from SimplifyIntrinsics that turns an minloc+elemental into a single combined loop nest. It attempts to reuse the methods in genMinlocReductionLoop for constructing the loop with a modified loop body. The declaration for the function is currently in Optimizer/Support/Utils.h, but there might be a better place for it. It is added as part of the OptimizedBufferizationPass, like the similar count/any/all that have been added recently.
1 parent 4b8e55c commit 223d3da

File tree

5 files changed

+751
-103
lines changed

5 files changed

+751
-103
lines changed

flang/include/flang/Optimizer/Support/Utils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,25 @@ inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
133133
fir::numericMlirTypeToFortran(builder, type, loc, intrinsicName) +
134134
" in " + intrinsicName);
135135
}
136+
137+
using MinlocBodyOpGeneratorTy = llvm::function_ref<mlir::Value(
138+
fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
139+
mlir::Value, mlir::Value, const llvm::SmallVectorImpl<mlir::Value> &)>;
140+
using InitValGeneratorTy = llvm::function_ref<mlir::Value(
141+
fir::FirOpBuilder &, mlir::Location, const mlir::Type &)>;
142+
using AddrGeneratorTy = llvm::function_ref<mlir::Value(
143+
fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
144+
mlir::Value)>;
145+
146+
// Produces a loop nest for a Minloc intrinsic.
147+
void genMinMaxlocReductionLoop(fir::FirOpBuilder &builder, mlir::Value array,
148+
InitValGeneratorTy initVal,
149+
MinlocBodyOpGeneratorTy genBody,
150+
fir::AddrGeneratorTy getAddrFn, unsigned rank,
151+
mlir::Type elementType, mlir::Location loc,
152+
mlir::Type maskElemType, mlir::Value resultArr,
153+
bool maskMayBeLogicalScalar);
154+
136155
} // namespace fir
137156

138157
#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
2121
#include "flang/Optimizer/HLFIR/HLFIROps.h"
2222
#include "flang/Optimizer/HLFIR/Passes.h"
23+
#include "flang/Optimizer/Support/Utils.h"
2324
#include "mlir/Dialect/Func/IR/FuncOps.h"
2425
#include "mlir/IR/Dominance.h"
2526
#include "mlir/IR/PatternMatch.h"
@@ -807,6 +808,203 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
807808
}
808809
};
809810

811+
// Look for minloc(mask=elemental) and generate the minloc loop with
812+
// inlined elemental.
813+
// %e = hlfir.elemental %shape ({ ... })
814+
// %m = hlfir.minloc %array mask %e
815+
class MinMaxlocElementalConversion
816+
: public mlir::OpRewritePattern<hlfir::MinlocOp> {
817+
public:
818+
using mlir::OpRewritePattern<hlfir::MinlocOp>::OpRewritePattern;
819+
820+
mlir::LogicalResult
821+
matchAndRewrite(hlfir::MinlocOp minloc,
822+
mlir::PatternRewriter &rewriter) const override {
823+
if (!minloc.getMask() || minloc.getDim() || minloc.getBack())
824+
return rewriter.notifyMatchFailure(minloc, "Did not find valid minloc");
825+
826+
auto elemental = minloc.getMask().getDefiningOp<hlfir::ElementalOp>();
827+
if (!elemental || hlfir::elementalOpMustProduceTemp(elemental))
828+
return rewriter.notifyMatchFailure(minloc, "Did not find elemental");
829+
830+
mlir::Value array = minloc.getArray();
831+
832+
unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType()).getShape()[0];
833+
mlir::Type arrayType = array.getType();
834+
if (!arrayType.isa<fir::BoxType>())
835+
return rewriter.notifyMatchFailure(
836+
minloc, "Currently requires a boxed type input");
837+
mlir::Type elementType = hlfir::getFortranElementType(arrayType);
838+
if (!fir::isa_trivial(elementType))
839+
return rewriter.notifyMatchFailure(
840+
minloc, "Character arrays are currently not handled");
841+
842+
mlir::Location loc = minloc.getLoc();
843+
fir::FirOpBuilder builder{rewriter, minloc.getOperation()};
844+
mlir::Value resultArr = builder.createTemporary(
845+
loc, fir::SequenceType::get(
846+
rank, hlfir::getFortranElementType(minloc.getType())));
847+
848+
auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
849+
mlir::Type elementType) {
850+
if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
851+
const llvm::fltSemantics &sem = ty.getFloatSemantics();
852+
return builder.createRealConstant(
853+
loc, elementType,
854+
llvm::APFloat::getLargest(sem, /*Negative=*/false));
855+
}
856+
unsigned bits = elementType.getIntOrFloatBitWidth();
857+
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
858+
return builder.createIntegerConstant(loc, elementType, maxInt);
859+
};
860+
861+
auto genBodyOp =
862+
[&rank, &resultArr, &elemental](
863+
fir::FirOpBuilder builder, mlir::Location loc,
864+
mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
865+
mlir::Value reduction,
866+
const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
867+
// We are in the innermost loop: generate the elemental inline
868+
mlir::Value oneIdx =
869+
builder.createIntegerConstant(loc, builder.getIndexType(), 1);
870+
llvm::SmallVector<mlir::Value> oneBasedIndices;
871+
llvm::transform(
872+
indices, std::back_inserter(oneBasedIndices), [&](mlir::Value V) {
873+
return builder.create<mlir::arith::AddIOp>(loc, V, oneIdx);
874+
});
875+
hlfir::YieldElementOp yield =
876+
hlfir::inlineElementalOp(loc, builder, elemental, oneBasedIndices);
877+
mlir::Value maskElem = yield.getElementValue();
878+
yield->erase();
879+
880+
mlir::Type ifCompatType = builder.getI1Type();
881+
mlir::Value ifCompatElem =
882+
builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem);
883+
884+
llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
885+
fir::IfOp maskIfOp =
886+
builder.create<fir::IfOp>(loc, elementType, ifCompatElem,
887+
/*withElseRegion=*/true);
888+
builder.setInsertionPointToStart(&maskIfOp.getThenRegion().front());
889+
890+
// Set flag that mask was true at some point
891+
mlir::Value flagSet = builder.createIntegerConstant(
892+
loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
893+
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
894+
mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array},
895+
oneBasedIndices);
896+
mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
897+
898+
// Compare with the max reduction value
899+
mlir::Value cmp;
900+
if (elementType.isa<mlir::FloatType>()) {
901+
cmp = builder.create<mlir::arith::CmpFOp>(
902+
loc, mlir::arith::CmpFPredicate::OLT, elem, reduction);
903+
} else if (elementType.isa<mlir::IntegerType>()) {
904+
cmp = builder.create<mlir::arith::CmpIOp>(
905+
loc, mlir::arith::CmpIPredicate::slt, elem, reduction);
906+
} else {
907+
llvm_unreachable("unsupported type");
908+
}
909+
910+
// Set the new coordinate to the result
911+
fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
912+
/*withElseRegion*/ true);
913+
914+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
915+
mlir::Type resultElemTy =
916+
hlfir::getFortranElementType(resultArr.getType());
917+
mlir::Type returnRefTy = builder.getRefType(resultElemTy);
918+
mlir::IndexType idxTy = builder.getIndexType();
919+
920+
for (unsigned int i = 0; i < rank; ++i) {
921+
mlir::Value index = builder.createIntegerConstant(loc, idxTy, i + 1);
922+
mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
923+
loc, returnRefTy, resultArr, index);
924+
mlir::Value fortranIndex = builder.create<fir::ConvertOp>(
925+
loc, resultElemTy, oneBasedIndices[i]);
926+
builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr);
927+
}
928+
builder.create<fir::ResultOp>(loc, elem);
929+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
930+
builder.create<fir::ResultOp>(loc, reduction);
931+
builder.setInsertionPointAfter(ifOp);
932+
933+
// Close the mask if
934+
builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
935+
builder.setInsertionPointToStart(&maskIfOp.getElseRegion().front());
936+
builder.create<fir::ResultOp>(loc, reduction);
937+
builder.setInsertionPointAfter(maskIfOp);
938+
939+
return maskIfOp.getResult(0);
940+
};
941+
auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
942+
const mlir::Type &resultElemType, mlir::Value resultArr,
943+
mlir::Value index) {
944+
mlir::Type resultRefTy = builder.getRefType(resultElemType);
945+
mlir::Value oneIdx =
946+
builder.createIntegerConstant(loc, builder.getIndexType(), 1);
947+
index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
948+
return builder.create<hlfir::DesignateOp>(loc, resultRefTy, resultArr,
949+
index);
950+
};
951+
952+
// Initialize the result
953+
mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
954+
mlir::Type resultRefTy = builder.getRefType(resultElemTy);
955+
mlir::Value returnValue =
956+
builder.createIntegerConstant(loc, resultElemTy, 0);
957+
for (unsigned int i = 0; i < rank; ++i) {
958+
mlir::Value index =
959+
builder.createIntegerConstant(loc, builder.getIndexType(), i + 1);
960+
mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>(
961+
loc, resultRefTy, resultArr, index);
962+
builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr);
963+
}
964+
965+
fir::genMinMaxlocReductionLoop(builder, array, init, genBodyOp, getAddrFn,
966+
rank, elementType, loc, builder.getI1Type(),
967+
resultArr, false);
968+
969+
mlir::Value asExpr = builder.create<hlfir::AsExprOp>(
970+
loc, resultArr, builder.createBool(loc, false));
971+
972+
// Check all the users - the destroy is no longer required, and any assign
973+
// can use resultArr directly so that VariableAssignBufferization in this
974+
// pass can optimize the results. Other operations are replaces with an
975+
// AsExpr for the temporary resultArr.
976+
llvm::SmallVector<hlfir::DestroyOp> destroys;
977+
llvm::SmallVector<hlfir::AssignOp> assigns;
978+
for (auto user : minloc->getUsers()) {
979+
if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
980+
destroys.push_back(destroy);
981+
else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
982+
assigns.push_back(assign);
983+
}
984+
985+
// Check if the minloc was the only user of the elemental (apart from a
986+
// destroy), and remove it if so.
987+
mlir::Operation::user_range elemUsers = elemental->getUsers();
988+
hlfir::DestroyOp elemDestroy;
989+
if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
990+
elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
991+
if (!elemDestroy)
992+
elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
993+
}
994+
995+
for (auto d : destroys)
996+
rewriter.eraseOp(d);
997+
for (auto a : assigns)
998+
a.setOperand(0, resultArr);
999+
rewriter.replaceOp(minloc, asExpr);
1000+
if (elemDestroy) {
1001+
rewriter.eraseOp(elemDestroy);
1002+
rewriter.eraseOp(elemental);
1003+
}
1004+
return mlir::success();
1005+
}
1006+
};
1007+
8101008
class OptimizedBufferizationPass
8111009
: public hlfir::impl::OptimizedBufferizationBase<
8121010
OptimizedBufferizationPass> {
@@ -832,6 +1030,7 @@ class OptimizedBufferizationPass
8321030
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
8331031
patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
8341032
patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
1033+
patterns.insert<MinMaxlocElementalConversion>(context);
8351034

8361035
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
8371036
func, std::move(patterns), config))) {

0 commit comments

Comments
 (0)