
Commit 46b7202

This commit adds test cases for allowInsertSliceLowering and
allowExtractSliceLowering.
1 parent 012f6d4 commit 46b7202

File tree

1 file changed (+56, -0 lines)


mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 56 additions & 0 deletions
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as pack_as_pad, but since we explicitly added {allowInsertSliceLowering = false},
+// it should not be lowered to insert_slice.
+// CHECK-LABEL: func.func @pack_disallowed_as_pad(
+// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
+func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+  // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
+  // CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
+  return %pack : tensor<1x1x1x1x136x64x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.pack">
+    transform.structured.lower_pack %pack {allowInsertSliceLowering = false} : (!transform.op<"tensor.pack">)
+      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check that we don't lower the following pack as a pad.
 // Although all the outer most dimensions in the resulting shape are 1s,
 // some of the original dimensions are not part of the inner_dims_pos, hence
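For context, when the insert_slice shortcut is disallowed, lower_pack falls back to its general pad + expand_shape + transpose path, which matches the result handles returned by the transform above. The following is a rough, hypothetical sketch of the IR that path would produce for @pack_disallowed_as_pad; the reassociation, output_shape, and permutation are illustrative reconstructions, not verbatim compiler output:

// Hypothetical sketch (not verbatim output) of the general lower_pack path:
// pad up to tile multiples, expand each dim into an (outer, inner) pair,
// then transpose all outer dims ahead of all inner dims.
%padded = tensor.pad %arg0 low[0, 0, 0, 0] high[7, 17, 0, 0] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
  tensor.yield %cst_0 : f32
} : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
%expanded = tensor.expand_shape %padded [[0, 1], [2, 3], [4, 5], [6, 7]]
    output_shape [1, 136, 1, 64, 1, 16, 1, 16]
    : tensor<136x64x16x16xf32> into tensor<1x136x1x64x1x16x1x16xf32>
%transposed = linalg.transpose ins(%expanded : tensor<1x136x1x64x1x16x1x16xf32>)
    outs(%arg1 : tensor<1x1x1x1x136x64x16x16xf32>)
    permutation = [0, 2, 4, 6, 1, 3, 5, 7]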
@@ -233,6 +261,34 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// This is the same as unpack_as_pad, but since we explicitly added {allowExtractSliceLowering = false},
+// it should not be lowered to extract_slice.
+// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
+func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
+  return %pack : tensor<129x47x16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack {allowExtractSliceLowering = false} : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_with_outer_dims_perm(
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)
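Similarly, with the extract_slice shortcut disallowed, lower_unpack takes its general empty + transpose + collapse_shape + extract_slice path. Note that this path still ends in a tensor.extract_slice that trims the padding; the CHECK-NOT above only guards against an extract_slice taken directly from %[[ARG0]]. A rough, hypothetical sketch for @unpack_disallowed_as_pad (shapes and permutation are illustrative, not verbatim compiler output):

// Hypothetical sketch (not verbatim output) of the general lower_unpack path:
// transpose each inner tile back next to its outer dim, collapse the
// (outer, inner) pairs, then extract the unpadded result.
%empty = tensor.empty() : tensor<1x136x1x64x1x16x1x16xf32>
%transposed = linalg.transpose ins(%arg0 : tensor<1x1x1x1x136x64x16x16xf32>)
    outs(%empty : tensor<1x136x1x64x1x16x1x16xf32>)
    permutation = [0, 4, 1, 5, 2, 6, 3, 7]
%collapsed = tensor.collapse_shape %transposed [[0, 1], [2, 3], [4, 5], [6, 7]]
    : tensor<1x136x1x64x1x16x1x16xf32> into tensor<136x64x16x16xf32>
%res = tensor.extract_slice %collapsed[0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
    : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>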
