Commit ea9d0b0

Revert "[mlir][MemRef] Use specialized index ops to fold expand/collapse_shape (llvm#138930)"
This reverts commit a891163.
1 parent fc01680 · commit ea9d0b0

3 files changed: +189 −122 lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 4 additions & 10 deletions
@@ -1362,14 +1362,14 @@ def MemRef_ReinterpretCastOp
     according to specified offsets, sizes, and strides.
 
     ```mlir
-    %result1 = memref.reinterpret_cast %arg0 to
+    %result1 = memref.reinterpret_cast %arg0 to
       offset: [9],
       sizes: [4, 4],
       strides: [16, 2]
     : memref<8x8xf32, strided<[8, 1], offset: 0>> to
       memref<4x4xf32, strided<[16, 2], offset: 9>>
 
-    %result2 = memref.reinterpret_cast %result1 to
+    %result2 = memref.reinterpret_cast %result1 to
       offset: [0],
       sizes: [2, 2],
       strides: [4, 2]
@@ -1775,12 +1775,6 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
         OpBuilder &b, Location loc, MemRefType expandedType,
         ArrayRef<ReassociationIndices> reassociation,
         ArrayRef<OpFoldResult> inputShape);
-
-    // Return a vector with all the static and dynamic values in the output shape.
-    SmallVector<OpFoldResult> getMixedOutputShape() {
-      OpBuilder builder(getContext());
-      return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
-    }
   }];
 
   let hasVerifier = 1;
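
The helper removed above packaged the op's static and dynamic result-shape values into one `SmallVector<OpFoldResult>`. A minimal sketch of the kind of op it queried, using the `output_shape` syntax from the `memref.expand_shape` documentation (SSA names are illustrative, not from this patch):

```mlir
// The result shape mixes inline constants with SSA operands; the removed
// getMixedOutputShape() returned them uniformly, here as [%sz0, 4, 64].
%b = memref.expand_shape %a [[0, 1], [2]] output_shape [%sz0, 4, 64]
    : memref<?x64xf32> into memref<?x4x64xf32>
```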
@@ -1899,7 +1893,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
-
+
     The number of indices must match the rank of the memref. The indices must
     be in-bounds: `0 <= idx < dim_size`.

@@ -2056,7 +2050,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     Unlike the `reinterpret_cast`, the values are relative to the strided
     memref of the input (`%result1` in this case) and not its
     underlying memory.
-
+
     Example 2:
 
     ```mlir

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 120 additions & 48 deletions
@@ -59,28 +59,92 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///    memref<12x42xf32>
-static LogicalResult resolveSourceIndicesExpandShape(
-    Location loc, PatternRewriter &rewriter,
-    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
-    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
-  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
+static LogicalResult
+resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
+                                memref::ExpandShapeOp expandShapeOp,
+                                ValueRange indices,
+                                SmallVectorImpl<Value> &sourceIndices) {
+  // Record the rewriter context for constructing ops later.
+  MLIRContext *ctx = rewriter.getContext();
+
+  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
+  // This is done for the purpose of inferring the output shape via
+  // `inferExpandShapeOutputShape`, which will in turn be used for the suffix
+  // product calculation later.
+  SmallVector<OpFoldResult> srcShape;
+  MemRefType srcType = expandShapeOp.getSrcType();
+
+  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
+    if (srcType.isDynamicDim(i)) {
+      srcShape.push_back(
+          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
+              .getResult());
+    } else {
+      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
+    }
+  }
+
+  auto outputShape = inferExpandShapeOutputShape(
+      rewriter, loc, expandShapeOp.getResultType(),
+      expandShapeOp.getReassociationIndices(), srcShape);
+  if (!outputShape.has_value())
+    return failure();
 
   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-    if (groupSize == 1) {
-      sourceIndices.push_back(indices[group[0]]);
-      continue;
+  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
+    assert(!groups.empty() && "association indices groups cannot be empty");
+    // Size of the current reassociation group.
+    int64_t groupSize = groups.size();
+
+    // Gather the output dimensions used in this reassociation group for the
+    // suffix product calculation.
+    SmallVector<OpFoldResult> sizesVal(groupSize);
+    for (int64_t i = 0; i < groupSize; ++i) {
+      sizesVal[i] = (*outputShape)[groups[i]];
     }
-    SmallVector<OpFoldResult> groupBasis =
-        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
-    SmallVector<Value> groupIndices =
-        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
-    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
-        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
-    sourceIndices.push_back(collapsedIndex);
+
+    // Calculate the suffix product of the relevant output dimension sizes.
+    SmallVector<OpFoldResult> suffixProduct =
+        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
+
+    // Create affine expression variables for the dimensions and symbols in
+    // the newly constructed affine map.
+    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
+    bindDimsList<AffineExpr>(ctx, dims);
+    bindSymbolsList<AffineExpr>(ctx, symbols);
+
+    // Linearize the bound dimensions and symbols to construct the resultant
+    // affine expression for this index.
+    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calculating the relevant index post op folding.
+    SmallVector<OpFoldResult> dynamicIndices(groupSize);
+    for (int64_t i = 0; i < groupSize; i++)
+      dynamicIndices[i] = indices[groups[i]];
+
+    // Supply the suffix product results followed by the load op indices as
+    // operands to the map.
+    SmallVector<OpFoldResult> mapOperands;
+    llvm::append_range(mapOperands, suffixProduct);
+    llvm::append_range(mapOperands, dynamicIndices);
+
+    // Creating a maximally folded and composed affine.apply composes better
+    // with other transformations without interleaving canonicalization
+    // passes.
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc,
+        AffineMap::get(/*numDims=*/groupSize,
+                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+        mapOperands);
+
+    // Push the index value corresponding to this reassociation group in the
+    // folded op.
+    sourceIndices.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
   return success();
 }
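
Taken together, a hedged before/after sketch of the fold this restored path implements, extending the example from the doc comment above (SSA names are illustrative; the exact folded form depends on what `makeComposedFoldedAffineApply` simplifies):

```mlir
// Before folding: load through the expanded view.
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
    : memref<12x42xf32> into memref<2x6x42xf32>
%1 = memref.load %0[%i0, %i1, %i2] : memref<2x6x42xf32>

// After folding: group [0, 1] has output sizes [2, 6], giving the suffix
// product [6, 1], so the composed affine.apply computes 6 * %i0 + %i1; the
// singleton group [2] folds to %i2 unchanged.
%lin = affine.apply affine_map<(d0, d1) -> (d0 * 6 + d1)>(%i0, %i1)
%2 = memref.load %arg0[%lin, %i2] : memref<12x42xf32>
```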
@@ -103,33 +167,49 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  // Note: collapse_shape requires a strided memref, we can do this.
-  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-      loc, collapseShapeOp.getSrc());
-  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
-  for (auto [index, group] :
-       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-
-    if (groupSize == 1) {
-      sourceIndices.push_back(index);
-      continue;
+  int64_t cnt = 0;
+  SmallVector<OpFoldResult> dynamicIndices;
+  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
+    assert(!groups.empty() && "association indices groups cannot be empty");
+    dynamicIndices.push_back(indices[cnt++]);
+    int64_t groupSize = groups.size();
+
+    // Calculate the suffix product for all collapse op source dimension sizes
+    // except the most major one of each group.
+    // We allow the most major source dimension to be dynamic but enforce all
+    // others to be known statically.
+    SmallVector<int64_t> sizes(groupSize, 1);
+    for (int64_t i = 1; i < groupSize; ++i) {
+      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
+      if (sizes[i] == ShapedType::kDynamic)
+        return failure();
     }
-
-    SmallVector<OpFoldResult> basis =
-        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
-    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        loc, index, basis, /*hasOuterBound=*/true);
-    llvm::append_range(sourceIndices, delinearize.getResults());
+    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+
+    // Derive the index values along all dimensions of the source
+    // corresponding to the index w.r.t. the collapse_shape op output.
+    auto d0 = rewriter.getAffineDimExpr(0);
+    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
+
+    // Construct the AffineApplyOp for each delinearizingExpr.
+    for (int64_t i = 0; i < groupSize; i++) {
+      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc,
+          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
+                         delinearizingExprs[i]),
+          dynamicIndices);
+      sourceIndices.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    }
+    dynamicIndices.clear();
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
+      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
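
And the inverse direction, again as a hedged sketch with illustrative shapes (note that all but the most major dimension of each group must be static, per the `ShapedType::kDynamic` check above):

```mlir
// Before folding: load through the collapsed view.
%0 = memref.collapse_shape %arg0 [[0, 1], [2]]
    : memref<4x5x6xf32> into memref<20x6xf32>
%1 = memref.load %0[%j, %k] : memref<20x6xf32>

// After folding: for group [0, 1] the recorded sizes are [1, 5], whose suffix
// product is [5, 1], so %j delinearizes into (%j floordiv 5, %j mod 5); the
// singleton group [2] passes %k through.
%i0 = affine.apply affine_map<(d0) -> (d0 floordiv 5)>(%j)
%i1 = affine.apply affine_map<(d0) -> (d0 mod 5)>(%j)
%2 = memref.load %arg0[%i0, %i1, %k] : memref<4x5x6xf32>
```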
@@ -433,12 +513,8 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
-  // memref.load and affine.load guarantee that indices start in-bounds
-  // while the vector operations don't. This affects whether our
-  // linearization is `disjoint`.
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
-          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
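
For contrast, the comment deleted in this hunk (and the matching one in the store hunk below) referred to the `disjoint` unit attribute of the specialized op that the pre-revert path emitted for each reassociation group; a hedged sketch with illustrative operands, reusing the group [0, 1] from the expand_shape example earlier:

```mlir
// Pre-revert form: one dedicated linearization op per reassociation group,
// marked disjoint only when the surrounding load/store guarantees that the
// indices start in-bounds.
%lin = affine.linearize_index disjoint [%i0, %i1] by (2, 6) : index
```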
@@ -600,12 +676,8 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
-  // memref.store and affine.store guarantee that indices start in-bounds
-  // while the vector operations don't. This affects whether our
-  // linearization is `disjoint`.
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
-          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {
