Skip to content

Commit a1f1c72

Browse files
committed
[mlir][Transform] Add support for multiway split in SplitOp
Add functionality that enables SplitOp to do a multiway split of a traget op along a given dimension. With multiway attribute, SplitOp takes a list of chunk sizes and applies it to a single target along the given dimension to generate multiple structured ops extracted from the target.
1 parent 77a93d6 commit a1f1c72

File tree

5 files changed

+190
-94
lines changed

5 files changed

+190
-94
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,29 +1396,43 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
13961396
DeclareOpInterfaceMethods<TransformOpInterface>,
13971397
ReportTrackingListenerFailuresOpTrait]> {
13981398
let description = [{
1399-
Indicates that the given `target` op should be split into two complementary
1399+
Splits the given `target` op into two or more complementary
14001400
parts, which combined cover the entire iteration domain of the original op.
14011401
The split is performed along the iteration space dimension provided as
1402-
attribute. In case of dimension overflow, the transformation fails. The
1403-
split is performed at the dimension iterator value specified as either the
1404-
static split point attribute when it is known at transform IR construction
1405-
time or as the handle to an operation producing a single index-typed value
1406-
when it is computed by payload IR. In the latter case, the static split
1402+
chunk size attribute specifying the size of the lower part; the remaining
1403+
range in the iteration space is assigned as the upper part. In case of
1404+
dimension overflow, the transformation fails. The split is performed at the
1405+
dimension iterator value specified as either the static chunk size
1406+
attribute when it is known at transform IR construction time or
1407+
as the handle to an operation producing a single index-typed value
1408+
when it is computed by payload IR. In the latter case, the chunk size
14071409
point must be set to `ShapedType::kDynamic` and the dynamic size handle
14081410
must point to as many value-producing operations as there are structured
14091411
operations pointed to by the target handle.
14101412

1411-
The operation consumes the target handle, but preserves the split point
1412-
handle if provided. It produces two new handles pointing to the two parts
1413-
of the structured op after splitting, in the same order as the target
1414-
operand, with the first handle corresponding to the part with lower
1415-
iteration space indices.
1413+
The operation consumes the target handle, but preserves the chunk size
1414+
handle if provided. Without the `multiway` attribute, it produces two
1415+
new handles pointing to the two parts of the structured op after splitting,
1416+
in the same order as the target operand, with the first handle
1417+
corresponding to the part with lower iteration space indices.
1418+
1419+
Multiway split mode is enabled by specifying the `multiway` attribute.
1420+
In this mode a single `target` op is split into multiple parts covering
1421+
the iteration space of the specified dimension. `static_chunk_sizes` and
1422+
`dynamic_chunk_sizes` in this case is a list of chunk sizes that the given
1423+
dimension should be split into. With `multiway` it produces two handles;
1424+
the first handle is a list of the multiple parts of the structured op
1425+
after splitting, where the target dimensions for each linalg op in the
1426+
list corresponds to the chunk sizes specfied in the input split list.
1427+
If the chunk sizes do not cover the entire iteration space, the leftover
1428+
chunk is the last payload in the first handle. The second handle is empty.
14161429
}];
14171430

14181431
let arguments = (ins TransformHandleTypeInterface:$target,
14191432
I64Attr:$dimension,
1420-
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
1421-
I64Attr:$static_split_point);
1433+
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
1434+
I64Attr:$static_chunk_sizes,
1435+
UnitAttr:$multiway);
14221436
let results = (outs TransformHandleTypeInterface:$first,
14231437
TransformHandleTypeInterface:$second);
14241438
let hasCustomAssemblyFormat = 1;

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 152 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,13 +2266,26 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
22662266
// Collect the dynamic split points if provided.
22672267
SmallVector<Operation *> payload =
22682268
llvm::to_vector(state.getPayloadOps(getTarget()));
2269-
SmallVector<OpFoldResult> splitPoints;
2270-
splitPoints.reserve(payload.size());
2271-
if (getDynamicSplitPoint()) {
2269+
2270+
bool isMultiwaySplit = getMultiway();
2271+
2272+
if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2273+
return mlir::emitSilenceableFailure(getLoc())
2274+
<< "requires exactly one target when "
2275+
"multiway split is enabled (got "
2276+
<< llvm::range_size(payload) << ")";
2277+
}
2278+
2279+
SmallVector<OpFoldResult> chunkSizes;
2280+
2281+
if (!isMultiwaySplit)
2282+
chunkSizes.reserve(payload.size());
2283+
2284+
if (getDynamicChunkSizes()) {
22722285
auto diag = DiagnosedSilenceableFailure::success();
2273-
if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
2274-
splitPoints = llvm::to_vector(llvm::map_range(
2275-
state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
2286+
if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2287+
chunkSizes = llvm::to_vector(llvm::map_range(
2288+
state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
22762289
if (op->getNumResults() != 1 ||
22772290
!op->getResult(0).getType().isIndex()) {
22782291
diag = emitSilenceableError()
@@ -2283,103 +2296,172 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
22832296
return OpFoldResult(op->getResult(0));
22842297
}));
22852298
} else {
2286-
splitPoints = llvm::to_vector(
2287-
llvm::map_range(state.getParams(getDynamicSplitPoint()),
2299+
chunkSizes = llvm::to_vector(
2300+
llvm::map_range(state.getParams(getDynamicChunkSizes()),
22882301
[](Attribute attr) { return OpFoldResult(attr); }));
22892302
}
22902303
if (diag.isSilenceableFailure())
22912304
return diag;
22922305

2293-
if (splitPoints.size() != payload.size()) {
2306+
// For multiway split, a single payload is expected to have multiple
2307+
// split points.
2308+
if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
22942309
return emitDefiniteFailure()
22952310
<< "expected the dynamic split point handle to point to as "
22962311
"many operations ("
2297-
<< splitPoints.size() << ") as the target handle ("
2312+
<< chunkSizes.size() << ") as the target handle ("
22982313
<< payload.size() << ")";
22992314
}
23002315
} else {
2301-
splitPoints.resize(payload.size(),
2302-
rewriter.getIndexAttr(getStaticSplitPoint()));
2316+
chunkSizes.resize(payload.size(),
2317+
rewriter.getIndexAttr(getStaticChunkSizes()));
23032318
}
23042319

2305-
// Split each target operation.
2306-
SmallVector<Operation *> first, second;
2307-
Operation *noSecondPart = nullptr;
2308-
for (const auto &pair : llvm::zip(payload, splitPoints)) {
2309-
Operation *target = std::get<0>(pair);
2310-
auto linalgOp = dyn_cast<LinalgOp>(target);
2320+
auto checkStructuredOpAndDimensions =
2321+
[&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
23112322
if (!linalgOp) {
23122323
auto diag = emitSilenceableError() << "only applies to structured ops";
2313-
diag.attachNote(target->getLoc()) << "target op";
2324+
diag.attachNote(loc) << "target op";
23142325
return diag;
23152326
}
23162327

23172328
if (getDimension() >= linalgOp.getNumLoops()) {
23182329
auto diag = emitSilenceableError() << "dimension " << getDimension()
2319-
<< " does not exist in target op";
2320-
diag.attachNote(target->getLoc()) << "target op";
2330+
<< " does not exist in target op";
2331+
diag.attachNote(loc) << "target op";
23212332
return diag;
23222333
}
2334+
return DiagnosedSilenceableFailure::success();
2335+
};
23232336

2324-
rewriter.setInsertionPoint(linalgOp);
2325-
std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2326-
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2327-
getDimension(), std::get<1>(pair));
2328-
2329-
// Propagate errors.
2330-
if (!first.back() && !second.back()) {
2337+
auto checkFailureInSplitting =
2338+
[&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2339+
if (hasFailed) {
23312340
auto diag = emitDefiniteFailure() << "internal failure in splitting";
2332-
diag.attachNote(target->getLoc()) << "target op";
2341+
diag.attachNote(loc) << "target op";
23332342
return diag;
23342343
}
2344+
return DiagnosedSilenceableFailure::success();
2345+
};
2346+
2347+
if (isMultiwaySplit) {
2348+
2349+
// Split a single target operation at multiple points.
2350+
SmallVector<Operation *> opList;
2351+
TilingInterface head, tail;
2352+
Operation *target = payload.front();
2353+
2354+
LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2355+
2356+
// Check that the target is a valid LinalgOp with correct dimensions.
2357+
DiagnosedSilenceableFailure diag =
2358+
checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2359+
if (diag.isSilenceableFailure())
2360+
return diag;
2361+
2362+
for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2363+
2364+
if (idx > 0)
2365+
target = tail.getOperation();
2366+
2367+
if (!target)
2368+
break;
23352369

2336-
// Do not add null second parts.
2337-
if (!second.back()) {
2338-
noSecondPart = target;
2339-
second.pop_back();
2370+
linalgOp = cast<LinalgOp>(target);
2371+
2372+
rewriter.setInsertionPoint(linalgOp);
2373+
std::tie(head, tail) = linalg::splitOp(
2374+
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2375+
getDimension(), chunkSize);
2376+
2377+
// Propagate errors.
2378+
DiagnosedSilenceableFailure diag =
2379+
checkFailureInSplitting(!head && !tail, target->getLoc());
2380+
if (diag.isDefiniteFailure())
2381+
return diag;
2382+
2383+
opList.push_back(head.getOperation());
23402384
}
2341-
}
23422385

2343-
if (second.size() != first.size() && !second.empty()) {
2344-
auto diag = emitSilenceableError()
2345-
<< "splitting does not produce the second part for a subset "
2346-
"of targets";
2347-
diag.attachNote() << "expected splitting to produce the second part of all "
2348-
"or none of the targets";
2349-
diag.attachNote(noSecondPart->getLoc())
2350-
<< "first target with no second part";
2351-
return diag;
2352-
}
2386+
// Append any leftover parts to the end of the result list.
2387+
if (tail)
2388+
opList.push_back(tail.getOperation());
2389+
results.set(cast<OpResult>(getFirst()), opList);
2390+
results.set(cast<OpResult>(getSecond()), {});
2391+
2392+
} else {
2393+
// Split each target operation.
2394+
SmallVector<Operation *> first, second;
2395+
Operation *noSecondPart = nullptr;
2396+
for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2397+
Operation *target = std::get<0>(pair);
2398+
LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2399+
DiagnosedSilenceableFailure diag =
2400+
checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2401+
2402+
if (diag.isSilenceableFailure())
2403+
return diag;
2404+
2405+
rewriter.setInsertionPoint(linalgOp);
2406+
std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2407+
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2408+
getDimension(), std::get<1>(pair));
2409+
2410+
// Propagate errors.
2411+
DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
2412+
!first.back() && !second.back(), target->getLoc());
2413+
if (diagSplit.isDefiniteFailure())
2414+
return diag;
2415+
2416+
// Do not add null second parts.
2417+
if (!second.back()) {
2418+
noSecondPart = target;
2419+
second.pop_back();
2420+
}
2421+
}
2422+
2423+
if (second.size() != first.size() && !second.empty()) {
2424+
auto diag = emitSilenceableError()
2425+
<< "splitting does not produce the second part for a subset "
2426+
"of targets";
2427+
diag.attachNote()
2428+
<< "expected splitting to produce the second part of all "
2429+
"or none of the targets";
2430+
diag.attachNote(noSecondPart->getLoc())
2431+
<< "first target with no second part";
2432+
return diag;
2433+
}
23532434

2354-
results.set(cast<OpResult>(getFirst()), first);
2355-
results.set(cast<OpResult>(getSecond()), second);
2435+
results.set(cast<OpResult>(getFirst()), first);
2436+
results.set(cast<OpResult>(getSecond()), second);
2437+
}
23562438
return DiagnosedSilenceableFailure::success();
23572439
}
23582440

23592441
void SplitOp::getEffects(
23602442
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
23612443
consumesHandle(getTargetMutable(), effects);
2362-
if (getDynamicSplitPoint())
2363-
onlyReadsHandle(getDynamicSplitPointMutable(), effects);
2444+
if (getDynamicChunkSizes())
2445+
onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
23642446
producesHandle(getOperation()->getOpResults(), effects);
23652447
modifiesPayload(effects);
23662448
}
23672449

23682450
ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2369-
OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
2370-
IntegerAttr staticSplitPoint;
2451+
OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2452+
IntegerAttr staticChunkSizes;
23712453
if (parser.parseOperand(target) || parser.parseKeyword("after"))
23722454
return failure();
23732455

23742456
OptionalParseResult dynamicPointParseResult =
2375-
parser.parseOptionalOperand(dynamicSplitPoint);
2457+
parser.parseOptionalOperand(dynamicChunkSizes);
23762458
if (!dynamicPointParseResult.has_value()) {
2377-
int64_t staticSplitPointValue;
2378-
if (failed(parser.parseInteger(staticSplitPointValue)))
2459+
int64_t staticChunkSizesValue;
2460+
if (failed(parser.parseInteger(staticChunkSizesValue)))
23792461
return failure();
23802462

2381-
staticSplitPoint =
2382-
parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
2463+
staticChunkSizes =
2464+
parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
23832465
}
23842466

23852467
Type targetType;
@@ -2389,43 +2471,43 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
23892471
return failure();
23902472
}
23912473
if (dynamicPointParseResult.has_value()) {
2392-
Type splitPointType;
2474+
Type ChunkSizesType;
23932475
if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2394-
parser.parseType(splitPointType) ||
2395-
parser.resolveOperand(dynamicSplitPoint, splitPointType,
2476+
parser.parseType(ChunkSizesType) ||
2477+
parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
23962478
result.operands)) {
23972479
return failure();
23982480
}
23992481

2400-
staticSplitPoint =
2482+
staticChunkSizes =
24012483
parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
24022484
}
24032485

24042486
result.addAttribute(
2405-
SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
2406-
staticSplitPoint);
2487+
SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2488+
staticChunkSizes);
24072489
result.addTypes({targetType, targetType});
24082490
return success();
24092491
}
24102492

24112493
void SplitOp::print(OpAsmPrinter &printer) {
24122494
printer << " " << getTarget() << " after ";
2413-
int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
2414-
if (staticSplitSize != ShapedType::kDynamic)
2415-
printer << staticSplitSize;
2495+
int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2496+
if (staticChunkSize != ShapedType::kDynamic)
2497+
printer << staticChunkSize;
24162498
else
2417-
printer << getDynamicSplitPoint();
2499+
printer << getDynamicChunkSizes();
24182500
printer << " ";
24192501
printer.printOptionalAttrDict(getOperation()->getAttrs(),
2420-
{getStaticSplitPointAttrName()});
2502+
{getStaticChunkSizesAttrName()});
24212503
printer << " : " << getTarget().getType();
2422-
if (staticSplitSize == ShapedType::kDynamic)
2423-
printer << ", " << getDynamicSplitPoint().getType();
2504+
if (staticChunkSize == ShapedType::kDynamic)
2505+
printer << ", " << getDynamicChunkSizes().getType();
24242506
}
24252507

24262508
LogicalResult SplitOp::verify() {
2427-
if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
2428-
(getDynamicSplitPoint() == nullptr)) {
2509+
if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2510+
(getDynamicChunkSizes() == nullptr)) {
24292511
return emitOpError() << "expects either a dynamic or a static split "
24302512
"point to be provided";
24312513
}

0 commit comments

Comments
 (0)