Skip to content

Commit 6a8dde0

Browse files
wsmosesftynse
andauthored
[MLIR] Change getBackwardSlice to return a logicalresult rather than crash (#140961)
The current implementation of getBackwardSlice will crash if an operation in the dependency chain is defined by an operation with multiple regions or blocks. Crashing is bad (and forbids many analyses from using getBackwardSlice, as well as causing existing users of getBackwardSlice to fail for IR with this property). This PR instead causes the analysis to return a failure, rather than crash in the cases it cannot compute the full slice --------- Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
1 parent 4a6b1fb commit 6a8dde0

File tree

10 files changed

+67
-39
lines changed

10 files changed

+67
-39
lines changed

mlir/include/mlir/Analysis/SliceAnalysis.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,17 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
138138
/// Assuming all local orders match the numbering order:
139139
/// {1, 2, 5, 3, 4, 6}
140140
///
141-
void getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice,
142-
const BackwardSliceOptions &options = {});
141+
/// This function returns whether the backwards slice was able to be
142+
/// successfully computed, and failure if it was unable to determine the slice.
143+
LogicalResult getBackwardSlice(Operation *op,
144+
SetVector<Operation *> *backwardSlice,
145+
const BackwardSliceOptions &options = {});
143146

144147
/// Value-rooted version of `getBackwardSlice`. Return the union of all backward
145148
/// slices for the op defining or owning the value `root`.
146-
void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
147-
const BackwardSliceOptions &options = {});
149+
LogicalResult getBackwardSlice(Value root,
150+
SetVector<Operation *> *backwardSlice,
151+
const BackwardSliceOptions &options = {});
148152

149153
/// Iteratively computes backward slices and forward slices until
150154
/// a fixed point is reached. Returns an `SetVector<Operation *>` which

mlir/include/mlir/Query/Matcher/SliceMatchers.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ bool BackwardSliceMatcher<Matcher>::matches(
112112
}
113113
return true;
114114
};
115-
getBackwardSlice(rootOp, &backwardSlice, options);
115+
LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
116+
assert(result.succeeded() && "expected backward slice to succeed");
116117
return options.inclusive ? backwardSlice.size() > 1
117118
: backwardSlice.size() >= 1;
118119
}

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -80,41 +80,43 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
8080
forwardSlice->insert(v.rbegin(), v.rend());
8181
}
8282

83-
static void getBackwardSliceImpl(Operation *op,
84-
SetVector<Operation *> *backwardSlice,
85-
const BackwardSliceOptions &options) {
83+
static LogicalResult getBackwardSliceImpl(Operation *op,
84+
SetVector<Operation *> *backwardSlice,
85+
const BackwardSliceOptions &options) {
8686
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
87-
return;
87+
return success();
8888

8989
// Evaluate whether we should keep this def.
9090
// This is useful in particular to implement scoping; i.e. return the
9191
// transitive backwardSlice in the current scope.
9292
if (options.filter && !options.filter(op))
93-
return;
93+
return success();
9494

9595
auto processValue = [&](Value value) {
9696
if (auto *definingOp = value.getDefiningOp()) {
9797
if (backwardSlice->count(definingOp) == 0)
98-
getBackwardSliceImpl(definingOp, backwardSlice, options);
98+
return getBackwardSliceImpl(definingOp, backwardSlice, options);
9999
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
100100
if (options.omitBlockArguments)
101-
return;
101+
return success();
102102

103103
Block *block = blockArg.getOwner();
104104
Operation *parentOp = block->getParentOp();
105105
// TODO: determine whether we want to recurse backward into the other
106106
// blocks of parentOp, which are not technically backward unless they flow
107107
// into us. For now, just bail.
108108
if (parentOp && backwardSlice->count(parentOp) == 0) {
109-
assert(parentOp->getNumRegions() == 1 &&
110-
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
111-
getBackwardSliceImpl(parentOp, backwardSlice, options);
109+
if (parentOp->getNumRegions() == 1 &&
110+
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) {
111+
return getBackwardSliceImpl(parentOp, backwardSlice, options);
112+
}
112113
}
113-
} else {
114-
llvm_unreachable("No definingOp and not a block argument.");
115114
}
115+
return failure();
116116
};
117117

118+
bool succeeded = true;
119+
118120
if (!options.omitUsesFromAbove) {
119121
llvm::for_each(op->getRegions(), [&](Region &region) {
120122
// Walk this region recursively to collect the regions that descend from
@@ -125,36 +127,41 @@ static void getBackwardSliceImpl(Operation *op,
125127
region.walk([&](Operation *op) {
126128
for (OpOperand &operand : op->getOpOperands()) {
127129
if (!descendents.contains(operand.get().getParentRegion()))
128-
processValue(operand.get());
130+
if (!processValue(operand.get()).succeeded()) {
131+
return WalkResult::interrupt();
132+
}
129133
}
134+
return WalkResult::advance();
130135
});
131136
});
132137
}
133138
llvm::for_each(op->getOperands(), processValue);
134139

135140
backwardSlice->insert(op);
141+
return success(succeeded);
136142
}
137143

138-
void mlir::getBackwardSlice(Operation *op,
139-
SetVector<Operation *> *backwardSlice,
140-
const BackwardSliceOptions &options) {
141-
getBackwardSliceImpl(op, backwardSlice, options);
144+
LogicalResult mlir::getBackwardSlice(Operation *op,
145+
SetVector<Operation *> *backwardSlice,
146+
const BackwardSliceOptions &options) {
147+
LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options);
142148

143149
if (!options.inclusive) {
144150
// Don't insert the top level operation, we just queried on it and don't
145151
// want it in the results.
146152
backwardSlice->remove(op);
147153
}
154+
return result;
148155
}
149156

150-
void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
151-
const BackwardSliceOptions &options) {
157+
LogicalResult mlir::getBackwardSlice(Value root,
158+
SetVector<Operation *> *backwardSlice,
159+
const BackwardSliceOptions &options) {
152160
if (Operation *definingOp = root.getDefiningOp()) {
153-
getBackwardSlice(definingOp, backwardSlice, options);
154-
return;
161+
return getBackwardSlice(definingOp, backwardSlice, options);
155162
}
156163
Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
157-
getBackwardSlice(bbAargOwner, backwardSlice, options);
164+
return getBackwardSlice(bbAargOwner, backwardSlice, options);
158165
}
159166

160167
SetVector<Operation *>
@@ -170,7 +177,9 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
170177
auto *currentOp = (slice)[currentIndex];
171178
// Compute and insert the backwardSlice starting from currentOp.
172179
backwardSlice.clear();
173-
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
180+
LogicalResult result =
181+
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
182+
assert(result.succeeded());
174183
slice.insert_range(backwardSlice);
175184

176185
// Compute and insert the forwardSlice starting from currentOp.
@@ -193,7 +202,8 @@ static bool dependsOnCarriedVals(Value value,
193202
sliceOptions.filter = [&](Operation *op) {
194203
return !ancestorOp->isAncestor(op);
195204
};
196-
getBackwardSlice(value, &slice, sliceOptions);
205+
LogicalResult result = getBackwardSlice(value, &slice, sliceOptions);
206+
assert(result.succeeded());
197207

198208
// Check that none of the operands of the operations in the backward slice are
199209
// loop iteration arguments, and neither is the value itself.

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,9 @@ getSliceContract(Operation *op,
317317
auto *currentOp = (slice)[currentIndex];
318318
// Compute and insert the backwardSlice starting from currentOp.
319319
backwardSlice.clear();
320-
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
320+
LogicalResult result =
321+
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
322+
assert(result.succeeded() && "expected a backward slice");
321323
slice.insert_range(backwardSlice);
322324

323325
// Compute and insert the forwardSlice starting from currentOp.

mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,13 @@ static void computeBackwardSlice(tensor::PadOp padOp,
124124
getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
125125
valuesDefinedAbove);
126126
for (Value v : valuesDefinedAbove) {
127-
getBackwardSlice(v, &backwardSlice, sliceOptions);
127+
LogicalResult result = getBackwardSlice(v, &backwardSlice, sliceOptions);
128+
assert(result.succeeded() && "expected a backward slice");
128129
}
129130
// Then, add the backward slice from padOp itself.
130-
getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
131+
LogicalResult result =
132+
getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
133+
assert(result.succeeded() && "expected a backward slice");
131134
}
132135

133136
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,10 @@ static void getPipelineStages(
290290
});
291291
options.inclusive = true;
292292
for (Operation &op : forOp.getBody()->getOperations()) {
293-
if (stage0Ops.contains(&op))
294-
getBackwardSlice(&op, &dependencies, options);
293+
if (stage0Ops.contains(&op)) {
294+
LogicalResult result = getBackwardSlice(&op, &dependencies, options);
295+
assert(result.succeeded() && "expected a backward slice");
296+
}
295297
}
296298

297299
for (Operation &op : forOp.getBody()->getOperations()) {

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1772,7 +1772,8 @@ checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
17721772
};
17731773
llvm::SetVector<Operation *> slice;
17741774
for (auto operand : consumerOp->getOperands()) {
1775-
getBackwardSlice(operand, &slice, options);
1775+
LogicalResult result = getBackwardSlice(operand, &slice, options);
1776+
assert(result.succeeded() && "expected a backward slice");
17761777
}
17771778

17781779
if (!slice.empty()) {

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,8 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
10941094
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
10951095
};
10961096
llvm::SetVector<Operation *> slice;
1097-
getBackwardSlice(op, &slice, options);
1097+
LogicalResult result = getBackwardSlice(op, &slice, options);
1098+
assert(result.succeeded() && "expected a backward slice");
10981099

10991100
// If the slice contains `insertionPoint` cannot move the dependencies.
11001101
if (slice.contains(insertionPoint)) {
@@ -1159,7 +1160,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
11591160
};
11601161
llvm::SetVector<Operation *> slice;
11611162
for (auto value : prunedValues) {
1162-
getBackwardSlice(value, &slice, options);
1163+
LogicalResult result = getBackwardSlice(value, &slice, options);
1164+
assert(result.succeeded() && "expected a backward slice");
11631165
}
11641166

11651167
// If the slice contains `insertionPoint` cannot move the dependencies.

mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,9 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
154154
patternTestSlicingOps().match(f, &matches);
155155
for (auto m : matches) {
156156
SetVector<Operation *> backwardSlice;
157-
getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
157+
LogicalResult result =
158+
getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
159+
assert(result.succeeded() && "expected a backward slice");
158160
outs << "\nmatched: " << *m.getMatchedOperation()
159161
<< " backward static slice: ";
160162
for (auto *op : backwardSlice)

mlir/test/lib/IR/TestSlicing.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
4141
options.omitBlockArguments = omitBlockArguments;
4242
// TODO: Make this default.
4343
options.omitUsesFromAbove = false;
44-
getBackwardSlice(op, &slice, options);
44+
LogicalResult result = getBackwardSlice(op, &slice, options);
45+
assert(result.succeeded() && "expected a backward slice");
4546
for (Operation *slicedOp : slice)
4647
builder.clone(*slicedOp, mapper);
4748
builder.create<func::ReturnOp>(loc);

0 commit comments

Comments
 (0)