Skip to content

Commit 6e81eae

Browse files
[mlir][Vector] Support 0-D vectors in TransposeOp
Co-authored-by: Michal Terepeta <[email protected]> Reviewed-by: ftynse Differential Revision: https://reviews.llvm.org/D115743
1 parent 1736f76 commit 6e81eae

File tree

5 files changed

+56
-3
lines changed

5 files changed

+56
-3
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2229,12 +2229,13 @@ def Vector_TransposeOp :
22292229
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
22302230
PredOpTrait<"operand and result have same element type",
22312231
TCresVTEtIsSameAsOpBase<0, 0>>]>,
2232-
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,
2233-
Results<(outs AnyVector:$result)> {
2232+
Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
2233+
Results<(outs AnyVectorOfAnyRank:$result)> {
22342234
let summary = "vector transpose operation";
22352235
let description = [{
22362236
Takes a n-D vector and returns the transposed n-D vector defined by
2237-
the permutation of ranks in the n-sized integer array attribute.
2237+
the permutation of ranks in the n-sized integer array attribute (in case
2238+
of 0-D vectors the array attribute must be empty).
22382239
In the operation
22392240

22402241
```mlir

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,6 +1760,8 @@ func.func @create_mask_1d(%a : index) -> vector<4xi1> {
17601760
// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
17611761
// CHECK: return %[[result]] : vector<4xi1>
17621762

1763+
// -----
1764+
17631765
func.func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
17641766
%v = vector.create_mask %a : vector<[4]xi1>
17651767
return %v: vector<[4]xi1>
@@ -1776,6 +1778,17 @@ func.func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
17761778

17771779
// -----
17781780

1781+
func.func @transpose_0d(%arg0: vector<f32>) -> vector<f32> {
1782+
%0 = vector.transpose %arg0, [] : vector<f32> to vector<f32>
1783+
return %0 : vector<f32>
1784+
}
1785+
1786+
// CHECK-LABEL: func @transpose_0d
1787+
// CHECK-SAME: %[[A:.*]]: vector<f32>
1788+
// CHECK: return %[[A]] : vector<f32>
1789+
1790+
// -----
1791+
17791792
func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
17801793
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
17811794
: vector<16xf32> -> vector<16xf32>

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,11 +1145,25 @@ func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>, %acc: vector<16xf3
11451145

11461146
// -----
11471147

1148+
func.func @transpose_rank_mismatch_0d(%arg0: vector<f32>) {
1149+
// expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
1150+
%0 = vector.transpose %arg0, [] : vector<f32> to vector<100xf32>
1151+
}
1152+
1153+
// -----
1154+
11481155
func.func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
11491156
// expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
11501157
%0 = vector.transpose %arg0, [2, 1, 0] : vector<4x16x11xf32> to vector<100xf32>
11511158
}
11521159

1160+
// -----
1161+
1162+
func.func @transpose_length_mismatch_0d(%arg0: vector<f32>) {
1163+
// expected-error@+1 {{'vector.transpose' op transposition length mismatch: 1}}
1164+
%0 = vector.transpose %arg0, [1] : vector<f32> to vector<f32>
1165+
}
1166+
11531167
// -----
11541168

11551169
func.func @transpose_length_mismatch(%arg0: vector<4x4xf32>) {

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,22 @@ func.func @transpose_int(%arg0: vector<11x7x3x2xi32>) -> vector<2x11x7x3xi32> {
570570
return %0 : vector<2x11x7x3xi32>
571571
}
572572

573+
// CHECK-LABEL: @transpose_fp_0d
574+
func.func @transpose_fp_0d(%arg0: vector<f32>) -> vector<f32> {
575+
// CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [] : vector<f32> to vector<f32>
576+
%0 = vector.transpose %arg0, [] : vector<f32> to vector<f32>
577+
// CHECK: return %[[X]] : vector<f32>
578+
return %0 : vector<f32>
579+
}
580+
581+
// CHECK-LABEL: @transpose_int_0d
582+
func.func @transpose_int_0d(%arg0: vector<i32>) -> vector<i32> {
583+
// CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [] : vector<i32> to vector<i32>
584+
%0 = vector.transpose %arg0, [] : vector<i32> to vector<i32>
585+
// CHECK: return %[[X]] : vector<i32>
586+
return %0 : vector<i32>
587+
}
588+
573589
// CHECK-LABEL: @flat_transpose_fp
574590
func.func @flat_transpose_fp(%arg0: vector<16xf32>) -> vector<16xf32> {
575591
// CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>

mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ func.func @fma_0d(%four: vector<f32>) {
120120
return
121121
}
122122

123+
func.func @transpose_0d(%arg: vector<i32>) {
124+
%1 = vector.transpose %arg, [] : vector<i32> to vector<i32>
125+
// CHECK: ( 42 )
126+
vector.print %1: vector<i32>
127+
return
128+
}
129+
123130
func.func @entry() {
124131
%0 = arith.constant 42.0 : f32
125132
%1 = arith.constant dense<0.0> : vector<f32>
@@ -151,6 +158,8 @@ func.func @entry() {
151158

152159
%5 = arith.constant dense<4.0> : vector<f32>
153160
call @fma_0d(%5) : (vector<f32>) -> ()
161+
%6 = arith.constant dense<42> : vector<i32>
162+
call @transpose_0d(%6) : (vector<i32>) -> ()
154163

155164
return
156165
}

0 commit comments

Comments
 (0)