Skip to content

Commit a6d0588

Browse files
committed
check if isOpSibling in checkFusionStructuralLegality
1 parent ffb73a7 commit a6d0588

File tree

5 files changed

+99
-88
lines changed

5 files changed

+99
-88
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,11 @@ void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
160160
// Fusion related helpers
161161
//===----------------------------------------------------------------------===//
162162

163-
/// Check structural compatibility between two loops such as iteration space.
163+
/// Check structural compatibility between two loops such as iteration space
164+
/// and dominance.
164165
bool checkFusionStructuralLegality(LoopLikeOpInterface target,
165-
LoopLikeOpInterface source);
166+
LoopLikeOpInterface source,
167+
Diagnostic &diag);
166168

167169
/// Given two scf.forall loops, `target` and `source`, fuses `target` into
168170
/// `source`. Assumes that the given loops are siblings and are independent of

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 4 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -425,78 +425,6 @@ void transform::TakeAssumedBranchOp::getEffects(
425425
// LoopFuseSiblingOp
426426
//===----------------------------------------------------------------------===//
427427

428-
/// Check if `target` and `source` are siblings, in the context that `target`
429-
/// is being fused into `source`.
430-
///
431-
/// This is a simple check that just checks if both operations are in the same
432-
/// block and some checks to ensure that the fused IR does not violate
433-
/// dominance.
434-
static DiagnosedSilenceableFailure isOpSibling(Operation *target,
435-
Operation *source) {
436-
// Check if both operations are same.
437-
if (target == source)
438-
return emitSilenceableFailure(source)
439-
<< "target and source need to be different loops";
440-
441-
// Check if both operations are in the same block.
442-
if (target->getBlock() != source->getBlock())
443-
return emitSilenceableFailure(source)
444-
<< "target and source are not in the same block";
445-
446-
// Check if fusion will violate dominance.
447-
DominanceInfo domInfo(source);
448-
if (target->isBeforeInBlock(source)) {
449-
// Since `target` is before `source`, all users of results of `target`
450-
// need to be dominated by `source`.
451-
for (Operation *user : target->getUsers()) {
452-
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
453-
return emitSilenceableFailure(target)
454-
<< "user of results of target should be properly dominated by "
455-
"source";
456-
}
457-
}
458-
} else {
459-
// Since `target` is after `source`, all values used by `target` need
460-
// to dominate `source`.
461-
462-
// Check if operands of `target` are dominated by `source`.
463-
for (Value operand : target->getOperands()) {
464-
Operation *operandOp = operand.getDefiningOp();
465-
// Operands without defining operations are block arguments. When `target`
466-
// and `source` occur in the same block, these operands dominate `source`.
467-
if (!operandOp)
468-
continue;
469-
470-
// Operand's defining operation should properly dominate `source`.
471-
if (!domInfo.properlyDominates(operandOp, source,
472-
/*enclosingOpOk=*/false))
473-
return emitSilenceableFailure(target)
474-
<< "operands of target should be properly dominated by source";
475-
}
476-
477-
// Check if values used by `target` are dominated by `source`.
478-
bool failed = false;
479-
OpOperand *failedValue = nullptr;
480-
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
481-
Operation *operandOp = operand->get().getDefiningOp();
482-
if (operandOp && !domInfo.properlyDominates(operandOp, source,
483-
/*enclosingOpOk=*/false)) {
484-
// `operand` is not an argument of an enclosing block and the defining
485-
// op of `operand` is outside `target` but does not dominate `source`.
486-
failed = true;
487-
failedValue = operand;
488-
}
489-
});
490-
491-
if (failed)
492-
return emitSilenceableFailure(failedValue->getOwner())
493-
<< "values used inside regions of target should be properly "
494-
"dominated by source";
495-
}
496-
497-
return DiagnosedSilenceableFailure::success();
498-
}
499-
500428
DiagnosedSilenceableFailure
501429
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
502430
transform::TransformResults &results,
@@ -520,14 +448,10 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
520448
return emitSilenceableFailure(target->getLoc())
521449
<< "target or source is not a loop op";
522450

523-
// Check if the target and source are siblings.
524-
DiagnosedSilenceableFailure diag = isOpSibling(target, source);
525-
if (!diag.succeeded())
526-
return diag;
527-
528-
if (!mlir::checkFusionStructuralLegality(target, source))
529-
return emitSilenceableFailure(target->getLoc())
530-
<< "operations cannot be fused";
451+
// Check if loops can be fused
452+
Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
453+
if (!mlir::checkFusionStructuralLegality(target, source, diag))
454+
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
531455

532456
Operation *fusedLoop;
533457
// TODO: Support fusion for loop-like ops besides scf.for, scf.forall

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,10 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
136136
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
137137
const IRMapping &firstToSecondPloopIndices,
138138
llvm::function_ref<bool(Value, Value)> mayAlias) {
139+
Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
139140
return !hasNestedParallelOp(firstPloop) &&
140141
!hasNestedParallelOp(secondPloop) &&
141-
checkFusionStructuralLegality(firstPloop, secondPloop) &&
142+
checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
142143
succeeded(verifyDependencies(firstPloop, secondPloop,
143144
firstToSecondPloopIndices, mayAlias));
144145
}

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Func/IR/FuncOps.h"
1818
#include "mlir/Dialect/SCF/IR/SCF.h"
1919
#include "mlir/IR/BuiltinOps.h"
20+
#include "mlir/IR/Dominance.h"
2021
#include "mlir/IR/IRMapping.h"
2122
#include "mlir/IR/PatternMatch.h"
2223
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -1074,8 +1075,86 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
10741075
// Fusion related helpers
10751076
//===----------------------------------------------------------------------===//
10761077

1078+
/// Check if `target` and `source` are siblings, in the context that `target`
1079+
/// is being fused into `source`.
1080+
///
1081+
/// This is a simple check that just checks if both operations are in the same
1082+
/// block and some checks to ensure that the fused IR does not violate
1083+
/// dominance.
1084+
static bool isOpSibling(Operation *target, Operation *source,
1085+
Diagnostic &diag) {
1086+
// Check if both operations are same.
1087+
if (target == source) {
1088+
diag << "target and source need to be different loops";
1089+
return false;
1090+
}
1091+
1092+
// Check if both operations are in the same block.
1093+
if (target->getBlock() != source->getBlock()) {
1094+
diag << "target and source are not in the same block";
1095+
return false;
1096+
}
1097+
1098+
// Check if fusion will violate dominance.
1099+
DominanceInfo domInfo(source);
1100+
if (target->isBeforeInBlock(source)) {
1101+
// Since `target` is before `source`, all users of results of `target`
1102+
// need to be dominated by `source`.
1103+
for (Operation *user : target->getUsers()) {
1104+
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
1105+
diag << "user of results of target should "
1106+
"be properly dominated by source";
1107+
return false;
1108+
}
1109+
}
1110+
} else {
1111+
// Since `target` is after `source`, all values used by `target` need
1112+
// to dominate `source`.
1113+
1114+
// Check if operands of `target` are dominated by `source`.
1115+
for (Value operand : target->getOperands()) {
1116+
Operation *operandOp = operand.getDefiningOp();
1117+
// Operands without defining operations are block arguments. When `target`
1118+
// and `source` occur in the same block, these operands dominate `source`.
1119+
if (!operandOp)
1120+
continue;
1121+
1122+
// Operand's defining operation should properly dominate `source`.
1123+
if (!domInfo.properlyDominates(operandOp, source,
1124+
/*enclosingOpOk=*/false)) {
1125+
diag << "operands of target should be properly dominated by source";
1126+
return false;
1127+
}
1128+
}
1129+
1130+
// Check if values used by `target` are dominated by `source`.
1131+
bool failed = false;
1132+
OpOperand *failedValue = nullptr;
1133+
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
1134+
Operation *operandOp = operand->get().getDefiningOp();
1135+
if (operandOp && !domInfo.properlyDominates(operandOp, source,
1136+
/*enclosingOpOk=*/false)) {
1137+
// `operand` is not an argument of an enclosing block and the defining
1138+
// op of `operand` is outside `target` but does not dominate `source`.
1139+
failed = true;
1140+
failedValue = operand;
1141+
}
1142+
});
1143+
1144+
if (failed) {
1145+
diag << "values used inside regions of target should be properly "
1146+
"dominated by source";
1147+
diag.attachNote(failedValue->getOwner()->getLoc()) << "see operation";
1148+
return false;
1149+
}
1150+
}
1151+
1152+
return true;
1153+
}
1154+
10771155
bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
1078-
LoopLikeOpInterface source) {
1156+
LoopLikeOpInterface source,
1157+
Diagnostic &diag) {
10791158
bool iterSpaceEq =
10801159
target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
10811160
target.getLoopUpperBounds() == source.getLoopUpperBounds() &&
@@ -1085,9 +1164,13 @@ bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
10851164
// TODO: Decouple checks on concrete loop types and move this function
10861165
// somewhere for general utility for `LoopLikeOpInterface`
10871166
if (forAllTarget && forAllSource)
1088-
return iterSpaceEq &&
1089-
forAllTarget.getMapping() == forAllSource.getMapping();
1090-
return iterSpaceEq;
1167+
iterSpaceEq =
1168+
iterSpaceEq && forAllTarget.getMapping() == forAllSource.getMapping();
1169+
if (!iterSpaceEq) {
1170+
diag << "target and source iteration spaces must be equal";
1171+
return false;
1172+
}
1173+
return isOpSibling(target, source, diag);
10911174
}
10921175

10931176
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,

mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,9 @@ func.func @target_for_region_uses_result_of_source_for_err(%A: tensor<128xf32>,
335335
%6 = vector.transfer_write %5, %arg4[%arg3] {in_bounds = [true]} : vector<16xf32>, tensor<128xf32>
336336
scf.yield %6 : tensor<128xf32>
337337
}
338-
%dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
339338
// expected-error @below {{values used inside regions of target should be properly dominated by source}}
339+
%dup1 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %B) -> (tensor<128xf32>) {
340+
// expected-note @below {{see operation}}
340341
%dup2 = vector.transfer_read %1[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
341342
%dup3 = vector.transfer_read %arg4[%arg3], %cst {in_bounds = [true]} : tensor<128xf32>, vector<16xf32>
342343
%dup5 = arith.addf %dup3, %dup2 : vector<16xf32>

0 commit comments

Comments
 (0)