20
20
#include " flang/Optimizer/HLFIR/HLFIRDialect.h"
21
21
#include " flang/Optimizer/HLFIR/HLFIROps.h"
22
22
#include " flang/Optimizer/HLFIR/Passes.h"
23
+ #include " flang/Optimizer/Support/Utils.h"
23
24
#include " mlir/Dialect/Func/IR/FuncOps.h"
24
25
#include " mlir/IR/Dominance.h"
25
26
#include " mlir/IR/PatternMatch.h"
@@ -807,6 +808,203 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
807
808
}
808
809
};
809
810
811
+ // Look for minloc(mask=elemental) and generate the minloc loop with
812
+ // inlined elemental.
813
+ // %e = hlfir.elemental %shape ({ ... })
814
+ // %m = hlfir.minloc %array mask %e
815
+ class MinMaxlocElementalConversion
816
+ : public mlir::OpRewritePattern<hlfir::MinlocOp> {
817
+ public:
818
+ using mlir::OpRewritePattern<hlfir::MinlocOp>::OpRewritePattern;
819
+
820
+ mlir::LogicalResult
821
+ matchAndRewrite (hlfir::MinlocOp minloc,
822
+ mlir::PatternRewriter &rewriter) const override {
823
+ if (!minloc.getMask () || minloc.getDim () || minloc.getBack ())
824
+ return rewriter.notifyMatchFailure (minloc, " Did not find valid minloc" );
825
+
826
+ auto elemental = minloc.getMask ().getDefiningOp <hlfir::ElementalOp>();
827
+ if (!elemental || hlfir::elementalOpMustProduceTemp (elemental))
828
+ return rewriter.notifyMatchFailure (minloc, " Did not find elemental" );
829
+
830
+ mlir::Value array = minloc.getArray ();
831
+
832
+ unsigned rank = mlir::cast<hlfir::ExprType>(minloc.getType ()).getShape ()[0 ];
833
+ mlir::Type arrayType = array.getType ();
834
+ if (!arrayType.isa <fir::BoxType>())
835
+ return rewriter.notifyMatchFailure (
836
+ minloc, " Currently requires a boxed type input" );
837
+ mlir::Type elementType = hlfir::getFortranElementType (arrayType);
838
+ if (!fir::isa_trivial (elementType))
839
+ return rewriter.notifyMatchFailure (
840
+ minloc, " Character arrays are currently not handled" );
841
+
842
+ mlir::Location loc = minloc.getLoc ();
843
+ fir::FirOpBuilder builder{rewriter, minloc.getOperation ()};
844
+ mlir::Value resultArr = builder.createTemporary (
845
+ loc, fir::SequenceType::get (
846
+ rank, hlfir::getFortranElementType (minloc.getType ())));
847
+
848
+ auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
849
+ mlir::Type elementType) {
850
+ if (auto ty = elementType.dyn_cast <mlir::FloatType>()) {
851
+ const llvm::fltSemantics &sem = ty.getFloatSemantics ();
852
+ return builder.createRealConstant (
853
+ loc, elementType,
854
+ llvm::APFloat::getLargest (sem, /* Negative=*/ false ));
855
+ }
856
+ unsigned bits = elementType.getIntOrFloatBitWidth ();
857
+ int64_t maxInt = llvm::APInt::getSignedMaxValue (bits).getSExtValue ();
858
+ return builder.createIntegerConstant (loc, elementType, maxInt);
859
+ };
860
+
861
+ auto genBodyOp =
862
+ [&rank, &resultArr, &elemental](
863
+ fir::FirOpBuilder builder, mlir::Location loc,
864
+ mlir::Type elementType, mlir::Value array, mlir::Value flagRef,
865
+ mlir::Value reduction,
866
+ const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
867
+ // We are in the innermost loop: generate the elemental inline
868
+ mlir::Value oneIdx =
869
+ builder.createIntegerConstant (loc, builder.getIndexType (), 1 );
870
+ llvm::SmallVector<mlir::Value> oneBasedIndices;
871
+ llvm::transform (
872
+ indices, std::back_inserter (oneBasedIndices), [&](mlir::Value V) {
873
+ return builder.create <mlir::arith::AddIOp>(loc, V, oneIdx);
874
+ });
875
+ hlfir::YieldElementOp yield =
876
+ hlfir::inlineElementalOp (loc, builder, elemental, oneBasedIndices);
877
+ mlir::Value maskElem = yield.getElementValue ();
878
+ yield->erase ();
879
+
880
+ mlir::Type ifCompatType = builder.getI1Type ();
881
+ mlir::Value ifCompatElem =
882
+ builder.create <fir::ConvertOp>(loc, ifCompatType, maskElem);
883
+
884
+ llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
885
+ fir::IfOp maskIfOp =
886
+ builder.create <fir::IfOp>(loc, elementType, ifCompatElem,
887
+ /* withElseRegion=*/ true );
888
+ builder.setInsertionPointToStart (&maskIfOp.getThenRegion ().front ());
889
+
890
+ // Set flag that mask was true at some point
891
+ mlir::Value flagSet = builder.createIntegerConstant (
892
+ loc, mlir::cast<fir::ReferenceType>(flagRef.getType ()).getEleTy (), 1 );
893
+ builder.create <fir::StoreOp>(loc, flagSet, flagRef);
894
+ mlir::Value addr = hlfir::getElementAt (loc, builder, hlfir::Entity{array},
895
+ oneBasedIndices);
896
+ mlir::Value elem = builder.create <fir::LoadOp>(loc, addr);
897
+
898
+ // Compare with the max reduction value
899
+ mlir::Value cmp;
900
+ if (elementType.isa <mlir::FloatType>()) {
901
+ cmp = builder.create <mlir::arith::CmpFOp>(
902
+ loc, mlir::arith::CmpFPredicate::OLT, elem, reduction);
903
+ } else if (elementType.isa <mlir::IntegerType>()) {
904
+ cmp = builder.create <mlir::arith::CmpIOp>(
905
+ loc, mlir::arith::CmpIPredicate::slt, elem, reduction);
906
+ } else {
907
+ llvm_unreachable (" unsupported type" );
908
+ }
909
+
910
+ // Set the new coordinate to the result
911
+ fir::IfOp ifOp = builder.create <fir::IfOp>(loc, elementType, cmp,
912
+ /* withElseRegion*/ true );
913
+
914
+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
915
+ mlir::Type resultElemTy =
916
+ hlfir::getFortranElementType (resultArr.getType ());
917
+ mlir::Type returnRefTy = builder.getRefType (resultElemTy);
918
+ mlir::IndexType idxTy = builder.getIndexType ();
919
+
920
+ for (unsigned int i = 0 ; i < rank; ++i) {
921
+ mlir::Value index = builder.createIntegerConstant (loc, idxTy, i + 1 );
922
+ mlir::Value resultElemAddr = builder.create <hlfir::DesignateOp>(
923
+ loc, returnRefTy, resultArr, index);
924
+ mlir::Value fortranIndex = builder.create <fir::ConvertOp>(
925
+ loc, resultElemTy, oneBasedIndices[i]);
926
+ builder.create <fir::StoreOp>(loc, fortranIndex, resultElemAddr);
927
+ }
928
+ builder.create <fir::ResultOp>(loc, elem);
929
+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
930
+ builder.create <fir::ResultOp>(loc, reduction);
931
+ builder.setInsertionPointAfter (ifOp);
932
+
933
+ // Close the mask if
934
+ builder.create <fir::ResultOp>(loc, ifOp.getResult (0 ));
935
+ builder.setInsertionPointToStart (&maskIfOp.getElseRegion ().front ());
936
+ builder.create <fir::ResultOp>(loc, reduction);
937
+ builder.setInsertionPointAfter (maskIfOp);
938
+
939
+ return maskIfOp.getResult (0 );
940
+ };
941
+ auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
942
+ const mlir::Type &resultElemType, mlir::Value resultArr,
943
+ mlir::Value index) {
944
+ mlir::Type resultRefTy = builder.getRefType (resultElemType);
945
+ mlir::Value oneIdx =
946
+ builder.createIntegerConstant (loc, builder.getIndexType (), 1 );
947
+ index = builder.create <mlir::arith::AddIOp>(loc, index, oneIdx);
948
+ return builder.create <hlfir::DesignateOp>(loc, resultRefTy, resultArr,
949
+ index);
950
+ };
951
+
952
+ // Initialize the result
953
+ mlir::Type resultElemTy = hlfir::getFortranElementType (resultArr.getType ());
954
+ mlir::Type resultRefTy = builder.getRefType (resultElemTy);
955
+ mlir::Value returnValue =
956
+ builder.createIntegerConstant (loc, resultElemTy, 0 );
957
+ for (unsigned int i = 0 ; i < rank; ++i) {
958
+ mlir::Value index =
959
+ builder.createIntegerConstant (loc, builder.getIndexType (), i + 1 );
960
+ mlir::Value resultElemAddr = builder.create <hlfir::DesignateOp>(
961
+ loc, resultRefTy, resultArr, index);
962
+ builder.create <fir::StoreOp>(loc, returnValue, resultElemAddr);
963
+ }
964
+
965
+ fir::genMinMaxlocReductionLoop (builder, array, init, genBodyOp, getAddrFn,
966
+ rank, elementType, loc, builder.getI1Type (),
967
+ resultArr, false );
968
+
969
+ mlir::Value asExpr = builder.create <hlfir::AsExprOp>(
970
+ loc, resultArr, builder.createBool (loc, false ));
971
+
972
+ // Check all the users - the destroy is no longer required, and any assign
973
+ // can use resultArr directly so that VariableAssignBufferization in this
974
+ // pass can optimize the results. Other operations are replaces with an
975
+ // AsExpr for the temporary resultArr.
976
+ llvm::SmallVector<hlfir::DestroyOp> destroys;
977
+ llvm::SmallVector<hlfir::AssignOp> assigns;
978
+ for (auto user : minloc->getUsers ()) {
979
+ if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
980
+ destroys.push_back (destroy);
981
+ else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user))
982
+ assigns.push_back (assign);
983
+ }
984
+
985
+ // Check if the minloc was the only user of the elemental (apart from a
986
+ // destroy), and remove it if so.
987
+ mlir::Operation::user_range elemUsers = elemental->getUsers ();
988
+ hlfir::DestroyOp elemDestroy;
989
+ if (std::distance (elemUsers.begin (), elemUsers.end ()) == 2 ) {
990
+ elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin ());
991
+ if (!elemDestroy)
992
+ elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin ());
993
+ }
994
+
995
+ for (auto d : destroys)
996
+ rewriter.eraseOp (d);
997
+ for (auto a : assigns)
998
+ a.setOperand (0 , resultArr);
999
+ rewriter.replaceOp (minloc, asExpr);
1000
+ if (elemDestroy) {
1001
+ rewriter.eraseOp (elemDestroy);
1002
+ rewriter.eraseOp (elemental);
1003
+ }
1004
+ return mlir::success ();
1005
+ }
1006
+ };
1007
+
810
1008
class OptimizedBufferizationPass
811
1009
: public hlfir::impl::OptimizedBufferizationBase<
812
1010
OptimizedBufferizationPass> {
@@ -832,6 +1030,7 @@ class OptimizedBufferizationPass
832
1030
patterns.insert <ReductionElementalConversion<hlfir::CountOp>>(context);
833
1031
patterns.insert <ReductionElementalConversion<hlfir::AnyOp>>(context);
834
1032
patterns.insert <ReductionElementalConversion<hlfir::AllOp>>(context);
1033
+ patterns.insert <MinMaxlocElementalConversion>(context);
835
1034
836
1035
if (mlir::failed (mlir::applyPatternsAndFoldGreedily (
837
1036
func, std::move (patterns), config))) {
0 commit comments