Skip to content

Commit 64b06ba

Browse files
committed
[mlir][Vector] Add vector.to_elements op
This PR introduces the `vector.to_elements` op, which decomposes a vector into its scalar elements. This operation is symmetrical to the existing `vector.from_elements`. Examples: ``` // Decompose a 0-D vector. %0 = vector.to_elements %v0 : vector<f32> // %0 = %v0[0] // Decompose a 1-D vector. %0:2 = vector.to_elements %v1 : vector<2xf32> // %0#0 = %v1[0] // %0#1 = %v1[1] // Decompose a 2-D. %0:6 = vector.to_elements %v2 : vector<2x3xf32> // %0#0 = %v2[0, 0] // %0#1 = %v2[0, 1] // %0#2 = %v2[0, 2] // %0#3 = %v2[1, 0] // %0#4 = %v2[1, 1] // %0#5 = %v2[1, 2] ``` This op is aimed at reducing code size when modeling "structured" vector extractions and simplifying canonicalizations of large sequences of `vector.extract` and `vector.insert` ops into `vector.shuffle` and other sophisticated ops that can re-arrange vector elements. More related PRs to come!
1 parent 773d357 commit 64b06ba

File tree

3 files changed

+105
-15
lines changed

3 files changed

+105
-15
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,57 @@ def Vector_FMAOp :
790790
}];
791791
}
792792

793+
def Vector_ToElementsOp : Vector_Op<"to_elements", [
794+
Pure,
795+
TypesMatchWith<"operand element type matches result types",
796+
"input", "elements", "SmallVector<Type>("
797+
"::llvm::cast<VectorType>($_self).getNumElements(), "
798+
"::llvm::cast<VectorType>($_self).getElementType())">]> {
799+
let summary = "operation that decomposes a vector into all its scalar elements";
800+
let description = [{
801+
This operation decomposes all the scalar elements from a vector. The
802+
decomposed scalar elements are returned in row-major order. The number of
803+
scalar results must match the number of elements in the input vector type.
804+
All the result elements have the same result type, which must match the
805+
element type of the input vector. Scalable vectors are not supported.
806+
807+
Examples:
808+
809+
```mlir
810+
// Decompose a 0-D vector.
811+
%0 = vector.to_elements %v0 : vector<f32>
812+
// %0 = %v0[0]
813+
814+
// Decompose a 1-D vector.
815+
%0:2 = vector.to_elements %v1 : vector<2xf32>
816+
// %0#0 = %v1[0]
817+
// %0#1 = %v1[1]
818+
819+
// Decompose a 2-D.
820+
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
821+
// %0#0 = %v2[0, 0]
822+
// %0#1 = %v2[0, 1]
823+
// %0#2 = %v2[0, 2]
824+
// %0#3 = %v2[1, 0]
825+
// %0#4 = %v2[1, 1]
826+
// %0#5 = %v2[1, 2]
827+
828+
// Decompose a 3-D vector.
829+
%0:6 = vector.to_elements %v3 : vector<3x1x2xf32>
830+
// %0#0 = %v3[0, 0, 0]
831+
// %0#1 = %v3[0, 0, 1]
832+
// %0#2 = %v3[1, 0, 0]
833+
// %0#3 = %v3[1, 0, 1]
834+
// %0#4 = %v3[2, 0, 0]
835+
// %0#5 = %v3[2, 0, 1]
836+
```
837+
}];
838+
839+
let arguments = (ins AnyVectorOfAnyRank:$input);
840+
let results = (outs Variadic<AnyType>:$elements);
841+
let assemblyFormat = "$input attr-dict `:` type($input)";
842+
}
843+
793844
def Vector_FromElementsOp : Vector_Op<"from_elements", [
794845
Pure,
795846
TypesMatchWith<"operand types match result element type",
@@ -799,26 +850,30 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
799850
let summary = "operation that defines a vector from scalar elements";
800851
let description = [{
801852
This operation defines a vector from one or multiple scalar elements. The
802-
number of elements must match the number of elements in the result type.
803-
All elements must have the same type, which must match the element type of
804-
the result vector type.
805-
806-
`elements` are a flattened version of the result vector in row-major order.
853+
scalar elements are arranged in row-major within the vector. The number of
854+
elements must match the number of elements in the result type. All elements
855+
must have the same type, which must match the element type of the result
856+
vector type. Scalable vectors are not supported.
807857

808-
Example:
858+
Examples:
809859

810860
```mlir
811-
// %f1
861+
// Define a 0-D vector.
812862
%0 = vector.from_elements %f1 : vector<f32>
813-
// [%f1, %f2]
863+
// [%f1]
864+
865+
// Define a 1-D vector.
814866
%1 = vector.from_elements %f1, %f2 : vector<2xf32>
815-
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]
867+
// [%f1, %f2]
868+
869+
// Define a 2-D vector.
816870
%2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
817-
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
871+
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]
872+
873+
// Define a 3-D vector.
818874
%3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
875+
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
819876
```
820-
821-
Note, scalable vectors are not supported.
822877
}];
823878

824879
let arguments = (ins Variadic<AnyType>:$elements);

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,7 +1896,24 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
18961896

18971897
// -----
18981898

1899-
func.func @invalid_from_elements(%a: f32) {
1899+
func.func @to_elements_wrong_num_results(%a: vector<1x1x2xf32>) {
1900+
// expected-error @+1 {{operation defines 2 results but was provided 4 to bind}}
1901+
%0:4 = vector.to_elements %a : vector<1x1x2xf32>
1902+
return
1903+
}
1904+
1905+
// -----
1906+
1907+
func.func @to_elements_wrong_result_type(%a: vector<2xf32>) -> i32 {
1908+
// expected-error @+3 {{use of value '%0' expects different type than prior uses: 'i32'}}
1909+
// expected-note @+1 {{prior use here}}
1910+
%0:2 = vector.to_elements %a : vector<2xf32>
1911+
return %0#0 : i32
1912+
}
1913+
1914+
// -----
1915+
1916+
func.func @from_elements_wrong_num_operands(%a: f32) {
19001917
// expected-error @+1 {{'vector.from_elements' number of operands and types do not match: got 1 operands and 2 types}}
19011918
vector.from_elements %a : vector<2xf32>
19021919
return
@@ -1905,12 +1922,11 @@ func.func @invalid_from_elements(%a: f32) {
19051922
// -----
19061923

19071924
// expected-note @+1 {{prior use here}}
1908-
func.func @invalid_from_elements(%a: f32, %b: i32) {
1925+
func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
19091926
// expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
19101927
vector.from_elements %a, %b : vector<2xf32>
19111928
return
19121929
}
1913-
19141930
// -----
19151931

19161932
func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,6 +1175,25 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
11751175
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
11761176
}
11771177

1178+
// CHECK-LABEL: func @to_elements(
1179+
// CHECK-SAME: %[[A_VEC:.*]]: vector<f32>, %[[B_VEC:.*]]: vector<4xf32>,
1180+
// CHECK-SAME: %[[C_VEC:.*]]: vector<1xf32>, %[[D_VEC:.*]]: vector<2x2xf32>)
1181+
func.func @to_elements(%a_vec : vector<f32>, %b_vec : vector<4xf32>, %c_vec : vector<1xf32>, %d_vec : vector<2x2xf32>)
1182+
-> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
1183+
// CHECK: %[[A_ELEMS:.*]] = vector.to_elements %[[A_VEC]] : vector<f32>
1184+
%0 = vector.to_elements %a_vec : vector<f32>
1185+
// CHECK: %[[B_ELEMS:.*]]:4 = vector.to_elements %[[B_VEC]] : vector<4xf32>
1186+
%1:4 = vector.to_elements %b_vec : vector<4xf32>
1187+
// CHECK: %[[C_ELEMS:.*]] = vector.to_elements %[[C_VEC]] : vector<1xf32>
1188+
%2 = vector.to_elements %c_vec : vector<1xf32>
1189+
// CHECK: %[[D_ELEMS:.*]]:4 = vector.to_elements %[[D_VEC]] : vector<2x2xf32>
1190+
%3:4 = vector.to_elements %d_vec : vector<2x2xf32>
1191+
// CHECK: return %[[A_ELEMS]], %[[B_ELEMS]]#0, %[[B_ELEMS]]#1, %[[B_ELEMS]]#2,
1192+
// CHECK-SAME: %[[B_ELEMS]]#3, %[[C_ELEMS]], %[[D_ELEMS]]#0, %[[D_ELEMS]]#1,
1193+
// CHECK-SAME: %[[D_ELEMS]]#2, %[[D_ELEMS]]#3
1194+
return %0, %1#0, %1#1, %1#2, %1#3, %2, %3#0, %3#1, %3#2, %3#3 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
1195+
}
1196+
11781197
// CHECK-LABEL: func @from_elements(
11791198
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
11801199
func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {

0 commit comments

Comments
 (0)