Skip to content

Commit 6ecebb4

Browse files
[mlir][bufferization] Support unstructured control flow
This revision adds support for unstructured control flow to the bufferization infrastructure. In particular: regions with multiple blocks, `cf.br`, `cf.cond_br`. Two helper templates are added to `BufferizableOpInterface.h`, which can be implemented by ops that supported unstructured control flow in their regions (e.g., `func.func`) and ops that branch to another block (e.g., `cf.br`). A block signature is always bufferized together with the op that owns the block. Differential Revision: https://reviews.llvm.org/D158094
1 parent e2cb07c commit 6ecebb4

23 files changed

+956
-29
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,10 @@ struct TraversalConfig {
411411
/// Specifies whether OpOperands with a different type that are not the result
412412
/// of a CastOpInterface op should be followed.
413413
bool followSameTypeOrCastsOnly = false;
414+
415+
/// Specifies whether already visited values should be visited again.
416+
/// (Note: This can result in infinite looping.)
417+
bool revisitAlreadyVisitedValues = false;
414418
};
415419

416420
/// AnalysisState provides a variety of helper functions for dealing with

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,13 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
415415
the input IR and returns `failure` in that case. If this op is
416416
expected to survive bufferization, `success` should be returned
417417
(together with `allow-unknown-ops` enabled).
418+
419+
Note: If this op supports unstructured control flow in its regions,
420+
then this function should also bufferize all block signatures that
421+
belong to this op. Branch ops (that branch to a block) are typically
422+
bufferized together with the block signature (this is just a
423+
suggestion to make sure IR is valid at every point in time and could
424+
be done differently).
418425
}],
419426
/*retType=*/"::mlir::LogicalResult",
420427
/*methodName=*/"bufferize",
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
//===- UnstructuredControlFlow.h - Op Interface Helpers ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
10+
#define MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
11+
12+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
14+
15+
//===----------------------------------------------------------------------===//
16+
// Helpers for Unstructured Control Flow
17+
//===----------------------------------------------------------------------===//
18+
19+
namespace mlir {
20+
namespace bufferization {
21+
22+
namespace detail {
23+
/// Return a list of operands that are forwarded to the given block argument.
24+
/// I.e., find all predecessors of the block argument's owner and gather the
25+
/// operands that are equivalent to the block argument.
26+
SmallVector<OpOperand *> getCallerOpOperands(BlockArgument bbArg);
27+
} // namespace detail
28+
29+
/// A template that provides a default implementation of `getAliasingOpOperands`
30+
/// for ops that support unstructured control flow within their regions.
31+
template <typename ConcreteModel, typename ConcreteOp>
32+
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
33+
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
34+
35+
FailureOr<BaseMemRefType>
36+
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
37+
SmallVector<Value> &invocationStack) const {
38+
// Note: The user may want to override this function for OpResults in
39+
// case the bufferized result type is different from the bufferized type of
40+
// the aliasing OpOperand (if any).
41+
if (isa<OpResult>(value))
42+
return bufferization::detail::defaultGetBufferType(value, options,
43+
invocationStack);
44+
45+
// Compute the buffer type of the block argument by computing the bufferized
46+
// operand types of all forwarded values. If these are all the same type,
47+
// take that type. Otherwise, take only the memory space and fall back to a
48+
// buffer type with a fully dynamic layout map.
49+
BaseMemRefType bufferType;
50+
auto tensorType = cast<TensorType>(value.getType());
51+
for (OpOperand *opOperand :
52+
detail::getCallerOpOperands(cast<BlockArgument>(value))) {
53+
54+
// If the forwarded operand is already on the invocation stack, we ran
55+
// into a loop and this operand cannot be used to compute the bufferized
56+
// type.
57+
if (llvm::find(invocationStack, opOperand->get()) !=
58+
invocationStack.end())
59+
continue;
60+
61+
// Compute the bufferized type of the forwarded operand.
62+
BaseMemRefType callerType;
63+
if (auto memrefType =
64+
dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
65+
// The operand was already bufferized. Take its type directly.
66+
callerType = memrefType;
67+
} else {
68+
FailureOr<BaseMemRefType> maybeCallerType =
69+
bufferization::getBufferType(opOperand->get(), options,
70+
invocationStack);
71+
if (failed(maybeCallerType))
72+
return failure();
73+
callerType = *maybeCallerType;
74+
}
75+
76+
if (!bufferType) {
77+
// This is the first buffer type that we computed.
78+
bufferType = callerType;
79+
continue;
80+
}
81+
82+
if (bufferType == callerType)
83+
continue;
84+
85+
// If the computed buffer type does not match the computed buffer type
86+
// of the earlier forwarded operands, fall back to a buffer type with a
87+
// fully dynamic layout map.
88+
#ifndef NDEBUG
89+
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
90+
assert(bufferType.hasRank() && callerType.hasRank() &&
91+
"expected ranked memrefs");
92+
assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
93+
rankedTensorType.getShape()}) &&
94+
"expected same shape");
95+
} else {
96+
assert(!bufferType.hasRank() && !callerType.hasRank() &&
97+
"expected unranked memrefs");
98+
}
99+
#endif // NDEBUG
100+
101+
if (bufferType.getMemorySpace() != callerType.getMemorySpace())
102+
return op->emitOpError("incoming operands of block argument have "
103+
"inconsistent memory spaces");
104+
105+
bufferType = getMemRefTypeWithFullyDynamicLayout(
106+
tensorType, bufferType.getMemorySpace());
107+
}
108+
109+
if (!bufferType)
110+
return op->emitOpError("could not infer buffer type of block argument");
111+
112+
return bufferType;
113+
}
114+
115+
protected:
116+
/// Assuming that `bbArg` is a block argument of a block that belongs to the
117+
/// given `op`, return all OpOperands of users of this block that are
118+
/// aliasing with the given block argument.
119+
AliasingOpOperandList
120+
getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg,
121+
const AnalysisState &state) const {
122+
assert(bbArg.getOwner()->getParentOp() == op && "invalid bbArg");
123+
124+
// Gather aliasing OpOperands of all operations (callers) that link to
125+
// this block.
126+
AliasingOpOperandList result;
127+
for (OpOperand *opOperand : detail::getCallerOpOperands(bbArg))
128+
result.addAlias(
129+
{opOperand, BufferRelation::Equivalent, /*isDefinite=*/false});
130+
131+
return result;
132+
}
133+
};
134+
135+
/// A template that provides a default implementation of `getAliasingValues`
136+
/// for ops that implement the `BranchOpInterface`.
137+
template <typename ConcreteModel, typename ConcreteOp>
138+
struct BranchOpBufferizableOpInterfaceExternalModel
139+
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
140+
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
141+
const AnalysisState &state) const {
142+
AliasingValueList result;
143+
auto branchOp = cast<BranchOpInterface>(op);
144+
auto operandNumber = opOperand.getOperandNumber();
145+
146+
// Gather aliasing block arguments of blocks to which this op may branch to.
147+
for (const auto &it : llvm::enumerate(op->getSuccessors())) {
148+
Block *block = it.value();
149+
SuccessorOperands operands = branchOp.getSuccessorOperands(it.index());
150+
assert(operands.getProducedOperandCount() == 0 &&
151+
"produced operands not supported");
152+
if (operands.getForwardedOperands().empty())
153+
continue;
154+
// The first and last operands that are forwarded to this successor.
155+
int64_t firstOperandIndex =
156+
operands.getForwardedOperands().getBeginOperandIndex();
157+
int64_t lastOperandIndex =
158+
firstOperandIndex + operands.getForwardedOperands().size();
159+
bool matchingDestination = operandNumber >= firstOperandIndex &&
160+
operandNumber < lastOperandIndex;
161+
// A branch op may have multiple successors. Find the ones that correspond
162+
// to this OpOperand. (There is usually only one.)
163+
if (!matchingDestination)
164+
continue;
165+
// Compute the matching block argument of the destination block.
166+
BlockArgument bbArg =
167+
block->getArgument(operandNumber - firstOperandIndex);
168+
result.addAlias(
169+
{bbArg, BufferRelation::Equivalent, /*isDefinite=*/false});
170+
}
171+
172+
return result;
173+
}
174+
};
175+
176+
} // namespace bufferization
177+
} // namespace mlir
178+
179+
#endif // MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_

mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,15 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
7878
const OpFilter *opFilter = nullptr,
7979
BufferizationStatistics *statistics = nullptr);
8080

81-
/// Bufferize the signature of `block`. All block argument types are changed to
82-
/// memref types.
81+
/// Bufferize the signature of `block` and its callers (i.e., ops that have the
82+
/// given block as a successor). All block argument types are changed to memref
83+
/// types. All corresponding operands of all callers are wrapped in
84+
/// bufferization.to_memref ops. All uses of bufferized tensor block arguments
85+
/// are wrapped in bufferization.to_tensor ops.
86+
///
87+
/// It is expected that all callers implement the `BranchOpInterface`.
88+
/// Otherwise, this function will fail. The `BranchOpInterface` is used to query
89+
/// the range of operands that are forwarded to this block.
8390
///
8491
/// It is expected that the parent op of this block implements the
8592
/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace cf {
16+
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace cf
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
3030
#include "mlir/Dialect/Complex/IR/Complex.h"
3131
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
32+
#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h"
3233
#include "mlir/Dialect/DLTI/DLTI.h"
3334
#include "mlir/Dialect/EmitC/IR/EmitC.h"
3435
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -135,6 +136,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
135136
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
136137
registry);
137138
builtin::registerCastOpInterfaceExternalModels(registry);
139+
cf::registerBufferizableOpInterfaceExternalModels(registry);
138140
linalg::registerBufferizableOpInterfaceExternalModels(registry);
139141
linalg::registerTilingInterfaceExternalModels(registry);
140142
linalg::registerValueBoundsOpInterfaceExternalModels(registry);

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
234234
});
235235

236236
if (aliasingValues.getNumAliases() == 1 &&
237+
isa<OpResult>(aliasingValues.getAliases()[0].value) &&
237238
!state.bufferizesToMemoryWrite(opOperand) &&
238239
state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
239240
.getNumAliases() == 1 &&
@@ -498,11 +499,16 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
498499
bool AnalysisState::isValueRead(Value value) const {
499500
assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
500501
SmallVector<OpOperand *> workingSet;
502+
DenseSet<OpOperand *> visited;
501503
for (OpOperand &use : value.getUses())
502504
workingSet.push_back(&use);
503505

504506
while (!workingSet.empty()) {
505507
OpOperand *uMaybeReading = workingSet.pop_back_val();
508+
if (visited.contains(uMaybeReading))
509+
continue;
510+
visited.insert(uMaybeReading);
511+
506512
// Skip over all ops that neither read nor write (but create an alias).
507513
if (bufferizesToAliasOnly(*uMaybeReading))
508514
for (AliasingValue alias : getAliasingValues(*uMaybeReading))
@@ -522,11 +528,21 @@ bool AnalysisState::isValueRead(Value value) const {
522528
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
523529
Value value, llvm::function_ref<bool(Value)> condition,
524530
TraversalConfig config) const {
531+
llvm::DenseSet<Value> visited;
525532
llvm::SetVector<Value> result, workingSet;
526533
workingSet.insert(value);
527534

528535
while (!workingSet.empty()) {
529536
Value value = workingSet.pop_back_val();
537+
538+
if (!config.revisitAlreadyVisitedValues && visited.contains(value)) {
539+
// Stop traversal if value was already visited.
540+
if (config.alwaysIncludeLeaves)
541+
result.insert(value);
542+
continue;
543+
}
544+
visited.insert(value);
545+
530546
if (condition(value)) {
531547
result.insert(value);
532548
continue;
@@ -659,11 +675,15 @@ bool AnalysisState::isTensorYielded(Value tensor) const {
659675
// preceding value, so we can follow SSA use-def chains and do a simple
660676
// analysis.
661677
SmallVector<OpOperand *> worklist;
678+
DenseSet<OpOperand *> visited;
662679
for (OpOperand &use : tensor.getUses())
663680
worklist.push_back(&use);
664681

665682
while (!worklist.empty()) {
666683
OpOperand *operand = worklist.pop_back_val();
684+
if (visited.contains(operand))
685+
continue;
686+
visited.insert(operand);
667687
Operation *op = operand->getOwner();
668688

669689
// If the op is not bufferizable, we can safely assume that the value is not

mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
33
BufferizableOpInterface.cpp
44
BufferizationOps.cpp
55
BufferizationDialect.cpp
6+
UnstructuredControlFlow.cpp
67

78
ADDITIONAL_HEADER_DIRS
89
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===- UnstructuredControlFlow.cpp - Op Interface Helpers ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
10+
11+
using namespace mlir;
12+
13+
SmallVector<OpOperand *>
14+
mlir::bufferization::detail::getCallerOpOperands(BlockArgument bbArg) {
15+
SmallVector<OpOperand *> result;
16+
Block *block = bbArg.getOwner();
17+
for (Operation *caller : block->getUsers()) {
18+
auto branchOp = dyn_cast<BranchOpInterface>(caller);
19+
assert(branchOp && "expected that all callers implement BranchOpInterface");
20+
auto it = llvm::find(caller->getSuccessors(), block);
21+
assert(it != caller->getSuccessors().end() && "could not find successor");
22+
int64_t successorIdx = std::distance(caller->getSuccessors().begin(), it);
23+
SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
24+
assert(operands.getProducedOperandCount() == 0 &&
25+
"produced operands not supported");
26+
int64_t operandIndex =
27+
operands.getForwardedOperands().getBeginOperandIndex() +
28+
bbArg.getArgNumber();
29+
result.push_back(&caller->getOpOperand(operandIndex));
30+
}
31+
return result;
32+
}

0 commit comments

Comments
 (0)