@@ -2269,13 +2269,26 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
2269
2269
// Collect the dynamic split points if provided.
2270
2270
SmallVector<Operation *> payload =
2271
2271
llvm::to_vector (state.getPayloadOps (getTarget ()));
2272
- SmallVector<OpFoldResult> splitPoints;
2273
- splitPoints.reserve (payload.size ());
2274
- if (getDynamicSplitPoint ()) {
2272
+
2273
+ bool isMultiwaySplit = getMultiway ();
2274
+
2275
+ if (isMultiwaySplit && !llvm::hasSingleElement (payload)) {
2276
+ return mlir::emitSilenceableFailure (getLoc ())
2277
+ << " requires exactly one target when "
2278
+ " multiway split is enabled (got "
2279
+ << llvm::range_size (payload) << " )" ;
2280
+ }
2281
+
2282
+ SmallVector<OpFoldResult> chunkSizes;
2283
+
2284
+ if (!isMultiwaySplit)
2285
+ chunkSizes.reserve (payload.size ());
2286
+
2287
+ if (getDynamicChunkSizes ()) {
2275
2288
auto diag = DiagnosedSilenceableFailure::success ();
2276
- if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint ().getType ())) {
2277
- splitPoints = llvm::to_vector (llvm::map_range (
2278
- state.getPayloadOps (getDynamicSplitPoint ()), [&](Operation *op) {
2289
+ if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes ().getType ())) {
2290
+ chunkSizes = llvm::to_vector (llvm::map_range (
2291
+ state.getPayloadOps (getDynamicChunkSizes ()), [&](Operation *op) {
2279
2292
if (op->getNumResults () != 1 ||
2280
2293
!op->getResult (0 ).getType ().isIndex ()) {
2281
2294
diag = emitSilenceableError ()
@@ -2286,103 +2299,172 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
2286
2299
return OpFoldResult (op->getResult (0 ));
2287
2300
}));
2288
2301
} else {
2289
- splitPoints = llvm::to_vector (
2290
- llvm::map_range (state.getParams (getDynamicSplitPoint ()),
2302
+ chunkSizes = llvm::to_vector (
2303
+ llvm::map_range (state.getParams (getDynamicChunkSizes ()),
2291
2304
[](Attribute attr) { return OpFoldResult (attr); }));
2292
2305
}
2293
2306
if (diag.isSilenceableFailure ())
2294
2307
return diag;
2295
2308
2296
- if (splitPoints.size () != payload.size ()) {
2309
+ // For multiway split, a single payload is expected to have multiple
2310
+ // split points.
2311
+ if (!isMultiwaySplit && chunkSizes.size () != payload.size ()) {
2297
2312
return emitDefiniteFailure ()
2298
2313
<< " expected the dynamic split point handle to point to as "
2299
2314
" many operations ("
2300
- << splitPoints .size () << " ) as the target handle ("
2315
+ << chunkSizes .size () << " ) as the target handle ("
2301
2316
<< payload.size () << " )" ;
2302
2317
}
2303
2318
} else {
2304
- splitPoints .resize (payload.size (),
2305
- rewriter.getIndexAttr (getStaticSplitPoint ()));
2319
+ chunkSizes .resize (payload.size (),
2320
+ rewriter.getIndexAttr (getStaticChunkSizes ()));
2306
2321
}
2307
2322
2308
- // Split each target operation.
2309
- SmallVector<Operation *> first, second;
2310
- Operation *noSecondPart = nullptr ;
2311
- for (const auto &pair : llvm::zip (payload, splitPoints)) {
2312
- Operation *target = std::get<0 >(pair);
2313
- auto linalgOp = dyn_cast<LinalgOp>(target);
2323
+ auto checkStructuredOpAndDimensions =
2324
+ [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2314
2325
if (!linalgOp) {
2315
2326
auto diag = emitSilenceableError () << " only applies to structured ops" ;
2316
- diag.attachNote (target-> getLoc () ) << " target op" ;
2327
+ diag.attachNote (loc ) << " target op" ;
2317
2328
return diag;
2318
2329
}
2319
2330
2320
2331
if (getDimension () >= linalgOp.getNumLoops ()) {
2321
2332
auto diag = emitSilenceableError () << " dimension " << getDimension ()
2322
- << " does not exist in target op" ;
2323
- diag.attachNote (target-> getLoc () ) << " target op" ;
2333
+ << " does not exist in target op" ;
2334
+ diag.attachNote (loc ) << " target op" ;
2324
2335
return diag;
2325
2336
}
2337
+ return DiagnosedSilenceableFailure::success ();
2338
+ };
2326
2339
2327
- rewriter.setInsertionPoint (linalgOp);
2328
- std::tie (first.emplace_back (), second.emplace_back ()) = linalg::splitOp (
2329
- rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2330
- getDimension (), std::get<1 >(pair));
2331
-
2332
- // Propagate errors.
2333
- if (!first.back () && !second.back ()) {
2340
+ auto checkFailureInSplitting =
2341
+ [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2342
+ if (hasFailed) {
2334
2343
auto diag = emitDefiniteFailure () << " internal failure in splitting" ;
2335
- diag.attachNote (target-> getLoc () ) << " target op" ;
2344
+ diag.attachNote (loc ) << " target op" ;
2336
2345
return diag;
2337
2346
}
2347
+ return DiagnosedSilenceableFailure::success ();
2348
+ };
2349
+
2350
+ if (isMultiwaySplit) {
2351
+
2352
+ // Split a single target operation at multiple points.
2353
+ SmallVector<Operation *> opList;
2354
+ TilingInterface head, tail;
2355
+ Operation *target = payload.front ();
2356
+
2357
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2358
+
2359
+ // Check that the target is a valid LinalgOp with correct dimensions.
2360
+ DiagnosedSilenceableFailure diag =
2361
+ checkStructuredOpAndDimensions (linalgOp, target->getLoc ());
2362
+ if (diag.isSilenceableFailure ())
2363
+ return diag;
2364
+
2365
+ for (auto &&[idx, chunkSize] : llvm::enumerate (chunkSizes)) {
2366
+
2367
+ if (idx > 0 )
2368
+ target = tail.getOperation ();
2369
+
2370
+ if (!target)
2371
+ break ;
2338
2372
2339
- // Do not add null second parts.
2340
- if (!second.back ()) {
2341
- noSecondPart = target;
2342
- second.pop_back ();
2373
+ linalgOp = cast<LinalgOp>(target);
2374
+
2375
+ rewriter.setInsertionPoint (linalgOp);
2376
+ std::tie (head, tail) = linalg::splitOp (
2377
+ rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2378
+ getDimension (), chunkSize);
2379
+
2380
+ // Propagate errors.
2381
+ DiagnosedSilenceableFailure diag =
2382
+ checkFailureInSplitting (!head && !tail, target->getLoc ());
2383
+ if (diag.isDefiniteFailure ())
2384
+ return diag;
2385
+
2386
+ opList.push_back (head.getOperation ());
2343
2387
}
2344
- }
2345
2388
2346
- if (second.size () != first.size () && !second.empty ()) {
2347
- auto diag = emitSilenceableError ()
2348
- << " splitting does not produce the second part for a subset "
2349
- " of targets" ;
2350
- diag.attachNote () << " expected splitting to produce the second part of all "
2351
- " or none of the targets" ;
2352
- diag.attachNote (noSecondPart->getLoc ())
2353
- << " first target with no second part" ;
2354
- return diag;
2355
- }
2389
+ // Append any leftover parts to the end of the result list.
2390
+ if (tail)
2391
+ opList.push_back (tail.getOperation ());
2392
+ results.set (cast<OpResult>(getFirst ()), opList);
2393
+ results.set (cast<OpResult>(getSecond ()), {});
2394
+
2395
+ } else {
2396
+ // Split each target operation.
2397
+ SmallVector<Operation *> first, second;
2398
+ Operation *noSecondPart = nullptr ;
2399
+ for (const auto &pair : llvm::zip (payload, chunkSizes)) {
2400
+ Operation *target = std::get<0 >(pair);
2401
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2402
+ DiagnosedSilenceableFailure diag =
2403
+ checkStructuredOpAndDimensions (linalgOp, target->getLoc ());
2404
+
2405
+ if (diag.isSilenceableFailure ())
2406
+ return diag;
2407
+
2408
+ rewriter.setInsertionPoint (linalgOp);
2409
+ std::tie (first.emplace_back (), second.emplace_back ()) = linalg::splitOp (
2410
+ rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2411
+ getDimension (), std::get<1 >(pair));
2412
+
2413
+ // Propagate errors.
2414
+ DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting (
2415
+ !first.back () && !second.back (), target->getLoc ());
2416
+ if (diagSplit.isDefiniteFailure ())
2417
+ return diag;
2418
+
2419
+ // Do not add null second parts.
2420
+ if (!second.back ()) {
2421
+ noSecondPart = target;
2422
+ second.pop_back ();
2423
+ }
2424
+ }
2425
+
2426
+ if (second.size () != first.size () && !second.empty ()) {
2427
+ auto diag = emitSilenceableError ()
2428
+ << " splitting does not produce the second part for a subset "
2429
+ " of targets" ;
2430
+ diag.attachNote ()
2431
+ << " expected splitting to produce the second part of all "
2432
+ " or none of the targets" ;
2433
+ diag.attachNote (noSecondPart->getLoc ())
2434
+ << " first target with no second part" ;
2435
+ return diag;
2436
+ }
2356
2437
2357
- results.set (cast<OpResult>(getFirst ()), first);
2358
- results.set (cast<OpResult>(getSecond ()), second);
2438
+ results.set (cast<OpResult>(getFirst ()), first);
2439
+ results.set (cast<OpResult>(getSecond ()), second);
2440
+ }
2359
2441
return DiagnosedSilenceableFailure::success ();
2360
2442
}
2361
2443
2362
2444
void SplitOp::getEffects (
2363
2445
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2364
2446
consumesHandle (getTarget (), effects);
2365
- if (getDynamicSplitPoint ())
2366
- onlyReadsHandle (getDynamicSplitPoint (), effects);
2447
+ if (getDynamicChunkSizes ())
2448
+ onlyReadsHandle (getDynamicChunkSizes (), effects);
2367
2449
producesHandle (getResults (), effects);
2368
2450
modifiesPayload (effects);
2369
2451
}
2370
2452
2371
2453
ParseResult SplitOp::parse (OpAsmParser &parser, OperationState &result) {
2372
- OpAsmParser::UnresolvedOperand target, dynamicSplitPoint ;
2373
- IntegerAttr staticSplitPoint ;
2454
+ OpAsmParser::UnresolvedOperand target, dynamicChunkSizes ;
2455
+ IntegerAttr staticChunkSizes ;
2374
2456
if (parser.parseOperand (target) || parser.parseKeyword (" after" ))
2375
2457
return failure ();
2376
2458
2377
2459
OptionalParseResult dynamicPointParseResult =
2378
- parser.parseOptionalOperand (dynamicSplitPoint );
2460
+ parser.parseOptionalOperand (dynamicChunkSizes );
2379
2461
if (!dynamicPointParseResult.has_value ()) {
2380
- int64_t staticSplitPointValue ;
2381
- if (failed (parser.parseInteger (staticSplitPointValue )))
2462
+ int64_t staticChunkSizesValue ;
2463
+ if (failed (parser.parseInteger (staticChunkSizesValue )))
2382
2464
return failure ();
2383
2465
2384
- staticSplitPoint =
2385
- parser.getBuilder ().getI64IntegerAttr (staticSplitPointValue );
2466
+ staticChunkSizes =
2467
+ parser.getBuilder ().getI64IntegerAttr (staticChunkSizesValue );
2386
2468
}
2387
2469
2388
2470
Type targetType;
@@ -2392,43 +2474,43 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2392
2474
return failure ();
2393
2475
}
2394
2476
if (dynamicPointParseResult.has_value ()) {
2395
- Type splitPointType ;
2477
+ Type ChunkSizesType ;
2396
2478
if (failed (*dynamicPointParseResult) || parser.parseComma () ||
2397
- parser.parseType (splitPointType ) ||
2398
- parser.resolveOperand (dynamicSplitPoint, splitPointType ,
2479
+ parser.parseType (ChunkSizesType ) ||
2480
+ parser.resolveOperand (dynamicChunkSizes, ChunkSizesType ,
2399
2481
result.operands )) {
2400
2482
return failure ();
2401
2483
}
2402
2484
2403
- staticSplitPoint =
2485
+ staticChunkSizes =
2404
2486
parser.getBuilder ().getI64IntegerAttr (ShapedType::kDynamic );
2405
2487
}
2406
2488
2407
2489
result.addAttribute (
2408
- SplitOp::getStaticSplitPointAttrName (result.name ).getValue (),
2409
- staticSplitPoint );
2490
+ SplitOp::getStaticChunkSizesAttrName (result.name ).getValue (),
2491
+ staticChunkSizes );
2410
2492
result.addTypes ({targetType, targetType});
2411
2493
return success ();
2412
2494
}
2413
2495
2414
2496
void SplitOp::print (OpAsmPrinter &printer) {
2415
2497
printer << " " << getTarget () << " after " ;
2416
- int64_t staticSplitSize = static_cast <int64_t >(getStaticSplitPoint ());
2417
- if (staticSplitSize != ShapedType::kDynamic )
2418
- printer << staticSplitSize ;
2498
+ int64_t staticChunkSize = static_cast <int64_t >(getStaticChunkSizes ());
2499
+ if (staticChunkSize != ShapedType::kDynamic )
2500
+ printer << staticChunkSize ;
2419
2501
else
2420
- printer << getDynamicSplitPoint ();
2502
+ printer << getDynamicChunkSizes ();
2421
2503
printer << " " ;
2422
2504
printer.printOptionalAttrDict (getOperation ()->getAttrs (),
2423
- {getStaticSplitPointAttrName ()});
2505
+ {getStaticChunkSizesAttrName ()});
2424
2506
printer << " : " << getTarget ().getType ();
2425
- if (staticSplitSize == ShapedType::kDynamic )
2426
- printer << " , " << getDynamicSplitPoint ().getType ();
2507
+ if (staticChunkSize == ShapedType::kDynamic )
2508
+ printer << " , " << getDynamicChunkSizes ().getType ();
2427
2509
}
2428
2510
2429
2511
LogicalResult SplitOp::verify () {
2430
- if ((static_cast <int64_t >(getStaticSplitPoint ()) != ShapedType::kDynamic ) ^
2431
- (getDynamicSplitPoint () == nullptr )) {
2512
+ if ((static_cast <int64_t >(getStaticChunkSizes ()) != ShapedType::kDynamic ) ^
2513
+ (getDynamicChunkSizes () == nullptr )) {
2432
2514
return emitOpError () << " expects either a dynamic or a static split "
2433
2515
" point to be provided" ;
2434
2516
}
0 commit comments