Skip to content

Commit 10ae8ae

Browse files
committed
[mlir][NFC] Make ReturnLike trait imply RegionBranchTerminatorOpInterface
This implication was already done de-facto and there were plenty of users and wrapper functions specifically used to handle the "return-like or RegionBranchTerminatorOpInterface" case. These simply existed due to up until recently missing features in ODS. With the new capabilities of traits, we can make `ReturnLike` imply `RegionBranchTerminatorOpInterface` and auto generate proper definitions for its methods. Various occurrences and wrapper methods used for `isa<RegionBranchTerminatorOpInterface>() || hasTrait<ReturnLike>()` have all been removed. Differential Revision: https://reviews.llvm.org/D157402
1 parent df5137e commit 10ae8ae

File tree

13 files changed

+89
-156
lines changed

13 files changed

+89
-156
lines changed

mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -397,13 +397,13 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
397397
void visitRegionSuccessors(RegionBranchOpInterface branch,
398398
ArrayRef<AbstractSparseLattice *> operands);
399399

400-
/// Visit a terminator (an op implementing `RegionBranchTerminatorOpInterface`
401-
/// or a return-like op) to compute the lattice values of its operands, given
402-
/// its parent op `branch`. The lattice value of an operand is determined
403-
/// based on the corresponding arguments in `terminator`'s region
404-
/// successor(s).
405-
void visitRegionSuccessorsFromTerminator(Operation *terminator,
406-
RegionBranchOpInterface branch);
400+
/// Visit a `RegionBranchTerminatorOpInterface` to compute the lattice values
401+
/// of its operands, given its parent op `branch`. The lattice value of an
402+
/// operand is determined based on the corresponding arguments in
403+
/// `terminator`'s region successor(s).
404+
void visitRegionSuccessorsFromTerminator(
405+
RegionBranchTerminatorOpInterface terminator,
406+
RegionBranchOpInterface branch);
407407

408408
/// Get the lattice element for a value, and also set up
409409
/// dependencies so that the analysis on the given ProgramPoint is re-invoked

mlir/include/mlir/Interfaces/ControlFlowInterfaces.h

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -229,32 +229,6 @@ Region *getEnclosingRepetitiveRegion(Operation *op);
229229
/// exists.
230230
Region *getEnclosingRepetitiveRegion(Value value);
231231

232-
//===----------------------------------------------------------------------===//
233-
// RegionBranchTerminatorOpInterface
234-
//===----------------------------------------------------------------------===//
235-
236-
/// Returns true if the given operation is either annotated with the
237-
/// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
238-
bool isRegionReturnLike(Operation *operation);
239-
240-
/// Returns the mutable operands that are passed to the region with the given
241-
/// `regionIndex`. If the operation does not implement the
242-
/// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
243-
/// result will be `std::nullopt`. In all other cases, the resulting
244-
/// `OperandRange` represents all operands that are passed to the specified
245-
/// successor region. If `regionIndex` is `std::nullopt`, all operands that are
246-
/// passed to the parent operation will be returned.
247-
std::optional<MutableOperandRange>
248-
getMutableRegionBranchSuccessorOperands(Operation *operation,
249-
std::optional<unsigned> regionIndex);
250-
251-
/// Returns the read only operands that are passed to the region with the given
252-
/// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
253-
/// information.
254-
std::optional<OperandRange>
255-
getRegionBranchSuccessorOperands(Operation *operation,
256-
std::optional<unsigned> regionIndex);
257-
258232
//===----------------------------------------------------------------------===//
259233
// ControlFlow Traits
260234
//===----------------------------------------------------------------------===//

mlir/include/mlir/Interfaces/ControlFlowInterfaces.td

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,19 @@ def RegionBranchTerminatorOpInterface :
272272
//===----------------------------------------------------------------------===//
273273

274274
// Op is "return-like".
275-
def ReturnLike : NativeOpTrait<"ReturnLike">;
275+
def ReturnLike : TraitList<[
276+
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
277+
NativeOpTrait<
278+
/*name=*/"ReturnLike",
279+
/*traits=*/[],
280+
/*extraOpDeclaration=*/"",
281+
/*extraOpDefinition=*/[{
282+
::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
283+
::std::optional<unsigned> index) {
284+
return ::mlir::MutableOperandRange(*this);
285+
}
286+
}]
287+
>
288+
]>;
276289

277290
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES

mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,14 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
9191
for (int i = 0, e = op->getNumRegions(); i != e; ++i) {
9292
if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
9393
for (Block &block : op->getRegion(i)) {
94-
Operation *term = block.getTerminator();
9594
// Try to determine possible region-branch successor operands for the
9695
// current region.
97-
auto successorOperands =
98-
getRegionBranchSuccessorOperands(term, regionIndex);
99-
if (successorOperands) {
100-
collectUnderlyingAddressValues((*successorOperands)[*operandIndex],
101-
maxDepth, visited, output);
102-
} else if (term->getNumSuccessors()) {
96+
if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
97+
block.getTerminator())) {
98+
collectUnderlyingAddressValues(
99+
term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth,
100+
visited, output);
101+
} else if (block.getNumSuccessors()) {
103102
// Otherwise, if this terminator may exit the region we can't make
104103
// any assumptions about which values get passed.
105104
output.push_back(inputValue);

mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
337337
// There may be a weird case where a terminator may be transferring control
338338
// either to the parent or to another block, so exit blocks and successors
339339
// are not mutually exclusive.
340-
Operation *terminator = b->getTerminator();
341-
return terminator && (terminator->hasTrait<OpTrait::ReturnLike>() ||
342-
isa<RegionBranchTerminatorOpInterface>(terminator));
340+
return isa_and_nonnull<RegionBranchTerminatorOpInterface>(
341+
b->getTerminator());
343342
};
344343
if (isExitBlock(block)) {
345344
// If this block is exiting from a callable, the successors of exiting from

mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
9393
// `BranchOpInterface`, `RegionBranchTerminatorOpInterface` or return-like op.
9494
Operation *op = operand.getOwner();
9595
assert((isa<RegionBranchOpInterface>(op) || isa<BranchOpInterface>(op) ||
96-
isa<RegionBranchTerminatorOpInterface>(op) ||
97-
op->hasTrait<OpTrait::ReturnLike>()) &&
96+
isa<RegionBranchTerminatorOpInterface>(op)) &&
9897
"expected the op to be `RegionBranchOpInterface`, "
99-
"`BranchOpInterface`, `RegionBranchTerminatorOpInterface`, or "
100-
"return-like");
98+
"`BranchOpInterface` or `RegionBranchTerminatorOpInterface`");
10199

102100
// The lattices of the non-forwarded branch operands don't get updated like
103101
// the forwarded branch operands or the non-branch operands. Thus they need
@@ -161,11 +159,10 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
161159
visitOperation(op, operandLiveness, resultsLiveness);
162160

163161
// We also visit the parent op with the parent's results and this operand if
164-
// `op` is a `RegionBranchTerminatorOpInterface` or return-like because its
165-
// non-forwarded operand depends on not only its memory effects/results but
166-
// also on those of its parent's.
167-
if (!isa<RegionBranchTerminatorOpInterface>(op) &&
168-
!op->hasTrait<OpTrait::ReturnLike>())
162+
// `op` is a `RegionBranchTerminatorOpInterface` because its non-forwarded
163+
// operand depends on not only its memory effects/results but also on those of
164+
// its parent's.
165+
if (!isa<RegionBranchTerminatorOpInterface>(op))
169166
return;
170167
Operation *parentOp = op->getParentOp();
171168
SmallVector<const Liveness *, 4> parentResultsLiveness;

mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,9 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
226226
if (op == branch) {
227227
operands = branch.getSuccessorEntryOperands(successorIndex);
228228
// Otherwise, try to deduce the operands from a region return-like op.
229-
} else {
230-
if (isRegionReturnLike(op))
231-
operands = getRegionBranchSuccessorOperands(op, successorIndex);
229+
} else if (auto regionTerminator =
230+
dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
231+
operands = regionTerminator.getSuccessorOperands(successorIndex);
232232
}
233233

234234
if (!operands) {
@@ -439,10 +439,9 @@ void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
439439
// successor's input. There are two types of successor operands: the operands
440440
// of this op itself and the operands of the terminators of the regions of
441441
// this op.
442-
if (isa<RegionBranchTerminatorOpInterface>(op) ||
443-
op->hasTrait<OpTrait::ReturnLike>()) {
442+
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
444443
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
445-
visitRegionSuccessorsFromTerminator(op, branch);
444+
visitRegionSuccessorsFromTerminator(terminator, branch);
446445
return;
447446
}
448447
}
@@ -506,12 +505,11 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
506505
}
507506

508507
void AbstractSparseBackwardDataFlowAnalysis::
509-
visitRegionSuccessorsFromTerminator(Operation *terminator,
510-
RegionBranchOpInterface branch) {
511-
assert(isa<RegionBranchTerminatorOpInterface>(terminator) ||
512-
terminator->hasTrait<OpTrait::ReturnLike>() &&
513-
"expected a `RegionBranchTerminatorOpInterface` op or a "
514-
"return-like op");
508+
visitRegionSuccessorsFromTerminator(
509+
RegionBranchTerminatorOpInterface terminator,
510+
RegionBranchOpInterface branch) {
511+
assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
512+
"expected a `RegionBranchTerminatorOpInterface` op");
515513
assert(terminator->getParentOp() == branch.getOperation() &&
516514
"expected `branch` to be the parent op of `terminator`");
517515

@@ -527,10 +525,8 @@ void AbstractSparseBackwardDataFlowAnalysis::
527525
for (const RegionSuccessor &successor : successors) {
528526
ValueRange inputs = successor.getSuccessorInputs();
529527
Region *region = successor.getSuccessor();
530-
OperandRange operands =
531-
region ? *getRegionBranchSuccessorOperands(terminator,
532-
region->getRegionNumber())
533-
: *getRegionBranchSuccessorOperands(terminator, {});
528+
OperandRange operands = terminator.getSuccessorOperands(
529+
region ? region->getRegionNumber() : std::optional<unsigned>{});
534530
MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
535531
for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
536532
meet(getLatticeElement(opOperand.get()),

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ bool AnalysisState::isTensorYielded(Value tensor) const {
690690
return true;
691691

692692
// Check if the op is returning/yielding.
693-
if (isRegionReturnLike(op))
693+
if (isa<RegionBranchTerminatorOpInterface>(op))
694694
return true;
695695

696696
// Add all aliasing OpResults to the worklist.

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,15 @@ using namespace mlir;
7070
using namespace mlir::bufferization;
7171

7272
/// Walks over all immediate return-like terminators in the given region.
73-
static LogicalResult
74-
walkReturnOperations(Region *region,
75-
llvm::function_ref<LogicalResult(Operation *)> func) {
73+
static LogicalResult walkReturnOperations(
74+
Region *region,
75+
llvm::function_ref<LogicalResult(RegionBranchTerminatorOpInterface)> func) {
7676
for (Block &block : *region) {
7777
Operation *terminator = block.getTerminator();
7878
// Skip non region-return-like terminators.
79-
if (isRegionReturnLike(terminator)) {
80-
if (failed(func(terminator)))
79+
if (auto regionTerminator =
80+
dyn_cast<RegionBranchTerminatorOpInterface>(terminator)) {
81+
if (failed(func(regionTerminator)))
8182
return failure();
8283
}
8384
}
@@ -447,23 +448,25 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
447448
// Iterate over all immediate terminator operations to introduce
448449
// new buffer allocations. Thereby, the appropriate terminator operand
449450
// will be adjusted to point to the newly allocated buffer instead.
450-
if (failed(walkReturnOperations(&region, [&](Operation *terminator) {
451-
// Get the actual mutable operands for this terminator op.
452-
auto terminatorOperands = *getMutableRegionBranchSuccessorOperands(
453-
terminator, region.getRegionNumber());
454-
// Extract the source value from the current terminator.
455-
// This conversion needs to exist on a separate line due to a bug in
456-
// GCC conversion analysis.
457-
OperandRange immutableTerminatorOperands = terminatorOperands;
458-
Value sourceValue = immutableTerminatorOperands[operandIndex];
459-
// Create a new clone at the current location of the terminator.
460-
auto clone = introduceCloneBuffers(sourceValue, terminator);
461-
if (failed(clone))
462-
return failure();
463-
// Wire clone and terminator operand.
464-
terminatorOperands.slice(operandIndex, 1).assign(*clone);
465-
return success();
466-
})))
451+
if (failed(walkReturnOperations(
452+
&region, [&](RegionBranchTerminatorOpInterface terminator) {
453+
// Get the actual mutable operands for this terminator op.
454+
auto terminatorOperands =
455+
terminator.getMutableSuccessorOperands(
456+
region.getRegionNumber());
457+
// Extract the source value from the current terminator.
458+
// This conversion needs to exist on a separate line due to a
459+
// bug in GCC conversion analysis.
460+
OperandRange immutableTerminatorOperands = terminatorOperands;
461+
Value sourceValue = immutableTerminatorOperands[operandIndex];
462+
// Create a new clone at the current location of the terminator.
463+
auto clone = introduceCloneBuffers(sourceValue, terminator);
464+
if (failed(clone))
465+
return failure();
466+
// Wire clone and terminator operand.
467+
terminatorOperands.slice(operandIndex, 1).assign(*clone);
468+
return success();
469+
})))
467470
return failure();
468471
}
469472
return success();

mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ leavesAllocationScope(Region *parentRegion,
7575
// If there is at least one alias that leaves the parent region, we know
7676
// that this alias escapes the whole region and hence the associated
7777
// allocation leaves allocation scope.
78-
if (isRegionReturnLike(use) && use->getParentRegion() == parentRegion)
78+
if (isa<RegionBranchTerminatorOpInterface>(use) &&
79+
use->getParentRegion() == parentRegion)
7980
return true;
8081
}
8182
}

mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,11 @@ void BufferViewFlowAnalysis::build(Operation *op) {
128128
regionIndex = regionSuccessor->getRegionNumber();
129129
// Iterate over all immediate terminator operations and wire the
130130
// successor inputs with the successor operands of each terminator.
131-
for (Block &block : region) {
132-
auto successorOperands = getRegionBranchSuccessorOperands(
133-
block.getTerminator(), regionIndex);
134-
if (successorOperands) {
135-
registerDependencies(*successorOperands,
131+
for (Block &block : region)
132+
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
133+
block.getTerminator()))
134+
registerDependencies(terminator.getSuccessorOperands(regionIndex),
136135
successorRegion.getSuccessorInputs());
137-
}
138-
}
139136
}
140137
}
141138

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ void OneShotAnalysisState::createAliasInfoEntry(Value v) {
183183
// the IR.
184184
void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
185185
op->walk([&](Operation *returnOp) {
186-
if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
186+
if (!isa<RegionBranchTerminatorOpInterface>(returnOp) ||
187+
!getOptions().isOpAllowed(returnOp))
187188
return WalkResult::advance();
188189

189190
for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
@@ -1059,7 +1060,7 @@ static LogicalResult assertNoAllocsReturned(Operation *op,
10591060
LogicalResult status = success();
10601061
DominanceInfo domInfo(op);
10611062
op->walk([&](Operation *returnOp) {
1062-
if (!isRegionReturnLike(returnOp) ||
1063+
if (!isa<RegionBranchTerminatorOpInterface>(returnOp) ||
10631064
!state.getOptions().isOpAllowed(returnOp))
10641065
return WalkResult::advance();
10651066

0 commit comments

Comments
 (0)