Skip to content

Commit 96ff025

Browse files
authored
[mlir] cleanup of structured.tile* transform ops (#67320)
Rename and restructure tiling-related transform ops from the structured extension to be more homogeneous. In particular, all ops now follow a consistent naming scheme: - `transform.structured.tile_using_for`; - `transform.structured.tile_using_forall`; - `transform.structured.tile_reduction_using_for`; - `transform.structured.tile_reduction_using_forall`. This drops the "_op" naming artifact from `tile_to_forall_op` that shouldn't have been included in the first place, consistently specifies the name of the control flow op to be produced for loops (instead of `tile_reduction_using_scf` since `scf.forall` also belongs to `scf`), and opts for the `using` connector to avoid ambiguity. The loops produced by tiling are now systematically placed as *trailing* results of the transform op. While this required changing 3 out of 4 ops (except for `tile_using_for`), this is the only choice that makes sense when producing multiple `scf.for` ops that can be associated with a variadic number of handles. This choice is also most consistent with *other* transform ops from the structured extension, in particular with fusion ops, that produce the structured op as the leading result and the loop as the trailing result.
1 parent 9f276d4 commit 96ff025

File tree

52 files changed

+309
-306
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+309
-306
lines changed

mlir/docs/Tutorials/transform/Ch1.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,13 @@ transform.sequence failures(propagate) {
119119
%arg1: !transform.op<"linalg.matmul">,
120120
%arg2: !transform.op<"linalg.elemwise_binary">):
121121
// The actual tiling transformation takes tile sizes as attributes.
122-
%loop, %tiled = transform.structured.tile_to_forall_op %arg1 tile_sizes [4, 32]
122+
%loop, %tiled = transform.structured.tile_using_forall %arg1 tile_sizes [4, 32]
123123
: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
124124
transform.yield
125125
}
126126
```
127127
128-
The transformation returns two handles, as indicated in its [documentation](https://mlir.llvm.org/docs/Dialects/Transform/#transformstructuredtile_to_forall_op-transformtiletoforallop):
128+
The transformation returns two handles, as indicated in its [documentation](https://mlir.llvm.org/docs/Dialects/Transform/#transformstructuredtile_using_forall-transformtiletoforallop):
129129
130130
* A handle to the `scf.forall` “multi-for” loop around tensors.
131131
* A handle to `linalg.generic` operating on the subset of the original data.
@@ -176,7 +176,7 @@ transform.sequence failures(propagate) {
176176
%arg1: !transform.op<"linalg.matmul">,
177177
%arg2: !transform.op<"linalg.elemwise_binary">):
178178
// The actual tiling transformation takes tile sizes as attributes.
179-
%loop, %tiled = transform.structured.tile_to_forall_op %arg1 tile_sizes [4, 32]
179+
%loop, %tiled = transform.structured.tile_using_forall %arg1 tile_sizes [4, 32]
180180
: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
181181
182182
// This is trying to use an invalidated handle leading to undefined behavior.
@@ -203,7 +203,7 @@ matmul.mlir:26:9: note: handle to invalidated ops
203203
%mm = transform.cast %matmul : !transform.op<"linalg.matmul"> to !transform.any_op
204204
^
205205
matmul.mlir:27:19: note: invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them
206-
%loop, %tiled = transform.structured.tile_to_forall_op %mm tile_sizes [4, 32]
206+
%loop, %tiled = transform.structured.tile_using_forall %mm tile_sizes [4, 32]
207207
```
208208
209209
One may observe that some operations such as `transform.cast` do not consume the operand (because they don’t erase the corresponding operation). So what would happen if we tried to use that operand instead?
@@ -219,7 +219,7 @@ transform.sequence failures(propagate) {
219219
to !transform.any_op
220220
221221
// The actual tiling transformation takes tile sizes as attributes.
222-
%loop, %tiled = transform.structured.tile_to_forall_op %arg1 tile_sizes [4, 32]
222+
%loop, %tiled = transform.structured.tile_using_forall %arg1 tile_sizes [4, 32]
223223
: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
224224
225225
// Consuming an operand invalidates the consumed handle and any other handle that is
@@ -240,7 +240,7 @@ matmul.mlir:21:29: note: handle to invalidated ops
240240
^bb0(%root: !transform.any_op, %matmul: !transform.op<"linalg.matmul">, %elemwise: !transform.op<"linalg.elemwise_binary">):
241241
^
242242
matmul.mlir:27:19: note: invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them
243-
%loop, %tiled = transform.structured.tile_to_forall_op %mm tile_sizes [4, 32]
243+
%loop, %tiled = transform.structured.tile_using_forall %mm tile_sizes [4, 32]
244244
```
245245
246246
## Chaining Transformations with Handles
@@ -262,7 +262,7 @@ transform.sequence failures(propagate) {
262262
// The actual tiling transformation takes tile sizes as attributes. It
263263
// produces a handle to the loop generated during tiling.
264264
%loop, %tiled_max =
265-
transform.structured.tile_to_forall_op %max tile_sizes [8, 32]
265+
transform.structured.tile_using_forall %max tile_sizes [8, 32]
266266
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
267267
268268
// We can now fuse the other operations into the loop. Here, we fuse
@@ -304,7 +304,7 @@ transform.sequence failures(propagate) {
304304
305305
// The actual tiling transformation takes tile sizes as attributes. It
306306
// produces a handle to the loop generated during tiling.
307-
%loop, %tiled = transform.structured.tile_to_forall_op %max tile_sizes [8, 32]
307+
%loop, %tiled = transform.structured.tile_using_forall %max tile_sizes [8, 32]
308308
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
309309
310310
// We can now fuse the other operations into the loop. Here, we fuse
@@ -328,7 +328,7 @@ transform.sequence failures(propagate) {
328328
// dialect. Otherwise, it is difficult to differentiate "add" and "max", both
329329
// of which having the same kind.
330330
%loop_2, %tiled_2 =
331-
transform.structured.tile_to_forall_op %add_fused tile_sizes [4, 4]
331+
transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
332332
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
333333
%matmul_fused_2, %loop_3 =
334334
transform.structured.fuse_into_containing_op %matmul_fused into %loop_2
@@ -339,7 +339,7 @@ transform.sequence failures(propagate) {
339339
// such as loops, use tiling to size 1 to materialize the outer loop that is
340340
// going to be outlined.
341341
%outline_target, %_ =
342-
transform.structured.tile_to_forall_op %tiled_2 tile_sizes [1]
342+
transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
343343
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
344344
transform.structured.fuse_into_containing_op %matmul_fused_2
345345
into %outline_target
@@ -361,7 +361,7 @@ test/Examples/transform/Ch1/invalidation-2.mlir:109:3: error: op uses a handle i
361361
transform.test_print_remark_at_operand %outline_target, "outlined loop" : !transform.any_op
362362
^
363363
test/Examples/transform/Ch1/invalidation-2.mlir:102:25: note: handle to invalidated ops
364-
%outline_target, %_ = transform.structured.tile_to_forall_op %tiled_2 tile_sizes [1]
364+
%outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
365365
^
366366
test/Examples/transform/Ch1/invalidation-2.mlir:106:18: note: invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them
367367
%func, %call = transform.loop.outline %outline_target {func_name = "outlined"}

mlir/docs/Tutorials/transform/Ch2.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ transform.sequence failures(propagate) {
292292
293293
// The actual tiling transformation takes tile sizes as attributes. It produces a
294294
// handle to the loop generated during tiling.
295-
%loop, %tiled = transform.structured.tile_to_forall_op %max tile_sizes [8, 32]
295+
%loop, %tiled = transform.structured.tile_using_forall %max tile_sizes [8, 32]
296296
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
297297
298298
// We can now fuse the other operations into the loop. Here, we fuse
@@ -311,15 +311,15 @@ transform.sequence failures(propagate) {
311311
// "max" operation. This illustrates the precise targeting with the transform
312312
// dialect. Otherwise, it is difficult to differentiate "add" and "max", both
313313
// of which having the same kind.
314-
%loop_2, %tiled_2 = transform.structured.tile_to_forall_op %add_fused tile_sizes [4, 4]
314+
%loop_2, %tiled_2 = transform.structured.tile_using_forall %add_fused tile_sizes [4, 4]
315315
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
316316
%matmul_fused_2 = transform.structured.fuse_into_containing_op %matmul_fused into %loop_2
317317
: (!transform.any_op, !transform.any_op) -> !transform.any_op
318318
319319
// Since outlining is currently only implemented for region-holding operations
320320
// such as loops, use tiling to size 1 to materialize the outer loop that is
321321
// going to be outlined.
322-
%outline_target, %_ = transform.structured.tile_to_forall_op %tiled_2 tile_sizes [1]
322+
%outline_target, %_ = transform.structured.tile_using_forall %tiled_2 tile_sizes [1]
323323
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
324324
transform.structured.fuse_into_containing_op %matmul_fused_2 into %outline_target
325325
: (!transform.any_op, !transform.any_op) -> !transform.any_op

mlir/docs/Tutorials/transform/ChH.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ Linalg as described below.
218218
the inner loop having at most the given number of iterations. This can be
219219
understood as loop _strip-mining_ or a degenerate case of tiling a single
220220
dimension using any of `linalg.tile_` transform ops. We will be using
221-
`transform.structured.tile_to_forall_op` as this kind of loop is best
221+
`transform.structured.tile_using_forall` as this kind of loop is best
222222
supported by bufferization and can also be turned into a parallel loop later
223223
on. Unlike Halide, this doesn’t add new dimensions to the original
224224
operation, but rather creates a loop around it and rewrites the operation
@@ -275,9 +275,9 @@ The remaining dimensions can be materialized as loops in one transformation.
275275

276276
```mlir
277277
// [n y x c]
278-
%co, %relu2 = transform.structured.tile_to_forall_op %relu
278+
%co, %relu2 = transform.structured.tile_using_forall %relu
279279
tile_sizes [0, 0, 0, 64]
280-
%n_y_xo, %relu3 = transform.structured.tile_to_forall_op %relu2
280+
%n_y_xo, %relu3 = transform.structured.tile_using_forall %relu2
281281
tile_sizes [1, 1, 5, 0]
282282
```
283283

@@ -355,7 +355,7 @@ more than one dimension at the moment of writing.)
355355

356356
```mlir
357357
%rz_ry_rx, %red_fill, %conv4, %comb
358-
= transform.structured.tile_reduction_using_scf %conv3
358+
= transform.structured.tile_reduction_using_for %conv3
359359
// n y x c rz ry rx
360360
by tile_sizes=[0, 0, 0, 0, 1, 1, 1]
361361
```
@@ -386,10 +386,10 @@ dimension:
386386

387387
```mlir
388388
// n y xi ci
389-
%1, %c5 = transform.structured.tile_to_forall_op %conv4 tile_sizes [0, 0, 1, 16]
390-
%2, %b4 = transform.structured.tile_to_forall_op %bias3 tile_sizes [0, 0, 1, 16]
391-
%3, %r4 = transform.structured.tile_to_forall_op %relu3 tile_sizes [0, 0, 1, 16]
392-
%4, %c2 = transform.structured.tile_to_forall_op %comb tile_sizes [0, 0, 1, 16]
389+
%1, %c5 = transform.structured.tile_using_forall %conv4 tile_sizes [0, 0, 1, 16]
390+
%2, %b4 = transform.structured.tile_using_forall %bias3 tile_sizes [0, 0, 1, 16]
391+
%3, %r4 = transform.structured.tile_using_forall %relu3 tile_sizes [0, 0, 1, 16]
392+
%4, %c2 = transform.structured.tile_using_forall %comb tile_sizes [0, 0, 1, 16]
393393
```
394394

395395
Note that the combiner operation produced by reduction tiling is also tiled here.
@@ -638,7 +638,7 @@ bufferization invalidates all loop handles including to loops that we are
638638
willing to unroll. This hurdle can be overcome by matching the payload IR
639639
operations after bufferization to produce new handles. We will first change the
640640
kind of loops produced in the schedule from `scf.for` to `scf.forall` to have
641-
less operations to match by using `transform.structured.tile_to_forall_op`
641+
less operations to match by using `transform.structured.tile_using_forall`
642642
instead of `transform.structured.tile` when tiling with sizes `[0, 0, 1, 16]`.
643643
Then we can match all `scf.forall` operations in the payload IR and transform
644644
them into single-iterator `scf.for` loops _after bufferization_.

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -618,10 +618,10 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
618618
!transform.param<i64>, !transform.param<i64>
619619
%low, %high = structured.split %target after %split { dimension = 1 }
620620
: !transform.any_op, !transform.param<i64>
621-
%tiled_low, %loop1 = structured.tile %low [0, %sz1]
621+
%tiled_low, %loop1 = structured.tile_using_for %low [0, %sz1]
622622
: (!transform.any_op, !transform.param<i64>)
623623
-> (!transform.any_op, !transform.any_op)
624-
%tiled_high, %loop2 = structured.tile %high [0, %sz2]
624+
%tiled_high, %loop2 = structured.tile_using_for %high [0, %sz2]
625625
: (!transform.any_op, !transform.param<i64>)
626626
-> (!transform.any_op, !transform.any_op)
627627
%common = merge_handles %tiled_low, %tiled_high : !transform.any_op
@@ -1514,10 +1514,10 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
15141514
}
15151515

15161516
//===----------------------------------------------------------------------===//
1517-
// TileReductionUsingScfOp
1517+
// TileReductionUsingForOp
15181518
//===----------------------------------------------------------------------===//
15191519

1520-
def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
1520+
def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_using_for",
15211521
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
15221522
TransformEachOpTrait, TransformOpInterface,
15231523
ReportTrackingListenerFailuresOpTrait]> {
@@ -1536,11 +1536,11 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
15361536

15371537
#### Return modes
15381538

1539-
This 4 returned handles point to:
1540-
- the parent for op,
1539+
Returns 4 handles associated with (in order):
15411540
- the fill op used to initialize the neutral element,
15421541
- the parallel tiled op and
1543-
- the result-combining op.
1542+
- the result-combining op,
1543+
- the parent `for` op.
15441544

15451545
#### Example:
15461546

@@ -1590,13 +1590,13 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
15901590
```
15911591
}];
15921592

1593-
// TODO: support mixed static-dynamic (see TileToForallOp).
1593+
// TODO: support mixed static-dynamic (see TileUsingForallOp).
15941594
let arguments = (ins TransformHandleTypeInterface:$target,
15951595
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
1596-
let results = (outs TransformHandleTypeInterface:$for_op,
1597-
TransformHandleTypeInterface:$fill_op,
1596+
let results = (outs TransformHandleTypeInterface:$fill_op,
15981597
TransformHandleTypeInterface:$split_linalg_op,
1599-
TransformHandleTypeInterface:$combining_linalg_op);
1598+
TransformHandleTypeInterface:$combining_linalg_op,
1599+
TransformHandleTypeInterface:$for_op);
16001600

16011601
let builders = [
16021602
OpBuilder<(ins "Value":$target,
@@ -1644,11 +1644,11 @@ def TileReductionUsingForallOp :
16441644

16451645
#### Return modes
16461646

1647-
This 4 returned handles point to:
1648-
- the parent forall op,
1647+
Returns 4 handles associated with (in order):
16491648
- the fill op used to initialize the neutral element,
16501649
- the parallel tiled op and
1651-
- the result-combining op.
1650+
- the result-combining op,
1651+
- the parent `forall` op.
16521652

16531653
#### Example:
16541654

@@ -1694,15 +1694,15 @@ def TileReductionUsingForallOp :
16941694
```
16951695
}];
16961696

1697-
// TODO: support mixed static-dynamic (see TileToForallOp).
1697+
// TODO: support mixed static-dynamic (see TileUsingForallOp).
16981698
let arguments = (ins TransformHandleTypeInterface:$target,
16991699
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
17001700
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
17011701
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
1702-
let results = (outs TransformHandleTypeInterface:$forall_op,
1703-
TransformHandleTypeInterface:$fill_op,
1702+
let results = (outs TransformHandleTypeInterface:$fill_op,
17041703
TransformHandleTypeInterface:$split_linalg_op,
1705-
TransformHandleTypeInterface:$combining_linalg_op);
1704+
TransformHandleTypeInterface:$combining_linalg_op,
1705+
TransformHandleTypeInterface:$forall_op);
17061706

17071707
let builders = [
17081708
OpBuilder<(ins "Value":$target,
@@ -1732,10 +1732,10 @@ def TileReductionUsingForallOp :
17321732
}
17331733

17341734
//===----------------------------------------------------------------------===//
1735-
// TileOp
1735+
// TileUsingForOp
17361736
//===----------------------------------------------------------------------===//
17371737

1738-
def TileOp : Op<Transform_Dialect, "structured.tile",
1738+
def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
17391739
[DeclareOpInterfaceMethods<TransformOpInterface>,
17401740
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
17411741
ReportTrackingListenerFailuresOpTrait]> {
@@ -1820,11 +1820,11 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
18201820
}
18211821

18221822
//===----------------------------------------------------------------------===//
1823-
// TileToForallOp
1823+
// TileUsingForallOp
18241824
//===----------------------------------------------------------------------===//
18251825

1826-
def TileToForallOp :
1827-
Op<Transform_Dialect, "structured.tile_to_forall_op",
1826+
def TileUsingForallOp :
1827+
Op<Transform_Dialect, "structured.tile_using_forall",
18281828
[AttrSizedOperandSegments,
18291829
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
18301830
TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> {
@@ -1834,9 +1834,9 @@ def TileToForallOp :
18341834
Tiling is applied by either specifying `num_threads` or `tile_size`. If
18351835
`num_threads` is specified, then the tile size for each dimension `i` is
18361836
calculated dynamically via `ceilDiv(dimSize[i], num_threads[i])`.
1837-
`num_threads` and `tile_size` can be either static index attributes or SSA
1838-
values of PDL operation handle type (or a mix thereof). Operation handles
1839-
must be mapped to exactly one op that has exactly one result of index type.
1837+
`num_threads` and `tile_size` can be either static index attributes or
1838+
operation handles (or a mix thereof). Operation handles must be mapped to
1839+
exactly one op that has exactly one result of index type.
18401840

18411841
Static zero tile sizes indicate that the dimension is not tiled and can be
18421842
thought of as tiling by the full size of data.
@@ -1872,15 +1872,15 @@ def TileToForallOp :
18721872

18731873
```
18741874
%0 = pdl_match @match_matmul in %arg1
1875-
%3:2 = transform.structured.tile_to_forall_op %0 num_threads [10, 20]
1875+
%3:2 = transform.structured.tile_using_forall %0 num_threads [10, 20]
18761876
```
18771877

18781878
#### Example using `tile_sizes`
18791879

18801880
```
18811881
%0 = pdl_match @match_matmul in %arg1
18821882
%sz = pdl_match @match_size_op in %arg1
1883-
%3:2 = transform.structured.tile_to_forall_op %0 tile_sizes [0, %sz, 20]
1883+
%3:2 = transform.structured.tile_using_forall %0 tile_sizes [0, %sz, 20]
18841884
```
18851885
}];
18861886

@@ -1892,8 +1892,8 @@ def TileToForallOp :
18921892
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
18931893
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
18941894
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
1895-
let results = (outs TransformHandleTypeInterface:$forall_op,
1896-
TransformHandleTypeInterface:$tiled_op);
1895+
let results = (outs TransformHandleTypeInterface:$tiled_op,
1896+
TransformHandleTypeInterface:$forall_op);
18971897

18981898
let builders = [
18991899
OpBuilder<(ins "Value":$target,

0 commit comments

Comments
 (0)