@@ -848,6 +848,17 @@ func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?
848
848
849
849
// -----
850
850
851
+ func.func @no_infer_pack_shape (%arg0: tensor <?x32 x100 xf32 >, %arg1: index ) -> tensor <32 x7 x?x16 x1 xf32 > {
852
+ %cst = arith.constant 0.000000e+00 : f32
853
+ %0 = tensor.empty (%arg1 ) : tensor <32 x7 x?x16 x1 xf32 >
854
+ %pack = tensor.pack %arg0 padding_value (%cst : f32 ) outer_dims_perm = [1 , 2 , 0 ] inner_dims_pos = [2 , 0 ] inner_tiles = [16 , 1 ] into %0 : tensor <?x32 x100 xf32 > -> tensor <32 x7 x?x16 x1 xf32 >
855
+ return %pack : tensor <32 x7 x?x16 x1 xf32 >
856
+ }
857
+ // CHECK-LABEL: func.func @no_infer_pack_shape
858
+ // CHECK-NOT: tensor.cast
859
+
860
+ // -----
861
+
851
862
func.func @fold_padding_value_pack_negative1 (%arg0: tensor <1200 x499999 xf32 >) -> tensor <31250 x1200 x16 x1 xf32 > {
852
863
%cst = arith.constant 0.000000e+00 : f32
853
864
%0 = tensor.empty () : tensor <31250 x1200 x16 x1 xf32 >
@@ -944,6 +955,18 @@ func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30
944
955
945
956
// -----
946
957
958
+ func.func @no_infer_unpack_shape (%arg1: tensor <32 x7 x?x16 x1 xf32 >, %arg2: index ) -> tensor <?x32 x100 xf32 > {
959
+ %cst = arith.constant 0.000000e+00 : f32
960
+ %0 = tensor.empty (%arg2 ) : tensor <?x32 x100 xf32 >
961
+ %unpack = tensor.unpack %arg1 outer_dims_perm = [1 , 2 , 0 ] inner_dims_pos = [2 , 0 ] inner_tiles = [16 , 1 ] into %0 : tensor <32 x7 x?x16 x1 xf32 > -> tensor <?x32 x100 xf32 >
962
+ return %unpack : tensor <?x32 x100 xf32 >
963
+ }
964
+ // CHECK-LABEL: func.func @no_infer_unpack_shape
965
+ // CHECK-NOT: tensor.cast
966
+
967
+ // -----
968
+
969
+
947
970
// CHECK-LABEL: func @fold_overlapping_insert
948
971
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
949
972
func.func @fold_overlapping_insert (%input : tensor <?x?x?xf32 >, %slice1: tensor <4 x?x8 xf32 >, %slice2: tensor <4 x?x8 xf32 >, %i: index , %size: index ) -> (tensor <?x?x?xf32 >) {
0 commit comments