Skip to content

Commit 19989bb

Browse files
committed
add new test cases
1 parent b3a13e0 commit 19989bb

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,17 @@ func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?
848848

849849
// -----
850850

851+
func.func @no_infer_pack_shape(%arg0: tensor<?x32x100xf32>, %arg1: index) -> tensor<32x7x?x16x1xf32> {
852+
%cst = arith.constant 0.000000e+00 : f32
853+
%0 = tensor.empty(%arg1) : tensor<32x7x?x16x1xf32>
854+
%pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<?x32x100xf32> -> tensor<32x7x?x16x1xf32>
855+
return %pack : tensor<32x7x?x16x1xf32>
856+
}
857+
// CHECK-LABEL: func.func @no_infer_pack_shape
858+
// CHECK-NOT: tensor.cast
859+
860+
// -----
861+
851862
func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
852863
%cst = arith.constant 0.000000e+00 : f32
853864
%0 = tensor.empty() : tensor<31250x1200x16x1xf32>
@@ -944,6 +955,18 @@ func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30
944955

945956
// -----
946957

958+
func.func @no_infer_unpack_shape(%arg1: tensor<32x7x?x16x1xf32>, %arg2: index) -> tensor<?x32x100xf32> {
959+
%cst = arith.constant 0.000000e+00 : f32
960+
%0 = tensor.empty(%arg2) : tensor<?x32x100xf32>
961+
%unpack = tensor.unpack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<32x7x?x16x1xf32> -> tensor<?x32x100xf32>
962+
return %unpack : tensor<?x32x100xf32>
963+
}
964+
// CHECK-LABEL: func.func @no_infer_unpack_shape
965+
// CHECK-NOT: tensor.cast
966+
967+
// -----
968+
969+
947970
// CHECK-LABEL: func @fold_overlapping_insert
948971
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
949972
func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {

0 commit comments

Comments
 (0)