Skip to content

Commit b73238a

Browse files
committed
add checkFusionStructuralLegality
1 parent 50852d5 commit b73238a

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,13 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
156156
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
157157
scf::ForOp root);
158158

159+
//===----------------------------------------------------------------------===//
160+
// Fusion related helpers
161+
//===----------------------------------------------------------------------===//
162+
163+
template <typename LoopTy>
164+
bool checkFusionStructuralLegality(Operation *target, Operation *source);
165+
159166
/// Prepends operations of firstPloop's body into secondPloop's body.
160167
/// Updates secondPloop with new loop.
161168
void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,13 +1184,39 @@ static bool equalIterationSpaces(scf::ParallelOp firstPloop,
11841184
matchOperands(firstPloop.getStep(), secondPloop.getStep());
11851185
}
11861186

1187+
//===----------------------------------------------------------------------===//
1188+
// Fusion related helpers
1189+
//===----------------------------------------------------------------------===//
1190+
11871191
/// Verify there are no nested ParallelOps.
11881192
static bool hasNestedParallelOp(scf::ParallelOp ploop) {
11891193
auto walkResult = ploop.getBody()->walk(
11901194
[](scf::ParallelOp) { return WalkResult::interrupt(); });
11911195
return walkResult.wasInterrupted();
11921196
}
11931197

1198+
template <typename LoopTy>
1199+
static bool checkFusionStructuralLegality(Operation *target,
1200+
Operation *source) {
1201+
static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
1202+
scf::ParallelOp>::value,
1203+
"applies to only `forall`, `for` and `parallel`");
1204+
auto targetOp = dyn_cast<LoopTy>(target);
1205+
auto sourceOp = dyn_cast<LoopTy>(source);
1206+
if (!targetOp || !sourceOp)
1207+
return false;
1208+
1209+
if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
1210+
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
1211+
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
1212+
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
1213+
targetOp.getMapping() == sourceOp.getMapping();
1214+
else
1215+
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
1216+
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
1217+
targetOp.getStep() == sourceOp.getStep();
1218+
}
1219+
11941220
static bool isFusionLegal(scf::ParallelOp firstPloop,
11951221
scf::ParallelOp secondPloop,
11961222
const IRMapping &firstToSecondPloopIndices,

0 commit comments

Comments
 (0)