@@ -58,6 +58,62 @@ module {
}
}

+ // -----
+ // For the pack op, lowerPadLikeWithInsertSlice defaults to true, which generates an insert_slice and blocks fusion.
+
+ module {
+   // CHECK-LABEL: func @fuse_pack_as_producer_blocked_by_insert_slice
+   // CHECK:      tensor.insert_slice
+   // CHECK:      scf.forall {{.*}} {
+   // CHECK:        scf.forall.in_parallel
+   // CHECK:      }
+   func.func @fuse_pack_as_producer_blocked_by_insert_slice(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
+       -> tensor<4x4x128x256xf32> {
+     %dest = tensor.empty() : tensor<1x1x128x256xf32>
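+     // The pack below only adds unit outer dimensions (128x256 -> 1x1x128x256), i.e. the
+     // pad-like case that the lowerPadLikeWithInsertSlice option applies to.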
+     %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+         into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
+
+     %out = tensor.empty() : tensor<4x4x128x256xf32>
+     %res = linalg.generic
+         {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
+                           affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                           affine_map<(i, j, k, l) -> (i, j, k, l)>],
+          iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+         ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
+         outs(%out: tensor<4x4x128x256xf32>) {
+       ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
+         %r = arith.addf %pack_elem, %other_elem : f32
+         linalg.yield %r : f32
+     } -> tensor<4x4x128x256xf32>
+
+     return %res : tensor<4x4x128x256xf32>
+   }
+
+   module attributes {transform.with_named_sequence} {
+     transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+       // Find and lower pack operation.
+       %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
+         : (!transform.any_op) -> !transform.op<"tensor.pack">
+       %paded, %expanded, %transpose = transform.structured.lower_pack %pack
+         : (!transform.op<"tensor.pack">)
+         -> (!transform.op<"tensor.pad">,
+             !transform.op<"tensor.expand_shape">,
+             !transform.op<"linalg.transpose">)
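+       // lowerPadLikeWithInsertSlice is left at its default (true) here, so the pad-like pack is
+       // lowered through tensor.insert_slice rather than through a fusable linalg.transpose.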
+
+       %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+         : (!transform.any_op) -> !transform.any_op
+       // Tile the linalg operation with a parallel forall loop, num_threads [4, 4].
+       %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+         : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+       // Fuse the transpose operation into the tiled loop.
+       transform.structured.fuse_into_containing_op %transpose into %forall_op
+         : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
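+       // Since the pack was lowered to a tensor.insert_slice, fusion is blocked; the CHECK lines
+       // above expect the insert_slice to remain outside the scf.forall.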
+       transform.yield
+     }
+   }
+ }
+
// -----
// For unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is generated.
// This allows linalg.transpose to be fused as a consumer operation. Alternatively, without this attribute
@@ -119,3 +175,64 @@ module {
}
}
}
+
+ // -----
+ // For the unpack op, lowerUnpadLikeWithExtractSlice defaults to true, which generates an extract_slice and blocks fusion.
+
+ module {
+   // CHECK-LABEL: func @fuse_unpack_as_consumer_blocked_by_extract_slice
+   // CHECK:      scf.forall {{.*}} {
+   // CHECK:        linalg.generic
+   // CHECK:        scf.forall.in_parallel
+   // CHECK:      }
+   // CHECK:      tensor.extract_slice
+   func.func @fuse_unpack_as_consumer_blocked_by_extract_slice(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
+       -> tensor<128x256xf32> {
+     %out = tensor.empty() : tensor<1x1x128x256xf32>
+     %res = linalg.generic
+         {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                           affine_map<(i, j, k, l) -> (i, j, k, l)>,
+                           affine_map<(i, j, k, l) -> (0, 0, k, l)>],
+          iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+         ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
+         outs(%out: tensor<1x1x128x256xf32>) {
+       ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
+         %r = arith.addf %unpack_elem, %other_elem : f32
+         linalg.yield %r : f32
+     } -> tensor<1x1x128x256xf32>
+
+     %dest = tensor.empty() : tensor<128x256xf32>
+     %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
+         into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
+
+     return %unpack : tensor<128x256xf32>
+   }
+
+   module attributes {transform.with_named_sequence} {
+     transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+       // Find and lower unpack operation.
+       %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
+         : (!transform.any_op) -> !transform.op<"tensor.unpack">
+       transform.structured.lower_unpack %unpack
+         : (!transform.op<"tensor.unpack">)
+         -> (!transform.op<"tensor.empty">,
+             !transform.op<"linalg.transpose">,
+             !transform.op<"tensor.collapse_shape">,
+             !transform.op<"tensor.extract_slice">)
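+       // lowerUnpadLikeWithExtractSlice is left at its default (true) here, so the unpad-like
+       // unpack is lowered through tensor.extract_slice, which blocks the consumer fusion below.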
+
+       %root = transform.structured.match ops{["linalg.generic"]} in %arg1
+         : (!transform.any_op) -> !transform.any_op
+       // Tile the linalg operation with a parallel forall loop, num_threads [4, 4].
+       %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
+         : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+       // Fuse the consumer operation into the tiled loop.
+       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
+         : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
+       // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+       // is not a qualified consumer operation. Forcing it would yield a "could not fetch consumer
+       // to fuse" error.
+       transform.yield
+     }
+   }
+ }