Skip to content

Commit 8433b18

Browse files
committed
turn into a general slice walk
1 parent c458ec8 commit 8433b18

File tree

3 files changed

+105
-97
lines changed

3 files changed

+105
-97
lines changed

mlir/include/mlir/IR/SliceSupport.h

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@ namespace mlir {
1515

1616
/// A class to signal how to proceed with the walk of the backward slice:
1717
/// - Interrupt: Stops the walk.
18-
/// - Advance: Continues the walk to control flow predecessors values.
1918
/// - AdvanceTo: Continues the walk to user-specified values.
2019
/// - Skip: Continues the walk, but skips the predecessors of the current value.
2120
class WalkContinuation {
2221
public:
2322
enum class WalkAction {
2423
/// Stops the walk.
2524
Interrupt,
26-
/// Continues the walk to control flow predecessors values.
27-
Advance,
2825
/// Continues the walk to user-specified values.
2926
AdvanceTo,
3027
/// Continues the walk, but skips the predecessors of the current value.
@@ -34,10 +31,6 @@ class WalkContinuation {
3431
WalkContinuation(WalkAction action, mlir::ValueRange nextValues)
3532
: action(action), nextValues(nextValues) {}
3633

37-
/// Allows LogicalResult to interrupt the walk on failure.
38-
explicit WalkContinuation(llvm::LogicalResult action)
39-
: action(failed(action) ? WalkAction::Interrupt : WalkAction::Advance) {}
40-
4134
/// Allows diagnostics to interrupt the walk.
4235
explicit WalkContinuation(mlir::Diagnostic &&)
4336
: action(WalkAction::Interrupt) {}
@@ -58,12 +51,6 @@ class WalkContinuation {
5851
return WalkContinuation(WalkAction::AdvanceTo, nextValues);
5952
}
6053

61-
/// Creates a continuation that adds the control flow predecessor values to
62-
/// the work list and advances the walk.
63-
static WalkContinuation advance() {
64-
return WalkContinuation(WalkAction::Advance, {});
65-
}
66-
6754
/// Creates a continuation that advances the walk without adding any
6855
/// predecessor values to the work list.
6956
static WalkContinuation skip() {
@@ -89,34 +76,23 @@ class WalkContinuation {
8976
};
9077

9178
/// A callback that is invoked for each value encountered during the walk of the
92-
/// backward slice. The callback takes the current value, and returns the walk
93-
/// continuation, which determines if the walk should proceed and if yes, with
94-
/// which values.
95-
using WalkCallback = mlir::function_ref<WalkContinuation(mlir::Value)>;
96-
97-
/// Walks the backward slice starting from the `rootValues` using a depth-first
98-
/// traversal following the use-def chains. The walk calls the provided
99-
/// `walkCallback` for each value encountered in the backward slice and uses the
100-
/// returned walk continuation to determine how to proceed. Additionally, the
101-
/// walk also transparently traverses through select operations and control flow
102-
/// operations that implement RegionBranchOpInterface or BranchOpInterface.
103-
WalkContinuation walkBackwardSlice(mlir::ValueRange rootValues,
104-
WalkCallback walkCallback);
105-
106-
/// A callback that is invoked for each value encountered during the walk of the
107-
/// backward slice. The callback takes the current value, and returns the walk
79+
/// slice. The callback takes the current value, and returns the walk
10880
/// continuation, which determines if the walk should proceed and if yes, with
10981
/// which values.
11082
using WalkCallback = mlir::function_ref<WalkContinuation(mlir::Value)>;
11183

112-
/// Walks the backward slice starting from the `rootValues` using a depth-first
113-
/// traversal following the use-def chains. The walk calls the provided
114-
/// `walkCallback` for each value encountered in the backward slice and uses the
115-
/// returned walk continuation to determine how to proceed. Additionally, the
116-
/// walk also transparently traverses through select operations and control flow
117-
/// operations that implement RegionBranchOpInterface or BranchOpInterface.
118-
WalkContinuation walkBackwardSlice(mlir::ValueRange rootValues,
119-
WalkCallback walkCallback);
84+
/// Walks the slice starting from the `rootValues` using a depth-first
85+
/// traversal. The walk calls the provided `walkCallback` for each value
86+
/// encountered in the slice and uses the returned walk continuation to
87+
/// determine how to proceed.
88+
WalkContinuation walkSlice(mlir::ValueRange rootValues,
89+
WalkCallback walkCallback);
90+
91+
/// Computes a vector of all control predecessors of `value`. Relies on
92+
/// RegionBranchOpInterface and BranchOpInterface to determine predecessors.
93+
/// Returns nullopt if value has no predecessors or when the relevant operations
94+
/// are missing the interface implementations.
95+
std::optional<SmallVector<Value>> getControlFlowPredecessors(Value value);
12096

12197
} // namespace mlir
12298

mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,11 @@ static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs) {
224224

225225
/// Attempts to return the set of all underlying pointer values that
226226
/// `pointerValue` is based on. This function traverses through select
227-
/// operations and block arguments unlike getUnderlyingObject.
228-
static SmallVector<Value> getUnderlyingObjectSet(Value pointerValue) {
227+
/// operations and block arguments.
228+
static FailureOr<SmallVector<Value>>
229+
getUnderlyingObjectSet(Value pointerValue) {
229230
SmallVector<Value> result;
230-
walkBackwardSlice(pointerValue, [&](Value val) {
231+
WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) {
231232
if (auto gepOp = val.getDefiningOp<LLVM::GEPOp>())
232233
return WalkContinuation::advanceTo(gepOp.getBase());
233234

@@ -239,12 +240,28 @@ static SmallVector<Value> getUnderlyingObjectSet(Value pointerValue) {
239240
return WalkContinuation::advanceTo(
240241
{selectOp.getTrueValue(), selectOp.getFalseValue()});
241242

242-
if (isa<OpResult>(val))
243+
// Attempt to advance to control flow predecessors.
244+
std::optional<SmallVector<Value>> controlFlowPredecessors =
245+
getControlFlowPredecessors(val);
246+
if (controlFlowPredecessors)
247+
return WalkContinuation::advanceTo(*controlFlowPredecessors);
248+
249+
// For all non-control flow results, consider `val` an underlying object.
250+
if (isa<OpResult>(val)) {
243251
result.push_back(val);
252+
return WalkContinuation::skip();
253+
}
244254

245-
return WalkContinuation::advance();
255+
// If this place is reached, `val` is a block argument that is not
256+
// understood. Therfore, we conservatively interrupt.
257+
// Note: Dealing with the function arguments is not necessary, as the slice
258+
// would have to go through an SSACopyOp first.
259+
return WalkContinuation::interrupt();
246260
});
247261

262+
if (walkResult.wasInterrupted())
263+
return failure();
264+
248265
return result;
249266
}
250267

@@ -306,9 +323,14 @@ static void createNewAliasScopesFromNoAliasParameter(
306323

307324
// Find the set of underlying pointers that this pointer is based on.
308325
SmallPtrSet<Value, 4> basedOnPointers;
309-
for (Value pointer : pointerArgs)
310-
llvm::copy(getUnderlyingObjectSet(pointer),
326+
for (Value pointer : pointerArgs) {
327+
FailureOr<SmallVector<Value>> underlyingObjectSet =
328+
getUnderlyingObjectSet(pointer);
329+
if (failed(underlyingObjectSet))
330+
return;
331+
llvm::copy(*underlyingObjectSet,
311332
std::inserter(basedOnPointers, basedOnPointers.begin()));
333+
}
312334

313335
bool aliasesOtherKnownObject = false;
314336
// Go through the based on pointers and check that they are either:

mlir/lib/IR/SliceSupport.cpp

Lines changed: 63 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,35 @@
33

44
using namespace mlir;
55

6+
WalkContinuation mlir::walkSlice(ValueRange rootValues,
7+
WalkCallback walkCallback) {
8+
// Search the backward slice starting from the root values.
9+
SmallVector<Value> workList = rootValues;
10+
llvm::SmallDenseSet<Value, 16> seenValues;
11+
while (!workList.empty()) {
12+
// Search the backward slice of the current value.
13+
Value current = workList.pop_back_val();
14+
15+
// Skip the current value if it has already been seen.
16+
if (!seenValues.insert(current).second)
17+
continue;
18+
19+
// Call the walk callback with the current value.
20+
WalkContinuation continuation = walkCallback(current);
21+
if (continuation.wasInterrupted())
22+
return continuation;
23+
if (continuation.wasSkipped())
24+
continue;
25+
26+
assert(continuation.wasAdvancedTo());
27+
// Add the next values to the work list if the walk should continue.
28+
workList.append(continuation.getNextValues().begin(),
29+
continuation.getNextValues().end());
30+
}
31+
32+
return WalkContinuation::skip();
33+
}
34+
635
/// Returns the operands from all predecessor regions that match `operandNumber`
736
/// for the `successor` region within `regionOp`.
837
static SmallVector<Value>
@@ -49,15 +78,20 @@ getRegionPredecessorOperands(RegionBranchOpInterface regionOp,
4978
return predecessorOperands;
5079
}
5180

52-
/// Returns the predecessor branch operands that match `blockArg`.
53-
static SmallVector<Value> getBlockPredecessorOperands(BlockArgument blockArg) {
81+
/// Returns the predecessor branch operands that match `blockArg`. Returns a
82+
/// nullopt when some of the predecessor terminators do not implement the
83+
/// BranchOpInterface.
84+
static std::optional<SmallVector<Value>>
85+
getBlockPredecessorOperands(BlockArgument blockArg) {
5486
Block *block = blockArg.getOwner();
5587

5688
// Search the predecessor operands for all predecessor terminators.
5789
SmallVector<Value> predecessorOperands;
5890
for (auto it = block->pred_begin(); it != block->pred_end(); ++it) {
5991
Block *predecessor = *it;
60-
auto branchOp = cast<BranchOpInterface>(predecessor->getTerminator());
92+
auto branchOp = dyn_cast<BranchOpInterface>(predecessor->getTerminator());
93+
if (!branchOp)
94+
return std::nullopt;
6195
SuccessorOperands successorOperands =
6296
branchOp.getSuccessorOperands(it.getSuccessorIndex());
6397
// Store the predecessor operand if the block argument matches an operand
@@ -69,62 +103,38 @@ static SmallVector<Value> getBlockPredecessorOperands(BlockArgument blockArg) {
69103
return predecessorOperands;
70104
}
71105

72-
mlir::WalkContinuation mlir::walkBackwardSlice(ValueRange rootValues,
73-
WalkCallback walkCallback) {
74-
// Search the backward slice starting from the root values.
75-
SmallVector<Value> workList = rootValues;
76-
llvm::SmallDenseSet<Value, 16> seenValues;
77-
while (!workList.empty()) {
78-
// Search the backward slice of the current value.
79-
Value current = workList.pop_back_val();
80-
81-
// Skip the current value if it has already been seen.
82-
if (!seenValues.insert(current).second)
83-
continue;
84-
85-
// Call the walk callback with the current value.
86-
WalkContinuation continuation = walkCallback(current);
87-
if (continuation.wasInterrupted())
88-
return continuation;
89-
if (continuation.wasSkipped())
90-
continue;
91-
92-
// Add the next values to the work list if the walk should continue.
93-
if (continuation.wasAdvancedTo()) {
94-
workList.append(continuation.getNextValues().begin(),
95-
continuation.getNextValues().end());
96-
continue;
97-
}
98-
106+
std::optional<SmallVector<Value>>
107+
mlir::getControlFlowPredecessors(Value value) {
108+
SmallVector<Value> result;
109+
if (OpResult opResult = dyn_cast<OpResult>(value)) {
110+
auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
111+
// If the interface is not implemented, there are no control flow
112+
// predecessors to work with.
113+
if (!regionOp)
114+
return std::nullopt;
99115
// Add the control flow predecessor operands to the work list.
100-
if (OpResult opResult = dyn_cast<OpResult>(current)) {
101-
auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
102-
if (!regionOp)
103-
continue;
104-
RegionSuccessor region(regionOp->getResults());
105-
SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
106-
regionOp, region, opResult.getResultNumber());
107-
workList.append(predecessorOperands.begin(), predecessorOperands.end());
108-
continue;
109-
}
116+
RegionSuccessor region(regionOp->getResults());
117+
SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
118+
regionOp, region, opResult.getResultNumber());
119+
return predecessorOperands;
120+
}
110121

111-
auto blockArg = cast<BlockArgument>(current);
112-
Block *block = blockArg.getOwner();
113-
// Search the region predecessor operands for structured control flow.
114-
auto regionBranchOp =
115-
dyn_cast<RegionBranchOpInterface>(block->getParentOp());
116-
if (block->isEntryBlock() && regionBranchOp) {
122+
auto blockArg = cast<BlockArgument>(value);
123+
Block *block = blockArg.getOwner();
124+
// Search the region predecessor operands for structured control flow.
125+
if (block->isEntryBlock()) {
126+
if (auto regionBranchOp =
127+
dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
117128
RegionSuccessor region(blockArg.getParentRegion());
118129
SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
119130
regionBranchOp, region, blockArg.getArgNumber());
120-
workList.append(predecessorOperands.begin(), predecessorOperands.end());
121-
continue;
131+
return predecessorOperands;
122132
}
123-
// Search the block predecessor operands for unstructured control flow.
124-
SmallVector<Value> predecessorOperands =
125-
getBlockPredecessorOperands(blockArg);
126-
workList.append(predecessorOperands.begin(), predecessorOperands.end());
133+
// Unclear how to deal with this operation, conservatively return a
134+
// failure.
135+
return std::nullopt;
127136
}
128137

129-
return WalkContinuation::advance();
138+
// Search the block predecessor operands for unstructured control flow.
139+
return getBlockPredecessorOperands(blockArg);
130140
}

0 commit comments

Comments
 (0)