@@ -812,3 +812,37 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
812
812
transform.yield
813
813
}
814
814
}
815
+
816
+ // -----
817
+
818
+ // CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
819
+ // Payload under test: packs a 32x7x15 source into 32x4x1x16x2 using inner
+ // tiles [16, 2] on dims [2, 1]. Neither 7 (tile 2) nor 15 (tile 16) divides
+ // evenly, so the pack carries an f32 zero as its padding value.
+ func.func @test_vectorize_padded_pack_no_vector_sizes(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
820
+ %zero = arith.constant 0.000000e+00 : f32
821
+ %packed = tensor.pack %arg0 padding_value(%zero : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
822
+ return %packed : tensor<32x4x1x16x2xf32>
823
+ }
824
+ // CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
825
+ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
826
+ // CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
827
+ // CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
828
+ // CHECK-DAG: %[[c15:.*]] = arith.constant 15 : index
829
+ // CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c7]], %[[c15]] : vector<32x8x16xi1>
830
+ // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
831
+ // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
832
+ // CHECK-SAME: {in_bounds = [true, true, true]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
833
+ // CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
834
+ // CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
835
+ // CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
836
+ // CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
837
+ // CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
838
+ // CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
839
+ // CHECK-SAME: {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
840
+ // CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
841
+
842
+ // Transform script driving this test: match the tensor.pack op in the
+ // payload and vectorize it. No vector_sizes are passed to vectorize.
+ module attributes {transform.with_named_sequence } {
843
+ transform.named_sequence @__transform_main (%arg0: !transform.any_op {transform.readonly }) {
844
+ // The op name must be spelled exactly: whitespace inside the string literal
+ // is part of the name and would no longer refer to tensor.pack.
+ %0 = transform.structured.match ops {["tensor.pack" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
845
+ transform.structured.vectorize %0 : !transform.any_op
846
+ transform.yield
847
+ }
848
+ }
0 commit comments