Skip to content

Commit 795f4b8

Browse files
committed
Fixups
- Remove vector.interleave -> vector.shuffle canonicalization - Add vector.shuffle -> vector.interleave canonicalization - Split vector.interleave unrolling and LLVM lowering - Unrolling now done in LowerVectorInterleave.cpp - Add missing tests to vector ops.mlir - Fixed a few nits
1 parent 171007a commit 795f4b8

File tree

11 files changed

+192
-112
lines changed

11 files changed

+192
-112
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,6 @@ def Vector_InterleaveOp :
539539
return ::llvm::cast<VectorType>(getResult().getType());
540540
}
541541
}];
542-
543-
let hasCanonicalizer = 1;
544542
}
545543

546544
def Vector_ExtractElementOp :

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,14 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
264264
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
265265
PatternBenefit benefit = 1);
266266

267+
/// Populate the pattern set with the following patterns:
268+
///
269+
/// [InterleaveOpLowering]
270+
/// Progressive lowering of InterleaveOp to ExtractOp + InsertOp + lower-D
271+
/// InterleaveOp until dim 1.
272+
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
273+
PatternBenefit benefit = 1);
274+
267275
} // namespace vector
268276
} // namespace mlir
269277
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,66 +1734,40 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
17341734
}
17351735
};
17361736

1737+
/// Conversion pattern for a `vector.interleave`.
1738+
/// This supports fixed-sized vectors and scalable vectors.
17371739
struct VectorInterleaveOpLowering
17381740
: public ConvertOpToLLVMPattern<vector::InterleaveOp> {
17391741
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
17401742

1741-
void initialize() {
1742-
// This pattern recursively unpacks one dimension at a time. The recursion
1743-
// bounded as the rank is strictly decreasing.
1744-
setHasBoundedRewriteRecursion();
1745-
}
1746-
17471743
LogicalResult
17481744
matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
17491745
ConversionPatternRewriter &rewriter) const override {
17501746
VectorType resultType = interleaveOp.getResultVectorType();
1751-
1747+
// n-D interleaves should have been lowered already.
1748+
if (resultType.getRank() != 1)
1749+
return failure();
17521750
// If the result is rank 1, then this directly maps to LLVM.
1753-
if (resultType.getRank() == 1) {
1754-
if (resultType.isScalable()) {
1755-
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_interleave2>(
1756-
interleaveOp, typeConverter->convertType(resultType),
1757-
adaptor.getLhs(), adaptor.getRhs());
1758-
return success();
1759-
}
1760-
// Lower fixed-size interleaves to a shufflevector. While the
1761-
// vector.interleave2 intrinsic supports fixed and scalable vectors, the
1762-
// langref still recommends fixed-vectors use shufflevector, see:
1763-
// https://llvm.org/docs/LangRef.html#id876.
1764-
int64_t resultVectorSize = resultType.getNumElements();
1765-
SmallVector<int32_t> interleaveShuffleMask;
1766-
interleaveShuffleMask.reserve(resultVectorSize);
1767-
for (int i = 0; i < resultVectorSize / 2; i++) {
1768-
interleaveShuffleMask.push_back(i);
1769-
interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1770-
}
1771-
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1772-
interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1773-
interleaveShuffleMask);
1751+
if (resultType.isScalable()) {
1752+
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_interleave2>(
1753+
interleaveOp, typeConverter->convertType(resultType),
1754+
adaptor.getLhs(), adaptor.getRhs());
17741755
return success();
17751756
}
1776-
1777-
// It's not possible to unroll a scalable dimension.
1778-
if (resultType.getScalableDims().front())
1779-
return failure();
1780-
1781-
// n-D case: Unroll the leading dimension.
1782-
// This eventually converges to an LLVM lowering.
1783-
auto loc = interleaveOp.getLoc();
1784-
Value result = rewriter.create<arith::ConstantOp>(
1785-
loc, resultType, rewriter.getZeroAttr(resultType));
1786-
for (int d = 0; d < resultType.getDimSize(0); d++) {
1787-
Value extractLhs =
1788-
rewriter.create<ExtractOp>(loc, interleaveOp.getLhs(), d);
1789-
Value extractRhs =
1790-
rewriter.create<ExtractOp>(loc, interleaveOp.getRhs(), d);
1791-
Value dimInterleave =
1792-
rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
1793-
result = rewriter.create<InsertOp>(loc, dimInterleave, result, d);
1757+
// Lower fixed-size interleaves to a shufflevector. While the
1758+
// vector.interleave2 intrinsic supports fixed and scalable vectors, the
1759+
// langref still recommends fixed-vectors use shufflevector, see:
1760+
// https://llvm.org/docs/LangRef.html#id876.
1761+
int64_t resultVectorSize = resultType.getNumElements();
1762+
SmallVector<int32_t> interleaveShuffleMask;
1763+
interleaveShuffleMask.reserve(resultVectorSize);
1764+
for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1765+
interleaveShuffleMask.push_back(i);
1766+
interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
17941767
}
1795-
1796-
rewriter.replaceOp(interleaveOp, result);
1768+
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1769+
interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1770+
interleaveShuffleMask);
17971771
return success();
17981772
}
17991773
};

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
6868
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
6969
populateVectorMaskOpLoweringPatterns(patterns);
7070
populateVectorShapeCastLoweringPatterns(patterns);
71+
populateVectorInterleaveLoweringPatterns(patterns);
7172
populateVectorTransposeLoweringPatterns(patterns,
7273
VectorTransformsOptions());
7374
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2478,11 +2478,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
24782478
}
24792479
};
24802480

2481+
/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
2482+
/// vector.interleave.
2483+
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
2484+
public:
2485+
using OpRewritePattern::OpRewritePattern;
2486+
2487+
LogicalResult matchAndRewrite(ShuffleOp op,
2488+
PatternRewriter &rewriter) const override {
2489+
VectorType resultType = op.getResultVectorType();
2490+
if (resultType.isScalable())
2491+
return rewriter.notifyMatchFailure(
2492+
op, "ShuffleOp can't represent a scalable interleave");
2493+
2494+
if (resultType.getRank() != 1)
2495+
return rewriter.notifyMatchFailure(
2496+
op, "ShuffleOp can't represent an n-D interleave");
2497+
2498+
VectorType sourceType = op.getV1VectorType();
2499+
if (sourceType != op.getV2VectorType() ||
2500+
ArrayRef<int64_t>{sourceType.getNumElements() * 2} !=
2501+
resultType.getShape()) {
2502+
return rewriter.notifyMatchFailure(
2503+
op, "ShuffleOp types don't match an interleave");
2504+
}
2505+
2506+
ArrayAttr shuffleMask = op.getMask();
2507+
int64_t resultVectorSize = resultType.getNumElements();
2508+
for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2509+
int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
2510+
int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
2511+
if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
2512+
return rewriter.notifyMatchFailure(op,
2513+
"ShuffleOp mask not interleaving");
2514+
}
2515+
2516+
rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
2517+
return success();
2518+
}
2519+
};
2520+
24812521
} // namespace
24822522

24832523
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
24842524
MLIRContext *context) {
2485-
results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
2525+
results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
2526+
context);
24862527
}
24872528

24882529
//===----------------------------------------------------------------------===//
@@ -6308,48 +6349,6 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
63086349
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
63096350
}
63106351

6311-
//===----------------------------------------------------------------------===//
6312-
// InterleaveOp
6313-
//===----------------------------------------------------------------------===//
6314-
6315-
// The rank 1 case of vector.interleave on fixed-size vectors is equivalent to a
6316-
// vector.shuffle, which (as an older op) is more likely to be matched by
6317-
// existing pipelines.
6318-
struct FoldRank1FixedSizeInterleaveOp : public OpRewritePattern<InterleaveOp> {
6319-
using OpRewritePattern::OpRewritePattern;
6320-
6321-
LogicalResult matchAndRewrite(InterleaveOp interleaveOp,
6322-
PatternRewriter &rewriter) const override {
6323-
auto resultType = interleaveOp.getResultVectorType();
6324-
if (resultType.getRank() != 1)
6325-
return rewriter.notifyMatchFailure(
6326-
interleaveOp, "cannot fold interleave with result rank > 1");
6327-
6328-
if (resultType.isScalable())
6329-
return rewriter.notifyMatchFailure(
6330-
interleaveOp, "cannot fold interleave of scalable vectors");
6331-
6332-
int64_t resultVectorSize = resultType.getNumElements();
6333-
SmallVector<int64_t> interleaveShuffleMask;
6334-
interleaveShuffleMask.reserve(resultVectorSize);
6335-
for (int i = 0; i < resultVectorSize / 2; i++) {
6336-
interleaveShuffleMask.push_back(i);
6337-
interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
6338-
}
6339-
6340-
rewriter.replaceOpWithNewOp<ShuffleOp>(interleaveOp, interleaveOp.getLhs(),
6341-
interleaveOp.getRhs(),
6342-
interleaveShuffleMask);
6343-
6344-
return success();
6345-
}
6346-
};
6347-
6348-
void InterleaveOp::getCanonicalizationPatterns(RewritePatternSet &results,
6349-
MLIRContext *context) {
6350-
results.add<FoldRank1FixedSizeInterleaveOp>(context);
6351-
}
6352-
63536352
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
63546353
CombiningKind kind, Value v1, Value acc,
63556354
arith::FastMathFlagsAttr fastmath,

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
44
LowerVectorBroadcast.cpp
55
LowerVectorContract.cpp
66
LowerVectorGather.cpp
7+
LowerVectorInterleave.cpp
78
LowerVectorMask.cpp
89
LowerVectorMultiReduction.cpp
910
LowerVectorScan.cpp
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to lower the
10+
// 'vector.interleave' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16+
#include "mlir/IR/BuiltinTypes.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
#define DEBUG_TYPE "vector-interleave-lowering"
20+
21+
using namespace mlir;
22+
using namespace mlir::vector;
23+
24+
namespace {
25+
/// Progressive lowering of InterleaveOp.
26+
class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
27+
public:
28+
using OpRewritePattern::OpRewritePattern;
29+
30+
LogicalResult matchAndRewrite(vector::InterleaveOp op,
31+
PatternRewriter &rewriter) const override {
32+
VectorType resultType = op.getResultVectorType();
33+
// 1-D vector.interleave ops can be directly lowered to LLVM (later).
34+
if (resultType.getRank() == 1)
35+
return failure();
36+
37+
// Below we unroll the leading (or front) dimension. If that dimension is
38+
// scalable we can't unroll it.
39+
if (resultType.getScalableDims().front())
40+
return failure();
41+
42+
// n-D case: Unroll the leading dimension.
43+
auto loc = op.getLoc();
44+
Value result = rewriter.create<arith::ConstantOp>(
45+
loc, resultType, rewriter.getZeroAttr(resultType));
46+
for (int idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
47+
Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
48+
Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
49+
Value interleave =
50+
rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
51+
result = rewriter.create<InsertOp>(loc, interleave, result, idx);
52+
}
53+
54+
rewriter.replaceOp(op, result);
55+
return success();
56+
}
57+
};
58+
59+
} // namespace
60+
61+
void mlir::vector::populateVectorInterleaveLoweringPatterns(
62+
RewritePatternSet &patterns, PatternBenefit benefit) {
63+
patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
64+
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2570,23 +2570,23 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
25702570

25712571
// -----
25722572

2573-
// CHECK-LABEL: func.func @fold_rank_1_vector_interleave(
2574-
// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
2575-
func.func @fold_rank_1_vector_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
2576-
// CHECK: %[[ZIP:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
2577-
// CHECK: return %[[ZIP]] : vector<12xi32>
2578-
%0 = vector.interleave %arg0, %arg1 : vector<6xi32>
2579-
return %0 : vector<12xi32>
2573+
// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
2574+
// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
2575+
func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
2576+
{
2577+
// CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
2578+
// CHECK: return %[[ZIP]]
2579+
%0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
2580+
return %0 : vector<2xf64>
25802581
}
25812582

25822583
// -----
25832584

2584-
// CHECK-LABEL: func.func @fold_rank_0_vector_interleave(
2585-
// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
2586-
func.func @fold_rank_0_vector_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
2587-
{
2588-
// CHECK: %[[ZIP:.*]] = vector.shuffle %[[LHS]], %[[RHS]] [0, 1] : vector<f64>, vector<f64>
2589-
// CHECK: return %[[ZIP]] : vector<2xf64>
2590-
%0 = vector.interleave %arg0, %arg1 : vector<f64>
2591-
return %0 : vector<2xf64>
2585+
// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
2586+
// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
2587+
func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
2588+
// CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
2589+
// CHECK: return %[[ZIP]]
2590+
%0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
2591+
return %0 : vector<12xi32>
25922592
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,3 +1081,38 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 {
10811081
%min = vector.reduction <minnumf>, %x fastmath<reassoc,nnan,ninf> : vector<42xf32> into f32
10821082
return %min: f32
10831083
}
1084+
1085+
// CHECK-LABEL: @interleave_0d
1086+
func.func @interleave_0d(%a: vector<f32>, %b: vector<f32>) -> vector<2xf32> {
1087+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32>
1088+
%0 = vector.interleave %a, %b : vector<f32>
1089+
return %0 : vector<2xf32>
1090+
}
1091+
1092+
// CHECK-LABEL: @interleave_1d
1093+
func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> {
1094+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32>
1095+
%0 = vector.interleave %a, %b : vector<4xf32>
1096+
return %0 : vector<8xf32>
1097+
}
1098+
1099+
// CHECK-LABEL: @interleave_1d_scalable
1100+
func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> {
1101+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16>
1102+
%0 = vector.interleave %a, %b : vector<[8]xi16>
1103+
return %0 : vector<[16]xi16>
1104+
}
1105+
1106+
// CHECK-LABEL: @interleave_2d
1107+
func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> {
1108+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32>
1109+
%0 = vector.interleave %a, %b : vector<2x8xf32>
1110+
return %0 : vector<2x16xf32>
1111+
}
1112+
1113+
// CHECK-LABEL: @interleave_2d_scalable
1114+
func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> {
1115+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64>
1116+
%0 = vector.interleave %a, %b : vector<2x[2]xf64>
1117+
return %0 : vector<2x[4]xf64>
1118+
}

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
// RUN: FileCheck %s
55

66
func.func @entry() {
7-
%f1 = arith.constant 1.0: f32
8-
%f2 = arith.constant 2.0: f32
7+
%f1 = arith.constant 1.0 : f32
8+
%f2 = arith.constant 2.0 : f32
99
%v1 = vector.splat %f1 : vector<[4]xf32>
1010
%v2 = vector.splat %f2 : vector<[4]xf32>
1111
vector.print %v1 : vector<[4]xf32>

mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
// RUN: FileCheck %s
55

66
func.func @entry() {
7-
%f1 = arith.constant 1.0: f32
8-
%f2 = arith.constant 2.0: f32
7+
%f1 = arith.constant 1.0 : f32
8+
%f2 = arith.constant 2.0 : f32
99
%v1 = vector.splat %f1 : vector<2x4xf32>
1010
%v2 = vector.splat %f2 : vector<2x4xf32>
1111
vector.print %v1 : vector<2x4xf32>

0 commit comments

Comments
 (0)