
Commit 68b5328

Adding additional negative test cases
- Added additional test cases to demonstrate that insert/extract slice will block producer/consumer fusion
- Readability enhancements
1 parent 3ad8cd6 commit 68b5328
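
For context, the tests exercise the lowering controls on transform.structured.lower_pack / transform.structured.lower_unpack. A minimal sketch of opting out of the pad-like/unpad-like shortcuts (not verbatim from this commit: the flag names come from the test comments below, the handle types from the tests themselves, and the attribute placement is assumed):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    // Keep tensor.pad + tensor.expand_shape + linalg.transpose instead of folding a
    // pad-like pack into a single tensor.insert_slice (attribute placement assumed).
    %pad, %expanded, %transpose = transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
      : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">,
          !transform.op<"tensor.expand_shape">,
          !transform.op<"linalg.transpose">)

    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
      : (!transform.any_op) -> !transform.op<"tensor.unpack">
    // Keep the transpose + collapse_shape + extract_slice chain instead of folding an
    // unpad-like unpack into a single tensor.extract_slice.
    %empty, %transposed, %collapsed, %slice = transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}
      : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
    transform.yield
  }
}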

2 files changed: +134 -17 lines changed


mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 17 additions & 17 deletions
@@ -98,15 +98,15 @@ module attributes {transform.with_named_sequence} {
 
 // This is same as pack_as_pad but since we explicitly added {lowerPadLikeWithInsertSlice = false}, it should not
 // be lowered to insert_slice.
-// CHECK-LABEL: func.func @pack_disallowed_as_pad(
-func.func @pack_disallowed_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
+// CHECK-LABEL: func.func @pack_as_pad_disabled_insert_slice(
+func.func @pack_as_pad_disabled_insert_slice(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
   // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
   // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<129x47x16x16xf32>
-  // CHECK: %[[PAD:.*]] = tensor.pad %[[ARG0]]
+  // CHECK-DAG: %[[PAD:.*]] = tensor.pad %[[ARG0]]
   // CHECK-NOT: %[[RES:.*]] = tensor.insert_slice %[[PAD]]
   // CHECK: %[[PAD_EXPANDED:.*]] = tensor.expand_shape %[[PAD]]
-  // CHECK: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
+  // CHECK-DAG: %[[RES:.*]] = linalg.transpose ins(%[[PAD_EXPANDED]]
   %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
     : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
   return %pack : tensor<1x1x1x1x136x64x16x16xf32>

@@ -261,18 +261,18 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// This is same as upack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not
+// This is same as upack_as_pad but since we explicitly added {lowerUnpadLikeWithExtractSlice = false}, it should not
 // be lowered to extract_slice.
-// CHECK-LABEL: func.func @unpack_disallowed_as_pad(
-func.func @unpack_disallowed_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
+// CHECK-LABEL: func.func @unpack_as_pad_disabled_extract_slice(
+func.func @unpack_as_pad_disabled_extract_slice(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.unpack is lowered to tensor.extract_slice + linalg.transpose + tensor.collapse_shape
-  // CHECK-SAME: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
-  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
-  // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
-  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
-  // CHECK: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
+  // CHECK-DAG: %[[ARG0:[^:]*]]: tensor<1x1x1x1x136x64x16x16xf32>
+  // CHECK-NOT: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[ARG0]]
+  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[TRANSPOSED]]
+  // CHECK-DAG: %[[RES:.*]] = tensor.extract_slice %[[COLLAPSED]]
   %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
     : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<129x47x16x16xf32>
   return %pack : tensor<129x47x16x16xf32>

@@ -632,7 +632,7 @@ func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?x
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
-      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
     transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
       -> (!transform.op<"tensor.empty">,
           !transform.op<"linalg.transpose">,

@@ -687,9 +687,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK-LABEL: @unpack_with_outer_dims_perm
 // CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
-// CHECK: %[[TRAN:.*]] = linalg.transpose
-// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
-// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
 // CHECK-SAME: permutation = [1, 3, 0, 2]
 // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
 // CHECK-SAME: : tensor<4x8x2x32xf32> into tensor<32x64xf32>

@@ -698,7 +698,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK: linalg.copy ins(%[[SLICE]]
 // CHECK-SAME: : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
 func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
-  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
+  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
     inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
   return %unpack : tensor<32x64xf32>
 }

mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir

Lines changed: 117 additions & 0 deletions
@@ -58,6 +58,62 @@ module {
   }
 }
 
+// -----
+// For pack op, by default lowerPadLikeWithInsertSlice = true, which generates insert_slice and blocks fusion.
+
+module {
+  // CHECK-label: func @fuse_pack_as_producer_blocked_by_insert_slice
+  // CHECK: tensor.insert_slice
+  // CHECK: scf.forall {{.*}} {
+  // CHECK: scf.forall.in_parallel
+  // CHECK: }
+  func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<4x4x128x256xf32> {
+    %dest = tensor.empty() : tensor<1x1x128x256xf32>
+    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+    %out = tensor.empty() : tensor<4x4x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<4x4x128x256xf32>) {
+      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %pack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<4x4x128x256xf32>
+
+    return %res : tensor<4x4x128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower pack operation.
+      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.pack">
+      %paded, %expanded, %transpose = transform.structured.lower_pack %pack
+        : (!transform.op<"tensor.pack">)
+          -> (!transform.op<"tensor.pad">,
+              !transform.op<"tensor.expand_shape">,
+              !transform.op<"linalg.transpose">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+      // Tile the lialg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the transpose operation into the tiled loop.
+      transform.structured.fuse_into_containing_op %transpose into %forall_op
+        : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
 // -----
 // For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
 // This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
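
For contrast with the blocked pack-as-producer case added above, a hedged sketch (not part of this commit) of how fusion is expected to go through once the pad-like shortcut is suppressed. It assumes, per the comments in transform-lower-pack.mlir, that the option is passed as the {lowerPadLikeWithInsertSlice = false} attribute; the remaining handles mirror the test above.

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
      : (!transform.any_op) -> !transform.op<"tensor.pack">
    // Suppress the pad-like shortcut so no tensor.insert_slice sits between the
    // lowered pack and its consumer (attribute placement assumed).
    %paded, %expanded, %transpose = transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
      : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">,
          !transform.op<"tensor.expand_shape">,
          !transform.op<"linalg.transpose">)

    %root = transform.structured.match ops{["linalg.generic"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // The linalg.transpose is now a direct producer of the tiled generic, so it can
    // be fused into the containing scf.forall.
    transform.structured.fuse_into_containing_op %transpose into %forall_op
      : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}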
@@ -119,3 +175,64 @@ module {
     }
   }
 }
+
+// -----
+// For unpack op, by default lowerUnpadLikeWithExtractSlice = true, which generates extract_slice and blocks fusion.
+
+module {
+  // CHECK-label: func @fuse_unpack_as_consumer_blocked_by_extract_slice
+  // CHECK: scf.forall {{.*}} {
+  // CHECK: linalg.generic
+  // CHECK: scf.forall.in_parallel
+  // CHECK: }
+  // CHECK: tensor.extract_slice
+  func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+      -> tensor<128x256xf32> {
+    %out = tensor.empty() : tensor<1x1x128x256xf32>
+    %res = linalg.generic
+        {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                          affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+        outs(%out: tensor<1x1x128x256xf32>) {
+      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+        %r = arith.addf %unpack_elem, %other_elem : f32
+        linalg.yield %r : f32
+    } -> tensor<1x1x128x256xf32>
+
+    %dest = tensor.empty() : tensor<128x256xf32>
+    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+    return %unpack : tensor<128x256xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+      // Find and lower unpack operation.
+      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+        : (!transform.any_op) -> !transform.op<"tensor.unpack">
+      transform.structured.lower_unpack %unpack
+        : (!transform.op<"tensor.unpack">)
+          -> (!transform.op<"tensor.empty">,
+              !transform.op<"linalg.transpose">,
+              !transform.op<"tensor.collapse_shape">,
+              !transform.op<"tensor.extract_slice">)
+
+      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+      // Tile the lialg operation with parallel forall loop tiling [4, 4].
+      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+      // Fuse the consumer operation into the tiled loop.
+      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+        : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+      // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+      // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer
+      // to fuse" error.
+      transform.yield
+    }
+  }
+}
