@@ -1184,13 +1184,39 @@ static bool equalIterationSpaces(scf::ParallelOp firstPloop,
1184
1184
matchOperands (firstPloop.getStep (), secondPloop.getStep ());
1185
1185
}
1186
1186
1187
+ // ===----------------------------------------------------------------------===//
1188
+ // Fusion related helpers
1189
+ // ===----------------------------------------------------------------------===//
1190
+
1187
1191
// / Verify there are no nested ParallelOps.
1188
1192
static bool hasNestedParallelOp (scf::ParallelOp ploop) {
1189
1193
auto walkResult = ploop.getBody ()->walk (
1190
1194
[](scf::ParallelOp) { return WalkResult::interrupt (); });
1191
1195
return walkResult.wasInterrupted ();
1192
1196
}
1193
1197
1198
+ template <typename LoopTy>
1199
+ static bool checkFusionStructuralLegality (Operation *target,
1200
+ Operation *source) {
1201
+ static_assert (llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
1202
+ scf::ParallelOp>::value,
1203
+ " applies to only `forall`, `for` and `parallel`" );
1204
+ auto targetOp = dyn_cast<LoopTy>(target);
1205
+ auto sourceOp = dyn_cast<LoopTy>(source);
1206
+ if (!targetOp || !sourceOp)
1207
+ return false ;
1208
+
1209
+ if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
1210
+ return targetOp.getMixedLowerBound () == sourceOp.getMixedLowerBound () &&
1211
+ targetOp.getMixedUpperBound () == sourceOp.getMixedUpperBound () &&
1212
+ targetOp.getMixedStep () == sourceOp.getMixedStep () &&
1213
+ targetOp.getMapping () == sourceOp.getMapping ();
1214
+ else
1215
+ return targetOp.getLowerBound () == sourceOp.getLowerBound () &&
1216
+ targetOp.getUpperBound () == sourceOp.getUpperBound () &&
1217
+ targetOp.getStep () == sourceOp.getStep ();
1218
+ }
1219
+
1194
1220
static bool isFusionLegal (scf::ParallelOp firstPloop,
1195
1221
scf::ParallelOp secondPloop,
1196
1222
const IRMapping &firstToSecondPloopIndices,
0 commit comments