Skip to content

Commit f5bbd13

Browse files
committed
replace isLoopWithIdenticalConfiguration with checkFusionStructuralLegality
1 parent b73238a commit f5bbd13

File tree

3 files changed

+29
-57
lines changed

3 files changed

+29
-57
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
160160
// Fusion related helpers
161161
//===----------------------------------------------------------------------===//
162162

163-
template <typename LoopTy>
164-
bool checkFusionStructuralLegality(Operation *target, Operation *source);
163+
bool checkFusionStructuralLegality(LoopLikeOpInterface &target,
164+
LoopLikeOpInterface &source);
165165

166166
/// Prepends operations of firstPloop's body into secondPloop's body.
167167
/// Updates secondPloop with new loop.

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -442,34 +442,6 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
442442
return DiagnosedSilenceableFailure::success();
443443
}
444444

445-
/// Check if `target` scf loop can be fused into `source` scf loop.
446-
/// Applies for scf.for, scf.forall, and scf.parallel.
447-
///
448-
/// This simply checks if both loops have the same bounds, steps and mapping.
449-
/// No attempt is made at checking that the side effects of `target` and
450-
/// `source` are independent of each other.
451-
template <typename LoopTy>
452-
static bool isLoopWithIdenticalConfiguration(Operation *target,
453-
Operation *source) {
454-
static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
455-
scf::ParallelOp>::value,
456-
"applies to only `forall`, `for` and `parallel`");
457-
auto targetOp = dyn_cast<LoopTy>(target);
458-
auto sourceOp = dyn_cast<LoopTy>(source);
459-
if (!targetOp || !sourceOp)
460-
return false;
461-
462-
if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
463-
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
464-
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
465-
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
466-
targetOp.getMapping() == sourceOp.getMapping();
467-
else
468-
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
469-
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
470-
targetOp.getStep() == sourceOp.getStep();
471-
}
472-
473445
DiagnosedSilenceableFailure
474446
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
475447
transform::TransformResults &results,
@@ -485,29 +457,37 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
485457
<< "source handle (got " << llvm::range_size(sourceOps) << ")";
486458
}
487459

488-
Operation *target = *targetOps.begin();
489-
Operation *source = *sourceOps.begin();
460+
LoopLikeOpInterface target =
461+
dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
462+
LoopLikeOpInterface source =
463+
dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
464+
if (!target || !source)
465+
return emitSilenceableFailure(target->getLoc())
466+
<< "target or source is not a loop op";
490467

491468
// Check if the target and source are siblings.
492469
DiagnosedSilenceableFailure diag = isOpSibling(target, source);
493470
if (!diag.succeeded())
494471
return diag;
495472

473+
if (!mlir::checkFusionStructuralLegality(target, source))
474+
return emitSilenceableFailure(target->getLoc())
475+
<< "operations cannot be fused";
476+
496477
Operation *fusedLoop;
497478
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
498-
if (isLoopWithIdenticalConfiguration<scf::ForOp>(target, source)) {
479+
if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
499480
fusedLoop = fuseIndependentSiblingForLoops(
500481
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
501-
} else if (isLoopWithIdenticalConfiguration<scf::ForallOp>(target, source)) {
482+
} else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
502483
fusedLoop = fuseIndependentSiblingForallLoops(
503484
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
504-
} else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
505-
source)) {
485+
} else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
506486
fusedLoop = fuseIndependentSiblingParallelLoops(
507487
cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
508488
} else
509489
return emitSilenceableFailure(target->getLoc())
510-
<< "operations cannot be fused";
490+
<< "unsupported loop type for fusion";
511491

512492
assert(fusedLoop && "failed to fuse operations");
513493

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,26 +1195,18 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) {
11951195
return walkResult.wasInterrupted();
11961196
}
11971197

1198-
template <typename LoopTy>
1199-
static bool checkFusionStructuralLegality(Operation *target,
1200-
Operation *source) {
1201-
static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
1202-
scf::ParallelOp>::value,
1203-
"applies to only `forall`, `for` and `parallel`");
1204-
auto targetOp = dyn_cast<LoopTy>(target);
1205-
auto sourceOp = dyn_cast<LoopTy>(source);
1206-
if (!targetOp || !sourceOp)
1207-
return false;
1208-
1209-
if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
1210-
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
1211-
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
1212-
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
1213-
targetOp.getMapping() == sourceOp.getMapping();
1214-
else
1215-
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
1216-
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
1217-
targetOp.getStep() == sourceOp.getStep();
1198+
bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
1199+
LoopLikeOpInterface &source) {
1200+
auto iterSpaceEq =
1201+
target.getMixedLowerBound() == source.getMixedLowerBound() &&
1202+
target.getMixedUpperBound() == source.getMixedUpperBound() &&
1203+
target.getMixedStep() == source.getMixedStep();
1204+
auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
1205+
auto forAllSource = dyn_cast<scf::ForallOp>(*source);
1206+
if (forAllTarget && forAllSource)
1207+
return iterSpaceEq &&
1208+
forAllTarget.getMapping() == forAllSource.getMapping();
1209+
return iterSpaceEq;
12181210
}
12191211

12201212
static bool isFusionLegal(scf::ParallelOp firstPloop,

0 commit comments

Comments
 (0)