@@ -2269,8 +2269,20 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
2269
2269
// Collect the dynamic split points if provided.
2270
2270
SmallVector<Operation *> payload =
2271
2271
llvm::to_vector (state.getPayloadOps (getTarget ()));
2272
+
2273
+ bool isMultiwaySplit = getMultiway () ? true : false ;
2274
+
2275
+ if (isMultiwaySplit && !llvm::hasSingleElement (payload)) {
2276
+ return emitDefiniteFailure () << " requires exactly one target when "
2277
+ " multiway split is enabled (got "
2278
+ << llvm::range_size (payload) << " )" ;
2279
+ }
2280
+
2272
2281
SmallVector<OpFoldResult> splitPoints;
2273
- splitPoints.reserve (payload.size ());
2282
+
2283
+ if (!isMultiwaySplit)
2284
+ splitPoints.reserve (payload.size ());
2285
+
2274
2286
if (getDynamicSplitPoint ()) {
2275
2287
auto diag = DiagnosedSilenceableFailure::success ();
2276
2288
if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint ().getType ())) {
@@ -2293,7 +2305,9 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
2293
2305
if (diag.isSilenceableFailure ())
2294
2306
return diag;
2295
2307
2296
- if (splitPoints.size () != payload.size ()) {
2308
+ // For multiway split, a single payload is expected to have multiple
2309
+ // split points.
2310
+ if (!isMultiwaySplit && splitPoints.size () != payload.size ()) {
2297
2311
return emitDefiniteFailure ()
2298
2312
<< " expected the dynamic split point handle to point to as "
2299
2313
" many operations ("
@@ -2305,57 +2319,105 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
2305
2319
rewriter.getIndexAttr (getStaticSplitPoint ()));
2306
2320
}
2307
2321
2308
- // Split each target operation.
2309
- SmallVector<Operation *> first, second;
2310
- Operation *noSecondPart = nullptr ;
2311
- for (const auto &pair : llvm::zip (payload, splitPoints)) {
2312
- Operation *target = std::get<0 >(pair);
2313
- auto linalgOp = dyn_cast<LinalgOp>(target);
2314
- if (!linalgOp) {
2315
- auto diag = emitSilenceableError () << " only applies to structured ops" ;
2316
- diag.attachNote (target->getLoc ()) << " target op" ;
2317
- return diag;
2318
- }
2322
+ if (isMultiwaySplit) {
2319
2323
2320
- if (getDimension () >= linalgOp.getNumLoops ()) {
2321
- auto diag = emitSilenceableError () << " dimension " << getDimension ()
2322
- << " does not exist in target op" ;
2323
- diag.attachNote (target->getLoc ()) << " target op" ;
2324
- return diag;
2324
+ // Split a single target operation at multiple points.
2325
+ SmallVector<Operation *> opList;
2326
+ Operation *head, *tail;
2327
+ for (const auto [idx, splitPoint] : llvm::enumerate (splitPoints)) {
2328
+
2329
+ Operation *target;
2330
+ if (idx == 0 )
2331
+ target = payload.front ();
2332
+ else
2333
+ target = tail;
2334
+
2335
+ if (!target)
2336
+ break ;
2337
+
2338
+ auto linalgOp = dyn_cast<LinalgOp>(target);
2339
+
2340
+ if (!linalgOp) {
2341
+ auto diag = emitSilenceableError () << " only applies to structured ops" ;
2342
+ diag.attachNote (target->getLoc ()) << " target op" ;
2343
+ return diag;
2344
+ }
2345
+
2346
+ if (getDimension () >= linalgOp.getNumLoops ()) {
2347
+ auto diag = emitSilenceableError () << " dimension " << getDimension ()
2348
+ << " does not exist in target op" ;
2349
+ diag.attachNote (target->getLoc ()) << " target op" ;
2350
+ return diag;
2351
+ }
2352
+
2353
+ rewriter.setInsertionPoint (linalgOp);
2354
+ std::tie (head, tail) = linalg::splitOp (
2355
+ rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2356
+ getDimension (), splitPoint);
2357
+
2358
+ opList.push_back (head);
2325
2359
}
2326
2360
2327
- rewriter.setInsertionPoint (linalgOp);
2328
- std::tie (first.emplace_back (), second.emplace_back ()) = linalg::splitOp (
2329
- rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2330
- getDimension (), std::get<1 >(pair));
2361
+ // Append any leftover parts to the end of the result list.
2362
+ if (tail)
2363
+ opList.push_back (tail);
2364
+ results.set (cast<OpResult>(getFirst ()), opList);
2365
+ results.set (cast<OpResult>(getSecond ()), {});
2331
2366
2332
- // Propagate errors.
2333
- if (!first.back () && !second.back ()) {
2334
- auto diag = emitDefiniteFailure () << " internal failure in splitting" ;
2335
- diag.attachNote (target->getLoc ()) << " target op" ;
2336
- return diag;
2367
+ } else {
2368
+ // Split each target operation.
2369
+ SmallVector<Operation *> first, second;
2370
+ Operation *noSecondPart = nullptr ;
2371
+ for (const auto &pair : llvm::zip (payload, splitPoints)) {
2372
+ Operation *target = std::get<0 >(pair);
2373
+ auto linalgOp = dyn_cast<LinalgOp>(target);
2374
+ if (!linalgOp) {
2375
+ auto diag = emitSilenceableError () << " only applies to structured ops" ;
2376
+ diag.attachNote (target->getLoc ()) << " target op" ;
2377
+ return diag;
2378
+ }
2379
+
2380
+ if (getDimension () >= linalgOp.getNumLoops ()) {
2381
+ auto diag = emitSilenceableError () << " dimension " << getDimension ()
2382
+ << " does not exist in target op" ;
2383
+ diag.attachNote (target->getLoc ()) << " target op" ;
2384
+ return diag;
2385
+ }
2386
+
2387
+ rewriter.setInsertionPoint (linalgOp);
2388
+ std::tie (first.emplace_back (), second.emplace_back ()) = linalg::splitOp (
2389
+ rewriter, cast<TilingInterface>(linalgOp.getOperation ()),
2390
+ getDimension (), std::get<1 >(pair));
2391
+
2392
+ // Propagate errors.
2393
+ if (!first.back () && !second.back ()) {
2394
+ auto diag = emitDefiniteFailure () << " internal failure in splitting" ;
2395
+ diag.attachNote (target->getLoc ()) << " target op" ;
2396
+ return diag;
2397
+ }
2398
+
2399
+ // Do not add null second parts.
2400
+ if (!second.back ()) {
2401
+ noSecondPart = target;
2402
+ second.pop_back ();
2403
+ }
2337
2404
}
2338
2405
2339
- // Do not add null second parts.
2340
- if (!second.back ()) {
2341
- noSecondPart = target;
2342
- second.pop_back ();
2406
+ if (second.size () != first.size () && !second.empty ()) {
2407
+ auto diag = emitSilenceableError ()
2408
+ << " splitting does not produce the second part for a subset "
2409
+ " of targets" ;
2410
+ diag.attachNote ()
2411
+ << " expected splitting to produce the second part of all "
2412
+ " or none of the targets" ;
2413
+ diag.attachNote (noSecondPart->getLoc ())
2414
+ << " first target with no second part" ;
2415
+ return diag;
2343
2416
}
2344
- }
2345
2417
2346
- if (second.size () != first.size () && !second.empty ()) {
2347
- auto diag = emitSilenceableError ()
2348
- << " splitting does not produce the second part for a subset "
2349
- " of targets" ;
2350
- diag.attachNote () << " expected splitting to produce the second part of all "
2351
- " or none of the targets" ;
2352
- diag.attachNote (noSecondPart->getLoc ())
2353
- << " first target with no second part" ;
2354
- return diag;
2418
+ results.set (cast<OpResult>(getFirst ()), first);
2419
+ results.set (cast<OpResult>(getSecond ()), second);
2355
2420
}
2356
-
2357
- results.set (cast<OpResult>(getFirst ()), first);
2358
- results.set (cast<OpResult>(getSecond ()), second);
2359
2421
return DiagnosedSilenceableFailure::success ();
2360
2422
}
2361
2423
0 commit comments