Commit a891163

[mlir][MemRef] Use specialized index ops to fold expand/collapse_shape (#138930)
This PR updates the FoldMemRefAliasOps pass to use `affine.linearize_index` and `affine.delinearize_index` to perform the index computations needed to fold a `memref.expand_shape` or `memref.collapse_shape`, respectively, into its consumers. This also loosens some limitations of the pass:

1. The existing `output_shape` argument to `memref.expand_shape` is now used, eliminating the need to re-infer this shape or call `memref.dim`.
2. Because we're using `affine.delinearize_index`, the restriction that each group in a `memref.collapse_shape` can have only one dynamic dimension is removed.
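As an illustration (not taken from the patch; operand names and shapes here are hypothetical), folding a load through an `expand_shape` now linearizes each reassociation group's indices against the op's `output_shape`:

```mlir
// Before: a load through the expanded view; %d0 is the dynamic output size.
%view = memref.expand_shape %src [[0, 1], [2]] output_shape [%d0, 8, 32]
    : memref<?x32xf32> into memref<?x8x32xf32>
%v0 = memref.load %view[%i, %j, %k] : memref<?x8x32xf32>

// After folding: group [0, 1] is linearized with affine.linearize_index;
// `disjoint` is legal because memref.load indices start in-bounds.
%lin = affine.linearize_index disjoint [%i, %j] by (%d0, 8) : index
%v1 = memref.load %src[%lin, %k] : memref<?x32xf32>
```

Because the linearization basis comes straight from `output_shape`, no `memref.dim` ops or shape re-inference are needed.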
1 parent 810148c commit a891163

File tree: 3 files changed, +123 -189 lines

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 10 additions & 4 deletions
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
     according to specified offsets, sizes, and strides.

     ```mlir
-    %result1 = memref.reinterpret_cast %arg0 to
+    %result1 = memref.reinterpret_cast %arg0 to
       offset: [9],
       sizes: [4, 4],
       strides: [16, 2]
     : memref<8x8xf32, strided<[8, 1], offset: 0>> to
       memref<4x4xf32, strided<[16, 2], offset: 9>>

-    %result2 = memref.reinterpret_cast %result1 to
+    %result2 = memref.reinterpret_cast %result1 to
       offset: [0],
       sizes: [2, 2],
       strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
         OpBuilder &b, Location loc, MemRefType expandedType,
         ArrayRef<ReassociationIndices> reassociation,
         ArrayRef<OpFoldResult> inputShape);
+
+    // Return a vector with all the static and dynamic values in the output shape.
+    SmallVector<OpFoldResult> getMixedOutputShape() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+    }
   }];

   let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
-
+
     The number of indices must match the rank of the memref. The indices must
     be in-bounds: `0 <= idx < dim_size`

@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     Unlike the `reinterpret_cast`, the values are relative to the strided
     memref of the input (`%result1` in this case) and not its
     underlying memory.
-
+
     Example 2:

     ```mlir

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 49 additions & 120 deletions
@@ -59,92 +59,28 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///   memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
-                                memref::ExpandShapeOp expandShapeOp,
-                                ValueRange indices,
-                                SmallVectorImpl<Value> &sourceIndices) {
-  // Record the rewriter context for constructing ops later.
-  MLIRContext *ctx = rewriter.getContext();
-
-  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
-  // This is done for the purpose of inferring the output shape via
-  // `inferExpandOutputShape` which will in turn be used for suffix product
-  // calculation later.
-  SmallVector<OpFoldResult> srcShape;
-  MemRefType srcType = expandShapeOp.getSrcType();
-
-  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
-    if (srcType.isDynamicDim(i)) {
-      srcShape.push_back(
-          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
-              .getResult());
-    } else {
-      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
-    }
-  }
-
-  auto outputShape = inferExpandShapeOutputShape(
-      rewriter, loc, expandShapeOp.getResultType(),
-      expandShapeOp.getReassociationIndices(), srcShape);
-  if (!outputShape.has_value())
-    return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();

   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    // Flag to indicate the presence of dynamic dimensions in current
-    // reassociation group.
-    int64_t groupSize = groups.size();
-
-    // Group output dimensions utilized in this reassociation group for suffix
-    // product calculation.
-    SmallVector<OpFoldResult> sizesVal(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i) {
-      sizesVal[i] = (*outputShape)[groups[i]];
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
     }
-
-    // Calculate suffix product of relevant output dimension sizes.
-    SmallVector<OpFoldResult> suffixProduct =
-        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
-    // Create affine expression variables for dimensions and symbols in the
-    // newly constructed affine map.
-    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
-    bindDimsList<AffineExpr>(ctx, dims);
-    bindSymbolsList<AffineExpr>(ctx, symbols);
-
-    // Linearize binded dimensions and symbols to construct the resultant
-    // affine expression for this indice.
-    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
-    // Record the load index corresponding to each dimension in the
-    // reassociation group. These are later supplied as operands to the affine
-    // map used for calulating relevant index post op folding.
-    SmallVector<OpFoldResult> dynamicIndices(groupSize);
-    for (int64_t i = 0; i < groupSize; i++)
-      dynamicIndices[i] = indices[groups[i]];
-
-    // Supply suffix product results followed by load op indices as operands
-    // to the map.
-    SmallVector<OpFoldResult> mapOperands;
-    llvm::append_range(mapOperands, suffixProduct);
-    llvm::append_range(mapOperands, dynamicIndices);
-
-    // Creating maximally folded and composed affine.apply composes better
-    // with other transformations without interleaving canonicalization
-    // passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
-        mapOperands);
-
-    // Push index value in the op post folding corresponding to this
-    // reassociation group.
-    sourceIndices.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
   }
   return success();
 }
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  int64_t cnt = 0;
-  SmallVector<OpFoldResult> dynamicIndices;
-  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    dynamicIndices.push_back(indices[cnt++]);
-    int64_t groupSize = groups.size();
-
-    // Calculate suffix product for all collapse op source dimension sizes
-    // except the most major one of each group.
-    // We allow the most major source dimension to be dynamic but enforce all
-    // others to be known statically.
-    SmallVector<int64_t> sizes(groupSize, 1);
-    for (int64_t i = 1; i < groupSize; ++i) {
-      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
-      if (sizes[i] == ShapedType::kDynamic)
-        return failure();
-    }
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
-    // Derive the index values along all dimensions of the source corresponding
-    // to the index wrt to collapsed shape op output.
-    auto d0 = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
-    // Construct the AffineApplyOp for each delinearizingExpr.
-    for (int64_t i = 0; i < groupSize; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc,
-          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
-                         delinearizingExprs[i]),
-          dynamicIndices);
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+  MemRefType sourceType = collapseShapeOp.getSrcType();
+  // Note: collapse_shape requires a strided memref, we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
     }
-    dynamicIndices.clear();
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
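For the collapse direction, a minimal sketch of the resulting IR (hypothetical names and shapes; the pass recovers the source sizes via `memref.extract_strided_metadata`, and `memref.dim` is used below only to keep the sketch short):

```mlir
// Before: a load through a collapse_shape whose first group has two dynamic
// sizes (previously rejected: only one dynamic dimension per group was allowed).
%view = memref.collapse_shape %src [[0, 1], [2]]
    : memref<?x?x16xf32> into memref<?x16xf32>
%v0 = memref.load %view[%i, %j] : memref<?x16xf32>

// After folding: the collapsed index %i is delinearized against the source sizes.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = memref.dim %src, %c0 : memref<?x?x16xf32>
%d1 = memref.dim %src, %c1 : memref<?x?x16xf32>
%ij:2 = affine.delinearize_index %i into (%d0, %d1) : index, index
%v1 = memref.load %src[%ij#0, %ij#1, %j] : memref<?x?x16xf32>
```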
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.load and affine.load guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
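The `startsInbounds` flag only controls whether the emitted linearization is marked `disjoint`; a small hypothetical sketch of the difference:

```mlir
// Folded from memref.load/store or affine.load/store: indices are known to
// start in-bounds, so the group's basis elements can be marked disjoint.
%lin0 = affine.linearize_index disjoint [%i, %j] by (%d0, 8) : index

// Folded from the vector load/store ops, which make no such guarantee.
%lin1 = affine.linearize_index [%i, %j] by (%d0, 8) : index
```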
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.store and affine.store guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {
