@@ -54,6 +54,10 @@ class ShapeCastOp2DDownCastRewritePattern
54
54
PatternRewriter &rewriter) const override {
55
55
auto sourceVectorType = op.getSourceVectorType ();
56
56
auto resultVectorType = op.getResultVectorType ();
57
+
58
+ if (sourceVectorType.isScalable () || resultVectorType.isScalable ())
59
+ return failure ();
60
+
57
61
if (sourceVectorType.getRank () != 2 || resultVectorType.getRank () != 1 )
58
62
return failure ();
59
63
@@ -87,6 +91,10 @@ class ShapeCastOp2DUpCastRewritePattern
87
91
PatternRewriter &rewriter) const override {
88
92
auto sourceVectorType = op.getSourceVectorType ();
89
93
auto resultVectorType = op.getResultVectorType ();
94
+
95
+ if (sourceVectorType.isScalable () || resultVectorType.isScalable ())
96
+ return failure ();
97
+
90
98
if (sourceVectorType.getRank () != 1 || resultVectorType.getRank () != 2 )
91
99
return failure ();
92
100
@@ -106,6 +114,20 @@ class ShapeCastOp2DUpCastRewritePattern
106
114
}
107
115
};
108
116
117
+ static void incIdx (llvm::MutableArrayRef<int64_t > idx, VectorType tp,
118
+ int dimIdx, int initialStep = 1 ) {
119
+ int step = initialStep;
120
+ for (int d = dimIdx; d >= 0 ; d--) {
121
+ idx[d] += step;
122
+ if (idx[d] >= tp.getDimSize (d)) {
123
+ idx[d] = 0 ;
124
+ step = 1 ;
125
+ } else {
126
+ break ;
127
+ }
128
+ }
129
+ }
130
+
109
131
// We typically should not lower general shape cast operations into data
110
132
// movement instructions, since the assumption is that these casts are
111
133
// optimized away during progressive lowering. For completeness, however,
@@ -121,6 +143,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
121
143
auto sourceVectorType = op.getSourceVectorType ();
122
144
auto resultVectorType = op.getResultVectorType ();
123
145
146
+ if (sourceVectorType.isScalable () || resultVectorType.isScalable ())
147
+ return failure ();
148
+
124
149
// Special case 2D / 1D lowerings with better implementations.
125
150
// TODO: make this ND / 1D to allow generic ND -> 1D -> MD.
126
151
int64_t srcRank = sourceVectorType.getRank ();
@@ -175,21 +200,161 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
175
200
rewriter.replaceOp (op, result);
176
201
return success ();
177
202
}
203
+ };
204
+
205
+ // / A shape_cast lowering for scalable vectors with a single trailing scalable
206
+ // / dimension. This is similar to the general shape_cast lowering but makes use
207
+ // / of vector.scalable.insert and vector.scalable.extract to move elements a
208
+ // / subvector at a time.
209
+ // /
210
+ // / E.g.:
211
+ // / ```
212
+ // / // Flatten scalable vector
213
+ // / %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
214
+ // / ```
215
+ // / is rewritten to:
216
+ // / ```
217
+ // / // Flatten scalable vector
218
+ // / %c = arith.constant dense<0> : vector<[8]xi32>
219
+ // / %0 = vector.extract %arg0[0, 0] : vector<2x1x[4]xi32>
220
+ // / %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
221
+ // / %2 = vector.extract %arg0[1, 0] : vector<2x1x[4]xi32>
222
+ // / %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
223
+ // / ```
224
+ // / or:
225
+ // / ```
226
+ // / // Un-flatten scalable vector
227
+ // / %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
228
+ // / ```
229
+ // / is rewritten to:
230
+ // / ```
231
+ // / // Un-flatten scalable vector
232
+ // / %c = arith.constant dense<0> : vector<2x1x[4]xi32>
233
+ // / %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
234
+ // / %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
235
+ // / %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
236
+ // / %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
237
+ // / ```
238
+ class ScalableShapeCastOpRewritePattern
239
+ : public OpRewritePattern<vector::ShapeCastOp> {
240
+ public:
241
+ using OpRewritePattern::OpRewritePattern;
242
+
243
+ LogicalResult matchAndRewrite (vector::ShapeCastOp op,
244
+ PatternRewriter &rewriter) const override {
245
+
246
+ Location loc = op.getLoc ();
247
+ auto sourceVectorType = op.getSourceVectorType ();
248
+ auto resultVectorType = op.getResultVectorType ();
249
+ auto srcRank = sourceVectorType.getRank ();
250
+ auto resRank = resultVectorType.getRank ();
251
+
252
+ // This can only lower shape_casts where both the source and result types
253
+ // have a single trailing scalable dimension. This is because there are no
254
+ // legal representation of other scalable types in LLVM (and likely won't be
255
+ // soon). There are also (currently) no operations that can index or extract
256
+ // from >= 2D scalable vectors or scalable vectors of fixed vectors.
257
+ if (!isTrailingDimScalable (sourceVectorType) ||
258
+ !isTrailingDimScalable (resultVectorType)) {
259
+ return failure ();
260
+ }
261
+
262
+ // The sizes of the trailing dimension of the source and result vectors, the
263
+ // size of subvector to move, and the number of elements in the vectors.
264
+ // These are "min" sizes as they are the size when vscale == 1.
265
+ auto minSourceTrailingSize = sourceVectorType.getShape ().back ();
266
+ auto minResultTrailingSize = resultVectorType.getShape ().back ();
267
+ auto minExtractionSize =
268
+ std::min (minSourceTrailingSize, minResultTrailingSize);
269
+ int64_t minNumElts = 1 ;
270
+ for (auto size : sourceVectorType.getShape ())
271
+ minNumElts *= size;
272
+
273
+ // The subvector type to move from the source to the result. Note that this
274
+ // is a scalable vector. This rewrite will generate code in terms of the
275
+ // "min" size (vscale == 1 case), that scales to any vscale.
276
+ auto extractionVectorType = VectorType::get (
277
+ {minExtractionSize}, sourceVectorType.getElementType (), {true });
278
+
279
+ Value result = rewriter.create <arith::ConstantOp>(
280
+ loc, resultVectorType, rewriter.getZeroAttr (resultVectorType));
281
+
282
+ SmallVector<int64_t > srcIdx (srcRank);
283
+ SmallVector<int64_t > resIdx (resRank);
284
+
285
+ // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
286
+ // once D150000 lands.
287
+ Value currentResultScalableVector;
288
+ Value currentSourceScalableVector;
289
+ for (int64_t i = 0 ; i < minNumElts; i += minExtractionSize) {
290
+ // 1. Extract a scalable subvector from the source vector.
291
+ if (!currentSourceScalableVector) {
292
+ if (srcRank != 1 ) {
293
+ currentSourceScalableVector = rewriter.create <vector::ExtractOp>(
294
+ loc, op.getSource (), llvm::ArrayRef (srcIdx).drop_back ());
295
+ } else {
296
+ currentSourceScalableVector = op.getSource ();
297
+ }
298
+ }
299
+ Value sourceSubVector = currentSourceScalableVector;
300
+ if (minExtractionSize < minSourceTrailingSize) {
301
+ sourceSubVector = rewriter.create <vector::ScalableExtractOp>(
302
+ loc, extractionVectorType, sourceSubVector, srcIdx.back ());
303
+ }
178
304
179
- private:
180
- static void incIdx (SmallVector<int64_t > &idx, VectorType tp, int64_t r) {
181
- assert (0 <= r && r < tp.getRank ());
182
- if (++idx[r] == tp.getDimSize (r)) {
183
- idx[r] = 0 ;
184
- incIdx (idx, tp, r - 1 );
305
+ // 2. Insert the scalable subvector into the result vector.
306
+ if (!currentResultScalableVector) {
307
+ if (minExtractionSize == minResultTrailingSize) {
308
+ currentResultScalableVector = sourceSubVector;
309
+ } else if (resRank != 1 ) {
310
+ currentResultScalableVector = rewriter.create <vector::ExtractOp>(
311
+ loc, result, llvm::ArrayRef (resIdx).drop_back ());
312
+ } else {
313
+ currentResultScalableVector = result;
314
+ }
315
+ }
316
+ if (minExtractionSize < minResultTrailingSize) {
317
+ currentResultScalableVector = rewriter.create <vector::ScalableInsertOp>(
318
+ loc, sourceSubVector, currentResultScalableVector, resIdx.back ());
319
+ }
320
+
321
+ // 3. Update the source and result scalable vectors if needed.
322
+ if (resIdx.back () + minExtractionSize >= minResultTrailingSize &&
323
+ currentResultScalableVector != result) {
324
+ // Finished row of result. Insert complete scalable vector into result
325
+ // (n-D) vector.
326
+ result = rewriter.create <vector::InsertOp>(
327
+ loc, currentResultScalableVector, result,
328
+ llvm::ArrayRef (resIdx).drop_back ());
329
+ currentResultScalableVector = {};
330
+ }
331
+ if (srcIdx.back () + minExtractionSize >= minSourceTrailingSize) {
332
+ // Finished row of source.
333
+ currentSourceScalableVector = {};
334
+ }
335
+
336
+ // 4. Increment the insert/extract indices, stepping by minExtractionSize
337
+ // for the trailing dimensions.
338
+ incIdx (srcIdx, sourceVectorType, srcRank - 1 , minExtractionSize);
339
+ incIdx (resIdx, resultVectorType, resRank - 1 , minExtractionSize);
185
340
}
341
+
342
+ rewriter.replaceOp (op, result);
343
+ return success ();
344
+ }
345
+
346
+ static bool isTrailingDimScalable (VectorType type) {
347
+ return type.getRank () >= 1 && type.getScalableDims ().back () &&
348
+ !llvm::is_contained (type.getScalableDims ().drop_back (), true );
186
349
}
187
350
};
351
+
188
352
} // namespace
189
353
190
354
void mlir::vector::populateVectorShapeCastLoweringPatterns (
191
355
RewritePatternSet &patterns, PatternBenefit benefit) {
192
356
patterns.add <ShapeCastOp2DDownCastRewritePattern,
193
- ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
194
- patterns.getContext (), benefit);
357
+ ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
358
+ ScalableShapeCastOpRewritePattern>(patterns.getContext (),
359
+ benefit);
195
360
}
0 commit comments