@@ -691,12 +691,12 @@ func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor<?x?xf32>, %arg1: t
691
691
// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
692
692
// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
693
693
// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
694
- // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
695
- // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<4x2x1x16xf32> to vector<4x16xf32>
696
- // CHECK: %[[empt0:.*]] = tensor.empty
697
- // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
698
- // CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
699
- // CHECK: return %[[write0]]
694
+ // CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 2, 3, 1] : vector<2x1x16x2xf32> to vector<2x16x2x1xf32>
695
+ // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x16x2x1xf32> to vector<32x2xf32>
696
+ // CHECK: %[[empt0:.*]] = tensor.empty
697
+ // CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<32x2xi1>
698
+ // CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[empt0]]
699
+ // CHECK: return %[[write0]]
700
700
%ret = tensor.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
701
701
return %ret : tensor<?x?xf32>
702
702
}
@@ -707,3 +707,58 @@ module attributes {transform.with_named_sequence} {
707
707
transform.yield
708
708
}
709
709
}
710
+
711
+ // -----
712
+
713
+ // CHECK-LABEL: func @test_vectorize_unpack
714
+ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
715
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
716
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
717
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
718
+ // CHECK: %[[C80:.*]] = arith.constant 8 : index
719
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
720
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
721
+ // CHECK: %[[MSK0:.*]] = vector.create_mask %c8, %c8_0, %c32, %c16 : vector<16x8x32x16xi1>
722
+ // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<16x8x32x16xi1> -> vector<16x8x32x16xf32>
723
+ // CHECK: %[[TRANSP0:.*]] = vector.transpose %[[READ0]], [0, 2, 1, 3] : vector<16x8x32x16xf32> to vector<16x32x8x16xf32>
724
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP0]] : vector<16x32x8x16xf32> to vector<512x128xf32>
725
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
726
+ // CHECK: %[[C01:.*]] = arith.constant 0 : index
727
+ // CHECK: %[[C256:.*]] = arith.constant 256 : index
728
+ // CHECK: %[[C128:.*]] = arith.constant 128 : index
729
+ // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C256]], %[[C128]] : vector<512x128xi1>
730
+ // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<512x128xi1> -> tensor<256x128xf32>
731
+ // CHECK: return %[[WRIT]] : tensor<256x128xf32>
732
+ %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
733
+ return %0 : tensor<256x128xf32>
734
+ }
735
+ module attributes {transform.with_named_sequence} {
736
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
737
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
738
+ transform.structured.vectorize %0 vector_sizes [512, 128] : !transform.any_op
739
+ transform.yield
740
+ }
741
+ }
742
+
743
+ // -----
744
+
745
+ // CHECK-LABEL: func @test_vectorize_unpack_no_masks
746
+ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
747
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
748
+ // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
749
+ // CHECK: %[[READ:.*]] = vector.transfer_read {{.*}} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
750
+ // CHECK: %[[TRANSP:.*]] = vector.transpose %[[READ]], [0, 2, 1, 3] : vector<8x8x32x16xf32> to vector<8x32x8x16xf32>
751
+ // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[TRANSP]] : vector<8x32x8x16xf32> to vector<256x128xf32>
752
+ // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<256x128xf32>
753
+ // CHECK: %[[C00:.*]] = arith.constant 0 : index
754
+ // CHECK: %[[WRIT:.*]] = vector.transfer_write %[[SHAPC]], {{.*}} : vector<256x128xf32>, tensor<256x128xf32>
755
+ // CHECK: return %[[WRIT]] : tensor<256x128xf32>
756
+ %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
757
+ return %0 : tensor<256x128xf32>
758
+ }
759
+ module attributes {transform.with_named_sequence} {
760
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
761
+ %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
762
+ transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
763
+ transform.yield
764
+ } }
0 commit comments