@@ -59,28 +59,92 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///          memref<12x42xf32>
-static LogicalResult resolveSourceIndicesExpandShape(
-    Location loc, PatternRewriter &rewriter,
-    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
-    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
-  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
+static LogicalResult
+resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
+                                memref::ExpandShapeOp expandShapeOp,
+                                ValueRange indices,
+                                SmallVectorImpl<Value> &sourceIndices) {
+  // Record the rewriter context for constructing ops later.
+  MLIRContext *ctx = rewriter.getContext();
+
+  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
+  // This is done for the purpose of inferring the output shape via
+  // `inferExpandShapeOutputShape`, which will in turn be used for suffix
+  // product calculation later.
+  SmallVector<OpFoldResult> srcShape;
+  MemRefType srcType = expandShapeOp.getSrcType();
+
+  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
+    if (srcType.isDynamicDim(i)) {
+      srcShape.push_back(
+          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
+              .getResult());
+    } else {
+      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
+    }
+  }
+
+  auto outputShape = inferExpandShapeOutputShape(
+      rewriter, loc, expandShapeOp.getResultType(),
+      expandShapeOp.getReassociationIndices(), srcShape);
+  if (!outputShape.has_value())
+    return failure();
 
   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-    if (groupSize == 1) {
-      sourceIndices.push_back(indices[group[0]]);
-      continue;
+  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
+    assert(!groups.empty() && "association indices groups cannot be empty");
+    // Number of dimensions in the current reassociation group.
+    int64_t groupSize = groups.size();
+
+    // Group output dimensions utilized in this reassociation group for suffix
+    // product calculation.
+    SmallVector<OpFoldResult> sizesVal(groupSize);
+    for (int64_t i = 0; i < groupSize; ++i) {
+      sizesVal[i] = (*outputShape)[groups[i]];
     }
-    SmallVector<OpFoldResult> groupBasis =
-        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
-    SmallVector<Value> groupIndices =
-        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
-    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
-        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
-    sourceIndices.push_back(collapsedIndex);
+
+    // Calculate the suffix product of the relevant output dimension sizes.
+    SmallVector<OpFoldResult> suffixProduct =
+        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
+
+    // Create affine expression variables for dimensions and symbols in the
+    // newly constructed affine map.
+    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
+    bindDimsList<AffineExpr>(ctx, dims);
+    bindSymbolsList<AffineExpr>(ctx, symbols);
+
+    // Linearize the bound dimensions and symbols to construct the resulting
+    // affine expression for this index.
+    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calculating the relevant index post op folding.
+    SmallVector<OpFoldResult> dynamicIndices(groupSize);
+    for (int64_t i = 0; i < groupSize; i++)
+      dynamicIndices[i] = indices[groups[i]];
+
+    // Supply suffix product results followed by load op indices as operands
+    // to the map.
+    SmallVector<OpFoldResult> mapOperands;
+    llvm::append_range(mapOperands, suffixProduct);
+    llvm::append_range(mapOperands, dynamicIndices);
+
+    // Creating a maximally folded and composed affine.apply composes better
+    // with other transformations without interleaving canonicalization
+    // passes.
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc,
+        AffineMap::get(/*numDims=*/groupSize,
+                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+        mapOperands);
+
+    // Push the folded index value corresponding to this reassociation group.
+    sourceIndices.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
   return success();
 }
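The restored expand_shape path above linearizes the indices of each reassociation group against a suffix product of the group's output sizes, instead of emitting an affine.linearize_index op. Below is a minimal sketch of that arithmetic, assuming all group sizes are statically known; the helper name and use of std::vector are illustrative only and are not part of the patch.

#include <cstdint>
#include <vector>

// Collapse the indices of one reassociation group into a single source index:
// srcIndex = sum_i groupIndices[i] * (product of groupSizes after position i),
// i.e. the suffix-product linearization the affine.apply above encodes.
int64_t linearizeGroup(const std::vector<int64_t> &groupIndices,
                       const std::vector<int64_t> &groupSizes) {
  int64_t stride = 1;
  int64_t srcIndex = 0;
  for (int64_t i = static_cast<int64_t>(groupIndices.size()) - 1; i >= 0; --i) {
    srcIndex += groupIndices[i] * stride;
    stride *= groupSizes[i];
  }
  return srcIndex;
}

For the doc-comment example, a group with output sizes [2, 6] and indices (%i1, %i2) collapses to 6 * i1 + i2, which is exactly the folded load index shown above.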
@@ -103,33 +167,49 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  // Note: collapse_shape requires a strided memref, we can do this.
-  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-      loc, collapseShapeOp.getSrc());
-  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
-  for (auto [index, group] :
-       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-
-    if (groupSize == 1) {
-      sourceIndices.push_back(index);
-      continue;
+  int64_t cnt = 0;
+  SmallVector<OpFoldResult> dynamicIndices;
+  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
+    assert(!groups.empty() && "association indices groups cannot be empty");
+    dynamicIndices.push_back(indices[cnt++]);
+    int64_t groupSize = groups.size();
+
+    // Calculate the suffix product for all collapse op source dimension sizes
+    // except the most major one of each group.
+    // We allow the most major source dimension to be dynamic but enforce all
+    // others to be known statically.
+    SmallVector<int64_t> sizes(groupSize, 1);
+    for (int64_t i = 1; i < groupSize; ++i) {
+      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
+      if (sizes[i] == ShapedType::kDynamic)
+        return failure();
     }
-
-    SmallVector<OpFoldResult> basis =
-        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
-    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        loc, index, basis, /*hasOuterBound=*/true);
-    llvm::append_range(sourceIndices, delinearize.getResults());
+    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+
+    // Derive the index values along all dimensions of the source,
+    // corresponding to the index w.r.t. the collapse_shape op output.
+    auto d0 = rewriter.getAffineDimExpr(0);
+    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
+
+    // Construct the AffineApplyOp for each delinearizingExpr.
+    for (int64_t i = 0; i < groupSize; i++) {
+      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc,
+          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
+                         delinearizingExprs[i]),
+          dynamicIndices);
+      sourceIndices.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    }
+    dynamicIndices.clear();
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
+      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
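The collapse_shape path does the inverse: each collapsed index is delinearized against the suffix product of the group's source sizes via floordiv/mod affine expressions, and the most major dimension absorbs the remaining quotient, so only the non-major sizes must be static. Below is a minimal sketch of the equivalent scalar arithmetic; the helper name is illustrative and not part of the patch.

#include <cstdint>
#include <vector>

// Split a collapsed index back into one index per source dimension of the
// group, working from the most minor dimension: index i receives the
// remainder modulo sizes[i]; the most major dimension keeps the remaining
// quotient, so its extent never needs to be known statically.
std::vector<int64_t> delinearizeGroup(int64_t linearIndex,
                                      const std::vector<int64_t> &sizes) {
  std::vector<int64_t> result(sizes.size());
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i > 0; --i) {
    result[i] = linearIndex % sizes[i];
    linearIndex /= sizes[i];
  }
  result[0] = linearIndex;
  return result;
}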
@@ -433,12 +513,8 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
-  // memref.load and affine.load guarantee that indexes start inbounds
-  // while the vector operations don't. This impacts if our linearization
-  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
-          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
@@ -600,12 +676,8 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
-  // memref.store and affine.store guarantee that indexes start inbounds
-  // while the vector operations don't. This impacts if our linearization
-  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
-          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {