Skip to content

Commit 6160a67

Browse files
authored
[flang] Inline hlfir.reshape as hlfir.elemental. (#124683)
This patch inlines hlfir.reshape as hlfir.elemental for simple cases only: RESHAPE must have no ORDER argument, and when PAD is present only trivial element types are handled.
1 parent 81f5098 commit 6160a67

File tree

2 files changed

+441
-0
lines changed

2 files changed

+441
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,218 @@ class DotProductConversion
951951
}
952952
};
953953

954+
class ReshapeAsElementalConversion
955+
: public mlir::OpRewritePattern<hlfir::ReshapeOp> {
956+
public:
957+
using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern;
958+
959+
llvm::LogicalResult
960+
matchAndRewrite(hlfir::ReshapeOp reshape,
961+
mlir::PatternRewriter &rewriter) const override {
962+
// Do not inline RESHAPE with ORDER yet. The runtime implementation
963+
// may be good enough, unless the temporary creation overhead
964+
// is high.
965+
// TODO: If ORDER is constant, then we can still easily inline.
966+
// TODO: If the result's rank is 1, then we can assume ORDER == (/1/).
967+
if (reshape.getOrder())
968+
return rewriter.notifyMatchFailure(reshape,
969+
"RESHAPE with ORDER argument");
970+
971+
// Verify that the element types of ARRAY, PAD and the result
972+
// match before doing any transformations. For example,
973+
// the character types of different lengths may appear in the dead
974+
// code, and it just does not make sense to inline hlfir.reshape
975+
// in this case (a runtime call might have less code size footprint).
976+
hlfir::Entity result = hlfir::Entity{reshape};
977+
hlfir::Entity array = hlfir::Entity{reshape.getArray()};
978+
mlir::Type elementType = array.getFortranElementType();
979+
if (result.getFortranElementType() != elementType)
980+
return rewriter.notifyMatchFailure(
981+
reshape, "ARRAY and result have different types");
982+
mlir::Value pad = reshape.getPad();
983+
if (pad && hlfir::getFortranElementType(pad.getType()) != elementType)
984+
return rewriter.notifyMatchFailure(reshape,
985+
"ARRAY and PAD have different types");
986+
987+
// TODO: selecting between ARRAY and PAD of non-trivial element types
988+
// requires more work. We have to select between two references
989+
// to elements in ARRAY and PAD. This requires conditional
990+
// bufferization of the element, if ARRAY/PAD is an expression.
991+
if (pad && !fir::isa_trivial(elementType))
992+
return rewriter.notifyMatchFailure(reshape,
993+
"PAD present with non-trivial type");
994+
995+
mlir::Location loc = reshape.getLoc();
996+
fir::FirOpBuilder builder{rewriter, reshape.getOperation()};
997+
// Assume that all the indices arithmetic does not overflow
998+
// the IndexType.
999+
builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nuw);
1000+
1001+
llvm::SmallVector<mlir::Value, 1> typeParams;
1002+
hlfir::genLengthParameters(loc, builder, array, typeParams);
1003+
1004+
// Fetch the extents of ARRAY, PAD and result beforehand.
1005+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
1006+
hlfir::genExtentsVector(loc, builder, array);
1007+
1008+
// If PAD is present, we have to use array size to start taking
1009+
// elements from the PAD array.
1010+
mlir::Value arraySize =
1011+
pad ? computeArraySize(loc, builder, arrayExtents) : nullptr;
1012+
hlfir::Entity shape = hlfir::Entity{reshape.getShape()};
1013+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
1014+
mlir::Type indexType = builder.getIndexType();
1015+
for (int idx = 0; idx < result.getRank(); ++idx)
1016+
resultExtents.push_back(hlfir::loadElementAt(
1017+
loc, builder, shape,
1018+
builder.createIntegerConstant(loc, indexType, idx + 1)));
1019+
auto resultShape = builder.create<fir::ShapeOp>(loc, resultExtents);
1020+
1021+
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
1022+
mlir::ValueRange inputIndices) -> hlfir::Entity {
1023+
mlir::Value linearIndex =
1024+
computeLinearIndex(loc, builder, resultExtents, inputIndices);
1025+
fir::IfOp ifOp;
1026+
if (pad) {
1027+
// PAD is present. Check if this element comes from the PAD array.
1028+
mlir::Value isInsideArray = builder.create<mlir::arith::CmpIOp>(
1029+
loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize);
1030+
ifOp = builder.create<fir::IfOp>(loc, elementType, isInsideArray,
1031+
/*withElseRegion=*/true);
1032+
1033+
// In the 'else' block, return an element from the PAD.
1034+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1035+
// PAD is dynamically optional, but we can unconditionally access it
1036+
// in the 'else' block. If we have to start taking elements from it,
1037+
// then it must be present in a valid program.
1038+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents =
1039+
hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad});
1040+
// Subtract the ARRAY size from the zero-based linear index
1041+
// to get the zero-based linear index into PAD.
1042+
mlir::Value padLinearIndex =
1043+
builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize);
1044+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
1045+
delinearizeIndex(loc, builder, padExtents, padLinearIndex,
1046+
/*wrapAround=*/true);
1047+
mlir::Value padElement =
1048+
hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices);
1049+
builder.create<fir::ResultOp>(loc, padElement);
1050+
1051+
// In the 'then' block, return an element from the ARRAY.
1052+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1053+
}
1054+
1055+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
1056+
delinearizeIndex(loc, builder, arrayExtents, linearIndex,
1057+
/*wrapAround=*/false);
1058+
mlir::Value arrayElement =
1059+
hlfir::loadElementAt(loc, builder, array, arrayIndices);
1060+
1061+
if (ifOp) {
1062+
builder.create<fir::ResultOp>(loc, arrayElement);
1063+
builder.setInsertionPointAfter(ifOp);
1064+
arrayElement = ifOp.getResult(0);
1065+
}
1066+
1067+
return hlfir::Entity{arrayElement};
1068+
};
1069+
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
1070+
loc, builder, elementType, resultShape, typeParams, genKernel,
1071+
/*isUnordered=*/true,
1072+
/*polymorphicMold=*/result.isPolymorphic() ? array : mlir::Value{},
1073+
reshape.getResult().getType());
1074+
assert(elementalOp.getResult().getType() == reshape.getResult().getType());
1075+
rewriter.replaceOp(reshape, elementalOp);
1076+
return mlir::success();
1077+
}
1078+
1079+
private:
1080+
/// Compute zero-based linear index given an array extents
1081+
/// and one-based indices:
1082+
/// \p extents: [e0, e1, ..., en]
1083+
/// \p indices: [i0, i1, ..., in]
1084+
///
1085+
/// linear-index :=
1086+
/// (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1)
1087+
static mlir::Value computeLinearIndex(mlir::Location loc,
1088+
fir::FirOpBuilder &builder,
1089+
mlir::ValueRange extents,
1090+
mlir::ValueRange indices) {
1091+
std::size_t rank = extents.size();
1092+
assert(rank = indices.size());
1093+
mlir::Type indexType = builder.getIndexType();
1094+
mlir::Value zero = builder.createIntegerConstant(loc, indexType, 0);
1095+
mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
1096+
mlir::Value linearIndex = zero;
1097+
for (auto idx : llvm::enumerate(llvm::reverse(indices))) {
1098+
mlir::Value tmp = builder.create<mlir::arith::SubIOp>(
1099+
loc, builder.createConvert(loc, indexType, idx.value()), one);
1100+
tmp = builder.create<mlir::arith::AddIOp>(loc, linearIndex, tmp);
1101+
if (idx.index() + 1 < rank)
1102+
tmp = builder.create<mlir::arith::MulIOp>(
1103+
loc, tmp,
1104+
builder.createConvert(loc, indexType,
1105+
extents[rank - idx.index() - 2]));
1106+
1107+
linearIndex = tmp;
1108+
}
1109+
return linearIndex;
1110+
}
1111+
1112+
/// Compute one-based array indices from the given zero-based \p linearIndex
1113+
/// and the array \p extents [e0, e1, ..., en].
1114+
/// i0 := linearIndex % e0 + 1
1115+
/// linearIndex := linearIndex / e0
1116+
/// i1 := linearIndex % e1 + 1
1117+
/// linearIndex := linearIndex / e1
1118+
/// ...
1119+
/// i(n-1) := linearIndex % e(n-1) + 1
1120+
/// linearIndex := linearIndex / e(n-1)
1121+
/// if (wrapAround) {
1122+
/// // If the index is allowed to wrap around, then
1123+
/// // we need to modulo it by the last dimension's extent.
1124+
/// in := linearIndex % en + 1
1125+
/// } else {
1126+
/// in := linearIndex + 1
1127+
/// }
1128+
static llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
1129+
delinearizeIndex(mlir::Location loc, fir::FirOpBuilder &builder,
1130+
mlir::ValueRange extents, mlir::Value linearIndex,
1131+
bool wrapAround) {
1132+
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
1133+
mlir::Type indexType = builder.getIndexType();
1134+
mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
1135+
linearIndex = builder.createConvert(loc, indexType, linearIndex);
1136+
1137+
for (std::size_t dim = 0; dim < extents.size(); ++dim) {
1138+
mlir::Value extent = builder.createConvert(loc, indexType, extents[dim]);
1139+
// Avoid the modulo for the last index, unless wrap around is allowed.
1140+
mlir::Value currentIndex = linearIndex;
1141+
if (dim != extents.size() - 1 || wrapAround)
1142+
currentIndex =
1143+
builder.create<mlir::arith::RemUIOp>(loc, linearIndex, extent);
1144+
// The result of the last division is unused, so it will be DCEd.
1145+
linearIndex =
1146+
builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent);
1147+
indices.push_back(
1148+
builder.create<mlir::arith::AddIOp>(loc, currentIndex, one));
1149+
}
1150+
return indices;
1151+
}
1152+
1153+
/// Return size of an array given its extents.
1154+
static mlir::Value computeArraySize(mlir::Location loc,
1155+
fir::FirOpBuilder &builder,
1156+
mlir::ValueRange extents) {
1157+
mlir::Type indexType = builder.getIndexType();
1158+
mlir::Value size = builder.createIntegerConstant(loc, indexType, 1);
1159+
for (auto extent : extents)
1160+
size = builder.create<mlir::arith::MulIOp>(
1161+
loc, size, builder.createConvert(loc, indexType, extent));
1162+
return size;
1163+
}
1164+
};
1165+
9541166
class SimplifyHLFIRIntrinsics
9551167
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
9561168
public:
@@ -987,6 +1199,7 @@ class SimplifyHLFIRIntrinsics
9871199
patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
9881200

9891201
patterns.insert<DotProductConversion>(context);
1202+
patterns.insert<ReshapeAsElementalConversion>(context);
9901203

9911204
if (mlir::failed(mlir::applyPatternsGreedily(
9921205
getOperation(), std::move(patterns), config))) {

0 commit comments

Comments
 (0)