Commit aef1e99
[mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims
The fold-memref-alias-ops pass bails out in the presence of dynamic shapes, which leads to unwanted propagation of alias types during other transformations. This can percolate further down and lead to errors that should not have been created in the first place.
1 parent 3684a38 commit aef1e99
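For illustration, the first test case added by this change (shown in full in the test diff below) covers exactly this situation: a memref.load through a memref.expand_shape with a dynamic output dimension, which the pass previously left untouched, now folds onto the source memref.

Before folding:

  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
  %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>

After fold-memref-alias-ops:

  %0 = memref.load %arg0[%arg1, %arg2] : memref<16x?xf32, strided<[16, 1]>>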

File tree

4 files changed: +178 -38 lines changed


mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 25 additions & 0 deletions
@@ -64,6 +64,31 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
 // it means both the allocations and associated stores can be removed.
 void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp);
 
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector<Value>
+/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// It is the caller's responsibility to provide valid OpFoldResult type values
+/// and construct valid IR in the end.
+///
+/// `sizes` elements are asserted to be non-negative.
+///
+/// Return an empty vector if `sizes` is empty.
+///
+SmallVector<OpFoldResult> computeSuffixProduct(Location loc, OpBuilder &builder,
+                                               ArrayRef<OpFoldResult> sizes);
+inline SmallVector<OpFoldResult>
+computeStrides(Location loc, OpBuilder &builder, ArrayRef<OpFoldResult> sizes) {
+  return computeSuffixProduct(loc, builder, sizes);
+}
+
 } // namespace memref
 } // namespace mlir
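As a caller-side illustration of the new helper (a hypothetical sketch, not part of this commit): assuming an OpBuilder `builder`, a Location `loc`, and index Values `d0` and `d1` holding dynamic sizes, the strides for sizes [d0, d1, 4] could be requested as follows, with the dynamic entry materialized through a folded affine.apply:

  // Hypothetical usage sketch; `builder`, `loc`, `d0`, and `d1` are assumed.
  SmallVector<OpFoldResult> sizes = {d0, d1, builder.getIndexAttr(4)};
  SmallVector<OpFoldResult> strides =
      memref::computeSuffixProduct(loc, builder, sizes);
  // Expected contents of `strides`: [d1 * 4, 4, 1].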

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 66 additions & 19 deletions
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -63,39 +64,85 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  // The below implementation uses computeSuffixProduct method, which only
-  // allows int64_t values (i.e., static shape). Bail out if it has dynamic
-  // shapes.
-  if (!expandShapeOp.getResultType().hasStaticShape())
+  // Record the rewriter context for constructing ops later.
+  MLIRContext *ctx = rewriter.getContext();
+
+  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
+  // This is done for the purpose of inferring the output shape via
+  // `inferExpandOutputShape` which will in turn be used for suffix product
+  // calculation later.
+  SmallVector<OpFoldResult> srcShape;
+  MemRefType srcType = expandShapeOp.getSrcType();
+
+  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
+    if (srcType.isDynamicDim(i)) {
+      srcShape.push_back(
+          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
+              .getResult());
+    } else {
+      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
+    }
+  }
+
+  auto outputShape = inferExpandShapeOutputShape(
+      rewriter, loc, expandShapeOp.getResultType(),
+      expandShapeOp.getReassociationIndices(), srcShape);
+  if (!outputShape.has_value())
     return failure();
 
-  MLIRContext *ctx = rewriter.getContext();
+  // Traverse all reassociation groups to determine the appropriate indices
+  // corresponding to each one of them post op folding.
   for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
+    // Flag to indicate the presence of dynamic dimensions in current
+    // reassociation group.
     int64_t groupSize = groups.size();
 
-    // Construct the expression for the index value w.r.t to expand shape op
-    // source corresponding the indices wrt to expand shape op result.
-    SmallVector<int64_t> sizes(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i)
-      sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-    SmallVector<AffineExpr> dims(groupSize);
-    bindDimsList(ctx, MutableArrayRef{dims});
-    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
+    // Group output dimensions utilized in this reassociation group for suffix
+    // product calculation.
+    SmallVector<OpFoldResult> sizesVal(groupSize);
+    for (int64_t i = 0; i < groupSize; ++i) {
+      sizesVal[i] = (*outputShape)[groups[i]];
+    }
+
+    // Calculate suffix product of relevant output dimension sizes.
+    SmallVector<OpFoldResult> suffixProduct =
+        memref::computeSuffixProduct(loc, rewriter, sizesVal);
+
+    // Create affine expression variables for symbols in the newly constructed
+    // affine map.
+    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
+    bindDimsList<AffineExpr>(ctx, dims);
+    bindSymbolsList<AffineExpr>(ctx, symbols);
 
-    /// Apply permutation and create AffineApplyOp.
+    // Linearize binded dimensions and symbols to construct the resultant
+    // affine expression for this indice.
+    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calulating relevant index post op folding.
    SmallVector<OpFoldResult> dynamicIndices(groupSize);
    for (int64_t i = 0; i < groupSize; i++)
      dynamicIndices[i] = indices[groups[i]];
 
-    // Creating maximally folded and composd affine.apply composes better with
-    // other transformations without interleaving canonicalization passes.
+    // Supply suffix product results followed by load op indices as operands
+    // to the map.
+    SmallVector<OpFoldResult> mapOperands;
+    llvm::append_range(mapOperands, suffixProduct);
+    llvm::append_range(mapOperands, dynamicIndices);
+
+    // Creating maximally folded and composed affine.apply composes better
+    // with other transformations without interleaving canonicalization
+    // passes.
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc,
        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/0, srcIndexExpr),
-        dynamicIndices);
+                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+        mapOperands);
+
+    // Push index value in the op post folding corresponding to this
+    // reassociation group.
    sourceIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
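Connecting this to the first test case below (a worked reading of the new logic): the load %expand_shape[%c0, %arg1, %arg2, %c0] has two reassociation groups. Group [0, 1] has output sizes [1, 16], so its suffix product is [16, 1] and the source row index is 0 * 16 + %arg1 * 1, which makeComposedFoldedAffineApply folds to %arg1. Group [2, 3] has output sizes [%sz0, 1], so its suffix product is [1, 1] and the source column index is %arg2 * 1 + 0 * 1, which folds to %arg2. The load therefore folds to memref.load %arg0[%arg1, %arg2], matching the updated CHECK lines.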

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 25 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace memref {
@@ -155,5 +156,29 @@ void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
     rewriter.eraseOp(op);
 }
 
+static SmallVector<OpFoldResult>
+computeSuffixProductImpl(Location loc, OpBuilder &builder,
+                         ArrayRef<OpFoldResult> sizes, OpFoldResult unit) {
+  if (sizes.empty())
+    return {};
+  SmallVector<OpFoldResult> strides(sizes.size(), unit);
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+
+  for (int64_t r = strides.size() - 2; r >= 0; --r) {
+    strides[r] = affine::makeComposedFoldedAffineApply(
+        builder, loc, s0 * s1, {strides[r + 1], sizes[r + 1]});
+  }
+  return strides;
+}
+
+SmallVector<OpFoldResult> computeSuffixProduct(Location loc, OpBuilder &builder,
+                                               ArrayRef<OpFoldResult> sizes) {
+  if (sizes.empty())
+    return {};
+  OpFoldResult unit = builder.getIndexAttr(1);
+  return computeSuffixProductImpl(loc, builder, sizes, unit);
+}
+
 } // namespace memref
 } // namespace mlir
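To make the loop in computeSuffixProductImpl concrete (a hypothetical trace, not part of the commit): with sizes = [%d0, %d1, 4], where %d0 and %d1 are dynamic index values, strides starts as [1, 1, 1] and the loop runs r = 1, then r = 0:

  r = 1: strides[1] = fold(s0 * s1, {strides[2] = 1, sizes[2] = 4})   -> constant 4
  r = 0: strides[0] = fold(s0 * s1, {strides[1] = 4, sizes[1] = %d1}) -> affine.apply (%d1 * 4)

giving strides = [%d1 * 4, 4, 1], which matches the documented suffix product [s1 * s2, s2, 1].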

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 62 additions & 19 deletions
@@ -468,23 +468,66 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 
 // -----
 
-// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[SZ0:.*]]: index)
-func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
+// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32
+func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 16, %[[SZ0]], 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: %[[VAL_0:.*]] = memref.load %[[EXPAND_SHAPE]][%[[C0]], %[[ARG1]], %[[ARG2]], %[[C0]]] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: return %[[VAL_0]] : f32
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0 : index) {
+  %c0 = arith.constant 0 : index
+  %c1f32 = arith.constant 1.0 : f32
+  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  return
+}
+// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
+func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
+  %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
+  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [1, 16, %sz0, 1] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+  %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+
+  affine.for %arg6 = 0 to %dim step 64 {
+    affine.for %arg7 = 0 to 16 step 16 {
+      %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+      affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
+    }
+  }
+  return
+}
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
+// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
+// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
+// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -506,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
+// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
 // CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -535,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
+// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
 // CHECK-NEXT: affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -565,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
+// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
 // CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
 
 // -----
