Skip to content

Commit 11dc393

Browse files
committed
[mlir][VectorOps] Add deinterleave operation to vector dialect
The deinterleave operation constructs two vectors from a single input vector. Each new vector is the collection of even and odd elements from the input, respectively. This is essentially the inverse of an interleave operation. Each output's size is half of the input vector's trailing dimension for the n-D case and only dimension for 1-D cases. It is not possible to conduct the operation on 0-D inputs or vectors where the size of the (trailing) dimension is 1. The operation supports scalable vectors. Example: ```mlir %0 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32> %1 = vector.deinterleave %b : vector<8xi8> -> vector<4xi8> %2 = vector.deinterleave %c : vector<2x8xf32> -> vector<2x4xf32> %3 = vector.deinterleave %d : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64> ```
1 parent 0987e00 commit 11dc393

File tree

3 files changed

+56
-52
lines changed

3 files changed

+56
-52
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,8 @@ def Vector_InterleaveOp :
544544
}
545545

546546
class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
547-
"type of 'input' is double the width of results",
548-
"input", result,
547+
"the trailing dimension of the results is half the width of source trailing dimension",
548+
"source", result,
549549
[{
550550
[&]() -> ::mlir::VectorType {
551551
auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
@@ -559,63 +559,67 @@ class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
559559
}]
560560
>;
561561

562-
def Vector_DeinterleaveOp :
563-
Vector_Op<"deinterleave", [Pure,
564-
PredOpTrait<"trailing dimension of input vector must be an even number",
562+
def SourceVectorEvenElementCount : PredOpTrait<
563+
"the trailing dimension of the source vector has an even number of elements",
565564
CPred<[{
566565
[&](){
567566
auto srcVec = getSourceVectorType();
568567
return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0;
569568
}()
570-
}]>>,
569+
}]>
570+
>;
571+
572+
def Vector_DeinterleaveOp :
573+
Vector_Op<"deinterleave", [Pure,
574+
SourceVectorEvenElementCount,
571575
ResultIsHalfSourceVectorType<"res1">,
572-
ResultIsHalfSourceVectorType<"res2">,
573576
AllTypesMatch<["res1", "res2"]>
574577
]> {
575578
let summary = "constructs two vectors by deinterleaving an input vector";
576579
let description = [{
577580
The deinterleave operation constructs two vectors from a single input
578-
vector. Each new vector is the collection of even and odd elements
579-
from the input, respectively. This is essentially the inverse of an
580-
interleave operation.
581+
vector. The first result vector contains the elements from even indexes
582+
of the input, and the second contains elements from odd indexes. This is
583+
the inverse of a `vector.interleave` operation.
581584

582-
Each output's size is half of the input vector's trailing dimension
583-
for the n-D case and only dimension for 1-D cases. It is not possible
584-
to conduct the operation on 0-D inputs or vectors where the size of
585-
the (trailing) dimension is 1.
585+
Each output's trailing dimension is half of the size of the input
586+
vector's trailing dimension. This operation requires the input vector
587+
to have a rank > 0 and an even number of elements in its trailing
588+
dimension.
586589

587590
The operation supports scalable vectors.
588591

589592
Example:
590593
```mlir
591-
%0 = vector.deinterleave %a
592-
: vector<[4]xi32> ; yields vector<[2]xi32>, vector<[2]xi32>
593-
%1 = vector.deinterleave %b
594-
: vector<8xi8> ; yields vector<4xi8>, vector<4xi8>
595-
%2 = vector.deinterleave %c
596-
: vector<2x8xf32> ; yields vector<2x4xf32>, vector<2x4xf32>
597-
%3 = vector.deinterleave %d
598-
: vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64>
594+
%0, %1 = vector.deinterleave %a
595+
:vector<8xi8> -> vector<4xi8>
596+
%2, %3 = vector.deinterleave %b
597+
: vector<2x8xi8> -> vector<2x4xi8>
598+
%4, %5 = vector.deinterleave %b
599+
: vector<2x8x4xi8> -> vector<2x8x2xi8>
600+
%6, %7 = vector.deinterleave %c
601+
: vector<[8]xf32> -> vector<[4]xf32>
602+
%8, %9 = vector.deinterleave %d
603+
: vector<2x[6]xf64> -> vector<2x[3]xf64>
604+
%10, %11 = vector.deinterleave %d
605+
: vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
599606
```
600607
}];
601608

602-
let arguments = (ins AnyVector:$input);
609+
let arguments = (ins AnyVector:$source);
603610
let results = (outs AnyVector:$res1, AnyVector:$res2);
604611

605612
let assemblyFormat = [{
606-
$input attr-dict `:` type($input)
613+
$source attr-dict `:` type($source) `->` type($res1)
607614
}];
608615

609616
let extraClassDeclaration = [{
610617
VectorType getSourceVectorType() {
611-
return ::llvm::cast<VectorType>(getInput().getType());
618+
return ::llvm::cast<VectorType>(getSource().getType());
612619
}
613-
VectorType getResultOneVectorType() {
620+
VectorType getResultVectorType() {
614621
return ::llvm::cast<VectorType>(getRes1().getType());
615622
}
616-
VectorType getResultTwoVectorType() {
617-
return ::llvm::cast<VectorType>(getRes2().getType());
618-
}
619623
}];
620624
}
621625

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,55 +1802,55 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>) {
18021802
// -----
18031803

18041804
func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
1805-
// expected-error @+1 {{'vector.deinterleave' 'input' must be vector of any type values, but got 'vector<f32>'}}
1806-
%0, %1 = vector.deinterleave %vec : vector<f32>
1805+
// expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector<f32>}}
1806+
%0, %1 = vector.deinterleave %vec : vector<f32> -> vector<f32>
18071807
return
18081808
}
18091809

18101810
// -----
18111811

18121812
func.func @deinterleave_one_dim_fail(%vec : vector<1xf32>) {
1813-
// expected-error @+1 {{'vector.deinterleave' op failed to verify that trailing dimension of input vector must be an even number}}
1814-
%0, %1 = vector.deinterleave %vec : vector<1xf32>
1813+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the source vector has an even number of elements}}
1814+
%0, %1 = vector.deinterleave %vec : vector<1xf32> -> vector<1xf32>
18151815
return
18161816
}
18171817

18181818
// -----
18191819

18201820
func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) {
1821-
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
1821+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
18221822
%0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>)
18231823
return
18241824
}
18251825

18261826
// -----
18271827

18281828
func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) {
1829-
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
1829+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
18301830
%0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>)
18311831
return
18321832
}
18331833

18341834
// -----
18351835

18361836
func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) {
1837-
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
1837+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
18381838
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>)
18391839
return
18401840
}
18411841

18421842
// -----
18431843

18441844
func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) {
1845-
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
1845+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
18461846
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>)
18471847
return
18481848
}
18491849

18501850
// -----
18511851

18521852
func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
1853-
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
1853+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
18541854
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
18551855
return
1856-
}
1856+
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,42 +1119,42 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>)
11191119

11201120
// CHECK-LABEL: @deinterleave_1d
11211121
func.func @deinterleave_1d(%arg: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
1122-
// CHECK: vector.deinterleave %{{.*}} : vector<4xf32>
1123-
%0, %1 = vector.deinterleave %arg : vector<4xf32>
1122+
// CHECK: vector.deinterleave %{{.*}} : vector<4xf32> -> vector<2xf32>
1123+
%0, %1 = vector.deinterleave %arg : vector<4xf32> -> vector<2xf32>
11241124
return %0, %1 : vector<2xf32>, vector<2xf32>
11251125
}
11261126

11271127
// CHECK-LABEL: @deinterleave_1d_scalable
11281128
func.func @deinterleave_1d_scalable(%arg: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) {
1129-
// CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32>
1130-
%0, %1 = vector.deinterleave %arg : vector<[4]xf32>
1129+
// CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32> -> vector<[2]xf32>
1130+
%0, %1 = vector.deinterleave %arg : vector<[4]xf32> -> vector<[2]xf32>
11311131
return %0, %1 : vector<[2]xf32>, vector<[2]xf32>
11321132
}
11331133

11341134
// CHECK-LABEL: @deinterleave_2d
11351135
func.func @deinterleave_2d(%arg: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) {
1136-
// CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32>
1137-
%0, %1 = vector.deinterleave %arg : vector<3x4xf32>
1136+
// CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32> -> vector<3x2xf32>
1137+
%0, %1 = vector.deinterleave %arg : vector<3x4xf32> -> vector<3x2xf32>
11381138
return %0, %1 : vector<3x2xf32>, vector<3x2xf32>
11391139
}
11401140

11411141
// CHECK-LABEL: @deinterleave_2d_scalable
11421142
func.func @deinterleave_2d_scalable(%arg: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) {
1143-
// CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32>
1144-
%0, %1 = vector.deinterleave %arg : vector<3x[4]xf32>
1143+
// CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32> -> vector<3x[2]xf32>
1144+
%0, %1 = vector.deinterleave %arg : vector<3x[4]xf32> -> vector<3x[2]xf32>
11451145
return %0, %1 : vector<3x[2]xf32>, vector<3x[2]xf32>
11461146
}
11471147

11481148
// CHECK-LABEL: @deinterleave_nd
11491149
func.func @deinterleave_nd(%arg: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) {
1150-
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32>
1151-
%0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32>
1150+
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
1151+
%0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
11521152
return %0, %1 : vector<2x3x4x3xf32>, vector<2x3x4x3xf32>
11531153
}
11541154

11551155
// CHECK-LABEL: @deinterleave_nd_scalable
11561156
func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) {
1157-
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32>
1158-
%0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32>
1157+
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
1158+
%0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
11591159
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
1160-
}
1160+
}

0 commit comments

Comments
 (0)