@@ -659,6 +659,124 @@ mlir::LogicalResult VariableAssignBufferization::matchAndRewrite(
659
659
return mlir::success ();
660
660
}
661
661
662
+ using GenBodyFn =
663
+ std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value,
664
+ const llvm::SmallVectorImpl<mlir::Value> &)>;
665
+ static mlir::Value generateReductionLoop (fir::FirOpBuilder &builder,
666
+ mlir::Location loc, mlir::Value init,
667
+ mlir::Value shape, GenBodyFn genBody) {
668
+ auto extents = hlfir::getIndexExtents (loc, builder, shape);
669
+ mlir::Value reduction = init;
670
+ mlir::IndexType idxTy = builder.getIndexType ();
671
+ mlir::Value oneIdx = builder.createIntegerConstant (loc, idxTy, 1 );
672
+
673
+ // Create a reduction loop nest. We use one-based indices so that they can be
674
+ // passed to the elemental.
675
+ llvm::SmallVector<mlir::Value> indices;
676
+ for (unsigned i = 0 ; i < extents.size (); ++i) {
677
+ auto loop =
678
+ builder.create <fir::DoLoopOp>(loc, oneIdx, extents[i], oneIdx, false ,
679
+ /* finalCountValue=*/ false , reduction);
680
+ reduction = loop.getRegionIterArgs ()[0 ];
681
+ indices.push_back (loop.getInductionVar ());
682
+ // Set insertion point to the loop body so that the next loop
683
+ // is inserted inside the current one.
684
+ builder.setInsertionPointToStart (loop.getBody ());
685
+ }
686
+
687
+ // Generate the body
688
+ reduction = genBody (builder, loc, reduction, indices);
689
+
690
+ // Unwind the loop nest.
691
+ for (unsigned i = 0 ; i < extents.size (); ++i) {
692
+ auto result = builder.create <fir::ResultOp>(loc, reduction);
693
+ auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp ());
694
+ reduction = loop.getResult (0 );
695
+ // Set insertion point after the loop operation that we have
696
+ // just processed.
697
+ builder.setInsertionPointAfter (loop.getOperation ());
698
+ }
699
+
700
+ return reduction;
701
+ }
702
+
703
+ // / Given a reduction operation with an elemental mask, attempt to generate a
704
+ // / do-loop to perform the operation inline.
705
+ // / %e = hlfir.elemental %shape unordered
706
+ // / %r = hlfir.count %e
707
+ // / =>
708
+ // / %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init)
709
+ // / %i = <inline elemental>
710
+ // / %c = <reduce count> %i
711
+ // / fir.result %c
712
+ template <typename Op>
713
+ class ReductionElementalConversion : public mlir ::OpRewritePattern<Op> {
714
+ public:
715
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
716
+
717
+ mlir::LogicalResult
718
+ matchAndRewrite (Op op, mlir::PatternRewriter &rewriter) const override {
719
+ mlir::Location loc = op.getLoc ();
720
+ hlfir::ElementalOp elemental =
721
+ op.getMask ().template getDefiningOp <hlfir::ElementalOp>();
722
+ if (!elemental || op.getDim ())
723
+ return rewriter.notifyMatchFailure (op, " Did not find valid elemental" );
724
+
725
+ fir::KindMapping kindMap =
726
+ fir::getKindMapping (op->template getParentOfType <mlir::ModuleOp>());
727
+ fir::FirOpBuilder builder{op, kindMap};
728
+
729
+ mlir::Value init;
730
+ GenBodyFn genBodyFn;
731
+ if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
732
+ init = builder.createIntegerConstant (loc, op.getType (), 0 );
733
+ genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
734
+ mlir::Value reduction,
735
+ const llvm::SmallVectorImpl<mlir::Value> &indices)
736
+ -> mlir::Value {
737
+ // Inline the elemental and get the condition from it.
738
+ auto yield = inlineElementalOp (loc, builder, elemental, indices);
739
+ mlir::Value cond = builder.create <fir::ConvertOp>(
740
+ loc, builder.getI1Type (), yield.getElementValue ());
741
+ yield->erase ();
742
+
743
+ // Conditionally add one to the current value
744
+ mlir::Value one =
745
+ builder.createIntegerConstant (loc, reduction.getType (), 1 );
746
+ mlir::Value add1 =
747
+ builder.create <mlir::arith::AddIOp>(loc, reduction, one);
748
+ return builder.create <mlir::arith::SelectOp>(loc, cond, add1,
749
+ reduction);
750
+ };
751
+ } else {
752
+ static_assert (" Expected Op to be handled" );
753
+ return mlir::failure ();
754
+ }
755
+
756
+ mlir::Value res = generateReductionLoop (builder, loc, init,
757
+ elemental.getOperand (0 ), genBodyFn);
758
+ if (res.getType () != op.getType ())
759
+ res = builder.create <fir::ConvertOp>(loc, op.getType (), res);
760
+
761
+ // Check if the op was the only user of the elemental (apart from a
762
+ // destroy), and remove it if so.
763
+ mlir::Operation::user_range elemUsers = elemental->getUsers ();
764
+ hlfir::DestroyOp elemDestroy;
765
+ if (std::distance (elemUsers.begin (), elemUsers.end ()) == 2 ) {
766
+ elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin ());
767
+ if (!elemDestroy)
768
+ elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin ());
769
+ }
770
+
771
+ rewriter.replaceOp (op, res);
772
+ if (elemDestroy) {
773
+ rewriter.eraseOp (elemDestroy);
774
+ rewriter.eraseOp (elemental);
775
+ }
776
+ return mlir::success ();
777
+ }
778
+ };
779
+
662
780
class OptimizedBufferizationPass
663
781
: public hlfir::impl::OptimizedBufferizationBase<
664
782
OptimizedBufferizationPass> {
@@ -681,6 +799,7 @@ class OptimizedBufferizationPass
681
799
patterns.insert <ElementalAssignBufferization>(context);
682
800
patterns.insert <BroadcastAssignBufferization>(context);
683
801
patterns.insert <VariableAssignBufferization>(context);
802
+ patterns.insert <ReductionElementalConversion<hlfir::CountOp>>(context);
684
803
685
804
if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
686
805
func, std::move (patterns), config))) {
0 commit comments