Skip to content

Commit b7c1976

Browse files
committed
[mlir][VectorOps] Add fold vector.shuffle -> vector.interleave
This folds fixed-size vector.shuffle ops that perform a 1-D interleave into a vector.interleave operation, e.g.:
```mlir
%0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
```
to:
```mlir
%0 = vector.interleave %a, %b : vector<2xi32>
```
1 parent d31406b commit b7c1976

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2479,11 +2479,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
24792479
}
24802480
};
24812481

2482+
/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2483+
/// vector.interleave.
2484+
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2485+
public:
2486+
using OpRewritePattern::OpRewritePattern;
2487+
2488+
LogicalResult matchAndRewrite(ShuffleOp op,
2489+
PatternRewriter &rewriter) const override {
2490+
VectorType resultType = op.getResultVectorType();
2491+
if (resultType.isScalable())
2492+
return rewriter.notifyMatchFailure(
2493+
op, "ShuffleOp can't represent a scalable interleave");
2494+
2495+
if (resultType.getRank() != 1)
2496+
return rewriter.notifyMatchFailure(
2497+
op, "ShuffleOp can't represent an n-D interleave");
2498+
2499+
VectorType sourceType = op.getV1VectorType();
2500+
if (sourceType != op.getV2VectorType() ||
2501+
ArrayRef<int64_t>{sourceType.getNumElements() * 2} !=
2502+
resultType.getShape()) {
2503+
return rewriter.notifyMatchFailure(
2504+
op, "ShuffleOp types don't match an interleave");
2505+
}
2506+
2507+
ArrayAttr shuffleMask = op.getMask();
2508+
int64_t resultVectorSize = resultType.getNumElements();
2509+
for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2510+
int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
2511+
int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
2512+
if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2513+
return rewriter.notifyMatchFailure(op,
2514+
"ShuffleOp mask not interleaving");
2515+
}
2516+
2517+
rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2518+
return success();
2519+
}
2520+
};
2521+
24822522
} // namespace
24832523

24842524
/// Registers the canonicalization patterns for ShuffleOp with `results`,
/// constructing each pattern with `context`.
/// NOTE(review): this span is a diff hunk. The `results.add` line under the
/// `2486-` marker is the pre-change registration; the lines under the `2526+`
/// and `2527+` markers replace it and additionally register ShuffleInterleave.
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
24852525
MLIRContext *context) {
2486-
results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
2526+
results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2527+
context);
24872528
}
24882529

24892530
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2567,3 +2567,26 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
25672567
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
25682568
return %r : vector<1x100x4x5xf32>
25692569
}
2570+
2571+
// -----
2572+
2573+
// Verifies that a vector.shuffle of two 0-D vectors with mask [0, 1] is
// canonicalized into a vector.interleave of the same two operands.
// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
2574+
// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
2575+
func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
2576+
{
2577+
// CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
2578+
// CHECK: return %[[ZIP]]
2579+
%0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
2580+
return %0 : vector<2xf64>
2581+
}
2582+
2583+
// -----
2584+
2585+
// Verifies that a vector.shuffle of two 6-element vectors whose mask
// alternates between the first operand (0..5) and the second operand (6..11)
// is canonicalized into a vector.interleave of the same two operands.
// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
2586+
// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
2587+
func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
2588+
// CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
2589+
// CHECK: return %[[ZIP]]
2590+
%0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
2591+
return %0 : vector<12xi32>
2592+
}

0 commit comments

Comments
 (0)