@@ -951,6 +951,218 @@ class DotProductConversion
951
951
}
952
952
};
953
953
954
+ class ReshapeAsElementalConversion
955
+ : public mlir::OpRewritePattern<hlfir::ReshapeOp> {
956
+ public:
957
+ using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern;
958
+
959
+ llvm::LogicalResult
960
+ matchAndRewrite (hlfir::ReshapeOp reshape,
961
+ mlir::PatternRewriter &rewriter) const override {
962
+ // Do not inline RESHAPE with ORDER yet. The runtime implementation
963
+ // may be good enough, unless the temporary creation overhead
964
+ // is high.
965
+ // TODO: If ORDER is constant, then we can still easily inline.
966
+ // TODO: If the result's rank is 1, then we can assume ORDER == (/1/).
967
+ if (reshape.getOrder ())
968
+ return rewriter.notifyMatchFailure (reshape,
969
+ " RESHAPE with ORDER argument" );
970
+
971
+ // Verify that the element types of ARRAY, PAD and the result
972
+ // match before doing any transformations. For example,
973
+ // the character types of different lengths may appear in the dead
974
+ // code, and it just does not make sense to inline hlfir.reshape
975
+ // in this case (a runtime call might have less code size footprint).
976
+ hlfir::Entity result = hlfir::Entity{reshape};
977
+ hlfir::Entity array = hlfir::Entity{reshape.getArray ()};
978
+ mlir::Type elementType = array.getFortranElementType ();
979
+ if (result.getFortranElementType () != elementType)
980
+ return rewriter.notifyMatchFailure (
981
+ reshape, " ARRAY and result have different types" );
982
+ mlir::Value pad = reshape.getPad ();
983
+ if (pad && hlfir::getFortranElementType (pad.getType ()) != elementType)
984
+ return rewriter.notifyMatchFailure (reshape,
985
+ " ARRAY and PAD have different types" );
986
+
987
+ // TODO: selecting between ARRAY and PAD of non-trivial element types
988
+ // requires more work. We have to select between two references
989
+ // to elements in ARRAY and PAD. This requires conditional
990
+ // bufferization of the element, if ARRAY/PAD is an expression.
991
+ if (pad && !fir::isa_trivial (elementType))
992
+ return rewriter.notifyMatchFailure (reshape,
993
+ " PAD present with non-trivial type" );
994
+
995
+ mlir::Location loc = reshape.getLoc ();
996
+ fir::FirOpBuilder builder{rewriter, reshape.getOperation ()};
997
+ // Assume that all the indices arithmetic does not overflow
998
+ // the IndexType.
999
+ builder.setIntegerOverflowFlags (mlir::arith::IntegerOverflowFlags::nuw);
1000
+
1001
+ llvm::SmallVector<mlir::Value, 1 > typeParams;
1002
+ hlfir::genLengthParameters (loc, builder, array, typeParams);
1003
+
1004
+ // Fetch the extents of ARRAY, PAD and result beforehand.
1005
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
1006
+ hlfir::genExtentsVector (loc, builder, array);
1007
+
1008
+ // If PAD is present, we have to use array size to start taking
1009
+ // elements from the PAD array.
1010
+ mlir::Value arraySize =
1011
+ pad ? computeArraySize (loc, builder, arrayExtents) : nullptr ;
1012
+ hlfir::Entity shape = hlfir::Entity{reshape.getShape ()};
1013
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
1014
+ mlir::Type indexType = builder.getIndexType ();
1015
+ for (int idx = 0 ; idx < result.getRank (); ++idx)
1016
+ resultExtents.push_back (hlfir::loadElementAt (
1017
+ loc, builder, shape,
1018
+ builder.createIntegerConstant (loc, indexType, idx + 1 )));
1019
+ auto resultShape = builder.create <fir::ShapeOp>(loc, resultExtents);
1020
+
1021
+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
1022
+ mlir::ValueRange inputIndices) -> hlfir::Entity {
1023
+ mlir::Value linearIndex =
1024
+ computeLinearIndex (loc, builder, resultExtents, inputIndices);
1025
+ fir::IfOp ifOp;
1026
+ if (pad) {
1027
+ // PAD is present. Check if this element comes from the PAD array.
1028
+ mlir::Value isInsideArray = builder.create <mlir::arith::CmpIOp>(
1029
+ loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize);
1030
+ ifOp = builder.create <fir::IfOp>(loc, elementType, isInsideArray,
1031
+ /* withElseRegion=*/ true );
1032
+
1033
+ // In the 'else' block, return an element from the PAD.
1034
+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
1035
+ // PAD is dynamically optional, but we can unconditionally access it
1036
+ // in the 'else' block. If we have to start taking elements from it,
1037
+ // then it must be present in a valid program.
1038
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents =
1039
+ hlfir::genExtentsVector (loc, builder, hlfir::Entity{pad});
1040
+ // Subtract the ARRAY size from the zero-based linear index
1041
+ // to get the zero-based linear index into PAD.
1042
+ mlir::Value padLinearIndex =
1043
+ builder.create <mlir::arith::SubIOp>(loc, linearIndex, arraySize);
1044
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
1045
+ delinearizeIndex (loc, builder, padExtents, padLinearIndex,
1046
+ /* wrapAround=*/ true );
1047
+ mlir::Value padElement =
1048
+ hlfir::loadElementAt (loc, builder, hlfir::Entity{pad}, padIndices);
1049
+ builder.create <fir::ResultOp>(loc, padElement);
1050
+
1051
+ // In the 'then' block, return an element from the ARRAY.
1052
+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
1053
+ }
1054
+
1055
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
1056
+ delinearizeIndex (loc, builder, arrayExtents, linearIndex,
1057
+ /* wrapAround=*/ false );
1058
+ mlir::Value arrayElement =
1059
+ hlfir::loadElementAt (loc, builder, array, arrayIndices);
1060
+
1061
+ if (ifOp) {
1062
+ builder.create <fir::ResultOp>(loc, arrayElement);
1063
+ builder.setInsertionPointAfter (ifOp);
1064
+ arrayElement = ifOp.getResult (0 );
1065
+ }
1066
+
1067
+ return hlfir::Entity{arrayElement};
1068
+ };
1069
+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
1070
+ loc, builder, elementType, resultShape, typeParams, genKernel,
1071
+ /* isUnordered=*/ true ,
1072
+ /* polymorphicMold=*/ result.isPolymorphic () ? array : mlir::Value{},
1073
+ reshape.getResult ().getType ());
1074
+ assert (elementalOp.getResult ().getType () == reshape.getResult ().getType ());
1075
+ rewriter.replaceOp (reshape, elementalOp);
1076
+ return mlir::success ();
1077
+ }
1078
+
1079
+ private:
1080
+ // / Compute zero-based linear index given an array extents
1081
+ // / and one-based indices:
1082
+ // / \p extents: [e0, e1, ..., en]
1083
+ // / \p indices: [i0, i1, ..., in]
1084
+ // /
1085
+ // / linear-index :=
1086
+ // / (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1)
1087
+ static mlir::Value computeLinearIndex (mlir::Location loc,
1088
+ fir::FirOpBuilder &builder,
1089
+ mlir::ValueRange extents,
1090
+ mlir::ValueRange indices) {
1091
+ std::size_t rank = extents.size ();
1092
+ assert (rank = indices.size ());
1093
+ mlir::Type indexType = builder.getIndexType ();
1094
+ mlir::Value zero = builder.createIntegerConstant (loc, indexType, 0 );
1095
+ mlir::Value one = builder.createIntegerConstant (loc, indexType, 1 );
1096
+ mlir::Value linearIndex = zero;
1097
+ for (auto idx : llvm::enumerate (llvm::reverse (indices))) {
1098
+ mlir::Value tmp = builder.create <mlir::arith::SubIOp>(
1099
+ loc, builder.createConvert (loc, indexType, idx.value ()), one);
1100
+ tmp = builder.create <mlir::arith::AddIOp>(loc, linearIndex, tmp);
1101
+ if (idx.index () + 1 < rank)
1102
+ tmp = builder.create <mlir::arith::MulIOp>(
1103
+ loc, tmp,
1104
+ builder.createConvert (loc, indexType,
1105
+ extents[rank - idx.index () - 2 ]));
1106
+
1107
+ linearIndex = tmp;
1108
+ }
1109
+ return linearIndex;
1110
+ }
1111
+
1112
+ // / Compute one-based array indices from the given zero-based \p linearIndex
1113
+ // / and the array \p extents [e0, e1, ..., en].
1114
+ // / i0 := linearIndex % e0 + 1
1115
+ // / linearIndex := linearIndex / e0
1116
+ // / i1 := linearIndex % e1 + 1
1117
+ // / linearIndex := linearIndex / e1
1118
+ // / ...
1119
+ // / i(n-1) := linearIndex % e(n-1) + 1
1120
+ // / linearIndex := linearIndex / e(n-1)
1121
+ // / if (wrapAround) {
1122
+ // / // If the index is allowed to wrap around, then
1123
+ // / // we need to modulo it by the last dimension's extent.
1124
+ // / in := linearIndex % en + 1
1125
+ // / } else {
1126
+ // / in := linearIndex + 1
1127
+ // / }
1128
+ static llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
1129
+ delinearizeIndex (mlir::Location loc, fir::FirOpBuilder &builder,
1130
+ mlir::ValueRange extents, mlir::Value linearIndex,
1131
+ bool wrapAround) {
1132
+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
1133
+ mlir::Type indexType = builder.getIndexType ();
1134
+ mlir::Value one = builder.createIntegerConstant (loc, indexType, 1 );
1135
+ linearIndex = builder.createConvert (loc, indexType, linearIndex);
1136
+
1137
+ for (std::size_t dim = 0 ; dim < extents.size (); ++dim) {
1138
+ mlir::Value extent = builder.createConvert (loc, indexType, extents[dim]);
1139
+ // Avoid the modulo for the last index, unless wrap around is allowed.
1140
+ mlir::Value currentIndex = linearIndex;
1141
+ if (dim != extents.size () - 1 || wrapAround)
1142
+ currentIndex =
1143
+ builder.create <mlir::arith::RemUIOp>(loc, linearIndex, extent);
1144
+ // The result of the last division is unused, so it will be DCEd.
1145
+ linearIndex =
1146
+ builder.create <mlir::arith::DivUIOp>(loc, linearIndex, extent);
1147
+ indices.push_back (
1148
+ builder.create <mlir::arith::AddIOp>(loc, currentIndex, one));
1149
+ }
1150
+ return indices;
1151
+ }
1152
+
1153
+ // / Return size of an array given its extents.
1154
+ static mlir::Value computeArraySize (mlir::Location loc,
1155
+ fir::FirOpBuilder &builder,
1156
+ mlir::ValueRange extents) {
1157
+ mlir::Type indexType = builder.getIndexType ();
1158
+ mlir::Value size = builder.createIntegerConstant (loc, indexType, 1 );
1159
+ for (auto extent : extents)
1160
+ size = builder.create <mlir::arith::MulIOp>(
1161
+ loc, size, builder.createConvert (loc, indexType, extent));
1162
+ return size;
1163
+ }
1164
+ };
1165
+
954
1166
class SimplifyHLFIRIntrinsics
955
1167
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
956
1168
public:
@@ -987,6 +1199,7 @@ class SimplifyHLFIRIntrinsics
987
1199
patterns.insert <MatmulConversion<hlfir::MatmulOp>>(context);
988
1200
989
1201
patterns.insert <DotProductConversion>(context);
1202
+ patterns.insert <ReshapeAsElementalConversion>(context);
990
1203
991
1204
if (mlir::failed (mlir::applyPatternsGreedily (
992
1205
getOperation (), std::move (patterns), config))) {
0 commit comments