Skip to content

Commit e1adf1e

Browse files
committed
[MLIR] Add support for multiway split in SplitOp
Add functionality that enables SplitOp to do a multiway split of a traget linalg along a given dimension. When multiway attribute is `true`, the SplitOp takes a list of split points and applies it to a single linalg along the given dimension to generate multiple linalgs extracted from the target.
1 parent 37f4d82 commit e1adf1e

File tree

2 files changed

+123
-50
lines changed

2 files changed

+123
-50
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,7 +1396,7 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
13961396
DeclareOpInterfaceMethods<TransformOpInterface>,
13971397
ReportTrackingListenerFailuresOpTrait]> {
13981398
let description = [{
1399-
Indicates that the given `target` op should be split into two complementary
1399+
Splits the given `target` op into two or more complementary
14001400
parts, which combined cover the entire iteration domain of the original op.
14011401
The split is performed along the iteration space dimension provided as
14021402
attribute. In case of dimension overflow, the transformation fails. The
@@ -1409,16 +1409,27 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
14091409
operations pointed to by the target handle.
14101410

14111411
The operation consumes the target handle, but preserves the split point
1412-
handle if provided. It produces two new handles pointing to the two parts
1413-
of the structured op after splitting, in the same order as the target
1414-
operand, with the first handle corresponding to the part with lower
1415-
iteration space indices.
1412+
handle if provided. Without the `multiway` attribute, it produces two
1413+
new handles pointing to the two parts of the structured op after splitting,
1414+
in the same order as the target operand, with the first handle
1415+
corresponding to the part with lower iteration space indices.
1416+
1417+
Multiway split mode is enabled by specifying the `multiway` attribute.
1418+
In this mode a single `target` op is split into multiple parts covering
1419+
the iteration space of the specified dimension. `static_split_point` and
1420+
`dynamic_split_point` in this case is a list of chunk sizes that the given
1421+
dimension should be split into. With `multiway` it produces two handles;
1422+
the first handle is a list of the multiple parts of the structured op
1423+
after splitting, where the target dimensions for each linalg op in the
1424+
list corresponds to the chunk sizes specfied in the input split list.
1425+
The second handle is empty.
14161426
}];
14171427

14181428
let arguments = (ins TransformHandleTypeInterface:$target,
14191429
I64Attr:$dimension,
14201430
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
1421-
I64Attr:$static_split_point);
1431+
I64Attr:$static_split_point,
1432+
UnitAttr:$multiway);
14221433
let results = (outs TransformHandleTypeInterface:$first,
14231434
TransformHandleTypeInterface:$second);
14241435
let hasCustomAssemblyFormat = 1;

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 106 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2269,8 +2269,20 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
22692269
// Collect the dynamic split points if provided.
22702270
SmallVector<Operation *> payload =
22712271
llvm::to_vector(state.getPayloadOps(getTarget()));
2272+
2273+
bool isMultiwaySplit = getMultiway() ? true : false;
2274+
2275+
if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2276+
return emitDefiniteFailure() << "requires exactly one target when "
2277+
"multiway split is enabled (got "
2278+
<< llvm::range_size(payload) << ")";
2279+
}
2280+
22722281
SmallVector<OpFoldResult> splitPoints;
2273-
splitPoints.reserve(payload.size());
2282+
2283+
if (!isMultiwaySplit)
2284+
splitPoints.reserve(payload.size());
2285+
22742286
if (getDynamicSplitPoint()) {
22752287
auto diag = DiagnosedSilenceableFailure::success();
22762288
if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
@@ -2293,7 +2305,9 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
22932305
if (diag.isSilenceableFailure())
22942306
return diag;
22952307

2296-
if (splitPoints.size() != payload.size()) {
2308+
// For multiway split, a single payload is expected to have multiple
2309+
// split points.
2310+
if (!isMultiwaySplit && splitPoints.size() != payload.size()) {
22972311
return emitDefiniteFailure()
22982312
<< "expected the dynamic split point handle to point to as "
22992313
"many operations ("
@@ -2305,57 +2319,105 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23052319
rewriter.getIndexAttr(getStaticSplitPoint()));
23062320
}
23072321

2308-
// Split each target operation.
2309-
SmallVector<Operation *> first, second;
2310-
Operation *noSecondPart = nullptr;
2311-
for (const auto &pair : llvm::zip(payload, splitPoints)) {
2312-
Operation *target = std::get<0>(pair);
2313-
auto linalgOp = dyn_cast<LinalgOp>(target);
2314-
if (!linalgOp) {
2315-
auto diag = emitSilenceableError() << "only applies to structured ops";
2316-
diag.attachNote(target->getLoc()) << "target op";
2317-
return diag;
2318-
}
2322+
if (isMultiwaySplit) {
23192323

2320-
if (getDimension() >= linalgOp.getNumLoops()) {
2321-
auto diag = emitSilenceableError() << "dimension " << getDimension()
2322-
<< " does not exist in target op";
2323-
diag.attachNote(target->getLoc()) << "target op";
2324-
return diag;
2324+
// Split a single target operation at multiple points.
2325+
SmallVector<Operation *> opList;
2326+
Operation *head, *tail;
2327+
for (const auto [idx, splitPoint] : llvm::enumerate(splitPoints)) {
2328+
2329+
Operation *target;
2330+
if (idx == 0)
2331+
target = payload.front();
2332+
else
2333+
target = tail;
2334+
2335+
if (!target)
2336+
break;
2337+
2338+
auto linalgOp = dyn_cast<LinalgOp>(target);
2339+
2340+
if (!linalgOp) {
2341+
auto diag = emitSilenceableError() << "only applies to structured ops";
2342+
diag.attachNote(target->getLoc()) << "target op";
2343+
return diag;
2344+
}
2345+
2346+
if (getDimension() >= linalgOp.getNumLoops()) {
2347+
auto diag = emitSilenceableError() << "dimension " << getDimension()
2348+
<< " does not exist in target op";
2349+
diag.attachNote(target->getLoc()) << "target op";
2350+
return diag;
2351+
}
2352+
2353+
rewriter.setInsertionPoint(linalgOp);
2354+
std::tie(head, tail) = linalg::splitOp(
2355+
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2356+
getDimension(), splitPoint);
2357+
2358+
opList.push_back(head);
23252359
}
23262360

2327-
rewriter.setInsertionPoint(linalgOp);
2328-
std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2329-
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2330-
getDimension(), std::get<1>(pair));
2361+
// Append any leftover parts to the end of the result list.
2362+
if (tail)
2363+
opList.push_back(tail);
2364+
results.set(cast<OpResult>(getFirst()), opList);
2365+
results.set(cast<OpResult>(getSecond()), {});
23312366

2332-
// Propagate errors.
2333-
if (!first.back() && !second.back()) {
2334-
auto diag = emitDefiniteFailure() << "internal failure in splitting";
2335-
diag.attachNote(target->getLoc()) << "target op";
2336-
return diag;
2367+
} else {
2368+
// Split each target operation.
2369+
SmallVector<Operation *> first, second;
2370+
Operation *noSecondPart = nullptr;
2371+
for (const auto &pair : llvm::zip(payload, splitPoints)) {
2372+
Operation *target = std::get<0>(pair);
2373+
auto linalgOp = dyn_cast<LinalgOp>(target);
2374+
if (!linalgOp) {
2375+
auto diag = emitSilenceableError() << "only applies to structured ops";
2376+
diag.attachNote(target->getLoc()) << "target op";
2377+
return diag;
2378+
}
2379+
2380+
if (getDimension() >= linalgOp.getNumLoops()) {
2381+
auto diag = emitSilenceableError() << "dimension " << getDimension()
2382+
<< " does not exist in target op";
2383+
diag.attachNote(target->getLoc()) << "target op";
2384+
return diag;
2385+
}
2386+
2387+
rewriter.setInsertionPoint(linalgOp);
2388+
std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2389+
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2390+
getDimension(), std::get<1>(pair));
2391+
2392+
// Propagate errors.
2393+
if (!first.back() && !second.back()) {
2394+
auto diag = emitDefiniteFailure() << "internal failure in splitting";
2395+
diag.attachNote(target->getLoc()) << "target op";
2396+
return diag;
2397+
}
2398+
2399+
// Do not add null second parts.
2400+
if (!second.back()) {
2401+
noSecondPart = target;
2402+
second.pop_back();
2403+
}
23372404
}
23382405

2339-
// Do not add null second parts.
2340-
if (!second.back()) {
2341-
noSecondPart = target;
2342-
second.pop_back();
2406+
if (second.size() != first.size() && !second.empty()) {
2407+
auto diag = emitSilenceableError()
2408+
<< "splitting does not produce the second part for a subset "
2409+
"of targets";
2410+
diag.attachNote()
2411+
<< "expected splitting to produce the second part of all "
2412+
"or none of the targets";
2413+
diag.attachNote(noSecondPart->getLoc())
2414+
<< "first target with no second part";
2415+
return diag;
23432416
}
2344-
}
23452417

2346-
if (second.size() != first.size() && !second.empty()) {
2347-
auto diag = emitSilenceableError()
2348-
<< "splitting does not produce the second part for a subset "
2349-
"of targets";
2350-
diag.attachNote() << "expected splitting to produce the second part of all "
2351-
"or none of the targets";
2352-
diag.attachNote(noSecondPart->getLoc())
2353-
<< "first target with no second part";
2354-
return diag;
2418+
results.set(cast<OpResult>(getFirst()), first);
2419+
results.set(cast<OpResult>(getSecond()), second);
23552420
}
2356-
2357-
results.set(cast<OpResult>(getFirst()), first);
2358-
results.set(cast<OpResult>(getSecond()), second);
23592421
return DiagnosedSilenceableFailure::success();
23602422
}
23612423

0 commit comments

Comments
 (0)