@@ -2478,11 +2478,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2478
2478
}
2479
2479
};
2480
2480
2481
+ // / Pattern to rewrite a fixed-size interleave via vector.shuffle to
2482
+ // / vector.interleave.
2483
+ class ShuffleInterleave : public OpRewritePattern <ShuffleOp> {
2484
+ public:
2485
+ using OpRewritePattern::OpRewritePattern;
2486
+
2487
+ LogicalResult matchAndRewrite (ShuffleOp op,
2488
+ PatternRewriter &rewriter) const override {
2489
+ VectorType resultType = op.getResultVectorType ();
2490
+ if (resultType.isScalable ())
2491
+ return rewriter.notifyMatchFailure (
2492
+ op, " ShuffleOp can't represent a scalable interleave" );
2493
+
2494
+ if (resultType.getRank () != 1 )
2495
+ return rewriter.notifyMatchFailure (
2496
+ op, " ShuffleOp can't represent an n-D interleave" );
2497
+
2498
+ VectorType sourceType = op.getV1VectorType ();
2499
+ if (sourceType != op.getV2VectorType () ||
2500
+ ArrayRef<int64_t >{sourceType.getNumElements () * 2 } !=
2501
+ resultType.getShape ()) {
2502
+ return rewriter.notifyMatchFailure (
2503
+ op, " ShuffleOp types don't match an interleave" );
2504
+ }
2505
+
2506
+ ArrayAttr shuffleMask = op.getMask ();
2507
+ int64_t resultVectorSize = resultType.getNumElements ();
2508
+ for (int i = 0 , e = resultVectorSize / 2 ; i < e; ++i) {
2509
+ int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2 ]).getInt ();
2510
+ int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2 ) + 1 ]).getInt ();
2511
+ if (maskValueA != i || maskValueB != (resultVectorSize / 2 ) + i)
2512
+ return rewriter.notifyMatchFailure (op,
2513
+ " ShuffleOp mask not interleaving" );
2514
+ }
2515
+
2516
+ rewriter.replaceOpWithNewOp <InterleaveOp>(op, op.getV1 (), op.getV2 ());
2517
+ return success ();
2518
+ }
2519
+ };
2520
+
2481
2521
} // namespace
2482
2522
2483
2523
void ShuffleOp::getCanonicalizationPatterns (RewritePatternSet &results,
2484
2524
MLIRContext *context) {
2485
- results.add <ShuffleSplat, Canonicalize0DShuffleOp>(context);
2525
+ results.add <ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2526
+ context);
2486
2527
}
2487
2528
2488
2529
// ===----------------------------------------------------------------------===//
0 commit comments