Skip to content

Commit 224262a

Browse files
committed
[mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims
1 parent d34a2c2 commit 224262a

File tree

4 files changed

+187
-24
lines changed

4 files changed

+187
-24
lines changed

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,31 @@ inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
4848
return computeSuffixProduct(sizes);
4949
}
5050

51+
/// Given a set of sizes, return the suffix product.
52+
///
53+
/// When applied to slicing, this is the calculation needed to derive the
54+
/// strides (i.e. the number of linear indices to skip along the (k-1) most
55+
/// minor dimensions to get the next k-slice).
56+
///
57+
/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
58+
///
59+
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<Value>
60+
/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
61+
///
62+
/// It is the caller's responsibility to provide valid values which are expected
63+
/// to be constants with index type or results of dimension extraction ops
64+
/// (for ex. memref.dim op).
65+
///
66+
/// `sizes` elements are asserted to be non-negative.
67+
///
68+
/// Return an empty vector if `sizes` is empty.
69+
SmallVector<Value> computeSuffixProduct(Location loc, OpBuilder &builder,
70+
ArrayRef<Value> sizes);
71+
inline SmallVector<Value> computeStrides(Location loc, OpBuilder &builder,
72+
ArrayRef<Value> sizes) {
73+
return computeSuffixProduct(loc, builder, sizes);
74+
}
75+
5176
/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
5277
///
5378
/// Return an empty vector if `v1` and `v2` are empty.

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,39 +63,99 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
6363
memref::ExpandShapeOp expandShapeOp,
6464
ValueRange indices,
6565
SmallVectorImpl<Value> &sourceIndices) {
66-
// The below implementation uses computeSuffixProduct method, which only
67-
// allows int64_t values (i.e., static shape). Bail out if it has dynamic
68-
// shapes.
69-
if (!expandShapeOp.getResultType().hasStaticShape())
70-
return failure();
71-
66+
// Record the rewriter context for constructing ops later.
7267
MLIRContext *ctx = rewriter.getContext();
68+
69+
  // Record result type to get result dimensions for calculating suffix product
70+
// later.
71+
ShapedType resultType = expandShapeOp.getResultType();
72+
73+
  // Traverse all reassociation groups to determine the appropriate index
74+
// corresponding to each one of them post op folding.
7375
for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
7476
assert(!groups.empty() && "association indices groups cannot be empty");
77+
// Flag to indicate the presence of dynamic dimensions in current
78+
// reassociation group.
79+
bool hasDynamicDims = false;
7580
int64_t groupSize = groups.size();
7681

77-
// Construct the expression for the index value w.r.t to expand shape op
78-
// source corresponding the indices wrt to expand shape op result.
82+
// Capture expand_shape's resultant memref dimensions which are to be used
83+
// in suffix product calculation later.
7984
SmallVector<int64_t> sizes(groupSize);
80-
for (int64_t i = 0; i < groupSize; ++i)
85+
for (int64_t i = 0; i < groupSize; ++i) {
8186
sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
82-
SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
87+
if (resultType.isDynamicDim(groups[i]))
88+
hasDynamicDims = true;
89+
}
90+
91+
// Declare resultant affine apply result and affine expression variables to
92+
// represent dimensions in the newly constructed affine map.
93+
OpFoldResult ofr;
8394
SmallVector<AffineExpr> dims(groupSize);
8495
bindDimsList(ctx, MutableArrayRef{dims});
85-
AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
8696

87-
/// Apply permutation and create AffineApplyOp.
97+
// Record the load index corresponding to each dimension in the
98+
// reassociation group. These are later supplied as operands to the affine
99+
    // map used for calculating relevant index post op folding.
88100
SmallVector<OpFoldResult> dynamicIndices(groupSize);
89101
for (int64_t i = 0; i < groupSize; i++)
90102
dynamicIndices[i] = indices[groups[i]];
91103

92-
// Creating maximally folded and composd affine.apply composes better with
93-
// other transformations without interleaving canonicalization passes.
94-
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
95-
rewriter, loc,
96-
AffineMap::get(/*numDims=*/groupSize,
97-
/*numSymbols=*/0, srcIndexExpr),
98-
dynamicIndices);
104+
if (hasDynamicDims) {
105+
// Record relevant dimension sizes for each result dimension in the
106+
// reassociation group.
107+
SmallVector<Value> sizesVal(groupSize);
108+
for (int64_t i = 0; i < groupSize; ++i) {
109+
if (sizes[i] <= 0)
110+
sizesVal[i] = rewriter.create<memref::DimOp>(
111+
loc, expandShapeOp.getResult(), groups[i]);
112+
else
113+
sizesVal[i] = rewriter.create<arith::ConstantIndexOp>(loc, sizes[i]);
114+
}
115+
116+
// Calculate suffix product of previously obtained dimension sizes.
117+
auto suffixProduct = computeSuffixProduct(loc, rewriter, sizesVal);
118+
119+
// Create affine expression variables for symbols in the newly constructed
120+
// affine map.
121+
SmallVector<AffineExpr> symbols(groupSize);
122+
bindSymbolsList(ctx, MutableArrayRef{symbols});
123+
124+
      // Linearize bound dimensions and symbols to construct the resultant
125+
      // affine expression for this index.
126+
AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
127+
128+
// Supply suffix product results followed by load op indices as operands
129+
// to the map.
130+
SmallVector<OpFoldResult> mapOperands;
131+
llvm::append_range(mapOperands, suffixProduct);
132+
llvm::append_range(mapOperands, dynamicIndices);
133+
134+
// Creating maximally folded and composed affine.apply composes better
135+
// with other transformations without interleaving canonicalization
136+
// passes.
137+
ofr = affine::makeComposedFoldedAffineApply(
138+
rewriter, loc,
139+
AffineMap::get(/*numDims=*/groupSize,
140+
/*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
141+
mapOperands);
142+
} else {
143+
// Calculate suffix product of static dimension sizes and linearize those
144+
// values with dimension affine variables defined previously.
145+
SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
146+
AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
147+
148+
// Creating maximally folded and composed affine.apply composes better
149+
// with other transformations without interleaving canonicalization
150+
// passes.
151+
ofr = affine::makeComposedFoldedAffineApply(
152+
rewriter, loc,
153+
AffineMap::get(/*numDims=*/groupSize,
154+
/*numSymbols=*/0, /*expression=*/srcIndexExpr),
155+
dynamicIndices);
156+
}
157+
// Push index value in the op post folding corresponding to this
158+
// reassociation group.
99159
sourceIndices.push_back(
100160
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
101161
}

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Utils/IndexingUtils.h"
10+
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
1012
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1113
#include "mlir/IR/AffineExpr.h"
1214
#include "mlir/IR/Builders.h"
@@ -29,6 +31,19 @@ SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
2931
return strides;
3032
}
3133

34+
static SmallVector<Value> computeSuffixProductImpl(Location loc,
35+
OpBuilder &builder,
36+
ArrayRef<Value> sizes,
37+
Value unit) {
38+
if (sizes.empty())
39+
return {};
40+
SmallVector<Value> strides(sizes.size(), unit);
41+
for (int64_t r = strides.size() - 2; r >= 0; --r)
42+
strides[r] =
43+
builder.create<arith::MulIOp>(loc, strides[r + 1], sizes[r + 1]);
44+
return strides;
45+
}
46+
3247
template <typename ExprType>
3348
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
3449
ArrayRef<ExprType> v2) {
@@ -197,6 +212,18 @@ SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
197212
return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
198213
}
199214

215+
//===----------------------------------------------------------------------===//
216+
// Utils that operate on compile time unknown values.
217+
//===----------------------------------------------------------------------===//
218+
219+
SmallVector<Value> mlir::computeSuffixProduct(Location loc, OpBuilder &builder,
220+
ArrayRef<Value> sizes) {
221+
if (sizes.empty())
222+
return {};
223+
Value unit = builder.create<arith::ConstantIndexOp>(loc, 1);
224+
return ::computeSuffixProductImpl(loc, builder, sizes, unit);
225+
}
226+
200227
//===----------------------------------------------------------------------===//
201228
// Permutation utils.
202229
//===----------------------------------------------------------------------===//

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -468,16 +468,67 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
468468

469469
// -----
470470

471-
// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
472-
func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
471+
// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
472+
// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
473+
// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> f32
474+
func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
473475
%c0 = arith.constant 0 : index
474476
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
475477
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
476478
return %0 : f32
477479
}
478-
// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
479-
// CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]]
480-
// CHECK: return %[[LOAD]]
480+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
481+
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
482+
// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
483+
// CHECK-NEXT: return %[[VAL1]] : f32
484+
485+
// -----
486+
487+
// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
488+
// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
489+
// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
490+
func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) {
491+
%c0 = arith.constant 0 : index
492+
%c1f32 = arith.constant 1.0 : f32
493+
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
494+
memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
495+
return
496+
}
497+
// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
498+
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
499+
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
500+
// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
501+
// CHECK-NEXT: return
502+
503+
// -----
504+
505+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
506+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
507+
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
508+
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
509+
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index) {
510+
%subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
511+
%expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
512+
%dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
513+
514+
affine.for %arg6 = 0 to %dim step 64 {
515+
affine.for %arg7 = 0 to 16 step 16 {
516+
%dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
517+
affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
518+
}
519+
}
520+
return
521+
}
522+
// CHECK-NEXT: memref.subview
523+
// CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
524+
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
525+
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
526+
// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
527+
// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
528+
// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
529+
// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
530+
// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
531+
// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
481532

482533
// -----
483534

0 commit comments

Comments
 (0)