@@ -326,6 +326,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
326
326
VectorType inputType = op.getSourceVectorType ();
327
327
VectorType resType = op.getResultVectorType ();
328
328
329
+ if (inputType.isScalable ())
330
+ return rewriter.notifyMatchFailure (
331
+ op, " This lowering does not support scalable vectors" );
332
+
329
333
// Set up convenience transposition table.
330
334
ArrayRef<int64_t > transp = op.getPermutation ();
331
335
@@ -334,28 +338,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
334
338
return rewriter.notifyMatchFailure (
335
339
op, " Options specifies lowering to shuffle" );
336
340
337
- // Replace:
338
- // vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
339
- // vector<1xnxelty>
340
- // with:
341
- // vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
342
- //
343
- // Source with leading unit dim (inverse) is also replaced. Unit dim must
344
- // be fixed. Non-unit can be scalable.
345
- if (resType.getRank () == 2 &&
346
- ((resType.getShape ().front () == 1 &&
347
- !resType.getScalableDims ().front ()) ||
348
- (resType.getShape ().back () == 1 &&
349
- !resType.getScalableDims ().back ())) &&
350
- transp == ArrayRef<int64_t >({1 , 0 })) {
351
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(op, resType, input);
352
- return success ();
353
- }
354
-
355
- // TODO: Add support for scalable vectors
356
- if (inputType.isScalable ())
357
- return failure ();
358
-
359
341
// Handle a true 2-D matrix transpose differently when requested.
360
342
if (vectorTransformOptions.vectorTransposeLowering ==
361
343
vector::VectorTransposeLowering::Flat &&
@@ -411,6 +393,64 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
411
393
vector::VectorTransformsOptions vectorTransformOptions;
412
394
};
413
395
396
+ // / Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
397
+ // / to 2D vectors with at least one unit dim. For example:
398
+ // /
399
+ // / Replace:
400
+ // / vector.transpose %0, [1, 0] : vector<4x1xi32>> to
401
+ // / vector<1x4xi32>
402
+ // / with:
403
+ // / vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
404
+ // /
405
+ // / Source with leading unit dim (inverse) is also replaced. Unit dim must
406
+ // / be fixed. Non-unit dim can be scalable.
407
+ // /
408
+ // / TODO: This pattern was introduced specifically to help lower scalable
409
+ // / vectors. In hindsight, a more specialised canonicalization (for shape_cast's
410
+ // / to cancel out) would be preferable:
411
+ // /
412
+ // / BEFORE:
413
+ // / %0 = some_op
414
+ // / %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
415
+ // / %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
416
+ // / AFTER:
417
+ // / %0 = some_op
418
+ // / %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
419
+ // /
420
+ // / Given the context above, we may want to consider (re-)moving this pattern
421
+ // / at some later time. I am leaving it for now in case there are other users
422
+ // / that I am not aware of.
423
+ class Transpose2DWithUnitDimToShapeCast
424
+ : public OpRewritePattern<vector::TransposeOp> {
425
+ public:
426
+ using OpRewritePattern::OpRewritePattern;
427
+
428
+ Transpose2DWithUnitDimToShapeCast (MLIRContext *context,
429
+ PatternBenefit benefit = 1 )
430
+ : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
431
+
432
+ LogicalResult matchAndRewrite (vector::TransposeOp op,
433
+ PatternRewriter &rewriter) const override {
434
+ Value input = op.getVector ();
435
+ VectorType resType = op.getResultVectorType ();
436
+
437
+ // Set up convenience transposition table.
438
+ ArrayRef<int64_t > transp = op.getPermutation ();
439
+
440
+ if (resType.getRank () == 2 &&
441
+ ((resType.getShape ().front () == 1 &&
442
+ !resType.getScalableDims ().front ()) ||
443
+ (resType.getShape ().back () == 1 &&
444
+ !resType.getScalableDims ().back ())) &&
445
+ transp == ArrayRef<int64_t >({1 , 0 })) {
446
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(op, resType, input);
447
+ return success ();
448
+ }
449
+
450
+ return failure ();
451
+ }
452
+ };
453
+
414
454
// / Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
415
455
// / If the strategy is Shuffle1D, it will be lowered to:
416
456
// / vector.shape_cast 2D -> 1D
@@ -483,6 +523,8 @@ class TransposeOp2DToShuffleLowering
483
523
void mlir::vector::populateVectorTransposeLoweringPatterns (
484
524
RewritePatternSet &patterns, VectorTransformsOptions options,
485
525
PatternBenefit benefit) {
526
+ patterns.add <Transpose2DWithUnitDimToShapeCast>(patterns.getContext (),
527
+ benefit);
486
528
patterns.add <TransposeOpLowering, TransposeOp2DToShuffleLowering>(
487
529
options, patterns.getContext (), benefit);
488
530
}
0 commit comments