@@ -88,7 +88,8 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
88
88
// CHECK-LABEL: @fully_connected
89
89
func.func @fully_connected (%arg0: tensor <5 x3 xf32 >, %arg1: tensor <6 x3 xf32 >, %arg2: tensor <6 xf32 >) -> (tensor <5 x6 xf32 >) {
90
90
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
91
- // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
91
+ // CHECK: %[[TRANSPOSEDINIT:.+]] = tensor.empty() : tensor<3x6xf32>
92
+ // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT]] : tensor<3x6xf32>) permutation = [1, 0]
92
93
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
93
94
94
95
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
@@ -110,7 +111,7 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
110
111
// CHECK-LABEL: @quantized_fully_connected
111
112
func.func @quantized_fully_connected (%arg0: tensor <5 x3 xi8 >, %arg1: tensor <6 x3 xi8 >, %arg2: tensor <6 xi32 >) -> (tensor <5 x6 xi32 >) {
112
113
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
113
- // CHECK: %[[TRANSPOSE:.+]] = tosa .transpose %arg1, %[[PERM ]] : ( tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
114
+ // CHECK: %[[TRANSPOSE:.+]] = linalg .transpose ins( %arg1 : tensor<6x3xi8>) outs( %[[TRANSPOSEDINIT:.+ ]] : tensor<3x6xi8>) permutation = [1, 0]
114
115
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
115
116
116
117
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
@@ -136,7 +137,7 @@ func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %
136
137
// CHECK: %[[C0:.+]] = arith.constant 0 : index
137
138
// CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
138
139
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
139
- // CHECK: %[[TRANSPOSED:.+]] = tosa .transpose %arg1, %[[PERM ]] : ( tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
140
+ // CHECK: %[[TRANSPOSED:.+]] = linalg .transpose ins( %arg1 : tensor<6x3xf32>) outs( %[[TRANSPOSEDINIT:.+ ]] : tensor<3x6xf32>) permutation = [1, 0]
140
141
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
141
142
142
143
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
@@ -377,7 +378,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
377
378
// CHECK-LABEL: @conv2d_i8
378
379
func.func @conv2d_i8 (%input: tensor <1 x49 x42 x27 xi8 >, %weights: tensor <28 x1 x1 x27 xi8 >, %bias: tensor <28 xi8 >) -> () {
379
380
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
380
- // HWCF: %[[TRANSPOSE:.+]] = tosa .transpose %arg1, %[[TRANSPOSE_DIMS ]] : ( tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
381
+ // HWCF: %[[TRANSPOSE:.+]] = linalg .transpose ins( %arg1 : tensor<28x1x1x27xi8>) outs( %[[TRANSPOSEDINIT:.+ ]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
381
382
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
382
383
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
383
384
// CHECK: arith.extsi
@@ -398,7 +399,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
398
399
// CHECK-LABEL: @conv2d_f32
399
400
func.func @conv2d_f32 (%input: tensor <1 x49 x42 x27 xf32 >, %weights: tensor <28 x3 x3 x27 xf32 >, %bias: tensor <28 xf32 >) -> () {
400
401
// HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
401
- // HWCF: %[[TRANSPOSE:.+]] = tosa .transpose %arg1, %[[TRANSPOSE_DIMS ]] : ( tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
402
+ // HWCF: %[[TRANSPOSE:.+]] = linalg .transpose ins( %arg1 : tensor<28x3x3x27xf32>) outs( %[[TRANSPOSEDINIT:.+ ]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
402
403
403
404
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
404
405
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
@@ -677,7 +678,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
677
678
// CHECK-LABEL: @conv3d_f32
678
679
func.func @conv3d_f32 (%input: tensor <1 x49 x48 x47 x27 xf32 >, %weights: tensor <28 x3 x4 x5 x27 xf32 >, %bias: tensor <28 xf32 >) -> () {
679
680
// CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
680
- // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa .transpose %arg1, %[[PERMS] ]
681
+ // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg .transpose ins( %arg1 : tensor<28x3x4x5x27xf32>) outs( %[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0 ]
681
682
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
682
683
// CHECK: %[[BROADCAST:.+]] = linalg.generic
683
684
// CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -701,7 +702,7 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4
701
702
// CHECK-LABEL: @conv3d_i8
702
703
func.func @conv3d_i8 (%input: tensor <1 x49 x48 x47 x27 xi8 >, %weights: tensor <28 x3 x4 x5 x27 xi8 >, %bias: tensor <28 xi32 >) -> () {
703
704
// CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
704
- // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa .transpose %arg1, %[[PERMS] ]
705
+ // CHECK-DAG: %[[TRANSPOSE:.+]] = linalg .transpose ins( %arg1 : tensor<28x3x4x5x27xi8>) outs( %[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0 ]
705
706
// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
706
707
// CHECK: %[[BROADCAST:.+]] = linalg.generic
707
708
// CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -720,3 +721,63 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
720
721
%0 = tosa.conv3d %input , %weights , %bias {pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, quantization_info = #tosa.conv_quant <input_zp = -128 , weight_zp = 42 >, stride = array<i64 : 1 , 1 , 1 >, dilation = array<i64 : 1 , 1 , 1 >} : (tensor <1 x49 x48 x47 x27 xi8 >, tensor <28 x3 x4 x5 x27 xi8 >, tensor <28 xi32 >) -> tensor <1 x47 x45 x43 x28 xi32 >
721
722
return
722
723
}
724
+
725
+ // -----
726
+
727
+ // CHECK-LABEL: @test_transpose
728
+ // CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
729
+ func.func @test_transpose (%arg0: tensor <1 x2 x3 xi32 >) -> () {
730
+ %0 = arith.constant dense <[1 , 2 , 0 ]> : tensor <3 xi32 >
731
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x3x1xi32>
732
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x2x3xi32>) outs(%[[INIT]] : tensor<2x3x1xi32>) permutation = [1, 2, 0]
733
+ %1 = tosa.transpose %arg0 , %0 : (tensor <1 x2 x3 xi32 >, tensor <3 xi32 >) -> tensor <2 x3 x1 xi32 >
734
+ return
735
+ }
736
+
737
+ // -----
738
+
739
+ // CHECK-LABEL: @test_transpose_dyn
740
+ // CHECK-SAME: (%[[ARG0:.+]]: tensor<1x?x3x4xi32>)
741
+ func.func @test_transpose_dyn (%arg0: tensor <1 x?x3 x4 xi32 >) -> () {
742
+ %0 = arith.constant dense <[1 , 3 , 0 , 2 ]> : tensor <4 xi32 >
743
+ // CHECK: %[[C1:.+]] = arith.constant 1
744
+ // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]]
745
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]]) : tensor<?x4x1x3xi32>
746
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<1x?x3x4xi32>) outs(%[[INIT]] : tensor<?x4x1x3xi32>) permutation = [1, 3, 0, 2]
747
+ %1 = tosa.transpose %arg0 , %0 : (tensor <1 x?x3 x4 xi32 >, tensor <4 xi32 >) -> tensor <?x4 x1 x3 xi32 >
748
+ return
749
+ }
750
+
751
+ // -----
752
+
753
+ // CHECK-LABEL: @test_transpose_dyn_multiple_2d
754
+ // CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
755
+ func.func @test_transpose_dyn_multiple_2d (%arg0: tensor <?x?xf32 >) -> () {
756
+ %0 = arith.constant dense <[1 , 0 ]> : tensor <2 xi32 >
757
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0
758
+ // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
759
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1
760
+ // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
761
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]])
762
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) permutation = [1, 0]
763
+ %1 = tosa.transpose %arg0 , %0 : (tensor <?x?xf32 >, tensor <2 xi32 >) -> tensor <?x?xf32 >
764
+ return
765
+ }
766
+
767
+ // -----
768
+
769
+ // CHECK-LABEL: @test_transpose_dyn_multiple_3d
770
+ // CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?x?xf32>)
771
+ func.func @test_transpose_dyn_multiple_3d (%arg0: tensor <?x?x?xf32 >) {
772
+ %0 = arith.constant dense <[2 , 0 , 1 ]> : tensor <3 xi32 >
773
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
774
+ // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
775
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
776
+ // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
777
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
778
+ // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
779
+ // CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM2]], %[[DIM0]], %[[DIM1]]) : tensor<?x?x?xf32>
780
+ // CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?xf32>) permutation = [2, 0, 1]
781
+ %1 = " tosa.transpose" (%arg0 , %0 ) : (tensor <?x?x?xf32 >, tensor <3 xi32 >) -> tensor <?x?x?xf32 >
782
+ return
783
+ }
0 commit comments