Skip to content

Commit 063406c

Browse files
committed
[mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims
The fold-memref-alias-ops pass bails out in the presence of dynamic shapes, which leads to unwanted propagation of alias types during other transformations. This can percolate down further and lead to errors that should not have been created in the first place.
1 parent d34a2c2 commit 063406c

File tree

4 files changed

+187
-24
lines changed

4 files changed

+187
-24
lines changed

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,31 @@ inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
4848
return computeSuffixProduct(sizes);
4949
}
5050

51+
/// Given a set of sizes, return the suffix product.
52+
///
53+
/// When applied to slicing, this is the calculation needed to derive the
54+
/// strides (i.e. the number of linear indices to skip along the (k-1) most
55+
/// minor dimensions to get the next k-slice).
56+
///
57+
/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
58+
///
59+
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<Value>
60+
/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
61+
///
62+
/// It is the caller's responsibility to provide valid values which are expected
63+
/// to be constants with index type or results of dimension extraction ops
64+
/// (for ex. memref.dim op).
65+
///
66+
/// `sizes` elements are asserted to be non-negative.
67+
///
68+
/// Return an empty vector if `sizes` is empty.
69+
SmallVector<Value> computeSuffixProduct(Location loc, OpBuilder &builder,
70+
ArrayRef<Value> sizes);
71+
inline SmallVector<Value> computeStrides(Location loc, OpBuilder &builder,
72+
ArrayRef<Value> sizes) {
73+
return computeSuffixProduct(loc, builder, sizes);
74+
}
75+
5176
/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
5277
///
5378
/// Return an empty vector if `v1` and `v2` are empty.

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,39 +63,99 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
6363
memref::ExpandShapeOp expandShapeOp,
6464
ValueRange indices,
6565
SmallVectorImpl<Value> &sourceIndices) {
66-
// The below implementation uses computeSuffixProduct method, which only
67-
// allows int64_t values (i.e., static shape). Bail out if it has dynamic
68-
// shapes.
69-
if (!expandShapeOp.getResultType().hasStaticShape())
70-
return failure();
71-
66+
// Record the rewriter context for constructing ops later.
7267
MLIRContext *ctx = rewriter.getContext();
68+
69+
// Record result type to get result dimensions for calculating suffix product
70+
// later.
71+
ShapedType resultType = expandShapeOp.getResultType();
72+
73+
// Traverse all reassociation groups to determine the appropriate index
74+
// corresponding to each one of them post op folding.
7375
for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
7476
assert(!groups.empty() && "association indices groups cannot be empty");
77+
// Flag to indicate the presence of dynamic dimensions in current
78+
// reassociation group.
79+
bool hasDynamicDims = false;
7580
int64_t groupSize = groups.size();
7681

77-
// Construct the expression for the index value w.r.t to expand shape op
78-
// source corresponding the indices wrt to expand shape op result.
82+
// Capture expand_shape's resultant memref dimensions which are to be used
83+
// in suffix product calculation later.
7984
SmallVector<int64_t> sizes(groupSize);
80-
for (int64_t i = 0; i < groupSize; ++i)
85+
for (int64_t i = 0; i < groupSize; ++i) {
8186
sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
82-
SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
87+
if (resultType.isDynamicDim(groups[i]))
88+
hasDynamicDims = true;
89+
}
90+
91+
// Declare resultant affine apply result and affine expression variables to
92+
// represent dimensions in the newly constructed affine map.
93+
OpFoldResult ofr;
8394
SmallVector<AffineExpr> dims(groupSize);
8495
bindDimsList(ctx, MutableArrayRef{dims});
85-
AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
8696

87-
/// Apply permutation and create AffineApplyOp.
97+
// Record the load index corresponding to each dimension in the
98+
// reassociation group. These are later supplied as operands to the affine
99+
// map used for calculating the relevant index post op folding.
88100
SmallVector<OpFoldResult> dynamicIndices(groupSize);
89101
for (int64_t i = 0; i < groupSize; i++)
90102
dynamicIndices[i] = indices[groups[i]];
91103

92-
// Creating maximally folded and composd affine.apply composes better with
93-
// other transformations without interleaving canonicalization passes.
94-
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
95-
rewriter, loc,
96-
AffineMap::get(/*numDims=*/groupSize,
97-
/*numSymbols=*/0, srcIndexExpr),
98-
dynamicIndices);
104+
if (hasDynamicDims) {
105+
// Record relevant dimension sizes for each result dimension in the
106+
// reassociation group.
107+
SmallVector<Value> sizesVal(groupSize);
108+
for (int64_t i = 0; i < groupSize; ++i) {
109+
if (sizes[i] <= 0)
110+
sizesVal[i] = rewriter.create<memref::DimOp>(
111+
loc, expandShapeOp.getResult(), groups[i]);
112+
else
113+
sizesVal[i] = rewriter.create<arith::ConstantIndexOp>(loc, sizes[i]);
114+
}
115+
116+
// Calculate suffix product of previously obtained dimension sizes.
117+
auto suffixProduct = computeSuffixProduct(loc, rewriter, sizesVal);
118+
119+
// Create affine expression variables for symbols in the newly constructed
120+
// affine map.
121+
SmallVector<AffineExpr> symbols(groupSize);
122+
bindSymbolsList(ctx, MutableArrayRef{symbols});
123+
124+
// Linearize bound dimensions and symbols to construct the resultant
125+
// affine expression for this index.
126+
AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
127+
128+
// Supply suffix product results followed by load op indices as operands
129+
// to the map.
130+
SmallVector<OpFoldResult> mapOperands;
131+
llvm::append_range(mapOperands, suffixProduct);
132+
llvm::append_range(mapOperands, dynamicIndices);
133+
134+
// Creating maximally folded and composed affine.apply composes better
135+
// with other transformations without interleaving canonicalization
136+
// passes.
137+
ofr = affine::makeComposedFoldedAffineApply(
138+
rewriter, loc,
139+
AffineMap::get(/*numDims=*/groupSize,
140+
/*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
141+
mapOperands);
142+
} else {
143+
// Calculate suffix product of static dimension sizes and linearize those
144+
// values with dimension affine variables defined previously.
145+
SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
146+
AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
147+
148+
// Creating maximally folded and composed affine.apply composes better
149+
// with other transformations without interleaving canonicalization
150+
// passes.
151+
ofr = affine::makeComposedFoldedAffineApply(
152+
rewriter, loc,
153+
AffineMap::get(/*numDims=*/groupSize,
154+
/*numSymbols=*/0, /*expression=*/srcIndexExpr),
155+
dynamicIndices);
156+
}
157+
// Push index value in the op post folding corresponding to this
158+
// reassociation group.
99159
sourceIndices.push_back(
100160
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
101161
}

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Utils/IndexingUtils.h"
10+
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
1012
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1113
#include "mlir/IR/AffineExpr.h"
1214
#include "mlir/IR/Builders.h"
@@ -29,6 +31,19 @@ SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
2931
return strides;
3032
}
3133

34+
static SmallVector<Value> computeSuffixProductImpl(Location loc,
35+
OpBuilder &builder,
36+
ArrayRef<Value> sizes,
37+
Value unit) {
38+
if (sizes.empty())
39+
return {};
40+
SmallVector<Value> strides(sizes.size(), unit);
41+
for (int64_t r = strides.size() - 2; r >= 0; --r)
42+
strides[r] =
43+
builder.create<arith::MulIOp>(loc, strides[r + 1], sizes[r + 1]);
44+
return strides;
45+
}
46+
3247
template <typename ExprType>
3348
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
3449
ArrayRef<ExprType> v2) {
@@ -197,6 +212,18 @@ SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
197212
return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
198213
}
199214

215+
//===----------------------------------------------------------------------===//
216+
// Utils that operate on compile time unknown values.
217+
//===----------------------------------------------------------------------===//
218+
219+
SmallVector<Value> mlir::computeSuffixProduct(Location loc, OpBuilder &builder,
220+
ArrayRef<Value> sizes) {
221+
if (sizes.empty())
222+
return {};
223+
Value unit = builder.create<arith::ConstantIndexOp>(loc, 1);
224+
return ::computeSuffixProductImpl(loc, builder, sizes, unit);
225+
}
226+
200227
//===----------------------------------------------------------------------===//
201228
// Permutation utils.
202229
//===----------------------------------------------------------------------===//

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -468,16 +468,67 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
468468

469469
// -----
470470

471-
// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
472-
func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
471+
// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
472+
// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
473+
// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> f32
474+
func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
473475
%c0 = arith.constant 0 : index
474476
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
475477
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
476478
return %0 : f32
477479
}
478-
// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
479-
// CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]]
480-
// CHECK: return %[[LOAD]]
480+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
481+
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
482+
// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
483+
// CHECK-NEXT: return %[[VAL1]] : f32
484+
485+
// -----
486+
487+
// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
488+
// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
489+
// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
490+
func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) {
491+
%c0 = arith.constant 0 : index
492+
%c1f32 = arith.constant 1.0 : f32
493+
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
494+
memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
495+
return
496+
}
497+
// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
498+
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
499+
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
500+
// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
501+
// CHECK-NEXT: return
502+
503+
// -----
504+
505+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
506+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
507+
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
508+
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
509+
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index) {
510+
%subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
511+
%expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
512+
%dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
513+
514+
affine.for %arg6 = 0 to %dim step 64 {
515+
affine.for %arg7 = 0 to 16 step 16 {
516+
%dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
517+
affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
518+
}
519+
}
520+
return
521+
}
522+
// CHECK-NEXT: memref.subview
523+
// CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
524+
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
525+
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
526+
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
527+
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
528+
// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
529+
// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
530+
// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
531+
// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
481532

482533
// -----
483534

0 commit comments

Comments
 (0)