Skip to content

Commit 58259bc

Browse files
committed
[mlir][vector] Add 1D vector.deinterleave lowering
This patch implements the lowering of vector.deinterleave for 1D vectors. For fixed vector types, the operation is lowered to two llvm shufflevector operations: one for the even-indexed elements and the other for the odd-indexed elements. A poison operation is used as the second vector operand that shufflevector requires. For scalable vectors, the llvm vector.deinterleave2 intrinsic is used for lowering. The intrinsic returns a struct, so the two results are obtained by extracting the struct's fields.
1 parent 7dd81f4 commit 58259bc

File tree

4 files changed

+136
-2
lines changed

4 files changed

+136
-2
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,6 +1761,70 @@ struct VectorInterleaveOpLowering
17611761
}
17621762
};
17631763

1764+
/// Conversion pattern for a `vector.deinterleave`.
1765+
/// This supports for fixed-sized vectors and scalable vectors.
1766+
struct VectorDeinterleaveOpLowering
1767+
: public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1768+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1769+
1770+
LogicalResult
1771+
matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1772+
ConversionPatternRewriter &rewriter) const override {
1773+
VectorType resultType = deinterleaveOp.getResultVectorType();
1774+
VectorType sourceType = deinterleaveOp.getSourceVectorType();
1775+
auto loc = deinterleaveOp.getLoc();
1776+
1777+
// Note: n-D deinterleave operations should be lowered to the 1-D before
1778+
// converting to LLVM.
1779+
if (resultType.getRank() != 1)
1780+
return rewriter.notifyMatchFailure(deinterleaveOp,
1781+
"DeinterleaveOp not rank 1");
1782+
1783+
if (resultType.isScalable()) {
1784+
auto llvmTypeConverter = this->getTypeConverter();
1785+
auto deinterleaveResults = deinterleaveOp.getResultTypes();
1786+
auto packedOpResults =
1787+
llvmTypeConverter->packOperationResults(deinterleaveResults);
1788+
auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
1789+
loc, packedOpResults, adaptor.getSource());
1790+
1791+
auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
1792+
loc, intrinsic->getResult(0), 0);
1793+
auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
1794+
loc, intrinsic->getResult(0), 1);
1795+
1796+
rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1797+
return success();
1798+
}
1799+
// Lower fixed-size deinterleave to two shufflevectors. While the
1800+
// vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1801+
// langref still recommends fixed-vectors use shufflevector, see:
1802+
// https://llvm.org/docs/LangRef.html#id889.
1803+
int64_t resultVectorSize = resultType.getNumElements();
1804+
SmallVector<int32_t> evenShuffleMask;
1805+
SmallVector<int32_t> oddShuffleMask;
1806+
1807+
evenShuffleMask.reserve(resultVectorSize);
1808+
oddShuffleMask.reserve(resultVectorSize);
1809+
1810+
for (int i = 0; i < sourceType.getNumElements(); ++i) {
1811+
if (i % 2 == 0)
1812+
evenShuffleMask.push_back(i);
1813+
else
1814+
oddShuffleMask.push_back(i);
1815+
}
1816+
1817+
auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
1818+
auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1819+
loc, adaptor.getSource(), poison, evenShuffleMask);
1820+
auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1821+
loc, adaptor.getSource(), poison, oddShuffleMask);
1822+
1823+
rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1824+
return success();
1825+
}
1826+
};
1827+
17641828
} // namespace
17651829

17661830
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1785,8 +1849,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
17851849
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
17861850
VectorSplatOpLowering, VectorSplatNdOpLowering,
17871851
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1788-
MaskedReductionOpConversion, VectorInterleaveOpLowering>(
1789-
converter);
1852+
MaskedReductionOpConversion, VectorInterleaveOpLowering,
1853+
VectorDeinterleaveOpLowering>(converter);
17901854
// Transfer ops with rank > 1 are handled by VectorToSCF.
17911855
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
17921856
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,3 +2546,25 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]
25462546
%0 = vector.interleave %a, %b : vector<2x[8]xi16>
25472547
return %0 : vector<2x[16]xi16>
25482548
}
2549+
2550+
// -----
2551+
2552+
// Fixed-size 1-D deinterleave: expected to lower to a poison value plus two
// shufflevectors that pick the even and the odd lanes of the source.
// CHECK-LABEL: @vector_deinterleave_1d
// CHECK-SAME: (%[[SRC:.*]]: vector<4xi32>) -> (vector<2xi32>, vector<2xi32>)
func.func @vector_deinterleave_1d(%src: vector<4xi32>) -> (vector<2xi32>, vector<2xi32>) {
  // CHECK: %[[POISON:.*]] = llvm.mlir.poison : vector<4xi32>
  // CHECK: llvm.shufflevector %[[SRC]], %[[POISON]] [0, 2] : vector<4xi32>
  // CHECK: llvm.shufflevector %[[SRC]], %[[POISON]] [1, 3] : vector<4xi32>
  %even, %odd = vector.deinterleave %src : vector<4xi32> -> vector<2xi32>
  return %even, %odd : vector<2xi32>, vector<2xi32>
}
2561+
2562+
// Scalable 1-D deinterleave: expected to lower to the deinterleave2 intrinsic
// applied to the source, followed by extraction of both struct fields from
// that intrinsic's result.
// CHECK-LABEL: @vector_deinterleave_1d_scalable
// CHECK-SAME: %[[SRC:.*]]: vector<[4]xi32>) -> (vector<[2]xi32>, vector<[2]xi32>)
func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi32>, vector<[2]xi32>) {
  // CHECK: %[[RES:.*]] = "llvm.intr.vector.deinterleave2"(%[[SRC]])
  // CHECK: llvm.extractvalue %[[RES]][0] : !llvm.struct<(vector<[2]xi32>, vector<[2]xi32>)>
  // CHECK: llvm.extractvalue %[[RES]][1] : !llvm.struct<(vector<[2]xi32>, vector<[2]xi32>)>
  %0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
  return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// DEFINE: %{entry_point} = entry
2+
// DEFINE: %{compile} = mlir-opt %s -test-lower-to-llvm
3+
// DEFINE: %{run} = %mcr_aarch64_cmd -march=aarch64 -mattr=+sve \
4+
// DEFINE: -e %{entry_point} -entry-point-result=void \
5+
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_runner_utils
6+
7+
// RUN: %{compile} | %{run} | FileCheck %s
8+
9+
func.func @entry() {
  // Fix the vector length at 256 bits, which corresponds to vscale = 2.
  // The expected output of the test is written for vscale = 2.
  %vl_bits = arith.constant 256 : i32
  func.call @setArmVLBits(%vl_bits) : (i32) -> ()
  func.call @test_deinterleave() : () -> ()
  return
}
17+
18+
func.func @test_deinterleave() {
  // Print the step-vector input first.
  %input = llvm.intr.experimental.stepvector : vector<[4]xi8>
  vector.print %input : vector<[4]xi8>
  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )

  // Deinterleave it and print the even lanes followed by the odd lanes.
  %evens, %odds = vector.deinterleave %input : vector<[4]xi8> -> vector<[2]xi8>
  vector.print %evens : vector<[2]xi8>
  vector.print %odds : vector<[2]xi8>
  // CHECK: ( 0, 2, 4, 6 )
  // CHECK: ( 1, 3, 5, 7 )
  return
}
29+
30+
// Declaration of an external helper taking a vector length in bits; @entry
// calls it with 256 before running the test.
// NOTE(review): presumably defined in one of the runner-utils shared libraries
// listed in the run command above; confirm.
func.func private @setArmVLBits(%bits : i32)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: mlir-opt %s -test-lower-to-llvm | \
2+
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
3+
// RUN: -shared-libs=%mlir_c_runner_utils | \
4+
// RUN: FileCheck %s
5+
6+
func.func @entry() {
  // Print the fixed-size input vector.
  %input = arith.constant dense<[1, 2, 3, 4]> : vector<4xi8>
  vector.print %input : vector<4xi8>
  // CHECK: ( 1, 2, 3, 4 )

  // Deinterleave it and print the even lanes followed by the odd lanes.
  %evens, %odds = vector.deinterleave %input : vector<4xi8> -> vector<2xi8>
  vector.print %evens : vector<2xi8>
  vector.print %odds : vector<2xi8>
  // CHECK: ( 1, 3 )
  // CHECK: ( 2, 4 )

  return
}

0 commit comments

Comments
 (0)