@@ -659,6 +659,125 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
659
659
return mlir::success ();
660
660
}
661
661
662
+ using GenBodyFn =
663
+ std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
664
+ const llvm::SmallVectorImpl<mlir::Value> &)>;
665
+ static mlir::Value generateReductionLoop (fir::FirOpBuilder &builder,
666
+ mlir::Location loc, mlir::Value init,
667
+ mlir::Value shape, GenBodyFn genBody) {
668
+ auto extents = hlfir::getIndexExtents (loc, builder, shape);
669
+ mlir::Value reduction = init;
670
+ mlir::IndexType idxTy = builder.getIndexType ();
671
+ mlir::Value oneIdx = builder.createIntegerConstant (loc, idxTy, 1 );
672
+
673
+ // Create a reduction loop nest. We use one-based indices so that they can be
674
+ // passed to the elemental, and reverse the order so that they can be
675
+ // generated in column-major order for better performance.
676
+ llvm::SmallVector<mlir::Value> indices (extents.size (), mlir::Value{});
677
+ for (unsigned i = 0 ; i < extents.size (); ++i) {
678
+ auto loop = builder.create <fir::DoLoopOp>(
679
+ loc, oneIdx, extents[extents.size () - i - 1 ], oneIdx, false ,
680
+ /* finalCountValue=*/ false , reduction);
681
+ reduction = loop.getRegionIterArgs ()[0 ];
682
+ indices[extents.size () - i - 1 ] = loop.getInductionVar ();
683
+ // Set insertion point to the loop body so that the next loop
684
+ // is inserted inside the current one.
685
+ builder.setInsertionPointToStart (loop.getBody ());
686
+ }
687
+
688
+ // Generate the body
689
+ reduction = genBody (builder, loc, reduction, indices);
690
+
691
+ // Unwind the loop nest.
692
+ for (unsigned i = 0 ; i < extents.size (); ++i) {
693
+ auto result = builder.create <fir::ResultOp>(loc, reduction);
694
+ auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp ());
695
+ reduction = loop.getResult (0 );
696
+ // Set insertion point after the loop operation that we have
697
+ // just processed.
698
+ builder.setInsertionPointAfter (loop.getOperation ());
699
+ }
700
+
701
+ return reduction;
702
+ }
703
+
704
+ // / Given a reduction operation with an elemental mask, attempt to generate a
705
+ // / do-loop to perform the operation inline.
706
+ // / %e = hlfir.elemental %shape unordered
707
+ // / %r = hlfir.count %e
708
+ // / =>
709
+ // / %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
710
+ // / %i = <inline elemental>
711
+ // / %c = <reduce count> %i
712
+ // / fir.result %c
713
+ template <typename Op>
714
+ class ReductionElementalConversion : public mlir ::OpRewritePattern<Op> {
715
+ public:
716
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
717
+
718
+ mlir::LogicalResult
719
+ matchAndRewrite (Op op, mlir::PatternRewriter &rewriter) const override {
720
+ mlir::Location loc = op.getLoc ();
721
+ hlfir::ElementalOp elemental =
722
+ op.getMask ().template getDefiningOp <hlfir::ElementalOp>();
723
+ if (!elemental || op.getDim ())
724
+ return rewriter.notifyMatchFailure (op, " Did not find valid elemental" );
725
+
726
+ fir::KindMapping kindMap =
727
+ fir::getKindMapping (op->template getParentOfType <mlir::ModuleOp>());
728
+ fir::FirOpBuilder builder{op, kindMap};
729
+
730
+ mlir::Value init;
731
+ GenBodyFn genBodyFn;
732
+ if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
733
+ init = builder.createIntegerConstant (loc, op.getType (), 0 );
734
+ genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
735
+ mlir::Value reduction,
736
+ const llvm::SmallVectorImpl<mlir::Value> &indices)
737
+ -> mlir::Value {
738
+ // Inline the elemental and get the condition from it.
739
+ auto yield = inlineElementalOp (loc, builder, elemental, indices);
740
+ mlir::Value cond = builder.create <fir::ConvertOp>(
741
+ loc, builder.getI1Type (), yield.getElementValue ());
742
+ yield->erase ();
743
+
744
+ // Conditionally add one to the current value
745
+ mlir::Value one =
746
+ builder.createIntegerConstant (loc, reduction.getType (), 1 );
747
+ mlir::Value add1 =
748
+ builder.create <mlir::arith::AddIOp>(loc, reduction, one);
749
+ return builder.create <mlir::arith::SelectOp>(loc, cond, add1,
750
+ reduction);
751
+ };
752
+ } else {
753
+ static_assert (" Expected Op to be handled" );
754
+ return mlir::failure ();
755
+ }
756
+
757
+ mlir::Value res = generateReductionLoop (builder, loc, init,
758
+ elemental.getOperand (0 ), genBodyFn);
759
+ if (res.getType () != op.getType ())
760
+ res = builder.create <fir::ConvertOp>(loc, op.getType (), res);
761
+
762
+ // Check if the op was the only user of the elemental (apart from a
763
+ // destroy), and remove it if so.
764
+ mlir::Operation::user_range elemUsers = elemental->getUsers ();
765
+ hlfir::DestroyOp elemDestroy;
766
+ if (std::distance (elemUsers.begin (), elemUsers.end ()) == 2 ) {
767
+ elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin ());
768
+ if (!elemDestroy)
769
+ elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin ());
770
+ }
771
+
772
+ rewriter.replaceOp (op, res);
773
+ if (elemDestroy) {
774
+ rewriter.eraseOp (elemDestroy);
775
+ rewriter.eraseOp (elemental);
776
+ }
777
+ return mlir::success ();
778
+ }
779
+ };
780
+
662
781
class OptimizedBufferizationPass
663
782
: public hlfir::impl::OptimizedBufferizationBase<
664
783
OptimizedBufferizationPass> {
@@ -681,6 +800,7 @@ class OptimizedBufferizationPass
681
800
patterns.insert <ElementalAssignBufferization>(context);
682
801
patterns.insert <BroadcastAssignBufferization>(context);
683
802
patterns.insert <VariableAssignBufferization>(context);
803
+ patterns.insert <ReductionElementalConversion<hlfir::CountOp>>(context);
684
804
685
805
if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
686
806
func, std::move (patterns), config))) {
0 commit comments