@@ -822,7 +822,7 @@ func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x3
822
822
// CHECK-LABEL: func.func @infer_src_shape_pack
823
823
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
824
824
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
825
- // CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32 >
825
+ // CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32 >
826
826
// CHECK: %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
827
827
// CHECK: return %[[PACK]]
828
828
@@ -841,13 +841,24 @@ func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?
841
841
// CHECK-LABEL: func.func @infer_dest_shape_pack
842
842
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
843
843
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
844
- // CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32 >
844
+ // CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32 >
845
845
// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
846
- // CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32 > to tensor<?x?x?x?x16xf32>
846
+ // CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32 > to tensor<?x?x?x?x16xf32>
847
847
// CHECK: return %[[CAST_PACK]]
848
848
849
849
// -----
850
850
851
+ func.func @no_infer_pack_shape (%arg0: tensor <?x32 x100 xf32 >, %arg1: index ) -> tensor <32 x7 x?x16 x1 xf32 > {
852
+ %cst = arith.constant 0.000000e+00 : f32
853
+ %0 = tensor.empty (%arg1 ) : tensor <32 x7 x?x16 x1 xf32 >
854
+ %pack = tensor.pack %arg0 padding_value (%cst : f32 ) outer_dims_perm = [1 , 2 , 0 ] inner_dims_pos = [2 , 0 ] inner_tiles = [16 , 1 ] into %0 : tensor <?x32 x100 xf32 > -> tensor <32 x7 x?x16 x1 xf32 >
855
+ return %pack : tensor <32 x7 x?x16 x1 xf32 >
856
+ }
857
+ // CHECK-LABEL: func.func @no_infer_pack_shape
858
+ // CHECK-NOT: tensor.cast
859
+
860
+ // -----
861
+
851
862
func.func @fold_padding_value_pack_negative1 (%arg0: tensor <1200 x499999 xf32 >) -> tensor <31250 x1200 x16 x1 xf32 > {
852
863
%cst = arith.constant 0.000000e+00 : f32
853
864
%0 = tensor.empty () : tensor <31250 x1200 x16 x1 xf32 >
@@ -920,9 +931,9 @@ func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tens
920
931
// CHECK-LABEL: func.func @infer_dest_shape_unpack
921
932
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
922
933
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
923
- // CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32 >
934
+ // CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32 >
924
935
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
925
- // CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32 > to tensor<?x?x?x?xf32>
936
+ // CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32 > to tensor<?x?x?x?xf32>
926
937
// CHECK: return %[[CAST_UNPACK]]
927
938
928
939
// -----
@@ -938,12 +949,24 @@ func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30
938
949
// CHECK-LABEL: func.func @infer_src_shape_unpack
939
950
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
940
951
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
941
- // CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32 >
952
+ // CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32 >
942
953
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
943
954
// CHECK: return %[[UNPACK]]
944
955
945
956
// -----
946
957
958
+ func.func @no_infer_unpack_shape (%arg1: tensor <32 x7 x?x16 x1 xf32 >, %arg2: index ) -> tensor <?x32 x100 xf32 > {
959
+ %cst = arith.constant 0.000000e+00 : f32
960
+ %0 = tensor.empty (%arg2 ) : tensor <?x32 x100 xf32 >
961
+ %unpack = tensor.unpack %arg1 outer_dims_perm = [1 , 2 , 0 ] inner_dims_pos = [2 , 0 ] inner_tiles = [16 , 1 ] into %0 : tensor <32 x7 x?x16 x1 xf32 > -> tensor <?x32 x100 xf32 >
962
+ return %unpack : tensor <?x32 x100 xf32 >
963
+ }
964
+ // CHECK-LABEL: func.func @no_infer_unpack_shape
965
+ // CHECK-NOT: tensor.cast
966
+
967
+ // -----
968
+
969
+
947
970
// CHECK-LABEL: func @fold_overlapping_insert
948
971
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
949
972
func.func @fold_overlapping_insert (%input : tensor <?x?x?xf32 >, %slice1: tensor <4 x?x8 xf32 >, %slice2: tensor <4 x?x8 xf32 >, %i: index , %size: index ) -> (tensor <?x?x?xf32 >) {
0 commit comments