Skip to content

Commit ed74ee7

Browse files
committed
[Flang] Generate inline reduction loops for elemental count intrinsics
This adds a ReductionElementalConversion transform to OptimizedBufferizationPass that takes hlfir::count(hlfir::elemental) and generates an inline loop that counts the true elements. This lets us generate a single loop instead of ending up with two loops plus a temporary. This is currently part of OptimizedBufferization, similar to llvm#74828. I attempted to move it to LowerHLFIRIntrinsics to make it part of the existing lowering, but it hit problems with inlining elementals that contain operations that are being legalized by the same pass. Any and All should be able to share the same code with a different function/initial value.
1 parent 3d688d4 commit ed74ee7

File tree

2 files changed

+433
-0
lines changed

2 files changed

+433
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,124 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
659659
return mlir::success();
660660
}
661661

662+
using GenBodyFn =
663+
std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
664+
const llvm::SmallVectorImpl<mlir::Value> &)>;
665+
static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
666+
mlir::Location loc, mlir::Value init,
667+
mlir::Value shape, GenBodyFn genBody) {
668+
auto extents = hlfir::getIndexExtents(loc, builder, shape);
669+
mlir::Value reduction = init;
670+
mlir::IndexType idxTy = builder.getIndexType();
671+
mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
672+
673+
// Create a reduction loop nest. We use one-based indices so that they can be
674+
// passed to the elemental.
675+
llvm::SmallVector<mlir::Value> indices;
676+
for (unsigned i = 0; i < extents.size(); ++i) {
677+
auto loop =
678+
builder.create<fir::DoLoopOp>(loc, oneIdx, extents[i], oneIdx, false,
679+
/*finalCountValue=*/false, reduction);
680+
reduction = loop.getRegionIterArgs()[0];
681+
indices.push_back(loop.getInductionVar());
682+
// Set insertion point to the loop body so that the next loop
683+
// is inserted inside the current one.
684+
builder.setInsertionPointToStart(loop.getBody());
685+
}
686+
687+
// Generate the body
688+
reduction = genBody(builder, loc, reduction, indices);
689+
690+
// Unwind the loop nest.
691+
for (unsigned i = 0; i < extents.size(); ++i) {
692+
auto result = builder.create<fir::ResultOp>(loc, reduction);
693+
auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
694+
reduction = loop.getResult(0);
695+
// Set insertion point after the loop operation that we have
696+
// just processed.
697+
builder.setInsertionPointAfter(loop.getOperation());
698+
}
699+
700+
return reduction;
701+
}
702+
703+
/// Given a reduction operation with an elemental mask, attempt to generate a
704+
/// do-loop to perform the operation inline.
705+
/// %e = hlfir.elemental %shape unordered
706+
/// %r = hlfir.count %e
707+
/// =>
708+
/// %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
709+
/// %i = <inline elemental>
710+
/// %c = <reduce count> %i
711+
/// fir.result %c
712+
template <typename Op>
713+
class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
714+
public:
715+
using mlir::OpRewritePattern<Op>::OpRewritePattern;
716+
717+
mlir::LogicalResult
718+
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
719+
mlir::Location loc = op.getLoc();
720+
hlfir::ElementalOp elemental =
721+
op.getMask().template getDefiningOp<hlfir::ElementalOp>();
722+
if (!elemental || op.getDim())
723+
return rewriter.notifyMatchFailure(op, "Did not find valid elemental");
724+
725+
fir::KindMapping kindMap =
726+
fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
727+
fir::FirOpBuilder builder{op, kindMap};
728+
729+
mlir::Value init;
730+
GenBodyFn genBodyFn;
731+
if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
732+
init = builder.createIntegerConstant(loc, op.getType(), 0);
733+
genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
734+
mlir::Value reduction,
735+
const llvm::SmallVectorImpl<mlir::Value> &indices)
736+
-> mlir::Value {
737+
// Inline the elemental and get the condition from it.
738+
auto yield = inlineElementalOp(loc, builder, elemental, indices);
739+
mlir::Value cond = builder.create<fir::ConvertOp>(
740+
loc, builder.getI1Type(), yield.getElementValue());
741+
yield->erase();
742+
743+
// Conditionally add one to the current value
744+
mlir::Value one =
745+
builder.createIntegerConstant(loc, reduction.getType(), 1);
746+
mlir::Value add1 =
747+
builder.create<mlir::arith::AddIOp>(loc, reduction, one);
748+
return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
749+
reduction);
750+
};
751+
} else {
752+
static_assert("Expected Op to be handled");
753+
return mlir::failure();
754+
}
755+
756+
mlir::Value res = generateReductionLoop(builder, loc, init,
757+
elemental.getOperand(0), genBodyFn);
758+
if (res.getType() != op.getType())
759+
res = builder.create<fir::ConvertOp>(loc, op.getType(), res);
760+
761+
// Check if the op was the only user of the elemental (apart from a
762+
// destroy), and remove it if so.
763+
mlir::Operation::user_range elemUsers = elemental->getUsers();
764+
hlfir::DestroyOp elemDestroy;
765+
if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
766+
elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
767+
if (!elemDestroy)
768+
elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
769+
}
770+
771+
rewriter.replaceOp(op, res);
772+
if (elemDestroy) {
773+
rewriter.eraseOp(elemDestroy);
774+
rewriter.eraseOp(elemental);
775+
}
776+
return mlir::success();
777+
}
778+
};
779+
662780
class OptimizedBufferizationPass
663781
: public hlfir::impl::OptimizedBufferizationBase<
664782
OptimizedBufferizationPass> {
@@ -681,6 +799,7 @@ class OptimizedBufferizationPass
681799
patterns.insert<ElementalAssignBufferization>(context);
682800
patterns.insert<BroadcastAssignBufferization>(context);
683801
patterns.insert<VariableAssignBufferization>(context);
802+
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
684803

685804
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
686805
func, std::move(patterns), config))) {

0 commit comments

Comments
 (0)