@@ -2266,13 +2266,26 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
2266
2266
// Collect the dynamic split points if provided.
2267
2267
SmallVector<Operation *> payload =
2268
2268
llvm::to_vector (state.getPayloadOps (getTarget ()));
2269
- SmallVector<OpFoldResult> splitPoints;
2270
- splitPoints.reserve (payload.size ());
2271
- if (getDynamicSplitPoint ()) {
2269
+
2270
+ bool isMultiwaySplit = getMultiway ();
2271
+
2272
+ if (isMultiwaySplit && !llvm::hasSingleElement (payload)) {
2273
+ return mlir::emitSilenceableFailure (getLoc ())
2274
+ << " requires exactly one target when "
2275
+ " multiway split is enabled (got "
2276
+ << llvm::range_size (payload) << " )" ;
2277
+ }
2278
+
2279
+ SmallVector<OpFoldResult> chunkSizes;
2280
+
2281
+ if (!isMultiwaySplit)
2282
+ chunkSizes.reserve (payload.size ());
2283
+
2284
+ if (getDynamicChunkSizes ()) {
2272
2285
auto diag = DiagnosedSilenceableFailure::success ();
2273
- if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint ().getType ())) {
2274
- splitPoints = llvm::to_vector (llvm::map_range (
2275
- state.getPayloadOps (getDynamicSplitPoint ()), [&](Operation *op) {
2286
+ if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes ().getType ())) {
2287
+ chunkSizes = llvm::to_vector (llvm::map_range (
2288
+ state.getPayloadOps (getDynamicChunkSizes ()), [&](Operation *op) {
2276
2289
if (op->getNumResults () != 1 ||
2277
2290
!op->getResult (0 ).getType ().isIndex ()) {
2278
2291
diag = emitSilenceableError ()
@@ -2283,103 +2296,172 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
2283
2296
return OpFoldResult (op->getResult (0 ));
2284
2297
}));
2285
2298
} else {
2286
- splitPoints = llvm::to_vector (
2287
- llvm::map_range (state.getParams (getDynamicSplitPoint ()),
2299
+ chunkSizes = llvm::to_vector (
2300
+ llvm::map_range (state.getParams (getDynamicChunkSizes ()),
2288
2301
[](Attribute attr) { return OpFoldResult (attr); }));
2289
2302
}
2290
2303
if (diag.isSilenceableFailure ())
2291
2304
return diag;
2292
2305
2293
- if (splitPoints.size () != payload.size ()) {
2306
+ // For multiway split, a single payload is expected to have multiple
2307
+ // split points.
2308
+ if (!isMultiwaySplit && chunkSizes.size () != payload.size ()) {
2294
2309
return emitDefiniteFailure ()
2295
2310
<< " expected the dynamic split point handle to point to as "
2296
2311
" many operations ("
2297
- << splitPoints .size () << " ) as the target handle ("
2312
+ << chunkSizes .size () << " ) as the target handle ("
2298
2313
<< payload.size () << " )" ;
2299
2314
}
2300
2315
} else {
2301
- splitPoints .resize (payload.size (),
2302
- rewriter.getIndexAttr (getStaticSplitPoint ()));
2316
+ chunkSizes .resize (payload.size (),
2317
+ rewriter.getIndexAttr (getStaticChunkSizes ()));
2303
2318
}
2304
2319
2305
- // Split each target operation.
2306
- SmallVector<Operation *> first, second;
2307
- Operation *noSecondPart = nullptr ;
2308
- for (const auto &pair : llvm::zip (payload, splitPoints)) {
2309
- Operation *target = std::get<0 >(pair);
2310
- auto linalgOp = dyn_cast<LinalgOp>(target);
2320
+ auto checkStructuredOpAndDimensions =
2321
+ [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2311
2322
if (!linalgOp) {
2312
2323
auto diag = emitSilenceableError () << " only applies to structured ops" ;
2313
- diag.attachNote (target-> getLoc () ) << " target op" ;
2324
+ diag.attachNote (loc ) << " target op" ;
2314
2325
return diag;
2315
2326
}
2316
2327
2317
2328
if (getDimension () >= linalgOp.getNumLoops ()) {
2318
2329
auto diag = emitSilenceableError () << " dimension " << getDimension ()
2319
- << " does not exist in target op" ;
2320
- diag.attachNote (target-> getLoc () ) << " target op" ;
2330
+ << " does not exist in target op" ;
2331
+ diag.attachNote (loc ) << " target op" ;
2321
2332
return diag;
2322
2333
}
2334
+ return DiagnosedSilenceableFailure::success ();
2335
+ };
2323
2336
2324
- rewriter.setInsertionPoint (linalgOp);
2325
- std::tie (first.emplace_back (), second.emplace_back ()) = linalg::splitOp (
2326
- rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2327
- getDimension (), std::get<1 >(pair));
2328
-
2329
- // Propagate errors.
2330
- if (!first.back () && !second.back ()) {
2337
+ auto checkFailureInSplitting =
2338
+ [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2339
+ if (hasFailed) {
2331
2340
auto diag = emitDefiniteFailure () << " internal failure in splitting" ;
2332
- diag.attachNote (target-> getLoc () ) << " target op" ;
2341
+ diag.attachNote (loc ) << " target op" ;
2333
2342
return diag;
2334
2343
}
2344
+ return DiagnosedSilenceableFailure::success ();
2345
+ };
2346
+
2347
+ if (isMultiwaySplit) {
2348
+
2349
+ // Split a single target operation at multiple points.
2350
+ SmallVector<Operation *> opList;
2351
+ TilingInterface head, tail;
2352
+ Operation *target = payload.front ();
2353
+
2354
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2355
+
2356
+ // Check that the target is a valid LinalgOp with correct dimensions.
2357
+ DiagnosedSilenceableFailure diag =
2358
+ checkStructuredOpAndDimensions (linalgOp, target->getLoc ());
2359
+ if (diag.isSilenceableFailure ())
2360
+ return diag;
2361
+
2362
+ for (auto &&[idx, chunkSize] : llvm::enumerate (chunkSizes)) {
2363
+
2364
+ if (idx > 0 )
2365
+ target = tail.getOperation ();
2366
+
2367
+ if (!target)
2368
+ break ;
2335
2369
2336
- // Do not add null second parts.
2337
- if (!second.back ()) {
2338
- noSecondPart = target;
2339
- second.pop_back ();
2370
+ linalgOp = cast<LinalgOp>(target);
2371
+
2372
+ rewriter.setInsertionPoint (linalgOp);
2373
+ std::tie (head, tail) = linalg::splitOp (
2374
+ rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2375
+ getDimension (), chunkSize);
2376
+
2377
+ // Propagate errors.
2378
+ DiagnosedSilenceableFailure diag =
2379
+ checkFailureInSplitting (!head && !tail, target->getLoc ());
2380
+ if (diag.isDefiniteFailure ())
2381
+ return diag;
2382
+
2383
+ opList.push_back (head.getOperation ());
2340
2384
}
2341
- }
2342
2385
2343
- if (second.size () != first.size () && !second.empty ()) {
2344
- auto diag = emitSilenceableError ()
2345
- << " splitting does not produce the second part for a subset "
2346
- " of targets" ;
2347
- diag.attachNote () << " expected splitting to produce the second part of all "
2348
- " or none of the targets" ;
2349
- diag.attachNote (noSecondPart->getLoc ())
2350
- << " first target with no second part" ;
2351
- return diag;
2352
- }
2386
+ // Append any leftover parts to the end of the result list.
2387
+ if (tail)
2388
+ opList.push_back (tail.getOperation ());
2389
+ results.set (cast<OpResult>(getFirst ()), opList);
2390
+ results.set (cast<OpResult>(getSecond ()), {});
2391
+
2392
+ } else {
2393
+ // Split each target operation.
2394
+ SmallVector<Operation *> first, second;
2395
+ Operation *noSecondPart = nullptr ;
2396
+ for (const auto &pair : llvm::zip (payload, chunkSizes)) {
2397
+ Operation *target = std::get<0 >(pair);
2398
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2399
+ DiagnosedSilenceableFailure diag =
2400
+ checkStructuredOpAndDimensions (linalgOp, target->getLoc ());
2401
+
2402
+ if (diag.isSilenceableFailure ())
2403
+ return diag;
2404
+
2405
+ rewriter.setInsertionPoint (linalgOp);
2406
+ std::tie (first.emplace_back (), second.emplace_back ()) = linalg::splitOp (
2407
+ rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2408
+ getDimension (), std::get<1 >(pair));
2409
+
2410
+ // Propagate errors.
2411
+ DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting (
2412
+ !first.back () && !second.back (), target->getLoc ());
2413
+ if (diagSplit.isDefiniteFailure ())
2414
+ return diag;
2415
+
2416
+ // Do not add null second parts.
2417
+ if (!second.back ()) {
2418
+ noSecondPart = target;
2419
+ second.pop_back ();
2420
+ }
2421
+ }
2422
+
2423
+ if (second.size () != first.size () && !second.empty ()) {
2424
+ auto diag = emitSilenceableError ()
2425
+ << " splitting does not produce the second part for a subset "
2426
+ " of targets" ;
2427
+ diag.attachNote ()
2428
+ << " expected splitting to produce the second part of all "
2429
+ " or none of the targets" ;
2430
+ diag.attachNote (noSecondPart->getLoc ())
2431
+ << " first target with no second part" ;
2432
+ return diag;
2433
+ }
2353
2434
2354
- results.set (cast<OpResult>(getFirst ()), first);
2355
- results.set (cast<OpResult>(getSecond ()), second);
2435
+ results.set (cast<OpResult>(getFirst ()), first);
2436
+ results.set (cast<OpResult>(getSecond ()), second);
2437
+ }
2356
2438
return DiagnosedSilenceableFailure::success ();
2357
2439
}
2358
2440
2359
2441
void SplitOp::getEffects (
2360
2442
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2361
2443
consumesHandle (getTargetMutable (), effects);
2362
- if (getDynamicSplitPoint ())
2363
- onlyReadsHandle (getDynamicSplitPointMutable (), effects);
2444
+ if (getDynamicChunkSizes ())
2445
+ onlyReadsHandle (getDynamicChunkSizesMutable (), effects);
2364
2446
producesHandle (getOperation ()->getOpResults (), effects);
2365
2447
modifiesPayload (effects);
2366
2448
}
2367
2449
2368
2450
ParseResult SplitOp::parse (OpAsmParser &parser, OperationState &result) {
2369
- OpAsmParser::UnresolvedOperand target, dynamicSplitPoint ;
2370
- IntegerAttr staticSplitPoint ;
2451
+ OpAsmParser::UnresolvedOperand target, dynamicChunkSizes ;
2452
+ IntegerAttr staticChunkSizes ;
2371
2453
if (parser.parseOperand (target) || parser.parseKeyword (" after" ))
2372
2454
return failure ();
2373
2455
2374
2456
OptionalParseResult dynamicPointParseResult =
2375
- parser.parseOptionalOperand (dynamicSplitPoint );
2457
+ parser.parseOptionalOperand (dynamicChunkSizes );
2376
2458
if (!dynamicPointParseResult.has_value ()) {
2377
- int64_t staticSplitPointValue ;
2378
- if (failed (parser.parseInteger (staticSplitPointValue )))
2459
+ int64_t staticChunkSizesValue ;
2460
+ if (failed (parser.parseInteger (staticChunkSizesValue )))
2379
2461
return failure ();
2380
2462
2381
- staticSplitPoint =
2382
- parser.getBuilder ().getI64IntegerAttr (staticSplitPointValue );
2463
+ staticChunkSizes =
2464
+ parser.getBuilder ().getI64IntegerAttr (staticChunkSizesValue );
2383
2465
}
2384
2466
2385
2467
Type targetType;
@@ -2389,43 +2471,43 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2389
2471
return failure ();
2390
2472
}
2391
2473
if (dynamicPointParseResult.has_value ()) {
2392
- Type splitPointType ;
2474
+ Type ChunkSizesType ;
2393
2475
if (failed (*dynamicPointParseResult) || parser.parseComma () ||
2394
- parser.parseType (splitPointType ) ||
2395
- parser.resolveOperand (dynamicSplitPoint, splitPointType ,
2476
+ parser.parseType (ChunkSizesType ) ||
2477
+ parser.resolveOperand (dynamicChunkSizes, ChunkSizesType ,
2396
2478
result.operands )) {
2397
2479
return failure ();
2398
2480
}
2399
2481
2400
- staticSplitPoint =
2482
+ staticChunkSizes =
2401
2483
parser.getBuilder ().getI64IntegerAttr (ShapedType::kDynamic );
2402
2484
}
2403
2485
2404
2486
result.addAttribute (
2405
- SplitOp::getStaticSplitPointAttrName (result.name ).getValue (),
2406
- staticSplitPoint );
2487
+ SplitOp::getStaticChunkSizesAttrName (result.name ).getValue (),
2488
+ staticChunkSizes );
2407
2489
result.addTypes ({targetType, targetType});
2408
2490
return success ();
2409
2491
}
2410
2492
2411
2493
void SplitOp::print (OpAsmPrinter &printer) {
2412
2494
printer << " " << getTarget () << " after " ;
2413
- int64_t staticSplitSize = static_cast <int64_t >(getStaticSplitPoint ());
2414
- if (staticSplitSize != ShapedType::kDynamic )
2415
- printer << staticSplitSize ;
2495
+ int64_t staticChunkSize = static_cast <int64_t >(getStaticChunkSizes ());
2496
+ if (staticChunkSize != ShapedType::kDynamic )
2497
+ printer << staticChunkSize ;
2416
2498
else
2417
- printer << getDynamicSplitPoint ();
2499
+ printer << getDynamicChunkSizes ();
2418
2500
printer << " " ;
2419
2501
printer.printOptionalAttrDict (getOperation ()->getAttrs (),
2420
- {getStaticSplitPointAttrName ()});
2502
+ {getStaticChunkSizesAttrName ()});
2421
2503
printer << " : " << getTarget ().getType ();
2422
- if (staticSplitSize == ShapedType::kDynamic )
2423
- printer << " , " << getDynamicSplitPoint ().getType ();
2504
+ if (staticChunkSize == ShapedType::kDynamic )
2505
+ printer << " , " << getDynamicChunkSizes ().getType ();
2424
2506
}
2425
2507
2426
2508
LogicalResult SplitOp::verify () {
2427
- if ((static_cast <int64_t >(getStaticSplitPoint ()) != ShapedType::kDynamic ) ^
2428
- (getDynamicSplitPoint () == nullptr )) {
2509
+ if ((static_cast <int64_t >(getStaticChunkSizes ()) != ShapedType::kDynamic ) ^
2510
+ (getDynamicChunkSizes () == nullptr )) {
2429
2511
return emitOpError () << " expects either a dynamic or a static split "
2430
2512
" point to be provided" ;
2431
2513
}
0 commit comments