
Commit 6ed8434
[mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims (#89093)
`fold-memref-alias-ops` bails out in the presence of dynamic shapes in the `memref.expand_shape` op. Handle this case.
Parent: 73bb8d9
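As a minimal sketch of what the pass can now fold (modeled on the new tests below; the function and value names are hypothetical), a load through a memref.expand_shape with a dynamic output dimension is redirected to the source memref:

func.func @example(%src: memref<16x?xf32>, %i: index, %j: index, %sz0: index) -> f32 {
  %c0 = arith.constant 0 : index
  %e = memref.expand_shape %src [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1]
      : memref<16x?xf32> into memref<1x16x?x1xf32>
  %v = memref.load %e[%c0, %i, %j, %c0] : memref<1x16x?x1xf32>
  return %v : f32
}

// After fold-memref-alias-ops, the load goes directly to the source; the
// linearized indices simplify here to the original (%i, %j):
//   %v = memref.load %src[%i, %j] : memref<16x?xf32>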

File tree

4 files changed: +180 -38 lines changed

mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 29 additions & 0 deletions
@@ -64,6 +64,35 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
 // it means both the allocations and associated stores can be removed.
 void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp);
 
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector<Value>
+/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// It is the caller's responsibility to provide valid OpFoldResult type values
+/// and construct valid IR in the end.
+///
+/// `sizes` elements are asserted to be non-negative.
+///
+/// Return an empty vector if `sizes` is empty.
+///
+/// The function emits an IR block which computes the suffix product for the
+/// provided sizes.
+SmallVector<OpFoldResult>
+computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
+                            ArrayRef<OpFoldResult> sizes);
+inline SmallVector<OpFoldResult>
+computeStridesIRBlock(Location loc, OpBuilder &builder,
+                      ArrayRef<OpFoldResult> sizes) {
+  return computeSuffixProductIRBlock(loc, builder, sizes);
+}
+
 } // namespace memref
 } // namespace mlir
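A quick worked instance of the contract documented above (illustrative numbers, not from the patch):

// sizes   = [2, 3, 4]
// strides = [3 * 4, 4, 1] = [12, 4, 1]
// an element at offset (i, j, k) linearizes to i * 12 + j * 4 + k

`computeStridesIRBlock` is a thin inline alias, reflecting that for row-major slicing the suffix product and the strides coincide.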

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 66 additions & 19 deletions
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -63,39 +64,85 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  // The below implementation uses computeSuffixProduct method, which only
-  // allows int64_t values (i.e., static shape). Bail out if it has dynamic
-  // shapes.
-  if (!expandShapeOp.getResultType().hasStaticShape())
+  // Record the rewriter context for constructing ops later.
+  MLIRContext *ctx = rewriter.getContext();
+
+  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
+  // This is done for the purpose of inferring the output shape via
+  // `inferExpandShapeOutputShape`, which will in turn be used for the suffix
+  // product calculation later.
+  SmallVector<OpFoldResult> srcShape;
+  MemRefType srcType = expandShapeOp.getSrcType();
+
+  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
+    if (srcType.isDynamicDim(i)) {
+      srcShape.push_back(
+          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
+              .getResult());
+    } else {
+      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
+    }
+  }
+
+  auto outputShape = inferExpandShapeOutputShape(
+      rewriter, loc, expandShapeOp.getResultType(),
+      expandShapeOp.getReassociationIndices(), srcShape);
+  if (!outputShape.has_value())
     return failure();
 
-  MLIRContext *ctx = rewriter.getContext();
+  // Traverse all reassociation groups to determine the appropriate indices
+  // corresponding to each one of them post op folding.
   for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
+    // Number of dimensions in the current reassociation group.
     int64_t groupSize = groups.size();
 
-    // Construct the expression for the index value w.r.t to expand shape op
-    // source corresponding the indices wrt to expand shape op result.
-    SmallVector<int64_t> sizes(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i)
-      sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-    SmallVector<AffineExpr> dims(groupSize);
-    bindDimsList(ctx, MutableArrayRef{dims});
-    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
+    // Gather the output dimensions used in this reassociation group for the
+    // suffix product calculation.
+    SmallVector<OpFoldResult> sizesVal(groupSize);
+    for (int64_t i = 0; i < groupSize; ++i) {
+      sizesVal[i] = (*outputShape)[groups[i]];
+    }
+
+    // Calculate the suffix product of the relevant output dimension sizes.
+    SmallVector<OpFoldResult> suffixProduct =
+        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
+
+    // Create affine expression variables for the dimensions and symbols in
+    // the newly constructed affine map.
+    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
+    bindDimsList<AffineExpr>(ctx, dims);
+    bindSymbolsList<AffineExpr>(ctx, symbols);
 
-    /// Apply permutation and create AffineApplyOp.
+    // Linearize the bound dimensions and symbols to construct the resulting
+    // affine expression for this index.
+    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calculating the relevant index post op folding.
     SmallVector<OpFoldResult> dynamicIndices(groupSize);
     for (int64_t i = 0; i < groupSize; i++)
       dynamicIndices[i] = indices[groups[i]];
 
-    // Creating maximally folded and composd affine.apply composes better with
-    // other transformations without interleaving canonicalization passes.
+    // Supply the suffix product results followed by the load op indices as
+    // operands to the map.
+    SmallVector<OpFoldResult> mapOperands;
+    llvm::append_range(mapOperands, suffixProduct);
+    llvm::append_range(mapOperands, dynamicIndices);
+
+    // Creating a maximally folded and composed affine.apply composes better
+    // with other transformations without interleaving canonicalization
+    // passes.
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc,
        AffineMap::get(/*numDims=*/groupSize,
-                      /*numSymbols=*/0, srcIndexExpr),
-        dynamicIndices);
+                      /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+        mapOperands);
+
+    // Push the index value corresponding to this reassociation group post
+    // folding.
     sourceIndices.push_back(
         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
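Concretely (an illustrative sketch; the stride and index names are hypothetical): for a reassociation group of two result dimensions with suffix-product strides %st0, %st1 and load indices %i0, %i1, the composed apply now evaluates %i0 * %st0 + %i1 * %st1 through a map of the form

// (d0, d1) bind the suffix-product strides and [s0, s1] the load indices;
// makeComposedFoldedAffineApply folds static operands into the expression,
// so fully static groups still reduce to a constant-coefficient map as before.
#map = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>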

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 23 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace memref {
@@ -155,5 +156,27 @@ void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
     rewriter.eraseOp(op);
 }
 
+static SmallVector<OpFoldResult>
+computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder,
+                                ArrayRef<OpFoldResult> sizes,
+                                OpFoldResult unit) {
+  SmallVector<OpFoldResult> strides(sizes.size(), unit);
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+
+  for (int64_t r = strides.size() - 1; r > 0; --r) {
+    strides[r - 1] = affine::makeComposedFoldedAffineApply(
+        builder, loc, s0 * s1, {strides[r], sizes[r]});
+  }
+  return strides;
+}
+
+SmallVector<OpFoldResult>
+computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
+                            ArrayRef<OpFoldResult> sizes) {
+  OpFoldResult unit = builder.getIndexAttr(1);
+  return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
+}
+
 } // namespace memref
 } // namespace mlir
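For example (a sketch with hypothetical dynamic sizes %a and %b), calling `computeSuffixProductIRBlock` with sizes = [%a, %b, 8] materializes a single op, because only one partial product is dynamic:

// strides = [%0, 8, 1]
%0 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%b]
// The static tail (8 and 1) stays folded as index attributes, and %a is
// never read: the suffix product never includes the outermost size.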

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 62 additions & 19 deletions
@@ -468,23 +468,66 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 
 // -----
 
-// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[SZ0:.*]]: index)
-func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
+// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32
+func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 16, %[[SZ0]], 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: %[[VAL_0:.*]] = memref.load %[[EXPAND_SHAPE]][%[[C0]], %[[ARG1]], %[[ARG2]], %[[C0]]] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: return %[[VAL_0]] : f32
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0 : index) {
+  %c0 = arith.constant 0 : index
+  %c1f32 = arith.constant 1.0 : f32
+  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  return
+}
+// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
+func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
+  %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
+  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [1, 16, %sz0, 1] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+  %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+
+  affine.for %arg6 = 0 to %dim step 64 {
+    affine.for %arg7 = 0 to 16 step 16 {
+      %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+      affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
+    }
+  }
+  return
+}
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
+// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
+// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
+// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {

@@ -506,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
+// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
 // CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {

@@ -535,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
+// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
 // CHECK-NEXT: affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {

@@ -565,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
+// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
 // CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
 
 // -----
