Skip to content

Commit 2b57fa2

Browse files
committed
Enable LICM for ops with read side effects in scf.for wrapped by a guard
1 parent bf700c3 commit 2b57fa2

File tree

9 files changed

+355
-18
lines changed

9 files changed

+355
-18
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def ForOp : SCF_Op<"for",
139139
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
140140
"getLoopUpperBounds", "getYieldedValuesMutable",
141141
"promoteIfSingleIteration", "replaceWithAdditionalYields",
142+
"wrapInTripCountCheck", "unwrapTripCountCheck",
142143
"yieldTiledValuesAndReplace"]>,
143144
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
144145
ConditionallySpeculatable,
@@ -302,7 +303,7 @@ def ForallOp : SCF_Op<"forall", [
302303
AttrSizedOperandSegments,
303304
AutomaticAllocationScope,
304305
DeclareOpInterfaceMethods<LoopLikeOpInterface,
305-
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
306+
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
306307
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
307308
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
308309
RecursiveMemoryEffects,

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
7979
/*methodBody=*/"",
8080
/*defaultImplementation=*/"op->moveBefore($_op);"
8181
>,
82+
InterfaceMethod<[{
83+
Wraps the loop into a trip-count check.
84+
}],
85+
/*retTy=*/"FailureOr<std::pair<::mlir::Operation *, ::mlir::Region *>>",
86+
/*methodName=*/"wrapInTripCountCheck",
87+
/*args=*/(ins),
88+
/*methodBody=*/"",
89+
/*defaultImplementation=*/"return ::mlir::failure();"
90+
>,
91+
InterfaceMethod<[{
92+
Unwraps the trip-count check.
93+
}],
94+
/*retTy=*/"::llvm::LogicalResult",
95+
/*methodName=*/"unwrapTripCountCheck",
96+
/*args=*/(ins),
97+
/*methodBody=*/"",
98+
/*defaultImplementation=*/[{
99+
return ::mlir::failure();
100+
}]
101+
>,
82102
InterfaceMethod<[{
83103
Promotes the loop body to its containing block if the loop is known to
84104
have a single iteration. Returns "success" if the promotion was

mlir/include/mlir/Interfaces/SideEffectInterfaces.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
433433
/// conditions are satisfied.
434434
bool isMemoryEffectFree(Operation *op);
435435

436+
/// Returns true if the given operation implements `MemoryEffectOpInterface` and
437+
/// has only read effects.
438+
bool hasOnlyReadEffect(Operation *op);
439+
436440
/// Returns the side effects of an operation. If the operation has
437441
/// RecursiveMemoryEffects, include all side effects of child operations.
438442
///

mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,19 @@ class Value;
4848
/// }
4949
/// ```
5050
///
51-
/// Users must supply three callbacks.
51+
/// Users must supply five callbacks.
5252
///
5353
/// - `isDefinedOutsideRegion` returns true if the given value is invariant with
5454
/// respect to the given region. A common implementation might be:
5555
/// `value.getParentRegion()->isProperAncestor(region)`.
5656
/// - `shouldMoveOutOfRegion` returns true if the provided operation can be
57-
/// moved of the given region, e.g. if it is side-effect free.
57+
/// moved of the given region, e.g. if it is side-effect free or has only read
58+
/// side effects.
59+
/// - `wrapInGuard` wraps the given operation in a trip-count check guard.
5860
/// - `moveOutOfRegion` moves the operation out of the given region. A common
5961
/// implementation might be: `op->moveBefore(region->getParentOp())`.
62+
/// - `unwrapGuard` unwraps the trip-count check if there is no op guarded by
63+
/// this check.
6064
///
6165
/// An operation is moved if all of its operands satisfy
6266
/// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -66,7 +70,9 @@ size_t moveLoopInvariantCode(
6670
ArrayRef<Region *> regions,
6771
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
6872
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
69-
function_ref<void(Operation *, Region *)> moveOutOfRegion);
73+
function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
74+
function_ref<void(Operation *, Region *)> moveOutOfRegion,
75+
function_ref<LogicalResult()> unwrapGuard);
7076

7177
/// Move side-effect free loop invariant code out of a loop-like op using
7278
/// methods provided by the interface.

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,83 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
395395

396396
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397397

398+
FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
399+
auto lowerBound = this->getLowerBound();
400+
auto upperBound = this->getUpperBound();
401+
auto step = this->getStep();
402+
auto initArgs = this->getInitArgs();
403+
auto results = this->getResults();
404+
auto loc = this->getLoc();
405+
406+
IRRewriter rewriter(this->getContext());
407+
OpBuilder::InsertionGuard insertGuard(rewriter);
408+
rewriter.setInsertionPointAfter(this->getOperation());
409+
410+
// Form the trip count calculation
411+
auto subOp = rewriter.create<arith::SubIOp>(loc, upperBound, lowerBound);
412+
auto ceilDivSIOp = rewriter.create<arith::CeilDivSIOp>(loc, subOp, step);
413+
Value zero;
414+
if (upperBound.getType().isIndex()) {
415+
zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
416+
} else {
417+
zero = rewriter.create<arith::ConstantIntOp>(
418+
loc, 0,
419+
/*width=*/
420+
upperBound.getType().getIntOrFloatBitWidth());
421+
}
422+
auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
423+
ceilDivSIOp, zero);
424+
scf::YieldOp yieldInThen;
425+
// Create the trip-count check
426+
auto ifOp = rewriter.create<scf::IfOp>(
427+
loc, cmpIOp,
428+
[&](OpBuilder &builder, Location loc) {
429+
yieldInThen = builder.create<scf::YieldOp>(loc, results);
430+
},
431+
[&](OpBuilder &builder, Location loc) {
432+
builder.create<scf::YieldOp>(loc, initArgs);
433+
});
434+
435+
for (auto [forOpResult, ifOpResult] : llvm::zip(results, ifOp.getResults()))
436+
rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
437+
// Move the scf.for into the then block
438+
rewriter.moveOpBefore(this->getOperation(), yieldInThen);
439+
return std::make_pair(ifOp.getOperation(), &this->getRegion());
440+
}
441+
442+
LogicalResult ForOp::unwrapTripCountCheck() {
443+
auto ifOp = (*this)->getParentRegion()->getParentOp();
444+
if (!isa<scf::IfOp>(ifOp))
445+
return failure();
446+
447+
auto wrappedForOp = this->getOperation();
448+
449+
IRRewriter rewriter(ifOp->getContext());
450+
OpBuilder::InsertionGuard insertGuard(rewriter);
451+
rewriter.setInsertionPoint(ifOp);
452+
453+
auto cmpOp = ifOp->getOperand(0).getDefiningOp();
454+
auto ceilDivSIOp = cmpOp->getOperand(0).getDefiningOp();
455+
auto zero = cmpOp->getOperand(1).getDefiningOp();
456+
auto subOp = ceilDivSIOp->getOperand(0).getDefiningOp();
457+
if (!isa<arith::CmpIOp>(cmpOp) || !isa<arith::CeilDivSIOp>(ceilDivSIOp) ||
458+
!isa<arith::SubIOp>(subOp))
459+
return failure();
460+
461+
rewriter.moveOpBefore(wrappedForOp, ifOp);
462+
463+
for (auto [forOpResult, ifOpResult] :
464+
llvm::zip(wrappedForOp->getResults(), ifOp->getResults()))
465+
rewriter.replaceAllUsesWith(ifOpResult, forOpResult);
466+
467+
rewriter.eraseOp(ifOp);
468+
rewriter.eraseOp(cmpOp);
469+
rewriter.eraseOp(zero);
470+
rewriter.eraseOp(ceilDivSIOp);
471+
rewriter.eraseOp(subOp);
472+
return success();
473+
}
474+
398475
/// Promotes the loop body of a forOp to its containing block if the forOp
399476
/// it can be determined that the loop has a single iteration.
400477
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
@@ -3397,9 +3474,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
33973474

33983475
if (functionType.getNumInputs() != operands.size()) {
33993476
return parser.emitError(typeLoc)
3400-
<< "expected as many input types as operands "
3401-
<< "(expected " << operands.size() << " got "
3402-
<< functionType.getNumInputs() << ")";
3477+
<< "expected as many input types as operands " << "(expected "
3478+
<< operands.size() << " got " << functionType.getNumInputs() << ")";
34033479
}
34043480

34053481
// Resolve input operands.

mlir/lib/Interfaces/SideEffectInterfaces.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,13 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) {
306306
return wouldOpBeTriviallyDeadImpl(op);
307307
}
308308

309+
bool mlir::hasOnlyReadEffect(Operation *op) {
310+
if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op)) {
311+
return memEffects.onlyHasEffect<MemoryEffects::Read>();
312+
}
313+
return false;
314+
}
315+
309316
bool mlir::isMemoryEffectFree(Operation *op) {
310317
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
311318
if (!memInterface.hasNoEffect())

mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,48 +56,117 @@ static bool canBeHoisted(Operation *op,
5656
op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
5757
}
5858

59+
static bool dependsOnGuarded(Operation *op,
60+
function_ref<bool(OpOperand &)> condition) {
61+
auto walkFn = [&](Operation *child) {
62+
for (OpOperand &operand : child->getOpOperands()) {
63+
if (!condition(operand))
64+
return WalkResult::interrupt();
65+
}
66+
return WalkResult::advance();
67+
};
68+
return op->walk(walkFn).wasInterrupted();
69+
}
70+
71+
static bool dependsOnGuarded(Operation *op,
72+
function_ref<bool(Value)> definedOutsideGuard) {
73+
return dependsOnGuarded(op, [&](OpOperand &operand) {
74+
return definedOutsideGuard(operand.get());
75+
});
76+
}
77+
78+
static bool loopSideEffectFreeOrHasOnlyReadEffect(Operation *loop) {
79+
for (auto &region : loop->getRegions()) {
80+
for (auto &block : region.getBlocks()) {
81+
for (Operation &op : block.getOperations()) {
82+
if (!isMemoryEffectFree(&op) && !hasOnlyReadEffect(&op))
83+
return false;
84+
}
85+
}
86+
}
87+
return true;
88+
}
89+
5990
size_t mlir::moveLoopInvariantCode(
6091
ArrayRef<Region *> regions,
6192
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
6293
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
63-
function_ref<void(Operation *, Region *)> moveOutOfRegion) {
94+
function_ref<FailureOr<std::pair<Operation *, Region *>>()> wrapInGuard,
95+
function_ref<void(Operation *, Region *)> moveOutOfRegion,
96+
function_ref<LogicalResult()> unwrapGuard) {
6497
size_t numMoved = 0;
6598

6699
for (Region *region : regions) {
67100
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
68101
<< *region->getParentOp() << "\n");
69102

103+
auto loopSideEffectFreeOrHasOnlyReadSideEffect =
104+
loopSideEffectFreeOrHasOnlyReadEffect(region->getParentOp());
105+
106+
size_t numMovedWithoutGuard = 0;
107+
108+
FailureOr<std::pair<Operation *, Region *>> ifOpAndRegion = wrapInGuard();
109+
Region *loopRegion = region;
110+
auto isLoopWrapped = false;
111+
if (succeeded(ifOpAndRegion)) {
112+
loopRegion = ifOpAndRegion->second;
113+
isLoopWrapped = true;
114+
}
115+
70116
std::queue<Operation *> worklist;
71117
// Add top-level operations in the loop body to the worklist.
72-
for (Operation &op : region->getOps())
118+
for (Operation &op : loopRegion->getOps())
73119
worklist.push(&op);
74120

75121
auto definedOutside = [&](Value value) {
76-
return isDefinedOutsideRegion(value, region);
122+
return isDefinedOutsideRegion(value, loopRegion);
123+
};
124+
125+
auto definedOutsideGuard = [&](Value value) {
126+
return isDefinedOutsideRegion(value, loopRegion->getParentRegion());
77127
};
78128

79129
while (!worklist.empty()) {
80130
Operation *op = worklist.front();
81131
worklist.pop();
82132
// Skip ops that have already been moved. Check if the op can be hoisted.
83-
if (op->getParentRegion() != region)
133+
if (op->getParentRegion() != loopRegion)
84134
continue;
85135

86136
LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
87-
if (!shouldMoveOutOfRegion(op, region) ||
137+
138+
if (!shouldMoveOutOfRegion(op, loopRegion) ||
88139
!canBeHoisted(op, definedOutside))
89140
continue;
141+
// Can only hoist pure ops (side-effect free) when there is an op with
142+
// write side effects in the loop
143+
if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
144+
continue;
90145

91146
LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
92-
moveOutOfRegion(op, region);
147+
148+
auto moveWithoutGuard = isMemoryEffectFree(op) &&
149+
!dependsOnGuarded(op, definedOutsideGuard) &&
150+
isLoopWrapped;
151+
numMovedWithoutGuard += moveWithoutGuard;
152+
153+
moveOutOfRegion(op, moveWithoutGuard ? loopRegion->getParentRegion()
154+
: loopRegion);
93155
++numMoved;
94156

95157
// Since the op has been moved, we need to check its users within the
96158
// top-level of the loop body.
97159
for (Operation *user : op->getUsers())
98-
if (user->getParentRegion() == region)
160+
if (user->getParentRegion() == loopRegion)
99161
worklist.push(user);
100162
}
163+
164+
// Unwrap the loop if it was wrapped but no ops were moved in the guard.
165+
if (isLoopWrapped && numMovedWithoutGuard == numMoved) {
166+
auto tripCountCheckUnwrapped = unwrapGuard();
167+
if (failed(tripCountCheckUnwrapped))
168+
llvm_unreachable("Should not fail unwrapping trip-count check");
169+
}
101170
}
102171

103172
return numMoved;
@@ -106,13 +175,18 @@ size_t mlir::moveLoopInvariantCode(
106175
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
107176
return moveLoopInvariantCode(
108177
loopLike.getLoopRegions(),
109-
[&](Value value, Region *) {
110-
return loopLike.isDefinedOutsideOfLoop(value);
178+
[&](Value value, Region *region) {
179+
return !region->isAncestor(value.getParentRegion());
111180
},
112181
[&](Operation *op, Region *) {
113-
return isMemoryEffectFree(op) && isSpeculatable(op);
182+
return isSpeculatable(op) &&
183+
(isMemoryEffectFree(op) || hasOnlyReadEffect(op));
184+
},
185+
[&]() { return loopLike.wrapInTripCountCheck(); },
186+
[&](Operation *op, Region *region) {
187+
op->moveBefore(region->getParentOp());
114188
},
115-
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
189+
[&]() { return loopLike.unwrapTripCountCheck(); });
116190
}
117191

118192
namespace {

0 commit comments

Comments
 (0)