Skip to content

Commit 810c291

Browse files
authored
[Flang] Generate inline reduction loops for elemental count intrinsics (#75774)
This adds a ReductionElementalConversion transform to OptimizedBufferizationPass, taking hlfir::count(hlfir::elemental) and generating the inline loop to perform the count of true elements. This lets us generate a single loop instead of ending up as two plus a temporary. Any and All should be able to share the same code with a different function/initial value.
1 parent 6eb372e commit 810c291

File tree

2 files changed

+434
-0
lines changed

2 files changed

+434
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,125 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
659659
return mlir::success();
660660
}
661661

662+
using GenBodyFn =
663+
std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
664+
const llvm::SmallVectorImpl<mlir::Value> &)>;
665+
static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder,
666+
mlir::Location loc, mlir::Value init,
667+
mlir::Value shape, GenBodyFn genBody) {
668+
auto extents = hlfir::getIndexExtents(loc, builder, shape);
669+
mlir::Value reduction = init;
670+
mlir::IndexType idxTy = builder.getIndexType();
671+
mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
672+
673+
// Create a reduction loop nest. We use one-based indices so that they can be
674+
// passed to the elemental, and reverse the order so that they can be
675+
// generated in column-major order for better performance.
676+
llvm::SmallVector<mlir::Value> indices(extents.size(), mlir::Value{});
677+
for (unsigned i = 0; i < extents.size(); ++i) {
678+
auto loop = builder.create<fir::DoLoopOp>(
679+
loc, oneIdx, extents[extents.size() - i - 1], oneIdx, false,
680+
/*finalCountValue=*/false, reduction);
681+
reduction = loop.getRegionIterArgs()[0];
682+
indices[extents.size() - i - 1] = loop.getInductionVar();
683+
// Set insertion point to the loop body so that the next loop
684+
// is inserted inside the current one.
685+
builder.setInsertionPointToStart(loop.getBody());
686+
}
687+
688+
// Generate the body
689+
reduction = genBody(builder, loc, reduction, indices);
690+
691+
// Unwind the loop nest.
692+
for (unsigned i = 0; i < extents.size(); ++i) {
693+
auto result = builder.create<fir::ResultOp>(loc, reduction);
694+
auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
695+
reduction = loop.getResult(0);
696+
// Set insertion point after the loop operation that we have
697+
// just processed.
698+
builder.setInsertionPointAfter(loop.getOperation());
699+
}
700+
701+
return reduction;
702+
}
703+
704+
/// Given a reduction operation with an elemental mask, attempt to generate a
705+
/// do-loop to perform the operation inline.
706+
/// %e = hlfir.elemental %shape unordered
707+
/// %r = hlfir.count %e
708+
/// =>
709+
/// %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
710+
/// %i = <inline elemental>
711+
/// %c = <reduce count> %i
712+
/// fir.result %c
713+
template <typename Op>
714+
class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
715+
public:
716+
using mlir::OpRewritePattern<Op>::OpRewritePattern;
717+
718+
mlir::LogicalResult
719+
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
720+
mlir::Location loc = op.getLoc();
721+
hlfir::ElementalOp elemental =
722+
op.getMask().template getDefiningOp<hlfir::ElementalOp>();
723+
if (!elemental || op.getDim())
724+
return rewriter.notifyMatchFailure(op, "Did not find valid elemental");
725+
726+
fir::KindMapping kindMap =
727+
fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>());
728+
fir::FirOpBuilder builder{op, kindMap};
729+
730+
mlir::Value init;
731+
GenBodyFn genBodyFn;
732+
if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
733+
init = builder.createIntegerConstant(loc, op.getType(), 0);
734+
genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
735+
mlir::Value reduction,
736+
const llvm::SmallVectorImpl<mlir::Value> &indices)
737+
-> mlir::Value {
738+
// Inline the elemental and get the condition from it.
739+
auto yield = inlineElementalOp(loc, builder, elemental, indices);
740+
mlir::Value cond = builder.create<fir::ConvertOp>(
741+
loc, builder.getI1Type(), yield.getElementValue());
742+
yield->erase();
743+
744+
// Conditionally add one to the current value
745+
mlir::Value one =
746+
builder.createIntegerConstant(loc, reduction.getType(), 1);
747+
mlir::Value add1 =
748+
builder.create<mlir::arith::AddIOp>(loc, reduction, one);
749+
return builder.create<mlir::arith::SelectOp>(loc, cond, add1,
750+
reduction);
751+
};
752+
} else {
753+
static_assert("Expected Op to be handled");
754+
return mlir::failure();
755+
}
756+
757+
mlir::Value res = generateReductionLoop(builder, loc, init,
758+
elemental.getOperand(0), genBodyFn);
759+
if (res.getType() != op.getType())
760+
res = builder.create<fir::ConvertOp>(loc, op.getType(), res);
761+
762+
// Check if the op was the only user of the elemental (apart from a
763+
// destroy), and remove it if so.
764+
mlir::Operation::user_range elemUsers = elemental->getUsers();
765+
hlfir::DestroyOp elemDestroy;
766+
if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) {
767+
elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin());
768+
if (!elemDestroy)
769+
elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin());
770+
}
771+
772+
rewriter.replaceOp(op, res);
773+
if (elemDestroy) {
774+
rewriter.eraseOp(elemDestroy);
775+
rewriter.eraseOp(elemental);
776+
}
777+
return mlir::success();
778+
}
779+
};
780+
662781
class OptimizedBufferizationPass
663782
: public hlfir::impl::OptimizedBufferizationBase<
664783
OptimizedBufferizationPass> {
@@ -681,6 +800,7 @@ class OptimizedBufferizationPass
681800
patterns.insert<ElementalAssignBufferization>(context);
682801
patterns.insert<BroadcastAssignBufferization>(context);
683802
patterns.insert<VariableAssignBufferization>(context);
803+
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
684804

685805
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
686806
func, std::move(patterns), config))) {

0 commit comments

Comments
 (0)