@@ -36,9 +36,10 @@ module attributes {transform.with_named_sequence} {
36
36
/// w == 0, kw = 0
37
37
// CHECK: %[[SC_INPUT:.*]] = vector.shape_cast %[[V_INPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
38
38
// CHECK: %[[SC_OUTPUT:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
39
- // CHECK: %[[B_FILTER:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<3xi8> to vector<1x8x3xi8>
40
- // CHECK: %[[SC_FILTER:.*]] = vector.shape_cast %[[B_FILTER]] : vector<1x8x3xi8> to vector<1x24xi8>
41
- // CHECK: %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[SC_FILTER]] : vector<1x24xi8>
39
+ // CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
40
+ // CHECK-SAME: [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
41
+ // CHECK: %[[B_FILTER:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<24xi8> to vector<1x24xi8>
42
+ // CHECK: %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[B_FILTER]] : vector<1x24xi8>
42
43
// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[SC_OUTPUT]] : vector<1x24xi8>
43
44
44
45
// Write the result back in one shot.
@@ -80,15 +81,17 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3
80
81
/// w == 0, kw = 0
81
82
// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xf32> to vector<3x8xf32>
82
83
// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xf32> to vector<3x8xf32>
83
- // CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
84
- // CHECK: %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xf32> to vector<3x8xf32>
85
- // CHECK: %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[SC_B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
84
+ // CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
85
+ // CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
86
+ // CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[SH_FILTER_0]] : vector<8xf32> to vector<3x8xf32>
87
+ // CHECK: %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
86
88
87
89
/// w == 0, kw = 1
88
90
// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xf32> to vector<3x8xf32>
89
- // CHECK: %[[B_V_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32>
90
- // CHECK: %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_V_FILTER_1]] : vector<3x2x4xf32> to vector<3x8xf32>
91
- // CHECK: %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[SC_B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
91
+ // CHECK: %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
92
+ // CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
93
+ // CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[SH_FILTER_1]] : vector<8xf32> to vector<3x8xf32>
94
+ // CHECK: %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
92
95
93
96
// Write the result back in one shot.
94
97
// CHECK: %[[SC_FMA_1:.*]] = vector.shape_cast %[[FMA_1]] : vector<3x8xf32> to vector<3x2x4xf32>
@@ -138,19 +141,21 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5
138
141
// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x8xi8>
139
142
// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xi32> to vector<3x8xi32>
140
143
// CHECK: %[[EXT_INPUT_0:.*]] = arith.extsi %[[SC_V_INPUT_0]] : vector<3x8xi8> to vector<3x8xi32>
141
- // CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
142
- // CHECK: %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x8xi8>
143
- // CHECK: %[[EXT_FILTER_0:.*]] = arith.extsi %[[SC_B_FILTER_0]] : vector<3x8xi8> to vector<3x8xi32>
144
- // CHECK: %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x8xi32>
144
+ // CHECK: %[[SH_FILTER_0:.*]] = vector.shuffle %[[V_FILTER_0]], %[[V_FILTER_0]]
145
+ // CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
146
+ // CHECK: %[[EXT_FILTER_0:.*]] = arith.extsi %[[SH_FILTER_0]] : vector<8xi8> to vector<8xi32>
147
+ // CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[EXT_FILTER_0]] : vector<8xi32> to vector<3x8xi32>
148
+ // CHECK: %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[B_FILTER_0]] : vector<3x8xi32>
145
149
// CHECK: %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xi32>
146
150
147
151
/// w == 0, kw = 1
148
152
// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x8xi8>
149
153
// CHECK: %[[EXT_INPUT_1:.*]] = arith.extsi %[[SC_V_INPUT_1]] : vector<3x8xi8> to vector<3x8xi32>
150
- // CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
151
- // CHECK: %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x8xi8>
152
- // CHECK: %[[EXT_FILTER_1:.*]] = arith.extsi %[[SC_B_FILTER_1]] : vector<3x8xi8> to vector<3x8xi32>
153
- // CHECK: %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x8xi32>
154
+ // CHECK: %[[SH_FILTER_1:.*]] = vector.shuffle %[[V_FILTER_1]], %[[V_FILTER_1]]
155
+ // CHECK-SAME: [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
156
+ // CHECK: %[[EXT_FILTER_1:.*]] = arith.extsi %[[SH_FILTER_1]] : vector<8xi8> to vector<8xi32>
157
+ // CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[EXT_FILTER_1]] : vector<8xi32> to vector<3x8xi32>
158
+ // CHECK: %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[B_FILTER_1]] : vector<3x8xi32>
154
159
// CHECK: %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x8xi32>
155
160
156
161
// Write the result back in one shot.
@@ -223,69 +228,60 @@ func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4
223
228
/// w == 0, kw == 0
224
229
// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
225
230
// CHECK: %[[VAL_24:.*]] = vector.shape_cast %[[V_OUTPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
226
- // CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
227
- // CHECK: %[[VAL_26:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
228
- // CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[VAL_26]] : vector<3x4xi8>
231
+ // CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
232
+ // CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[B_FILTER_0]] : vector<3x4xi8>
229
233
// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_27]], %[[VAL_24]] : vector<3x4xi8>
230
234
231
235
/// w == 1, kw == 0
232
236
// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
233
237
// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[V_OUTPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
234
- // CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
235
- // CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
236
- // CHECK: %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[VAL_32]] : vector<3x4xi8>
238
+ // CHECK: %[[B_FILTER_0_1:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
239
+ // CHECK: %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[B_FILTER_0_1]] : vector<3x4xi8>
237
240
// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_30]] : vector<3x4xi8>
238
241
239
242
/// w == 2, kw == 0
240
243
// CHECK: %[[VAL_35:.*]] = vector.shape_cast %[[V_INPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
241
244
// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[V_OUTPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
242
- // CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
243
- // CHECK: %[[VAL_38:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
244
- // CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[VAL_38]] : vector<3x4xi8>
245
+ // CHECK: %[[B_FILTER_0_2:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x4xi8>
246
+ // CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[B_FILTER_0_2]] : vector<3x4xi8>
245
247
// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_36]] : vector<3x4xi8>
246
248
247
249
/// w == 3, kw == 1
248
250
// CHECK: %[[VAL_41:.*]] = vector.shape_cast %[[V_INPUT_3]] : vector<3x1x4xi8> to vector<3x4xi8>
249
- // CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
250
- // CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
251
- // CHECK: %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[VAL_43]] : vector<3x4xi8>
251
+ // CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
252
+ // CHECK: %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[B_FILTER_1]] : vector<3x4xi8>
252
253
// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_28]] : vector<3x4xi8>
253
254
254
255
/// w == 4, kw == 1
255
256
// CHECK: %[[VAL_46:.*]] = vector.shape_cast %[[V_INPUT_4]] : vector<3x1x4xi8> to vector<3x4xi8>
256
- // CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
257
- // CHECK: %[[VAL_48:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
258
- // CHECK: %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[VAL_48]] : vector<3x4xi8>
257
+ // CHECK: %[[B_FILTER_1_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
258
+ // CHECK: %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[B_FILTER_1_1]] : vector<3x4xi8>
259
259
// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_49]], %[[VAL_34]] : vector<3x4xi8>
260
260
261
261
/// w == 5, kw == 1
262
262
// CHECK: %[[VAL_51:.*]] = vector.shape_cast %[[V_INPUT_5]] : vector<3x1x4xi8> to vector<3x4xi8>
263
- // CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
264
- // CHECK: %[[VAL_53:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
265
- // CHECK: %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[VAL_53]] : vector<3x4xi8>
263
+ // CHECK: %[[B_FILTER_1_2:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x4xi8>
264
+ // CHECK: %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[B_FILTER_1_2]] : vector<3x4xi8>
266
265
// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_54]], %[[VAL_40]] : vector<3x4xi8>
267
266
268
267
/// w == 6, kw == 2
269
268
// CHECK: %[[VAL_56:.*]] = vector.shape_cast %[[V_INPUT_6]] : vector<3x1x4xi8> to vector<3x4xi8>
270
- // CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
271
- // CHECK: %[[VAL_58:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
272
- // CHECK: %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[VAL_58]] : vector<3x4xi8>
269
+ // CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
270
+ // CHECK: %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[B_FILTER_2]] : vector<3x4xi8>
273
271
// CHECK: %[[VAL_60:.*]] = arith.addi %[[VAL_59]], %[[VAL_45]] : vector<3x4xi8>
274
272
275
273
/// w == 7, kw == 2
276
274
// CHECK: %[[VAL_61:.*]] = vector.shape_cast %[[VAL_60]] : vector<3x4xi8> to vector<3x1x4xi8>
277
275
// CHECK: %[[VAL_62:.*]] = vector.shape_cast %[[V_INPUT_7]] : vector<3x1x4xi8> to vector<3x4xi8>
278
- // CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
279
- // CHECK: %[[VAL_64:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
280
- // CHECK: %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[VAL_64]] : vector<3x4xi8>
276
+ // CHECK: %[[B_FILTER_2_1:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
277
+ // CHECK: %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[B_FILTER_2_1]] : vector<3x4xi8>
281
278
// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_65]], %[[VAL_50]] : vector<3x4xi8>
282
279
283
280
/// w == 8, kw == 2
284
281
// CHECK: %[[VAL_67:.*]] = vector.shape_cast %[[VAL_66]] : vector<3x4xi8> to vector<3x1x4xi8>
285
282
// CHECK: %[[VAL_68:.*]] = vector.shape_cast %[[V_INPUT_8]] : vector<3x1x4xi8> to vector<3x4xi8>
286
- // CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
287
- // CHECK: %[[VAL_70:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
288
- // CHECK: %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[VAL_70]] : vector<3x4xi8>
283
+ // CHECK: %[[B_FILTER_2_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x4xi8>
284
+ // CHECK: %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[B_FILTER_2_2]] : vector<3x4xi8>
289
285
// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_71]], %[[VAL_55]] : vector<3x4xi8>
290
286
291
287
// Write the result back.
0 commit comments