@@ -82,71 +82,69 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
// -----
- // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
- // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+ // CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
+ // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: @fully_connected
func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
- // CHECK: [[INITT:%.+]] = tensor.empty()
- // CHECK: [[ZERO:%.+]] = arith.constant 0
- // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
- // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
- // CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
- // CHECK: [[INITB:%.+]] = tensor.empty()
- // CHECK: [[MATMUL:%.+]] = linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32>
- // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xf32>, tensor<5x6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
- // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
- // CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
- // CHECK: linalg.yield [[ADD]] : f32
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+ // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
+
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
+ // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+ // CHECK: linalg.yield %[[IN]] : f32
+ // CHECK: } -> tensor<5x6xf32>
+
+ // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<5x3xf32>, tensor<3x6xf32>) outs(%[[BROADCAST]] : tensor<5x6xf32>) -> tensor<5x6xf32>

  %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<5x6xf32>
  return %0 : tensor<5x6xf32>
}

// -----
- // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
- // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+ // CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
+ // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: @quantized_fully_connected
func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
- // CHECK: [[INITT:%.+]] = tensor.empty()
- // CHECK: [[ZERO:%.+]] = arith.constant 0
- // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
- // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
- // CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
- // CHECK: [[INITB:%.+]] = tensor.empty()
- // CHECK: [[ONE:%.+]] = arith.constant 1
- // CHECK: [[TWO:%.+]] = arith.constant 2
- // CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
- // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
- // CHECK: ^bb0([[IN1:%.+]]: i32, [[IN2:%.+]]: i32, [[UNUSED:%.+]]: i32):
- // CHECK: [[ADD:%.+]] = arith.addi
- // CHECK: linalg.yield [[ADD]] : i32
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+ // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
+
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
+ // CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
+ // CHECK: linalg.yield %[[IN]] : i32
+ // CHECK: } -> tensor<5x6xi32>
+
+ // CHECK: %[[C1:.+]] = arith.constant 1 : i32
+ // CHECK: %[[C2:.+]] = arith.constant 2 : i32
+ // CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<5x6xi32>) -> tensor<5x6xi32>
+
  %0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
  return %0 : tensor<5x6xi32>
}

// -----
- // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
- // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+ // CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
+ // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: @fully_connected_dyn
func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<?x6xf32>) {
- // CHECK: %[[C0:.+]] = arith.constant 0
- // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
- // CHECK: %[[INITT:.+]] = tensor.empty(%[[DIM]])
- // CHECK: %[[ZERO:.+]] = arith.constant 0
- // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INITT]]
- // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]>
- // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]]
- // CHECK: %[[INITB:.+]] = tensor.empty(%[[DIM]])
- // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%arg0, %[[TRANSPOSE]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[FILL]] : tensor<?x6xf32>) -> tensor<?x6xf32>
- // CHECK: %[[ADDED:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, %[[MATMUL]] : tensor<6xf32>, tensor<?x6xf32>) outs(%[[INITB]] : tensor<?x6xf32>) {
- // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
- // CHECK: %[[ADD:.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
- // CHECK: linalg.yield %[[ADD]] : f32
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+ // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
+
+ // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
+ // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+ // CHECK: linalg.yield %[[IN]] : f32
+ // CHECK: } -> tensor<?x6xf32>
+
+ // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[BROADCAST]] : tensor<?x6xf32>) -> tensor<?x6xf32>

  %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<?x6xf32>
  return %0 : tensor<?x6xf32>
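Taken together, the updated CHECK lines verify that the bias is now broadcast into the matmul's init tensor up front, instead of being added to the matmul result by a trailing linalg.generic. A minimal sketch of the IR shape the new checks match for the static f32 case (function and SSA names here are illustrative, not the exact ones mlir-opt prints):

#map0 = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>

func.func @fully_connected_lowered(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> tensor<5x6xf32> {
  // Transpose the 6x3 weights into the 3x6 layout linalg.matmul expects.
  %perm = arith.constant dense<[1, 0]> : tensor<2xi64>
  %t = tosa.transpose %arg1, %perm : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
  // Broadcast the 1-D bias across the rows of the 5x6 result.
  %init = tensor.empty() : tensor<5x6xf32>
  %bcast = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]}
      ins(%arg2 : tensor<6xf32>) outs(%init : tensor<5x6xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<5x6xf32>
  // linalg.matmul accumulates into its outs operand, so seeding it with the
  // broadcast bias makes the separate elementwise add unnecessary.
  %res = linalg.matmul ins(%arg0, %t : tensor<5x3xf32>, tensor<3x6xf32>) outs(%bcast : tensor<5x6xf32>) -> tensor<5x6xf32>
  return %res : tensor<5x6xf32>
}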