@@ -253,6 +253,54 @@ module attributes {transform.with_named_sequence} {
253
253
transform.yield
254
254
}
255
255
}
256
+
257
+ // -----
258
+
259
// Identity map used to index the 2-D output of the generic op below.
#map = affine_map<(d0, d1) -> (d0, d1)>
// Sums its three operands; used to build the innermost extract index.
#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>

// Reads one f32 per output element from a 3-D source tensor into an 8x1
// result. The output shape is 8x1, so the unit dimension is the trailing
// one, not the leading one (hence "without_outer_unit_dim"). The CHECK
// lines for this case expect the extract to become a single
// vector.transfer_read.
func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1: index) -> tensor<8x1xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.empty() : tensor<8x1xf32>
  %1 = linalg.generic {
    indexing_maps = [#map],
    iterator_types = ["parallel", "parallel"]
  } outs(%0 : tensor<8x1xf32>) {
  ^bb0(%arg5: f32):
    // %2 iterates the size-8 dimension; %3 iterates the size-1 dimension.
    %2 = linalg.index 0 : index
    %3 = linalg.index 1 : index
    // Innermost index = %arg1 + %3 + %arg1.
    %4 = affine.apply #map1(%arg1, %3, %arg1)
    // Dimension 0 follows the loop index, dimension 1 is fixed at 0, and
    // dimension 2 uses the computed index.
    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
    linalg.yield %extracted : f32
  } -> tensor<8x1xf32>
  return %1 : tensor<8x1xf32>
}
277
+
278
// Transform script for the test case above: it locates the linalg.generic
// op, climbs to its isolated-from-above parent, and vectorizes everything
// under that parent with n-D tensor.extract vectorization enabled.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
    // Match every linalg.generic under the payload root.
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
    // Move up to the closest isolated-from-above ancestor of the match.
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    // Vectorize the children of that ancestor; vectorize_nd_extract opts in
    // to vectorizing tensor.extract with n-D indices.
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
286
+
287
+ // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim
288
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
289
+ // CHECK-SAME: %[[ARG1:.*]]: index
290
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
291
+ // CHECK-DAG: %[[C0_i32:.*]] = arith.constant 0 : i32
292
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
293
+ // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
294
+ // CHECK: %[[IDX0:.*]] = tensor.empty() : tensor<8x1xf32>
295
+ // CHECK: %[[IDX1:.*]] = vector.broadcast %[[CST_0]] : vector<8xindex> to vector<1x8xindex>
296
+ // CHECK: %[[IDX2:.*]] = vector.transpose %[[IDX1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
297
+ // CHECK: %[[IDX3:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
298
+ // CHECK: %[[IDX4:.*]] = vector.transpose %[[IDX2]], [1, 0] : vector<8x1xindex> to vector<1x8xindex>
299
+ // CHECK: %[[IDX5:.*]] = vector.shape_cast %[[IDX4]] : vector<1x8xindex> to vector<8xindex>
300
+ // CHECK: %[[IDX6:.*]] = vector.extractelement %[[IDX5]][%[[C0_i32]] : i32] : vector<8xindex>
301
+ // CHECK: %[[IDX7:.*]] = vector.transfer_read %[[ARG0]][%[[IDX6]], %[[C0]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true]} : tensor<8x128x768xf32>, vector<8x1xf32>
302
+ // CHECK: vector.transfer_write %[[IDX7]], %[[IDX0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
303
+
256
304
// -----
257
305
258
306
#map = affine_map <(d0 ) -> (d0 )>
0 commit comments