Skip to content

Commit 4449999

Browse files
committed
[mlir][vector] Implement lowering for 1D vector.deinterleave operations
This patchs implements the lowering of vector.deinterleave for 1D vectors. For fixed vector types, the operation is lowered to two llvm shufflevector operations. One for even indexed elements and the other for odd indexed elements. A poison operation is used to satisfy the parameters of the shufflevector parameters. For scalable vectors, the llvm vector.deinterleave2 intrinsic is used for lowering. As such the results found by extraction and used to form the result struct for the intrinsic.
1 parent 3d79b27 commit 4449999

File tree

4 files changed

+120
-1
lines changed

4 files changed

+120
-1
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1761,6 +1761,66 @@ struct VectorInterleaveOpLowering
17611761
}
17621762
};
17631763

1764+
/// Conversion pattern for a `vector.deinterleave`.
1765+
/// Support available for fixed-sized vectors and scalable vectors.
1766+
1767+
struct VectorDeinterleaveOpLowering
1768+
: public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1769+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1770+
1771+
LogicalResult
1772+
matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1773+
ConversionPatternRewriter &rewriter) const override {
1774+
VectorType resultType = deinterleaveOp.getResultVectorType();
1775+
VectorType sourceType = deinterleaveOp.getSourceVectorType();
1776+
auto loc = deinterleaveOp.getLoc();
1777+
1778+
if (resultType.getRank() != 1)
1779+
return rewriter.notifyMatchFailure(deinterleaveOp,
1780+
"deinterleaveOp not rank 1");
1781+
1782+
if (resultType.isScalable()) {
1783+
auto llvmTypeConverter = this->getTypeConverter();
1784+
auto deinterleaveResults = deinterleaveOp.getResultTypes();
1785+
auto packedOpResults = llvmTypeConverter->packOperationResults(deinterleaveResults);
1786+
auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(loc, packedOpResults, adaptor.getSource());
1787+
1788+
auto resultOne = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsic->getResult(0), 0);
1789+
auto resultTwo = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsic->getResult(0), 1);
1790+
1791+
rewriter.replaceOp(
1792+
deinterleaveOp, ValueRange{resultOne, resultTwo}
1793+
);
1794+
return success();
1795+
}
1796+
1797+
int64_t resultVectorSize = resultType.getNumElements();
1798+
auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
1799+
SmallVector<int32_t> shuffleMaskOne;
1800+
SmallVector<int32_t> shuffleMaskTwo;
1801+
1802+
shuffleMaskOne.reserve(resultVectorSize);
1803+
shuffleMaskTwo.reserve(resultVectorSize);
1804+
1805+
for (int i = 0; i < sourceType.getNumElements(); ++i) {
1806+
if (i % 2 == 0)
1807+
shuffleMaskOne.push_back(i);
1808+
else
1809+
shuffleMaskTwo.push_back(i);
1810+
}
1811+
1812+
auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1813+
loc, adaptor.getSource(), poison, shuffleMaskOne);
1814+
auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1815+
loc, adaptor.getSource(), poison, shuffleMaskTwo);
1816+
1817+
rewriter.replaceOp(
1818+
deinterleaveOp, ValueRange{evenShuffle, oddShuffle}
1819+
);
1820+
return::success();
1821+
}
1822+
};
1823+
17641824
} // namespace
17651825

17661826
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1785,7 +1845,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
17851845
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
17861846
VectorSplatOpLowering, VectorSplatNdOpLowering,
17871847
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1788-
MaskedReductionOpConversion, VectorInterleaveOpLowering>(
1848+
MaskedReductionOpConversion, VectorInterleaveOpLowering,
1849+
VectorDeinterleaveOpLowering>(
17891850
converter);
17901851
// Transfer ops with rank > 1 are handled by VectorToSCF.
17911852
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,3 +2546,25 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]
25462546
%0 = vector.interleave %a, %b : vector<2x[8]xi16>
25472547
return %0 : vector<2x[16]xi16>
25482548
}
2549+
2550+
// -----
2551+
2552+
// CHECK-LABEL: @vector_deinterleave_1d
2553+
// CHECK-SAME: (%{{.*}}: vector<4xi32>) -> (vector<2xi32>, vector<2xi32>)
2554+
func.func @vector_deinterleave_1d(%a: vector<4xi32>) -> (vector<2xi32>, vector<2xi32>) {
2555+
// CHECK: llvm.mlir.poison : vector<4xi32>
2556+
// CHECK: llvm.shufflevector %{{.*}}, %{{.*}} [0, 2] : vector<4xi32>
2557+
// CHECK: llvm.shufflevector %{{.*}}, %{{.*}} [1, 3] : vector<4xi32>
2558+
%0, %1 = vector.deinterleave %a : vector<4xi32> -> vector<2xi32>
2559+
return %0, %1 : vector<2xi32>, vector<2xi32>
2560+
}
2561+
2562+
// CHECK-LABEL: @vector_deinterleave_1d_scalable
2563+
// CHECK-SAME: %{{.*}}: vector<[4]xi32>) -> (vector<[2]xi32>, vector<[2]xi32>)
2564+
func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi32>, vector<[2]xi32>) {
2565+
// CHECK: llvm.intr.vector.deinterleave2
2566+
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<[2]xi32>, vector<[2]xi32>)>
2567+
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<[2]xi32>, vector<[2]xi32>)>
2568+
%0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
2569+
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
2570+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: mlir-opt %s -test-lower-to-llvm | \
2+
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
3+
// RUN: -shared-libs=%mlir_c_runner_utils | \
4+
// RUN: FileCheck %s
5+
6+
func.func @entry() {
7+
%step_vector = llvm.intr.experimental.stepvector : vector<[4]xi8>
8+
vector.print %step_vector : vector<[4]xi8>
9+
// CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
10+
11+
%v1, %v2 = vector.deinterleave %step_vector : vector<[4]xi8> -> vector<[2]xi8>
12+
vector.print %v1 : vector<[2]xi8>
13+
vector.print %v2 : vector<[2]xi8>
14+
// CHECK: ( 0, 2, 4, 6 )
15+
// CHECK: ( 1, 3, 5, 7 )
16+
17+
return
18+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: mlir-opt %s -test-lower-to-llvm | \
2+
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
3+
// RUN: -shared-libs=%mlir_c_runner_utils | \
4+
// RUN: FileCheck %s
5+
6+
func.func @entry() {
7+
%v0 = arith.constant dense<[1, 2, 3, 4]> : vector<4xi8>
8+
vector.print %v0 : vector<4xi8>
9+
// CHECK: ( 1, 2, 3, 4 )
10+
11+
%v1, %v2 = vector.deinterleave %v0 : vector<4xi8> -> vector<2xi8>
12+
vector.print %v1 : vector<2xi8>
13+
vector.print %v2 : vector<2xi8>
14+
// CHECK: ( 1, 3 )
15+
// CHECK: ( 2, 4 )
16+
17+
return
18+
}

0 commit comments

Comments
 (0)