@@ -442,34 +442,6 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
442
442
return DiagnosedSilenceableFailure::success ();
443
443
}
444
444
445
- // / Check if `target` scf loop can be fused into `source` scf loop.
446
- // / Applies for scf.for, scf.forall, and scf.parallel.
447
- // /
448
- // / This simply checks if both loops have the same bounds, steps and mapping.
449
- // / No attempt is made at checking that the side effects of `target` and
450
- // / `source` are independent of each other.
451
- template <typename LoopTy>
452
- static bool isLoopWithIdenticalConfiguration (Operation *target,
453
- Operation *source) {
454
- static_assert (llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
455
- scf::ParallelOp>::value,
456
- " applies to only `forall`, `for` and `parallel`" );
457
- auto targetOp = dyn_cast<LoopTy>(target);
458
- auto sourceOp = dyn_cast<LoopTy>(source);
459
- if (!targetOp || !sourceOp)
460
- return false ;
461
-
462
- if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
463
- return targetOp.getMixedLowerBound () == sourceOp.getMixedLowerBound () &&
464
- targetOp.getMixedUpperBound () == sourceOp.getMixedUpperBound () &&
465
- targetOp.getMixedStep () == sourceOp.getMixedStep () &&
466
- targetOp.getMapping () == sourceOp.getMapping ();
467
- else
468
- return targetOp.getLowerBound () == sourceOp.getLowerBound () &&
469
- targetOp.getUpperBound () == sourceOp.getUpperBound () &&
470
- targetOp.getStep () == sourceOp.getStep ();
471
- }
472
-
473
445
DiagnosedSilenceableFailure
474
446
transform::LoopFuseSiblingOp::apply (transform::TransformRewriter &rewriter,
475
447
transform::TransformResults &results,
@@ -485,29 +457,37 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
485
457
<< " source handle (got " << llvm::range_size (sourceOps) << " )" ;
486
458
}
487
459
488
- Operation *target = *targetOps.begin ();
489
- Operation *source = *sourceOps.begin ();
460
+ LoopLikeOpInterface target =
461
+ dyn_cast<LoopLikeOpInterface>(*targetOps.begin ());
462
+ LoopLikeOpInterface source =
463
+ dyn_cast<LoopLikeOpInterface>(*sourceOps.begin ());
464
+ if (!target || !source)
465
+ return emitSilenceableFailure (target->getLoc ())
466
+ << " target or source is not a loop op" ;
490
467
491
468
// Check if the target and source are siblings.
492
469
DiagnosedSilenceableFailure diag = isOpSibling (target, source);
493
470
if (!diag.succeeded ())
494
471
return diag;
495
472
473
+ if (!mlir::checkFusionStructuralLegality (target, source))
474
+ return emitSilenceableFailure (target->getLoc ())
475
+ << " operations cannot be fused" ;
476
+
496
477
Operation *fusedLoop;
497
478
// / TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
498
- if (isLoopWithIdenticalConfiguration <scf::ForOp>(target, source)) {
479
+ if (isa <scf::ForOp>(target) && isa<scf::ForOp>( source)) {
499
480
fusedLoop = fuseIndependentSiblingForLoops (
500
481
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
501
- } else if (isLoopWithIdenticalConfiguration <scf::ForallOp>(target, source)) {
482
+ } else if (isa <scf::ForallOp>(target) && isa<scf::ForallOp>( source)) {
502
483
fusedLoop = fuseIndependentSiblingForallLoops (
503
484
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
504
- } else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
505
- source)) {
485
+ } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
506
486
fusedLoop = fuseIndependentSiblingParallelLoops (
507
487
cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
508
488
} else
509
489
return emitSilenceableFailure (target->getLoc ())
510
- << " operations cannot be fused " ;
490
+ << " unsupported loop type for fusion " ;
511
491
512
492
assert (fusedLoop && " failed to fuse operations" );
513
493
0 commit comments