@@ -566,6 +566,46 @@ module attributes {transform.with_named_sequence} {
// -----
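+ // Vectorization of tensor.pack with dynamic outer dims in both the source
+ // and the packed result; the lowering is expected to go through masked
+ // transfer_read/transfer_write (see the CHECK lines below).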
+ func.func @test_vectorize_dynamic_result_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
+   %pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
+   return %pack : tensor<?x?x16x2xf32>
+ }
+ module attributes {transform.with_named_sequence} {
+   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+     %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+     transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
+     transform.yield
+   }
+ }
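+ // Expected lowering: a masked 8x16 read of the dynamic source (4*2 rows and
+ // 1*16 columns, from vector_sizes [4, 1] and inner_tiles [16, 2]), a
+ // shape_cast to 4x2x1x16, a transpose to 4x1x16x2, and a masked write into a
+ // tensor.empty sized from the result's dynamic dims.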
+ // CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[d0:.*]] = tensor.dim {{.*}} %[[c0]] : tensor<?x?x16x2xf32>
+ // CHECK-DAG: %[[d1:.*]] = tensor.dim {{.*}} %[[c1]] : tensor<?x?x16x2xf32>
+ // CHECK-DAG: %[[c0_0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c1_0:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[d0_0:.*]] = tensor.dim {{.*}} %[[c0_0]] : tensor<?x?xf32>
+ // CHECK-DAG: %[[d1_0:.*]] = tensor.dim {{.*}} %[[c1_0]] : tensor<?x?xf32>
+ // CHECK: %[[mask:.*]] = vector.create_mask %[[d0_0]], %[[d1_0]] : vector<8x16xi1>
+ // CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+ // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+ // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0_1]], %[[c0_1]]], %[[cst]]
+ // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<8x16xf32>
+ // CHECK-SAME: } : vector<8x16xi1> -> vector<8x16xf32>
+ // CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<8x16xf32> to vector<4x2x1x16xf32>
+ // CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 2, 3, 1] : vector<4x2x1x16xf32> to vector<4x1x16x2xf32>
+ // CHECK-DAG: %[[c0_2:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[d0]], %[[d1]]) : tensor<?x?x16x2xf32>
+ // CHECK: %[[mask_0:.*]] = vector.create_mask %[[d0]], %[[d1]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
+ // CHECK: %[[masked_write:.*]] = vector.mask %[[mask_0]] {
+ // CHECK-SAME: vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
+ // CHECK-SAME: {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>
+ // CHECK: return %[[masked_write]] : tensor<?x?x16x2xf32>
+
+ // -----
+
func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
                outs(%C: memref<?x?xf32>)