Skip to content

Commit 861d563

Browse files
committed
[MLIR][Transform] Consolidate result of structured.split into one list
E.g.: ``` %0:2 = transform.structured.split ``` is changed to ``` %t = transform.structured.split %0:2 = transform.split_handle %t ```
1 parent 4da4fac commit 861d563

File tree

9 files changed

+48
-37
lines changed

9 files changed

+48
-37
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -676,8 +676,10 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
676676
{ target_size = 10, dimension = 1 }
677677
: !transform.any_op, !transform.param<i64>,
678678
!transform.param<i64>, !transform.param<i64>
679-
%low, %high = structured.split %target after %split { dimension = 1 }
679+
%handles = structured.split %target after %split { dimension = 1 }
680680
: !transform.any_op, !transform.param<i64>
681+
%low, %high = transform.split_handle %handles : (!transform.any_op)
682+
-> (!transform.any_op, !transform.any_op)
681683
%tiled_low, %loop1 = structured.tile_using_for %low [0, %sz1]
682684
: (!transform.any_op, !transform.param<i64>)
683685
-> (!transform.any_op, !transform.any_op)
@@ -1422,30 +1424,32 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
14221424
operations pointed to by the target handle.
14231425

14241426
The operation consumes the target handle, but preserves the chunk size
1425-
handle if provided. Without the `multiway` attribute, it produces two
1426-
new handles pointing to the two parts of the structured op after splitting,
1427-
in the same order as the target operand, with the first handle
1428-
corresponding to the part with lower iteration space indices.
1427+
handle if provided. Without the `multiway` attribute, it produces a
1428+
new handle that is a list of the two parts of the structured op after
1429+
splitting, whose lower index part corresponding to the part with lower
1430+
iteration space indices.
14291431

14301432
Multiway split mode is enabled by specifying the `multiway` attribute.
14311433
In this mode a single `target` op is split into multiple parts covering
14321434
the iteration space of the specified dimension. `static_chunk_sizes` and
14331435
`dynamic_chunk_sizes` in this case is a list of chunk sizes that the given
1434-
dimension should be split into. With `multiway` it produces two handles;
1435-
the first handle is a list of the multiple parts of the structured op
1436+
dimension should be split into. With `multiway` it also produces a handle;
1437+
The result handle is a list of the multiple parts of the structured op
14361438
after splitting, where the target dimensions for each linalg op in the
14371439
list corresponds to the chunk sizes specfied in the input split list.
14381440
If the chunk sizes do not cover the entire iteration space, the leftover
1439-
chunk is the last payload in the first handle. The second handle is empty.
1441+
chunk is the last payload in the result handle.
1442+
1443+
As the result handle is most of time a list, an `transform.split_handle`
1444+
is needed to access individual handle.
14401445
}];
14411446

14421447
let arguments = (ins TransformHandleTypeInterface:$target,
14431448
I64Attr:$dimension,
14441449
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
14451450
I64Attr:$static_chunk_sizes,
14461451
UnitAttr:$multiway);
1447-
let results = (outs TransformHandleTypeInterface:$first,
1448-
TransformHandleTypeInterface:$second);
1452+
let results = (outs TransformHandleTypeInterface:$split_list);
14491453
let hasCustomAssemblyFormat = 1;
14501454
let hasVerifier = 1;
14511455
}

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2348,10 +2348,10 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23482348
return DiagnosedSilenceableFailure::success();
23492349
};
23502350

2351+
SmallVector<Operation *> opList;
23512352
if (isMultiwaySplit) {
23522353

23532354
// Split a single target operation at multiple points.
2354-
SmallVector<Operation *> opList;
23552355
TilingInterface head, tail;
23562356
Operation *target = payload.front();
23572357

@@ -2391,8 +2391,6 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23912391
// Append any leftover parts to the end of the result list.
23922392
if (tail)
23932393
opList.push_back(tail.getOperation());
2394-
results.set(cast<OpResult>(getFirst()), opList);
2395-
results.set(cast<OpResult>(getSecond()), {});
23962394

23972395
} else {
23982396
// Split each target operation.
@@ -2438,9 +2436,11 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
24382436
return diag;
24392437
}
24402438

2441-
results.set(cast<OpResult>(getFirst()), first);
2442-
results.set(cast<OpResult>(getSecond()), second);
2439+
opList.append(first);
2440+
if (second.size())
2441+
opList.append(second);
24432442
}
2443+
results.set(cast<OpResult>(getSplitList()), opList);
24442444
return DiagnosedSilenceableFailure::success();
24452445
}
24462446

@@ -2492,7 +2492,7 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
24922492
result.addAttribute(
24932493
SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
24942494
staticChunkSizes);
2495-
result.addTypes({targetType, targetType});
2495+
result.addTypes(targetType);
24962496
return success();
24972497
}
24982498

mlir/python/mlir/dialects/transform/structured.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,6 @@ def __init__(
445445
dynamic_chunk_sizes = chunk_sizes
446446

447447
super().__init__(
448-
target.type,
449448
target.type,
450449
target,
451450
dimension=dimension,

mlir/test/Dialect/Linalg/continuous-tiling-full.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module attributes {transform.with_named_sequence} {
44
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
55
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
66
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
7-
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
7+
%linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
88
transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.any_op {
99
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
1010
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
@@ -65,7 +65,7 @@ module attributes {transform.with_named_sequence} {
6565
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
6666
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
6767
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.param<i64>
68-
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
68+
%linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
6969
transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.param<i64> {
7070
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.param<i64>):
7171
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
@@ -126,7 +126,7 @@ module attributes {transform.with_named_sequence} {
126126
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
127127
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
128128
%tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
129-
%linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
129+
%linalg_splits = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
130130
transform.foreach %linalg_splits, %tile_sizes with_zip_shortest : !transform.any_op, !transform.any_op {
131131
^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
132132
%tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
@@ -177,4 +177,4 @@ func.func @continuous_tile_dynamic_linalg_matmul(
177177
// CHECK: %[[AM16:.*]] = affine.min #[[$MAP12]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]], %[[AM12]]]
178178
// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM16]] step %[[C1]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
179179
// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<1x?xf32>) -> tensor<1x?xf32>
180-
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>
180+
// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>

mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
88
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
99
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1010
%tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op
11-
%low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
11+
%splits2 = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
1212
transform.yield
1313
}
1414
}
@@ -58,7 +58,7 @@ module attributes {transform.with_named_sequence} {
5858
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
5959
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
6060
%tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param<i64>
61-
%low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
61+
%splits2 = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
6262
transform.yield
6363
}
6464
}

mlir/test/Dialect/Linalg/multisize-tiling-full.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@ module attributes {transform.with_named_sequence} {
66
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
77
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
88
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.any_op
9-
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
9+
%split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.any_op
10+
%2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
1011
%3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
1112
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
1213
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
1314
transform.foreach %5 : !transform.any_op {
1415
^bb0(%inner_linalg: !transform.any_op):
1516
%low, %high, %split_point = transform.structured.multitile_sizes %inner_linalg { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.any_op
16-
%inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
17+
%split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.any_op
18+
%inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
1719
transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
1820
transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
1921
}
@@ -111,14 +113,16 @@ module attributes {transform.with_named_sequence} {
111113
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
112114
%1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!transform.any_op) -> !transform.param<i64>
113115
%t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!transform.any_op) -> !transform.param<i64>
114-
%2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
116+
%split = transform.structured.split %0 after %1#2 { dimension = 0 } : !transform.any_op, !transform.param<i64>
117+
%2:2 = transform.split_handle %split : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
115118
%3:2 = transform.structured.tile_using_for %2#0 tile_sizes [%1#0] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
116119
%4:2 = transform.structured.tile_using_for %2#1 tile_sizes [%1#1] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
117120
%5 = transform.merge_handles %3#0, %4#0 : !transform.any_op
118121
%tt:3 = transform.replicate num(%5) %t#0, %t#1, %t#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
119122
transform.foreach %5, %tt#0, %tt#1, %tt#2 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64> {
120123
^bb0(%inner_linalg: !transform.any_op, %low: !transform.param<i64>, %high: !transform.param<i64>, %split_point: !transform.param<i64>):
121-
%inner_linalg_low, %inner_linalg_high = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
124+
%split2 = transform.structured.split %inner_linalg after %split_point { dimension = 1 } : !transform.any_op, !transform.param<i64>
125+
%inner_linalg_low, %inner_linalg_high = transform.split_handle %split2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
122126
transform.structured.tile_using_for %inner_linalg_low tile_sizes [0, %low] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
123127
transform.structured.tile_using_for %inner_linalg_high tile_sizes [0, %high] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
124128
}

mlir/test/Dialect/Linalg/transform-op-split.mlir

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module attributes {transform.with_named_sequence} {
44
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
55
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
6-
%1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
6+
%1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
77
transform.yield
88
}
99
}
@@ -53,7 +53,7 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
5353
module attributes {transform.with_named_sequence} {
5454
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
5555
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
56-
%1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
56+
%1 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
5757
transform.yield
5858
}
5959
}
@@ -138,8 +138,9 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
138138
module attributes {transform.with_named_sequence} {
139139
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
140140
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
141-
%1:2 = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
142-
%2:2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
141+
%t = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
142+
%1:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
143+
%2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
143144
transform.yield
144145
}
145146
}
@@ -197,7 +198,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
197198
module attributes {transform.with_named_sequence} {
198199
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
199200
// expected-error @below {{expects either a dynamic or a static split point to be provided}}
200-
%0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
201+
%0 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op)
201202
transform.yield
202203
}
203204
}
@@ -303,7 +304,7 @@ module attributes {transform.with_named_sequence} {
303304
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
304305
// expected-error @below {{splitting does not produce the second part for a subset of targets}}
305306
// expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
306-
%1:2 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
307+
%1 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
307308
transform.yield
308309
}
309310
}

mlir/test/Dialect/Linalg/transform-ops.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ transform.sequence failures(propagate) {
1818

1919
transform.sequence failures(propagate) {
2020
^bb1(%arg0: !transform.any_op):
21-
%0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
21+
%t = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
22+
%0:2 = transform.split_handle %t : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
2223
transform.structured.split %0#0 after %0#1 { dimension = 1 } : !transform.any_op, !transform.any_op
2324
}
2425

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,13 @@ def testScalarize(target):
361361
@run
362362
@create_sequence
363363
def testSplit(target):
364-
split = structured.SplitOp(target, dimension=1, chunk_sizes=42)
364+
handle = structured.SplitOp(target, dimension=1, chunk_sizes=42)
365+
split = transform.SplitHandleOp([transform.AnyOpType.get(), transform.AnyOpType.get()], handle)
365366
structured.SplitOp(split.results[0], dimension=3, chunk_sizes=split.results[1])
366367
# CHECK-LABEL: TEST: testSplit
367-
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
368-
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
368+
# CHECK: %[[G:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
369+
# CHECK: %[[F:.+]]:2 = split_handle %[[G]]
370+
# CHECK: transform.structured.split %[[F]]#0 after %[[F]]#1 {dimension = 3
369371

370372

371373
@run

0 commit comments

Comments
 (0)