Skip to content

Commit c4979c9

Browse files
authored
[mlir][VectorOps] Add fold vector.shuffle -> vector.interleave (#80968)
This folds fixed-size vector.shuffle ops that perform a 1-D interleave to a vector.interleave operation. For example:
```mlir
%0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
```
folds to:
```mlir
%0 = vector.interleave %a, %b : vector<2xi32>
```
Depends on: #80967
1 parent 38763be commit c4979c9

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2479,11 +2479,51 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
24792479
}
24802480
};
24812481

2482+
/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2483+
/// vector.interleave.
2484+
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2485+
public:
2486+
using OpRewritePattern::OpRewritePattern;
2487+
2488+
LogicalResult matchAndRewrite(ShuffleOp op,
2489+
PatternRewriter &rewriter) const override {
2490+
VectorType resultType = op.getResultVectorType();
2491+
if (resultType.isScalable())
2492+
return rewriter.notifyMatchFailure(
2493+
op, "ShuffleOp can't represent a scalable interleave");
2494+
2495+
if (resultType.getRank() != 1)
2496+
return rewriter.notifyMatchFailure(
2497+
op, "ShuffleOp can't represent an n-D interleave");
2498+
2499+
VectorType sourceType = op.getV1VectorType();
2500+
if (sourceType != op.getV2VectorType() ||
2501+
sourceType.getNumElements() * 2 != resultType.getNumElements()) {
2502+
return rewriter.notifyMatchFailure(
2503+
op, "ShuffleOp types don't match an interleave");
2504+
}
2505+
2506+
ArrayAttr shuffleMask = op.getMask();
2507+
int64_t resultVectorSize = resultType.getNumElements();
2508+
for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2509+
int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
2510+
int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
2511+
if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2512+
return rewriter.notifyMatchFailure(op,
2513+
"ShuffleOp mask not interleaving");
2514+
}
2515+
2516+
rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2517+
return success();
2518+
}
2519+
};
2520+
24822521
} // namespace
24832522

24842523
/// Adds the canonicalization patterns for ShuffleOp to `results`, constructing
/// each pattern with `context`. This diff hunk shows both sides of the change:
/// the `-` line registers ShuffleSplat and Canonicalize0DShuffleOp only, and
/// the `+` lines additionally register ShuffleInterleave.
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
24852524
MLIRContext *context) {
2486-
results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
2525+
results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2526+
context);
24872527
}
24882528

24892529
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2567,3 +2567,26 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
25672567
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
25682568
return %r : vector<1x100x4x5xf32>
25692569
}
2570+
2571+
// -----
2572+
2573+
// Shuffling two 0-D vectors with mask [0, 1] yields a 2-element result, which
// is expected to canonicalize to a single vector.interleave of the same
// operands.
// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
2574+
// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
2575+
func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
2576+
{
2577+
// CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
2578+
// CHECK: return %[[ZIP]]
2579+
%0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
2580+
return %0 : vector<2xf64>
2581+
}
2582+
2583+
// -----
2584+
2585+
// Shuffling two 6-element vectors with the mask i, 6 + i for i = 0..5 (the
// form the ShuffleInterleave pattern accepts) is expected to canonicalize to a
// single vector.interleave of the same operands.
// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
2586+
// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
2587+
func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
2588+
// CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
2589+
// CHECK: return %[[ZIP]]
2590+
%0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
2591+
return %0 : vector<12xi32>
2592+
}

0 commit comments

Comments
 (0)