
Commit fea75c9

[mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims
The fold-memref-alias-ops pass bails out in the presence of dynamic shapes, which leads to unwanted propagation of alias types during other transformations. This can percolate further down and lead to errors that should not have been created in the first place.
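As a condensed illustration (taken from the first test case added below; value names follow that test), a load through an expand_shape with a dynamic output dimension can now be folded into a load on the original memref, with the group indices linearized by an affine.apply:

Before folding:
  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1]
      : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>

After folding (roughly):
  %c1 = arith.constant 1 : index
  %idx = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%arg2, %c1]
  %0 = memref.load %arg0[%arg1, %idx] : memref<16x?xf32, strided<[16, 1]>>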
1 parent 3684a38 commit fea75c9

4 files changed, +186 −26 lines changed

mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 27 additions & 0 deletions
@@ -64,6 +64,33 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
 // it means both the allocations and associated stores can be removed.
 void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp);
 
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector<Value>
+/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// It is the caller's responsibility to provide valid values which are
+/// expected to be of index type and construct valid IR in the end.
+///
+/// `sizes` elements are asserted to be non-negative.
+///
+/// Return an empty vector if `sizes` is empty.
+///
+/// The function emits an IR block which computes suffix product for provided
+/// sizes.
+SmallVector<Value> computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
+                                               ArrayRef<Value> sizes);
+inline SmallVector<Value>
+computeStridesIRBlock(Location loc, OpBuilder &builder, ArrayRef<Value> sizes) {
+  return computeSuffixProductIRBlock(loc, builder, sizes);
+}
+
 } // namespace memref
 } // namespace mlir
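For instance (a worked example of the documented formula, not taken from the patch itself), for sizes [2, 3, 4] the returned values compute the suffix product [3 * 4, 4, 1] = [12, 4, 1]; computeStridesIRBlock is just a convenience alias for the same computation.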

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 81 additions & 19 deletions
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -63,39 +64,100 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  // The below implementation uses computeSuffixProduct method, which only
-  // allows int64_t values (i.e., static shape). Bail out if it has dynamic
-  // shapes.
-  if (!expandShapeOp.getResultType().hasStaticShape())
-    return failure();
-
+  // Record the rewriter context for constructing ops later.
   MLIRContext *ctx = rewriter.getContext();
+
+  // Record result type to get result dimensions for calculating suffix product
+  // later.
+  ShapedType resultType = expandShapeOp.getResultType();
+
+  // Traverse all reassociation groups to determine the appropriate indices
+  // corresponding to each one of them post op folding.
   for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
+    // Flag to indicate the presence of dynamic dimensions in current
+    // reassociation group.
+    bool hasDynamicDims = false;
     int64_t groupSize = groups.size();
 
-    // Construct the expression for the index value w.r.t to expand shape op
-    // source corresponding the indices wrt to expand shape op result.
+    // Capture expand_shape's resultant memref dimensions which are to be used
+    // in suffix product calculation later.
     SmallVector<int64_t> sizes(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i)
+    for (int64_t i = 0; i < groupSize; ++i) {
       sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+      if (resultType.isDynamicDim(groups[i]))
+        hasDynamicDims = true;
+    }
+
+    // Declare resultant affine apply result and affine expression variables to
+    // represent dimensions in the newly constructed affine map.
+    OpFoldResult ofr;
     SmallVector<AffineExpr> dims(groupSize);
     bindDimsList(ctx, MutableArrayRef{dims});
-    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
 
-    /// Apply permutation and create AffineApplyOp.
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calculating the relevant index post op folding.
     SmallVector<OpFoldResult> dynamicIndices(groupSize);
     for (int64_t i = 0; i < groupSize; i++)
      dynamicIndices[i] = indices[groups[i]];
 
-    // Creating maximally folded and composd affine.apply composes better with
-    // other transformations without interleaving canonicalization passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/0, srcIndexExpr),
-        dynamicIndices);
+    if (hasDynamicDims) {
+      // Record relevant dimension sizes for each result dimension in the
+      // reassociation group.
+      SmallVector<Value> sizesVal(groupSize);
+      for (int64_t i = 0; i < groupSize; ++i) {
+        if (ShapedType::isDynamic(sizes[i]))
+          sizesVal[i] = rewriter.create<memref::DimOp>(
+              loc, expandShapeOp.getResult(), groups[i]);
+        else
+          sizesVal[i] = rewriter.create<arith::ConstantIndexOp>(loc, sizes[i]);
+      }
+
+      // Calculate suffix product of previously obtained dimension sizes.
+      SmallVector<Value> suffixProduct =
+          memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
+
+      // Create affine expression variables for symbols in the newly
+      // constructed affine map.
+      SmallVector<AffineExpr> symbols(groupSize);
+      bindSymbolsList<AffineExpr>(ctx, symbols);
+
+      // Linearize the bound dimensions and symbols to construct the resultant
+      // affine expression for this index.
+      AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+      // Supply suffix product results followed by load op indices as operands
+      // to the map.
+      SmallVector<OpFoldResult> mapOperands;
+      llvm::append_range(mapOperands, suffixProduct);
+      llvm::append_range(mapOperands, dynamicIndices);
+
+      // Creating maximally folded and composed affine.apply composes better
+      // with other transformations without interleaving canonicalization
+      // passes.
+      ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc,
+          AffineMap::get(/*numDims=*/groupSize,
+                         /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+          mapOperands);
+    } else {
+      // Calculate suffix product of static dimension sizes and linearize those
+      // values with dimension affine variables defined previously.
+      SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+      AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
+
+      // Creating maximally folded and composed affine.apply composes better
+      // with other transformations without interleaving canonicalization
+      // passes.
+      ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc,
+          AffineMap::get(/*numDims=*/groupSize,
+                         /*numSymbols=*/0, /*expression=*/srcIndexExpr),
+          dynamicIndices);
+    }
+    // Push index value in the op post folding corresponding to this
+    // reassociation group.
     sourceIndices.push_back(
         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
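To make the dynamic branch above concrete: for a reassociation group of size 2, linearize(ctx, dims, symbols) produces an expression of the form below, which is wrapped in an affine map with groupSize dims and groupSize symbols; the strides returned by computeSuffixProductIRBlock and the per-dimension load indices are supplied as the map operands. This is a sketch of the pre-composition map; makeComposedFoldedAffineApply then folds it, which is why the tests below only show maps such as ()[s0, s1] -> (s1 * s0).

  affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>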

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 22 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace memref {
@@ -155,5 +156,26 @@ void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
     rewriter.eraseOp(op);
 }
 
+static SmallVector<Value> computeSuffixProductIRBlockImpl(Location loc,
+                                                          OpBuilder &builder,
+                                                          ArrayRef<Value> sizes,
+                                                          Value unit) {
+  if (sizes.empty())
+    return {};
+  SmallVector<Value> strides(sizes.size(), unit);
+  for (int64_t r = strides.size() - 2; r >= 0; --r)
+    strides[r] =
+        builder.create<arith::MulIOp>(loc, strides[r + 1], sizes[r + 1]);
+  return strides;
+}
+
+SmallVector<Value> computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
+                                               ArrayRef<Value> sizes) {
+  if (sizes.empty())
+    return {};
+  Value unit = builder.create<arith::ConstantIndexOp>(loc, 1);
+  return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
+}
+
 } // namespace memref
 } // namespace mlir
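As a rough sketch of the IR this emits: for three index-typed sizes %s0, %s1, %s2 (illustrative names, not from the patch), the builder materializes a unit constant and a chain of arith.muli ops producing the strides [%s1 * %s2, %s2, 1]:

  %c1 = arith.constant 1 : index
  // strides[1] = strides[2] * sizes[2]
  %stride1 = arith.muli %c1, %s2 : index
  // strides[0] = strides[1] * sizes[1]
  %stride0 = arith.muli %stride1, %s1 : index
  // returned suffix product: [%stride0, %stride1, %c1]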

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 56 additions & 7 deletions
@@ -468,18 +468,67 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 
 // -----
 
-// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[SZ0:.*]]: index)
-func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32
+func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 16, %[[SZ0]], 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: %[[VAL_0:.*]] = memref.load %[[EXPAND_SHAPE]][%[[C0]], %[[ARG1]], %[[ARG2]], %[[C0]]] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: return %[[VAL_0]] : f32
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return %[[VAL1]] : f32
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0 : index) {
+  %c0 = arith.constant 0 : index
+  %c1f32 = arith.constant 1.0 : f32
+  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  return
+}
+// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
+// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
+func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
+  %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
+  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [1, 16, %sz0, 1] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+  %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+
+  affine.for %arg6 = 0 to %dim step 64 {
+    affine.for %arg7 = 0 to 16 step 16 {
+      %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+      affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
+    }
+  }
+  return
+}
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
+// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
+// CHECK-NEXT:     %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT:     %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
+// CHECK-NEXT:     %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
+// CHECK-NEXT:     %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT:     affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
 
 // -----
 
