Skip to content

Commit 671829e

Browse files
committed
Adding test cases to allowInsertSliceLowering and allowExtractSliceLowering
1 parent 0fa5401 commit 671829e

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,34 @@ module attributes {transform.with_named_sequence} {
9696

9797
// -----
9898

99+
// This is same as pack_as_pad but since we explicitly added {lowerPadLikeWithInsertSlice = false}, it should not
100+
// be lowered to insert_slice.
101+
// CHECK-LABEL: func.func @pack_disallowed_as_pad(
102+
func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
103+
%cst_0 = arith.constant 0.0 : f32
104+
// tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
105+
// CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
106+
// CHECK: %[[PAD:.*]] = tensor.pad %[[ARG0]]
107+
// CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
108+
// CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
109+
// CHECK: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
110+
%pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
111+
: tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
112+
return %pack : tensor<1x1x1x1x136x64x16x16xf32>
113+
}
114+
115+
module attributes {transform.with_named_sequence} {
116+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
117+
%pack = transform.structured.match ops{["tensor.pack"]} in %module_op
118+
: (!transform.any_op) -> !transform.op<"tensor.pack">
119+
transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}: (!transform.op<"tensor.pack">)
120+
-> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
121+
transform.yield
122+
}
123+
}
124+
125+
// -----
126+
99127
// Check that we don't lower the following pack as a pad.
100128
// Although all the outer most dimensions in the resulting shape are 1s,
101129
// some of the original dimensions are not part of the inner_dims_pos, hence
@@ -233,6 +261,38 @@ module attributes {transform.with_named_sequence} {
233261

234262
// -----
235263

264+
// This is same as upack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not
265+
// be lowered to extract_slice.
266+
// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
267+
func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
268+
%cst_0 = arith.constant 0.0 : f32
269+
270+
// tensor.unpack is lowered to tensor.extract_slice + linalg.transpose + tensor.collapse_shape
271+
// CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
272+
// CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
273+
// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
274+
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
275+
// CHECK: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
276+
%pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
277+
: tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
278+
return %pack : tensor<129x47x16x16xf32>
279+
}
280+
281+
module attributes {transform.with_named_sequence} {
282+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
283+
%unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
284+
: (!transform.any_op) -> !transform.op<"tensor.unpack">
285+
transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}: (!transform.op<"tensor.unpack">)
286+
-> (!transform.op<"tensor.empty">,
287+
!transform.op<"linalg.transpose">,
288+
!transform.op<"tensor.collapse_shape">,
289+
!transform.op<"tensor.extract_slice">)
290+
transform.yield
291+
}
292+
}
293+
294+
// -----
295+
236296
// CHECK-LABEL: func.func @pack_with_outer_dims_perm(
237297
func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
238298
%dest: tensor<200x4x16x100x16x32xi32>)

0 commit comments

Comments
 (0)