@@ -2269,13 +2269,26 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
2269
2269
// Collect the dynamic split points if provided.
2270
2270
SmallVector<Operation *> payload =
2271
2271
llvm::to_vector (state.getPayloadOps (getTarget ()));
2272
- SmallVector<OpFoldResult> splitPoints;
2273
- splitPoints.reserve (payload.size ());
2274
- if (getDynamicSplitPoint ()) {
2272
+
2273
+ bool isMultiwaySplit = getMultiway ();
2274
+
2275
+ if (isMultiwaySplit && !llvm::hasSingleElement (payload)) {
2276
+ return mlir::emitSilenceableFailure (getLoc ())
2277
+ << " requires exactly one target when "
2278
+ " multiway split is enabled (got "
2279
+ << llvm::range_size (payload) << " )" ;
2280
+ }
2281
+
2282
+ SmallVector<OpFoldResult> chunkSizes;
2283
+
2284
+ if (!isMultiwaySplit)
2285
+ chunkSizes.reserve (payload.size ());
2286
+
2287
+ if (getDynamicChunkSizes ()) {
2275
2288
auto diag = DiagnosedSilenceableFailure::success ();
2276
- if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint ().getType ())) {
2277
- splitPoints = llvm::to_vector (llvm::map_range (
2278
- state.getPayloadOps (getDynamicSplitPoint ()), [&](Operation *op) {
2289
+ if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes ().getType ())) {
2290
+ chunkSizes = llvm::to_vector (llvm::map_range (
2291
+ state.getPayloadOps (getDynamicChunkSizes ()), [&](Operation *op) {
2279
2292
if (op->getNumResults () != 1 ||
2280
2293
!op->getResult (0 ).getType ().isIndex ()) {
2281
2294
diag = emitSilenceableError ()
@@ -2286,103 +2299,171 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
2286
2299
return OpFoldResult (op->getResult (0 ));
2287
2300
}));
2288
2301
} else {
2289
- splitPoints = llvm::to_vector (
2290
- llvm::map_range (state.getParams (getDynamicSplitPoint ()),
2302
+ chunkSizes = llvm::to_vector (
2303
+ llvm::map_range (state.getParams (getDynamicChunkSizes ()),
2291
2304
[](Attribute attr) { return OpFoldResult (attr); }));
2292
2305
}
2293
2306
if (diag.isSilenceableFailure ())
2294
2307
return diag;
2295
2308
2296
- if (splitPoints.size () != payload.size ()) {
2309
+ // For multiway split, a single payload is expected to have multiple
2310
+ // split points.
2311
+ if (!isMultiwaySplit && chunkSizes.size () != payload.size ()) {
2297
2312
return emitDefiniteFailure ()
2298
2313
<< " expected the dynamic split point handle to point to as "
2299
2314
" many operations ("
2300
- << splitPoints .size () << " ) as the target handle ("
2315
+ << chunkSizes .size () << " ) as the target handle ("
2301
2316
<< payload.size () << " )" ;
2302
2317
}
2303
2318
} else {
2304
- splitPoints .resize (payload.size (),
2305
- rewriter.getIndexAttr (getStaticSplitPoint ()));
2319
+ chunkSizes .resize (payload.size (),
2320
+ rewriter.getIndexAttr (getStaticChunkSizes ()));
2306
2321
}
2307
2322
2308
- // Split each target operation.
2309
- SmallVector<Operation *> first, second;
2310
- Operation *noSecondPart = nullptr ;
2311
- for (const auto &pair : llvm::zip (payload, splitPoints)) {
2312
- Operation *target = std::get<0 >(pair);
2313
- auto linalgOp = dyn_cast<LinalgOp>(target);
2323
+ auto checkStructuredOpAndDimensions =
2324
+ [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2314
2325
if (!linalgOp) {
2315
2326
auto diag = emitSilenceableError () << " only applies to structured ops" ;
2316
- diag.attachNote (target-> getLoc () ) << " target op" ;
2327
+ diag.attachNote (loc ) << " target op" ;
2317
2328
return diag;
2318
2329
}
2319
2330
2320
2331
if (getDimension () >= linalgOp.getNumLoops ()) {
2321
2332
auto diag = emitSilenceableError () << " dimension " << getDimension ()
2322
- << " does not exist in target op" ;
2323
- diag.attachNote (target-> getLoc () ) << " target op" ;
2333
+ << " does not exist in target op" ;
2334
+ diag.attachNote (loc ) << " target op" ;
2324
2335
return diag;
2325
2336
}
2337
+ return DiagnosedSilenceableFailure::success ();
2338
+ };
2326
2339
2327
- rewriter.setInsertionPoint (linalgOp);
2328
- std::tie (first.emplace_back (), second.emplace_back ()) = linalg::splitOp (
2329
- rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2330
- getDimension (), std::get<1 >(pair));
2331
-
2332
- // Propagate errors.
2333
- if (!first.back () && !second.back ()) {
2340
+ auto checkFailureInSplitting =
2341
+ [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2342
+ if (hasFailed) {
2334
2343
auto diag = emitDefiniteFailure () << " internal failure in splitting" ;
2335
- diag.attachNote (target-> getLoc () ) << " target op" ;
2344
+ diag.attachNote (loc ) << " target op" ;
2336
2345
return diag;
2337
2346
}
2347
+ return DiagnosedSilenceableFailure::success ();
2348
+ };
2349
+
2350
+ if (isMultiwaySplit) {
2351
+
2352
+ // Split a single target operation at multiple points.
2353
+ SmallVector<Operation *> opList;
2354
+ TilingInterface head, tail;
2355
+ Operation *target = payload.front ();
2356
+
2357
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2358
+ DiagnosedSilenceableFailure diag =
2359
+ checkStructuredOpAndDimensions (linalgOp, target->getLoc ());
2360
+
2361
+ if (diag.isSilenceableFailure ())
2362
+ return diag;
2363
+
2364
+ for (auto &&[idx, chunkSize] : llvm::enumerate (chunkSizes)) {
2365
+
2366
+ if (idx > 0 )
2367
+ target = tail.getOperation ();
2368
+
2369
+ if (!target)
2370
+ break ;
2338
2371
2339
- // Do not add null second parts.
2340
- if (!second.back ()) {
2341
- noSecondPart = target;
2342
- second.pop_back ();
2372
+ linalgOp = cast<LinalgOp>(target);
2373
+
2374
+ rewriter.setInsertionPoint (linalgOp);
2375
+ std::tie (head, tail) = linalg::splitOp (
2376
+ rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2377
+ getDimension (), chunkSize);
2378
+
2379
+ // Propagate errors.
2380
+ DiagnosedSilenceableFailure diag =
2381
+ checkFailureInSplitting (!head && !tail, target->getLoc ());
2382
+ if (diag.isDefiniteFailure ())
2383
+ return diag;
2384
+
2385
+ opList.push_back (head.getOperation ());
2343
2386
}
2344
- }
2345
2387
2346
- if (second.size () != first.size () && !second.empty ()) {
2347
- auto diag = emitSilenceableError ()
2348
- << " splitting does not produce the second part for a subset "
2349
- " of targets" ;
2350
- diag.attachNote () << " expected splitting to produce the second part of all "
2351
- " or none of the targets" ;
2352
- diag.attachNote (noSecondPart->getLoc ())
2353
- << " first target with no second part" ;
2354
- return diag;
2355
- }
2388
+ // Append any leftover parts to the end of the result list.
2389
+ if (tail)
2390
+ opList.push_back (tail);
2391
+ results.set (cast<OpResult>(getFirst ()), opList);
2392
+ results.set (cast<OpResult>(getSecond ()), {});
2393
+
2394
+ } else {
2395
+ // Split each target operation.
2396
+ SmallVector<Operation *> first, second;
2397
+ Operation *noSecondPart = nullptr ;
2398
+ for (const auto &pair : llvm::zip (payload, chunkSizes)) {
2399
+ Operation *target = std::get<0 >(pair);
2400
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2401
+ DiagnosedSilenceableFailure diag =
2402
+ checkStructuredOpAndDimensions (linalgOp, target->getLoc ());
2403
+
2404
+ if (diag.isSilenceableFailure ())
2405
+ return diag;
2406
+
2407
+ rewriter.setInsertionPoint (linalgOp);
2408
+ std::tie (first.emplace_back (), second.emplace_back ()) = linalg::splitOp (
2409
+ rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2410
+ getDimension (), std::get<1 >(pair));
2411
+
2412
+ // Propagate errors.
2413
+ DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting (
2414
+ !first.back () && !second.back (), target->getLoc ());
2415
+ if (diagSplit.isDefiniteFailure ())
2416
+ return diag;
2417
+
2418
+ // Do not add null second parts.
2419
+ if (!second.back ()) {
2420
+ noSecondPart = target;
2421
+ second.pop_back ();
2422
+ }
2423
+ }
2424
+
2425
+ if (second.size () != first.size () && !second.empty ()) {
2426
+ auto diag = emitSilenceableError ()
2427
+ << " splitting does not produce the second part for a subset "
2428
+ " of targets" ;
2429
+ diag.attachNote ()
2430
+ << " expected splitting to produce the second part of all "
2431
+ " or none of the targets" ;
2432
+ diag.attachNote (noSecondPart->getLoc ())
2433
+ << " first target with no second part" ;
2434
+ return diag;
2435
+ }
2356
2436
2357
- results.set (cast<OpResult>(getFirst ()), first);
2358
- results.set (cast<OpResult>(getSecond ()), second);
2437
+ results.set (cast<OpResult>(getFirst ()), first);
2438
+ results.set (cast<OpResult>(getSecond ()), second);
2439
+ }
2359
2440
return DiagnosedSilenceableFailure::success ();
2360
2441
}
2361
2442
2362
2443
void SplitOp::getEffects (
2363
2444
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2364
2445
consumesHandle (getTarget (), effects);
2365
- if (getDynamicSplitPoint ())
2366
- onlyReadsHandle (getDynamicSplitPoint (), effects);
2446
+ if (getDynamicChunkSizes ())
2447
+ onlyReadsHandle (getDynamicChunkSizes (), effects);
2367
2448
producesHandle (getResults (), effects);
2368
2449
modifiesPayload (effects);
2369
2450
}
2370
2451
2371
2452
ParseResult SplitOp::parse (OpAsmParser &parser, OperationState &result) {
2372
- OpAsmParser::UnresolvedOperand target, dynamicSplitPoint ;
2373
- IntegerAttr staticSplitPoint ;
2453
+ OpAsmParser::UnresolvedOperand target, dynamicChunkSizes ;
2454
+ IntegerAttr staticChunkSizes ;
2374
2455
if (parser.parseOperand (target) || parser.parseKeyword (" after" ))
2375
2456
return failure ();
2376
2457
2377
2458
OptionalParseResult dynamicPointParseResult =
2378
- parser.parseOptionalOperand (dynamicSplitPoint );
2459
+ parser.parseOptionalOperand (dynamicChunkSizes );
2379
2460
if (!dynamicPointParseResult.has_value ()) {
2380
- int64_t staticSplitPointValue ;
2381
- if (failed (parser.parseInteger (staticSplitPointValue )))
2461
+ int64_t staticChunkSizesValue ;
2462
+ if (failed (parser.parseInteger (staticChunkSizesValue )))
2382
2463
return failure ();
2383
2464
2384
- staticSplitPoint =
2385
- parser.getBuilder ().getI64IntegerAttr (staticSplitPointValue );
2465
+ staticChunkSizes =
2466
+ parser.getBuilder ().getI64IntegerAttr (staticChunkSizesValue );
2386
2467
}
2387
2468
2388
2469
Type targetType;
@@ -2392,43 +2473,43 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2392
2473
return failure ();
2393
2474
}
2394
2475
if (dynamicPointParseResult.has_value ()) {
2395
- Type splitPointType ;
2476
+ Type ChunkSizesType ;
2396
2477
if (failed (*dynamicPointParseResult) || parser.parseComma () ||
2397
- parser.parseType (splitPointType ) ||
2398
- parser.resolveOperand (dynamicSplitPoint, splitPointType ,
2478
+ parser.parseType (ChunkSizesType ) ||
2479
+ parser.resolveOperand (dynamicChunkSizes, ChunkSizesType ,
2399
2480
result.operands )) {
2400
2481
return failure ();
2401
2482
}
2402
2483
2403
- staticSplitPoint =
2484
+ staticChunkSizes =
2404
2485
parser.getBuilder ().getI64IntegerAttr (ShapedType::kDynamic );
2405
2486
}
2406
2487
2407
2488
result.addAttribute (
2408
- SplitOp::getStaticSplitPointAttrName (result.name ).getValue (),
2409
- staticSplitPoint );
2489
+ SplitOp::getStaticChunkSizesAttrName (result.name ).getValue (),
2490
+ staticChunkSizes );
2410
2491
result.addTypes ({targetType, targetType});
2411
2492
return success ();
2412
2493
}
2413
2494
2414
2495
void SplitOp::print (OpAsmPrinter &printer) {
2415
2496
printer << " " << getTarget () << " after " ;
2416
- int64_t staticSplitSize = static_cast <int64_t >(getStaticSplitPoint ());
2417
- if (staticSplitSize != ShapedType::kDynamic )
2418
- printer << staticSplitSize ;
2497
+ int64_t staticChunkSize = static_cast <int64_t >(getStaticChunkSizes ());
2498
+ if (staticChunkSize != ShapedType::kDynamic )
2499
+ printer << staticChunkSize ;
2419
2500
else
2420
- printer << getDynamicSplitPoint ();
2501
+ printer << getDynamicChunkSizes ();
2421
2502
printer << " " ;
2422
2503
printer.printOptionalAttrDict (getOperation ()->getAttrs (),
2423
- {getStaticSplitPointAttrName ()});
2504
+ {getStaticChunkSizesAttrName ()});
2424
2505
printer << " : " << getTarget ().getType ();
2425
- if (staticSplitSize == ShapedType::kDynamic )
2426
- printer << " , " << getDynamicSplitPoint ().getType ();
2506
+ if (staticChunkSize == ShapedType::kDynamic )
2507
+ printer << " , " << getDynamicChunkSizes ().getType ();
2427
2508
}
2428
2509
2429
2510
LogicalResult SplitOp::verify () {
2430
- if ((static_cast <int64_t >(getStaticSplitPoint ()) != ShapedType::kDynamic ) ^
2431
- (getDynamicSplitPoint () == nullptr )) {
2511
+ if ((static_cast <int64_t >(getStaticChunkSizes ()) != ShapedType::kDynamic ) ^
2512
+ (getDynamicChunkSizes () == nullptr )) {
2432
2513
return emitOpError () << " expects either a dynamic or a static split "
2433
2514
" point to be provided" ;
2434
2515
}
0 commit comments