Skip to content

Commit 618f231

Browse files
authored
[MLIR][Transform] Consolidate result of structured.split into one list (#111171)
Follow-up a review comment from #82792 (comment) as a separate PR: E.g.: ``` %0:2 = transform.structured.split ``` is changed to ``` %t = transform.structured.split %0:2 = transform.split_handle %t ```
1 parent 98daf22 commit 618f231

File tree

9 files changed

+50
-37
lines changed

9 files changed

+50
-37
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -703,8 +703,10 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
703703
{ target_size = 10, dimension = 1 }
704704
: !transform.any_op, !transform.param<i64>,
705705
!transform.param<i64>, !transform.param<i64>
706-
%low, %high = structured.split %target after %split { dimension = 1 }
706+
%handles = structured.split %target after %split { dimension = 1 }
707707
: !transform.any_op, !transform.param<i64>
708+
%low, %high = transform.split_handle %handles : (!transform.any_op)
709+
-> (!transform.any_op, !transform.any_op)
708710
%tiled_low, %loop1 = structured.tile_using_for %low [0, %sz1]
709711
: (!transform.any_op, !transform.param<i64>)
710712
-> (!transform.any_op, !transform.any_op)
@@ -1452,30 +1454,32 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
14521454
operations pointed to by the target handle.
14531455

14541456
The operation consumes the target handle, but preserves the chunk size
1455-
handle if provided. Without the `multiway` attribute, it produces two
1456-
new handles pointing to the two parts of the structured op after splitting,
1457-
in the same order as the target operand, with the first handle
1458-
corresponding to the part with lower iteration space indices.
1457+
handle if provided. Without the `multiway` attribute, it produces a
1458+
new handle that is a list of the two parts of the structured op after
1459+
splitting, whose lower index part corresponding to the part with lower
1460+
iteration space indices.
14591461

14601462
Multiway split mode is enabled by specifying the `multiway` attribute.
14611463
In this mode a single `target` op is split into multiple parts covering
14621464
the iteration space of the specified dimension. `static_chunk_sizes` and
14631465
`dynamic_chunk_sizes` in this case is a list of chunk sizes that the given
1464-
dimension should be split into. With `multiway` it produces two handles;
1465-
the first handle is a list of the multiple parts of the structured op
1466+
dimension should be split into. With `multiway` it also produces a handle;
1467+
The result handle is a list of the multiple parts of the structured op
14661468
after splitting, where the target dimensions for each linalg op in the
14671469
list corresponds to the chunk sizes specfied in the input split list.
14681470
If the chunk sizes do not cover the entire iteration space, the leftover
1469-
chunk is the last payload in the first handle. The second handle is empty.
1471+
chunk is the last payload in the result handle.
1472+
1473+
As the result handle is most of time a list, an `transform.split_handle`
1474+
is needed to access individual handle.
14701475
}];
14711476

14721477
let arguments = (ins TransformHandleTypeInterface:$target,
14731478
I64Attr:$dimension,
14741479
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
14751480
I64Attr:$static_chunk_sizes,
14761481
UnitAttr:$multiway);
1477-
let results = (outs TransformHandleTypeInterface:$first,
1478-
TransformHandleTypeInterface:$second);
1482+
let results = (outs TransformHandleTypeInterface:$split_list);
14791483
let hasCustomAssemblyFormat = 1;
14801484
let hasVerifier = 1;
14811485
}

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2363,10 +2363,10 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23632363
return DiagnosedSilenceableFailure::success();
23642364
};
23652365

2366+
SmallVector<Operation *> opList;
23662367
if (isMultiwaySplit) {
23672368

23682369
// Split a single target operation at multiple points.
2369-
SmallVector<Operation *> opList;
23702370
TilingInterface head, tail;
23712371
Operation *target = payload.front();
23722372

@@ -2406,8 +2406,6 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
24062406
// Append any leftover parts to the end of the result list.
24072407
if (tail)
24082408
opList.push_back(tail.getOperation());
2409-
results.set(cast<OpResult>(getFirst()), opList);
2410-
results.set(cast<OpResult>(getSecond()), {});
24112409

24122410
} else {
24132411
// Split each target operation.
@@ -2453,9 +2451,11 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
24532451
return diag;
24542452
}
24552453

2456-
results.set(cast<OpResult>(getFirst()), first);
2457-
results.set(cast<OpResult>(getSecond()), second);
2454+
opList.append(first);
2455+
if (second.size())
2456+
opList.append(second);
24582457
}
2458+
results.set(cast<OpResult>(getSplitList()), opList);
24592459
return DiagnosedSilenceableFailure::success();
24602460
}
24612461

@@ -2507,7 +2507,7 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
25072507
result.addAttribute(
25082508
SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
25092509
staticChunkSizes);
2510-
result.addTypes({targetType, targetType});
2510+
result.addTypes(targetType);
25112511
return success();
25122512
}
25132513

mlir/python/mlir/dialects/transform/structured.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,6 @@ def __init__(
445445
dynamic_chunk_sizes = chunk_sizes
446446

447447
super().__init__(
448-
target.type,
449448
target.type,
450449
target,
451450
dimension=dimension,

mlir/test/Dialect/Linalg/continuous-tiling-full.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module attributes {transform.with_named_sequence} {
44
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
55
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
66
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
7-
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
7+
%linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
88
transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.any_op {
99
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
1010
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
@@ -65,7 +65,7 @@ module attributes {transform.with_named_sequence} {
6565
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
6666
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
6767
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.param<i64>
68-
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
68+
%linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
6969
transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.param<i64> {
7070
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.param<i64>):
7171
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
@@ -126,7 +126,7 @@ module attributes {transform.with_named_sequence} {
126126
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
127127
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
128128
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
129-
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
129+
%linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
130130
transform.foreach %linalg_splits, %tile_sizes with_zip_shortest : !transform.any_op, !transform.any_op {
131131
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
132132
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
@@ -177,4 +177,4 @@ func.func @continuous_tile_dynamic_linalg_matmul(
177177
// CHECK: %[[AM16:.*]] = affine.min #[[$MAP12]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]], %[[AM12]]]
178178
// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM16]] step %[[C1]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
179179
// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<1x?xf32>) -> tensor<1x?xf32>
180-
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>
180+
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>

mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
88
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
99
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1010
%tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op
11-
%low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
11+
%splits2 = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
1212
transform.yield
1313
}
1414
}
@@ -58,7 +58,7 @@ module attributes {transform.with_named_sequence} {
5858
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
5959
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
6060
%tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param<i64>
61-
%low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
61+
%splits2 = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
6262
transform.yield
6363
}
6464
}

mlir/test/Dialect/Linalg/multisize-tiling-full.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@ module attributes {transform.with_named_sequence} {
66
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
77
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
88
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
9-
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
9+
%split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
10+
%2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
1011
%3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
1112
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
1213
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
1314
transform.foreach %5 : !transform.any_op {
1415
^bb0(%inner_linalg: !transform.any_op):
1516
%low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
16-
%inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
17+
%split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
18+
%inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
1719
transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
1820
transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
1921
}
@@ -111,14 +113,16 @@ module attributes {transform.with_named_sequence} {
111113
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
112114
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.param<i64>
113115
%t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.param<i64>
114-
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
116+
%split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
117+
%2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
115118
%3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
116119
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
117120
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
118121
%tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
119122
transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
120123
^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
121-
%inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
124+
%split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
125+
%inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
122126
transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
123127
transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
124128
}

mlir/test/Dialect/Linalg/transform-op-split.mlir

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module attributes {transform.with_named_sequence} {
44
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
55
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
6-
%1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
6+
%1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
77
transform.yield
88
}
99
}
@@ -53,7 +53,7 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
5353
module attributes {transform.with_named_sequence} {
5454
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
5555
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
56-
%1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
56+
%1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
5757
transform.yield
5858
}
5959
}
@@ -138,8 +138,9 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
138138
module attributes {transform.with_named_sequence} {
139139
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
140140
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
141-
%1:2 = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
142-
%2:2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
141+
%t = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
142+
%1:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
143+
%2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
143144
transform.yield
144145
}
145146
}
@@ -197,7 +198,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
197198
module attributes {transform.with_named_sequence} {
198199
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
199200
// expected-error @below {{expects either a dynamic or a static split point to be provided}}
200-
%0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
201+
%0 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op)
201202
transform.yield
202203
}
203204
}
@@ -303,7 +304,7 @@ module attributes {transform.with_named_sequence} {
303304
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
304305
// expected-error @below {{splitting does not produce the second part for a subset of targets}}
305306
// expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
306-
%1:2 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
307+
%1 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
307308
transform.yield
308309
}
309310
}

mlir/test/Dialect/Linalg/transform-ops.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ transform.sequence failures(propagate) {
1818

1919
transform.sequence failures(propagate) {
2020
^bb1(%arg0: !transform.any_op):
21-
%0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
21+
%t = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
22+
%0:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
2223
transform.structured.split %0#0 after %0#1 { dimension = 1 } : !transform.any_op, !transform.any_op
2324
}
2425

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,15 @@ def testScalarize(target):
361361
@run
362362
@create_sequence
363363
def testSplit(target):
364-
split = structured.SplitOp(target, dimension=1, chunk_sizes=42)
364+
handle = structured.SplitOp(target, dimension=1, chunk_sizes=42)
365+
split = transform.SplitHandleOp(
366+
[transform.AnyOpType.get(), transform.AnyOpType.get()], handle
367+
)
365368
structured.SplitOp(split.results[0], dimension=3, chunk_sizes=split.results[1])
366369
# CHECK-LABEL: TEST: testSplit
367-
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
368-
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
370+
# CHECK: %[[G:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
371+
# CHECK: %[[F:.+]]:2 = split_handle %[[G]]
372+
# CHECK: transform.structured.split %[[F]]#0 after %[[F]]#1 {dimension = 3
369373

370374

371375
@run

0 commit comments

Comments
 (0)