Skip to content

Commit 0987e00

Browse files
committed
[mlir][VectorOps] Add deinterleave operation to vector dialect
The deinterleave operation constructs two vectors from a single input vector. Each new vector is the collection of even and odd elements from the input, respectively. This is essentially the inverse of an interleave operation. Each output's size is half of the input vector's trailing dimension for the n-D case and only dimension for 1-D cases. It is not possible to conduct the operation on 0-D inputs or vectors where the size of the (trailing) dimension is 1. The operation supports scalable vectors. Example: ```mlir %0 = vector.deinterleave %a : vector<[4]xi32> ; yields vector<[2]xi32>, vector<[2]xi32> %1 = vector.deinterleave %b : vector<8xi8> ; yields vector<4xi8>, vector<4xi8> %2 = vector.deinterleave %c : vector<2x8xf32> ; yields vector<2x4xf32>, vector<2x4xf32> %3 = vector.deinterleave %d : vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64> ```
1 parent 11b97da commit 0987e00

File tree

3 files changed

+174
-0
lines changed

3 files changed

+174
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,82 @@ def Vector_InterleaveOp :
543543
}];
544544
}
545545

546+
class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
547+
"type of 'input' is double the width of results",
548+
"input", result,
549+
[{
550+
[&]() -> ::mlir::VectorType {
551+
auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
552+
::mlir::VectorType::Builder builder(vectorType);
553+
auto lastDim = vectorType.getRank() - 1;
554+
auto newDimSize = vectorType.getDimSize(lastDim) / 2;;
555+
if (newDimSize <= 0)
556+
return vectorType; // (invalid input type)
557+
return builder.setDim(lastDim, newDimSize);
558+
}()
559+
}]
560+
>;
561+
562+
def Vector_DeinterleaveOp :
563+
Vector_Op<"deinterleave", [Pure,
564+
PredOpTrait<"trailing dimension of input vector must be an even number",
565+
CPred<[{
566+
[&](){
567+
auto srcVec = getSourceVectorType();
568+
return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0;
569+
}()
570+
}]>>,
571+
ResultIsHalfSourceVectorType<"res1">,
572+
ResultIsHalfSourceVectorType<"res2">,
573+
AllTypesMatch<["res1", "res2"]>
574+
]> {
575+
let summary = "constructs two vectors by deinterleaving an input vector";
576+
let description = [{
577+
The deinterleave operation constructs two vectors from a single input
578+
vector. Each new vector is the collection of even and odd elements
579+
from the input, respectively. This is essentially the inverse of an
580+
interleave operation.
581+
582+
Each output's size is half of the input vector's trailing dimension
583+
for the n-D case and only dimension for 1-D cases. It is not possible
584+
to conduct the operation on 0-D inputs or vectors where the size of
585+
the (trailing) dimension is 1.
586+
587+
The operation supports scalable vectors.
588+
589+
Example:
590+
```mlir
591+
%0 = vector.deinterleave %a
592+
: vector<[4]xi32> ; yields vector<[2]xi32>, vector<[2]xi32>
593+
%1 = vector.deinterleave %b
594+
: vector<8xi8> ; yields vector<4xi8>, vector<4xi8>
595+
%2 = vector.deinterleave %c
596+
: vector<2x8xf32> ; yields vector<2x4xf32>, vector<2x4xf32>
597+
%3 = vector.deinterleave %d
598+
: vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64>
599+
```
600+
}];
601+
602+
let arguments = (ins AnyVector:$input);
603+
let results = (outs AnyVector:$res1, AnyVector:$res2);
604+
605+
let assemblyFormat = [{
606+
$input attr-dict `:` type($input)
607+
}];
608+
609+
let extraClassDeclaration = [{
610+
VectorType getSourceVectorType() {
611+
return ::llvm::cast<VectorType>(getInput().getType());
612+
}
613+
VectorType getResultOneVectorType() {
614+
return ::llvm::cast<VectorType>(getRes1().getType());
615+
}
616+
VectorType getResultTwoVectorType() {
617+
return ::llvm::cast<VectorType>(getRes2().getType());
618+
}
619+
}];
620+
}
621+
546622
def Vector_ExtractElementOp :
547623
Vector_Op<"extractelement", [Pure,
548624
TypesMatchWith<"result type matches element type of vector operand",

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,3 +1798,59 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>) {
17981798
// expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
17991799
%op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
18001800
}
1801+
1802+
// -----
1803+
1804+
func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
1805+
// expected-error @+1 {{'vector.deinterleave' 'input' must be vector of any type values, but got 'vector<f32>'}}
1806+
%0, %1 = vector.deinterleave %vec : vector<f32>
1807+
return
1808+
}
1809+
1810+
// -----
1811+
1812+
func.func @deinterleave_one_dim_fail(%vec : vector<1xf32>) {
1813+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that trailing dimension of input vector must be an even number}}
1814+
%0, %1 = vector.deinterleave %vec : vector<1xf32>
1815+
return
1816+
}
1817+
1818+
// -----
1819+
1820+
func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) {
1821+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
1822+
%0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>)
1823+
return
1824+
}
1825+
1826+
// -----
1827+
1828+
func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) {
1829+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
1830+
%0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>)
1831+
return
1832+
}
1833+
1834+
// -----
1835+
1836+
func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) {
1837+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
1838+
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>)
1839+
return
1840+
}
1841+
1842+
// -----
1843+
1844+
func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) {
1845+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
1846+
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>)
1847+
return
1848+
}
1849+
1850+
// -----
1851+
1852+
func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
1853+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
1854+
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
1855+
return
1856+
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,3 +1116,45 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>)
11161116
%0 = vector.interleave %a, %b : vector<2x[2]xf64>
11171117
return %0 : vector<2x[4]xf64>
11181118
}
1119+
1120+
// CHECK-LABEL: @deinterleave_1d
1121+
func.func @deinterleave_1d(%arg: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
1122+
// CHECK: vector.deinterleave %{{.*}} : vector<4xf32>
1123+
%0, %1 = vector.deinterleave %arg : vector<4xf32>
1124+
return %0, %1 : vector<2xf32>, vector<2xf32>
1125+
}
1126+
1127+
// CHECK-LABEL: @deinterleave_1d_scalable
1128+
func.func @deinterleave_1d_scalable(%arg: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) {
1129+
// CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32>
1130+
%0, %1 = vector.deinterleave %arg : vector<[4]xf32>
1131+
return %0, %1 : vector<[2]xf32>, vector<[2]xf32>
1132+
}
1133+
1134+
// CHECK-LABEL: @deinterleave_2d
1135+
func.func @deinterleave_2d(%arg: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) {
1136+
// CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32>
1137+
%0, %1 = vector.deinterleave %arg : vector<3x4xf32>
1138+
return %0, %1 : vector<3x2xf32>, vector<3x2xf32>
1139+
}
1140+
1141+
// CHECK-LABEL: @deinterleave_2d_scalable
1142+
func.func @deinterleave_2d_scalable(%arg: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) {
1143+
// CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32>
1144+
%0, %1 = vector.deinterleave %arg : vector<3x[4]xf32>
1145+
return %0, %1 : vector<3x[2]xf32>, vector<3x[2]xf32>
1146+
}
1147+
1148+
// CHECK-LABEL: @deinterleave_nd
1149+
func.func @deinterleave_nd(%arg: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) {
1150+
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32>
1151+
%0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32>
1152+
return %0, %1 : vector<2x3x4x3xf32>, vector<2x3x4x3xf32>
1153+
}
1154+
1155+
// CHECK-LABEL: @deinterleave_nd_scalable
1156+
func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) {
1157+
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32>
1158+
%0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32>
1159+
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
1160+
}

0 commit comments

Comments
 (0)