@@ -59,92 +59,28 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///          memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
-                                memref::ExpandShapeOp expandShapeOp,
-                                ValueRange indices,
-                                SmallVectorImpl<Value> &sourceIndices) {
-  // Record the rewriter context for constructing ops later.
-  MLIRContext *ctx = rewriter.getContext();
-
-  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
-  // This is done for the purpose of inferring the output shape via
-  // `inferExpandOutputShape` which will in turn be used for suffix product
-  // calculation later.
-  SmallVector<OpFoldResult> srcShape;
-  MemRefType srcType = expandShapeOp.getSrcType();
-
-  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
-    if (srcType.isDynamicDim(i)) {
-      srcShape.push_back(
-          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
-              .getResult());
-    } else {
-      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
-    }
-  }
-
-  auto outputShape = inferExpandShapeOutputShape(
-      rewriter, loc, expandShapeOp.getResultType(),
-      expandShapeOp.getReassociationIndices(), srcShape);
-  if (!outputShape.has_value())
-    return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
 
   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    // Flag to indicate the presence of dynamic dimensions in current
-    // reassociation group.
-    int64_t groupSize = groups.size();
-
-    // Group output dimensions utilized in this reassociation group for suffix
-    // product calculation.
-    SmallVector<OpFoldResult> sizesVal(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i) {
-      sizesVal[i] = (*outputShape)[groups[i]];
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
     }
-
-    // Calculate suffix product of relevant output dimension sizes.
-    SmallVector<OpFoldResult> suffixProduct =
-        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
-    // Create affine expression variables for dimensions and symbols in the
-    // newly constructed affine map.
-    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
-    bindDimsList<AffineExpr>(ctx, dims);
-    bindSymbolsList<AffineExpr>(ctx, symbols);
-
-    // Linearize binded dimensions and symbols to construct the resultant
-    // affine expression for this indice.
-    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
-    // Record the load index corresponding to each dimension in the
-    // reassociation group. These are later supplied as operands to the affine
-    // map used for calulating relevant index post op folding.
-    SmallVector<OpFoldResult> dynamicIndices(groupSize);
-    for (int64_t i = 0; i < groupSize; i++)
-      dynamicIndices[i] = indices[groups[i]];
-
-    // Supply suffix product results followed by load op indices as operands
-    // to the map.
-    SmallVector<OpFoldResult> mapOperands;
-    llvm::append_range(mapOperands, suffixProduct);
-    llvm::append_range(mapOperands, dynamicIndices);
-
-    // Creating maximally folded and composed affine.apply composes better
-    // with other transformations without interleaving canonicalization
-    // passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
-        mapOperands);
-
-    // Push index value in the op post folding corresponding to this
-    // reassociation group.
-    sourceIndices.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
   }
   return success();
 }
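In short, the per-group suffix-product computation and hand-built affine map are replaced by a single affine.linearize_index op per reassociation group. As a rough sketch of the effect (hypothetical IR, value names invented for illustration), folding a load through an expand_shape that splits one source dimension into two now looks like:

  %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [3, 7, 42]
      : memref<21x42xf32> into memref<3x7x42xf32>
  %2 = memref.load %1[%i, %j, %k] : memref<3x7x42xf32>

  // ... is rewritten to roughly:
  %idx = affine.linearize_index disjoint [%i, %j] by (3, 7) : index
  %2 = memref.load %0[%idx, %k] : memref<21x42xf32>

The new `startsInbounds` parameter decides whether the linearization may carry the `disjoint` flag (every index known to lie within its basis bound); the callers below derive it from the kind of memory operation being folded.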
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  int64_t cnt = 0;
-  SmallVector<OpFoldResult> dynamicIndices;
-  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    dynamicIndices.push_back(indices[cnt++]);
-    int64_t groupSize = groups.size();
-
-    // Calculate suffix product for all collapse op source dimension sizes
-    // except the most major one of each group.
-    // We allow the most major source dimension to be dynamic but enforce all
-    // others to be known statically.
-    SmallVector<int64_t> sizes(groupSize, 1);
-    for (int64_t i = 1; i < groupSize; ++i) {
-      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
-      if (sizes[i] == ShapedType::kDynamic)
-        return failure();
-    }
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
-    // Derive the index values along all dimensions of the source corresponding
-    // to the index wrt to collapsed shape op output.
-    auto d0 = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
-    // Construct the AffineApplyOp for each delinearizingExpr.
-    for (int64_t i = 0; i < groupSize; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc,
-          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
-                         delinearizingExprs[i]),
-          dynamicIndices);
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+  MemRefType sourceType = collapseShapeOp.getSrcType();
+  // Note: collapse_shape requires a strided memref, so we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
     }
-    dynamicIndices.clear();
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
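The collapse_shape direction is the mirror image: instead of delinearizing through a suffix product of static sizes (and bailing out on dynamic non-major dimensions), the index into each collapsed dimension is split with one affine.delinearize_index op, with the basis taken from source sizes recovered via memref.extract_strided_metadata, so dynamic sizes are handled too. A rough sketch (hypothetical IR, names invented):

  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<3x7x42xf32> into memref<21x42xf32>
  %2 = memref.load %1[%i, %k] : memref<21x42xf32>

  // ... is rewritten to roughly:
  %idx:2 = affine.delinearize_index %i into (3, 7) : index, index
  %2 = memref.load %0[%idx#0, %idx#1, %k] : memref<3x7x42xf32>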
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.load and affine.load guarantee that indices start inbounds,
+  // while the vector operations don't. This affects whether our
+  // linearization is `disjoint`.
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
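The `disjoint` flag is only sound when every index is already known to lie within its basis bound, which is why it is keyed on the op kind here. A sketch of the difference (hypothetical IR):

  // memref.load / affine.load: indices start inbounds, so:
  %idx = affine.linearize_index disjoint [%i, %j] by (3, 7) : index

  // vector loads carry no such guarantee, so conservatively:
  %idx2 = affine.linearize_index [%i, %j] by (3, 7) : index

The store pattern in the next hunk applies the same distinction.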
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.store and affine.store guarantee that indices start inbounds,
+  // while the vector operations don't. This affects whether our
+  // linearization is `disjoint`.
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {