Skip to content

Commit 4fc8bfb

Browse files
committed
[MLIR] Change getBackwardSlice to return a logicalresult rather than crash
1 parent fd8bc37 commit 4fc8bfb

File tree

10 files changed

+62
-35
lines changed

10 files changed

+62
-35
lines changed

mlir/include/mlir/Analysis/SliceAnalysis.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,15 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
138138
/// Assuming all local orders match the numbering order:
139139
/// {1, 2, 5, 3, 4, 6}
140140
///
141-
void getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice,
142-
const BackwardSliceOptions &options = {});
141+
LogicalResult getBackwardSlice(Operation *op,
142+
SetVector<Operation *> *backwardSlice,
143+
const BackwardSliceOptions &options = {});
143144

144145
/// Value-rooted version of `getBackwardSlice`. Return the union of all backward
145146
/// slices for the op defining or owning the value `root`.
146-
void getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
147-
const BackwardSliceOptions &options = {});
147+
LogicalResult getBackwardSlice(Value root,
148+
SetVector<Operation *> *backwardSlice,
149+
const BackwardSliceOptions &options = {});
148150

149151
/// Iteratively computes backward slices and forward slices until
150152
/// a fixed point is reached. Returns an `SetVector<Operation *>` which

mlir/include/mlir/Query/Matcher/SliceMatchers.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ bool BackwardSliceMatcher<Matcher>::matches(
112112
}
113113
return true;
114114
};
115-
getBackwardSlice(rootOp, &backwardSlice, options);
115+
auto result = getBackwardSlice(rootOp, &backwardSlice, options);
116+
assert(result.succeeded());
116117
return options.inclusive ? backwardSlice.size() > 1
117118
: backwardSlice.size() >= 1;
118119
}

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -80,22 +80,25 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
8080
forwardSlice->insert(v.rbegin(), v.rend());
8181
}
8282

83-
static void getBackwardSliceImpl(Operation *op,
84-
SetVector<Operation *> *backwardSlice,
85-
const BackwardSliceOptions &options) {
83+
static LogicalResult getBackwardSliceImpl(Operation *op,
84+
SetVector<Operation *> *backwardSlice,
85+
const BackwardSliceOptions &options) {
8686
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
87-
return;
87+
return success();
8888

8989
// Evaluate whether we should keep this def.
9090
// This is useful in particular to implement scoping; i.e. return the
9191
// transitive backwardSlice in the current scope.
9292
if (options.filter && !options.filter(op))
93-
return;
93+
return success();
94+
95+
bool succeeded = true;
9496

9597
auto processValue = [&](Value value) {
9698
if (auto *definingOp = value.getDefiningOp()) {
9799
if (backwardSlice->count(definingOp) == 0)
98-
getBackwardSliceImpl(definingOp, backwardSlice, options);
100+
succeeded &= getBackwardSliceImpl(definingOp, backwardSlice, options)
101+
.succeeded();
99102
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
100103
if (options.omitBlockArguments)
101104
return;
@@ -106,9 +109,13 @@ static void getBackwardSliceImpl(Operation *op,
106109
// blocks of parentOp, which are not technically backward unless they flow
107110
// into us. For now, just bail.
108111
if (parentOp && backwardSlice->count(parentOp) == 0) {
109-
assert(parentOp->getNumRegions() == 1 &&
110-
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
111-
getBackwardSliceImpl(parentOp, backwardSlice, options);
112+
if (parentOp->getNumRegions() == 1 &&
113+
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) {
114+
succeeded &= getBackwardSliceImpl(parentOp, backwardSlice, options)
115+
.succeeded();
116+
} else {
117+
succeeded = false;
118+
}
112119
}
113120
} else {
114121
llvm_unreachable("No definingOp and not a block argument.");
@@ -133,28 +140,30 @@ static void getBackwardSliceImpl(Operation *op,
133140
llvm::for_each(op->getOperands(), processValue);
134141

135142
backwardSlice->insert(op);
143+
return success(succeeded);
136144
}
137145

138-
void mlir::getBackwardSlice(Operation *op,
139-
SetVector<Operation *> *backwardSlice,
140-
const BackwardSliceOptions &options) {
141-
getBackwardSliceImpl(op, backwardSlice, options);
146+
LogicalResult
147+
mlir::getBackwardSlice(Operation *op, SetVector<Operation *> *backwardSlice,
148+
const BackwardSliceOptions &options) {
149+
LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options);
142150

143151
if (!options.inclusive) {
144152
// Don't insert the top level operation, we just queried on it and don't
145153
// want it in the results.
146154
backwardSlice->remove(op);
147155
}
156+
return result;
148157
}
149158

150-
void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
151-
const BackwardSliceOptions &options) {
159+
LogicalResult mlir::getBackwardSlice(Value root,
160+
SetVector<Operation *> *backwardSlice,
161+
const BackwardSliceOptions &options) {
152162
if (Operation *definingOp = root.getDefiningOp()) {
153-
getBackwardSlice(definingOp, backwardSlice, options);
154-
return;
163+
return getBackwardSlice(definingOp, backwardSlice, options);
155164
}
156165
Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
157-
getBackwardSlice(bbAargOwner, backwardSlice, options);
166+
return getBackwardSlice(bbAargOwner, backwardSlice, options);
158167
}
159168

160169
SetVector<Operation *>
@@ -170,7 +179,9 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
170179
auto *currentOp = (slice)[currentIndex];
171180
// Compute and insert the backwardSlice starting from currentOp.
172181
backwardSlice.clear();
173-
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
182+
auto result =
183+
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
184+
assert(result.succeeded());
174185
slice.insert_range(backwardSlice);
175186

176187
// Compute and insert the forwardSlice starting from currentOp.
@@ -193,7 +204,8 @@ static bool dependsOnCarriedVals(Value value,
193204
sliceOptions.filter = [&](Operation *op) {
194205
return !ancestorOp->isAncestor(op);
195206
};
196-
getBackwardSlice(value, &slice, sliceOptions);
207+
auto result = getBackwardSlice(value, &slice, sliceOptions);
208+
assert(result.succeeded());
197209

198210
// Check that none of the operands of the operations in the backward slice are
199211
// loop iteration arguments, and neither is the value itself.

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,9 @@ getSliceContract(Operation *op,
317317
auto *currentOp = (slice)[currentIndex];
318318
// Compute and insert the backwardSlice starting from currentOp.
319319
backwardSlice.clear();
320-
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
320+
auto result =
321+
getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
322+
assert(result.succeeded());
321323
slice.insert_range(backwardSlice);
322324

323325
// Compute and insert the forwardSlice starting from currentOp.

mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,13 @@ static void computeBackwardSlice(tensor::PadOp padOp,
124124
getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
125125
valuesDefinedAbove);
126126
for (Value v : valuesDefinedAbove) {
127-
getBackwardSlice(v, &backwardSlice, sliceOptions);
127+
auto result = getBackwardSlice(v, &backwardSlice, sliceOptions);
128+
assert(result.succeeded());
128129
}
129130
// Then, add the backward slice from padOp itself.
130-
getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
131+
auto result =
132+
getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
133+
assert(result.succeeded());
131134
}
132135

133136
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,10 @@ static void getPipelineStages(
290290
});
291291
options.inclusive = true;
292292
for (Operation &op : forOp.getBody()->getOperations()) {
293-
if (stage0Ops.contains(&op))
294-
getBackwardSlice(&op, &dependencies, options);
293+
if (stage0Ops.contains(&op)) {
294+
auto result = getBackwardSlice(&op, &dependencies, options);
295+
assert(result.succeeded());
296+
}
295297
}
296298

297299
for (Operation &op : forOp.getBody()->getOperations()) {

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1772,7 +1772,8 @@ checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
17721772
};
17731773
llvm::SetVector<Operation *> slice;
17741774
for (auto operand : consumerOp->getOperands()) {
1775-
getBackwardSlice(operand, &slice, options);
1775+
auto result = getBackwardSlice(operand, &slice, options);
1776+
assert(result.succeeded());
17761777
}
17771778

17781779
if (!slice.empty()) {

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,8 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
10941094
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
10951095
};
10961096
llvm::SetVector<Operation *> slice;
1097-
getBackwardSlice(op, &slice, options);
1097+
auto result = getBackwardSlice(op, &slice, options);
1098+
assert(result.succeeded());
10981099

10991100
// If the slice contains `insertionPoint` cannot move the dependencies.
11001101
if (slice.contains(insertionPoint)) {
@@ -1159,7 +1160,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
11591160
};
11601161
llvm::SetVector<Operation *> slice;
11611162
for (auto value : prunedValues) {
1162-
getBackwardSlice(value, &slice, options);
1163+
auto result = getBackwardSlice(value, &slice, options);
1164+
assert(result.succeeded());
11631165
}
11641166

11651167
// If the slice contains `insertionPoint` cannot move the dependencies.

mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
154154
patternTestSlicingOps().match(f, &matches);
155155
for (auto m : matches) {
156156
SetVector<Operation *> backwardSlice;
157-
getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
157+
auto result = getBackwardSlice(m.getMatchedOperation(), &backwardSlice);
158+
assert(result.succeeded());
158159
outs << "\nmatched: " << *m.getMatchedOperation()
159160
<< " backward static slice: ";
160161
for (auto *op : backwardSlice)

mlir/test/lib/IR/TestSlicing.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ static LogicalResult createBackwardSliceFunction(Operation *op,
4141
options.omitBlockArguments = omitBlockArguments;
4242
// TODO: Make this default.
4343
options.omitUsesFromAbove = false;
44-
getBackwardSlice(op, &slice, options);
44+
auto result = getBackwardSlice(op, &slice, options);
45+
assert(result.succeeded());
4546
for (Operation *slicedOp : slice)
4647
builder.clone(*slicedOp, mapper);
4748
builder.create<func::ReturnOp>(loc);

0 commit comments

Comments
 (0)