@@ -2479,11 +2479,51 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2479
2479
}
2480
2480
};
2481
2481
2482
+ // / Pattern to rewrite a fixed-size interleave via vector.shuffle to
2483
+ // / vector.interleave.
2484
+ class ShuffleInterleave : public OpRewritePattern <ShuffleOp> {
2485
+ public:
2486
+ using OpRewritePattern::OpRewritePattern;
2487
+
2488
+ LogicalResult matchAndRewrite (ShuffleOp op,
2489
+ PatternRewriter &rewriter) const override {
2490
+ VectorType resultType = op.getResultVectorType ();
2491
+ if (resultType.isScalable ())
2492
+ return rewriter.notifyMatchFailure (
2493
+ op, " ShuffleOp can't represent a scalable interleave" );
2494
+
2495
+ if (resultType.getRank () != 1 )
2496
+ return rewriter.notifyMatchFailure (
2497
+ op, " ShuffleOp can't represent an n-D interleave" );
2498
+
2499
+ VectorType sourceType = op.getV1VectorType ();
2500
+ if (sourceType != op.getV2VectorType () ||
2501
+ sourceType.getNumElements () * 2 != resultType.getNumElements ()) {
2502
+ return rewriter.notifyMatchFailure (
2503
+ op, " ShuffleOp types don't match an interleave" );
2504
+ }
2505
+
2506
+ ArrayAttr shuffleMask = op.getMask ();
2507
+ int64_t resultVectorSize = resultType.getNumElements ();
2508
+ for (int i = 0 , e = resultVectorSize / 2 ; i < e; ++i) {
2509
+ int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2 ]).getInt ();
2510
+ int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2 ) + 1 ]).getInt ();
2511
+ if (maskValueA != i || maskValueB != (resultVectorSize / 2 ) + i)
2512
+ return rewriter.notifyMatchFailure (op,
2513
+ " ShuffleOp mask not interleaving" );
2514
+ }
2515
+
2516
+ rewriter.replaceOpWithNewOp <InterleaveOp>(op, op.getV1 (), op.getV2 ());
2517
+ return success ();
2518
+ }
2519
+ };
2520
+
2482
2521
} // namespace
2483
2522
2484
2523
void ShuffleOp::getCanonicalizationPatterns (RewritePatternSet &results,
2485
2524
MLIRContext *context) {
2486
- results.add <ShuffleSplat, Canonicalize0DShuffleOp>(context);
2525
+ results.add <ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2526
+ context);
2487
2527
}
2488
2528
2489
2529
// ===----------------------------------------------------------------------===//
0 commit comments