@@ -2479,11 +2479,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2479
2479
}
2480
2480
};
2481
2481
2482
+ // / Pattern to rewrite a fixed-size interleave via vector.shuffle to
2483
+ // / vector.interleave.
2484
+ class ShuffleInterleave : public OpRewritePattern <ShuffleOp> {
2485
+ public:
2486
+ using OpRewritePattern::OpRewritePattern;
2487
+
2488
+ LogicalResult matchAndRewrite (ShuffleOp op,
2489
+ PatternRewriter &rewriter) const override {
2490
+ VectorType resultType = op.getResultVectorType ();
2491
+ if (resultType.isScalable ())
2492
+ return rewriter.notifyMatchFailure (
2493
+ op, " ShuffleOp can't represent a scalable interleave" );
2494
+
2495
+ if (resultType.getRank () != 1 )
2496
+ return rewriter.notifyMatchFailure (
2497
+ op, " ShuffleOp can't represent an n-D interleave" );
2498
+
2499
+ VectorType sourceType = op.getV1VectorType ();
2500
+ if (sourceType != op.getV2VectorType () ||
2501
+ ArrayRef<int64_t >{sourceType.getNumElements () * 2 } !=
2502
+ resultType.getShape ()) {
2503
+ return rewriter.notifyMatchFailure (
2504
+ op, " ShuffleOp types don't match an interleave" );
2505
+ }
2506
+
2507
+ ArrayAttr shuffleMask = op.getMask ();
2508
+ int64_t resultVectorSize = resultType.getNumElements ();
2509
+ for (int i = 0 , e = resultVectorSize / 2 ; i < e; ++i) {
2510
+ int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2 ]).getInt ();
2511
+ int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2 ) + 1 ]).getInt ();
2512
+ if (maskValueA != i || maskValueB != (resultVectorSize / 2 ) + i)
2513
+ return rewriter.notifyMatchFailure (op,
2514
+ " ShuffleOp mask not interleaving" );
2515
+ }
2516
+
2517
+ rewriter.replaceOpWithNewOp <InterleaveOp>(op, op.getV1 (), op.getV2 ());
2518
+ return success ();
2519
+ }
2520
+ };
2521
+
2482
2522
} // namespace
2483
2523
2484
2524
void ShuffleOp::getCanonicalizationPatterns (RewritePatternSet &results,
2485
2525
MLIRContext *context) {
2486
- results.add <ShuffleSplat, Canonicalize0DShuffleOp>(context);
2526
+ results.add <ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2527
+ context);
2487
2528
}
2488
2529
2489
2530
// ===----------------------------------------------------------------------===//
0 commit comments