Skip to content

Commit f7082ae

Browse files
[mlir][bufferization] Add BufferViewFlowOpInterface
This commit adds the `BufferViewFlowOpInterface` to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by the `BufferViewFlowAnalysis`. There are currently no ops that implement this interface. The first op implementations will be added in a consecutive commit. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
1 parent 0d51c87 commit f7082ae

File tree

16 files changed

+356
-14
lines changed

16 files changed

+356
-14
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace arith {
16+
void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace arith
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- BufferViewFlowOpInterface.h - Buffer View Flow Analysis --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
10+
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
11+
12+
#include "mlir/IR/OpDefinition.h"
13+
#include "mlir/Support/LLVM.h"
14+
15+
namespace mlir {
16+
class ValueRange;
17+
18+
namespace bufferization {
19+
20+
using RegisterDependenciesFn = std::function<void(ValueRange, ValueRange)>;
21+
22+
} // namespace bufferization
23+
} // namespace mlir
24+
25+
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc"
26+
27+
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//===-- BufferViewFlowOpInterface.td - Buffer View Flow ----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef BUFFER_VIEW_FLOW_OP_INTERFACE
10+
#define BUFFER_VIEW_FLOW_OP_INTERFACE
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
def BufferViewFlowOpInterface :
15+
OpInterface<"BufferViewFlowOpInterface"> {
16+
let description = [{
17+
An op interface for the buffer view flow analysis. This interface describes
18+
buffer dependencies between operands and op results/region entry block
19+
arguments.
20+
}];
21+
let cppNamespace = "::mlir::bufferization";
22+
let methods = [
23+
InterfaceMethod<
24+
/*desc=*/[{
25+
Populate buffer dependencies between operands and op results/region
26+
entry block arguments.
27+
28+
Implementations should register dependencies between an operand ("X")
29+
and an op result/region entry block argument ("Y") if Y may depend
30+
on X. Y depends on X if Y and X are the same buffer or if Y is a
31+
subview of X.
32+
33+
Example:
34+
```
35+
%r = arith.select %c, %m1, %m2 : memref<5xf32>
36+
```
37+
In the above example, %0 may depend on %m1 or %m2 and a correct
38+
interface implementation should call:
39+
- "registerDependenciesFn(%m1, %r)".
40+
- "registerDependenciesFn(%m2, %r)"
41+
}],
42+
/*retType=*/"void",
43+
/*methodName=*/"populateDependencies",
44+
/*args=*/(ins
45+
"::mlir::bufferization::RegisterDependenciesFn"
46+
:$registerDependenciesFn)
47+
>,
48+
InterfaceMethod<
49+
/*desc=*/[{
50+
Return "true" if the given value is a terminal buffer. A buffer value
51+
is "terminal" if it cannot be traced back any further in the buffer
52+
view flow analysis. E.g., because the value is a newly allocated
53+
buffer or because there is not enough information available.
54+
55+
The given SSA value is guaranteed to be an OpResult of this operation
56+
or a region entry block argument of this operation.
57+
}],
58+
/*retType=*/"bool",
59+
/*methodName=*/"isTerminalBuffer",
60+
/*args=*/(ins "Value":$value),
61+
/*methodBody=*/"",
62+
/*defaultImplementation=*/"return false;"
63+
>,
64+
];
65+
}
66+
67+
#endif // BUFFER_VIEW_FLOW_OP_INTERFACE

mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
33
add_mlir_interface(AllocationOpInterface)
44
add_mlir_interface(BufferDeallocationOpInterface)
55
add_mlir_interface(BufferizableOpInterface)
6+
add_mlir_interface(BufferViewFlowOpInterface)
67

78
set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
89
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,19 @@ class BufferViewFlowAnalysis {
6363
/// results have to be changed.
6464
void rename(Value from, Value to);
6565

66+
/// Return "true" if the given value is a terminal.
67+
bool isTerminalBuffer(Value value) const;
68+
6669
private:
6770
/// This function constructs a mapping from values to its immediate
6871
/// dependencies.
6972
void build(Operation *op);
7073

7174
/// Maps values to all immediate dependencies this value can have.
7275
ValueMapT dependencies;
76+
77+
/// A set of all terminal values. I.e., values at which the analysis stopped.
78+
DenseSet<Value> terminals;
7379
};
7480

7581
} // namespace mlir
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace memref {
16+
void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace memref
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/Arith/IR/Arith.h"
2222
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
2323
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
24+
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
2425
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
2526
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
2627
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -53,6 +54,7 @@
5354
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
5455
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
5556
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
57+
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
5658
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
5759
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
5860
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -145,6 +147,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
145147
affine::registerValueBoundsOpInterfaceExternalModels(registry);
146148
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
147149
arith::registerBufferizableOpInterfaceExternalModels(registry);
150+
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
148151
arith::registerValueBoundsOpInterfaceExternalModels(registry);
149152
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
150153
registry);
@@ -157,6 +160,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
157160
linalg::registerTilingInterfaceExternalModels(registry);
158161
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
159162
memref::registerAllocationOpInterfaceExternalModels(registry);
163+
memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
160164
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
161165
memref::registerValueBoundsOpInterfaceExternalModels(registry);
162166
memref::registerMemorySlotExternalModels(registry);
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
10+
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
12+
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
13+
14+
using namespace mlir;
15+
using namespace mlir::bufferization;
16+
using namespace mlir::arith;
17+
18+
namespace mlir {
19+
namespace arith {
20+
namespace {
21+
22+
struct SelectOpInterface
23+
: public BufferViewFlowOpInterface::ExternalModel<SelectOpInterface,
24+
SelectOp> {
25+
void
26+
populateDependencies(Operation *op,
27+
RegisterDependenciesFn registerDependenciesFn) const {
28+
auto selectOp = cast<SelectOp>(op);
29+
30+
// Either one of the true/false value may be selected at runtime.
31+
registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult());
32+
registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult());
33+
}
34+
};
35+
36+
} // namespace
37+
} // namespace arith
38+
} // namespace mlir
39+
40+
void arith::registerBufferViewFlowOpInterfaceExternalModels(
41+
DialectRegistry &registry) {
42+
registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
43+
SelectOp::attachInterface<SelectOpInterface>(*ctx);
44+
});
45+
}

mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
22
BufferDeallocationOpInterfaceImpl.cpp
33
BufferizableOpInterfaceImpl.cpp
44
Bufferize.cpp
5+
BufferViewFlowOpInterfaceImpl.cpp
56
EmulateUnsupportedFloats.cpp
67
EmulateWideInt.cpp
78
EmulateNarrowType.cpp
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
10+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11+
12+
namespace mlir {
13+
namespace bufferization {
14+
15+
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc"
16+
17+
} // namespace bufferization
18+
} // namespace mlir

mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
44
BufferDeallocationOpInterface.cpp
55
BufferizationOps.cpp
66
BufferizationDialect.cpp
7+
BufferViewFlowOpInterface.cpp
78
UnstructuredControlFlow.cpp
89

910
ADDITIONAL_HEADER_DIRS

mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88

99
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
1010

11+
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
12+
#include "mlir/Interfaces/CallInterfaces.h"
1113
#include "mlir/Interfaces/ControlFlowInterfaces.h"
14+
#include "mlir/Interfaces/FunctionInterfaces.h"
1215
#include "mlir/Interfaces/ViewLikeInterface.h"
1316
#include "llvm/ADT/SetOperations.h"
1417
#include "llvm/ADT/SetVector.h"
1518

1619
using namespace mlir;
20+
using namespace mlir::bufferization;
1721

1822
/// Constructs a new alias analysis using the op provided.
1923
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
@@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
6569
void BufferViewFlowAnalysis::build(Operation *op) {
6670
// Registers all dependencies of the given values.
6771
auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
68-
for (auto [value, dep] : llvm::zip(values, dependencies))
72+
for (auto [value, dep] : llvm::zip_equal(values, dependencies))
6973
this->dependencies[value].insert(dep);
7074
};
7175

76+
// Mark all buffer results and buffer region entry block arguments of the
77+
// given op as terminals.
78+
auto populateTerminalValues = [&](Operation *op) {
79+
for (Value v : op->getResults())
80+
if (isa<BaseMemRefType>(v.getType()))
81+
this->terminals.insert(v);
82+
for (Region &r : op->getRegions())
83+
for (BlockArgument v : r.getArguments())
84+
if (isa<BaseMemRefType>(v.getType()))
85+
this->terminals.insert(v);
86+
};
87+
7288
op->walk([&](Operation *op) {
73-
// TODO: We should have an op interface instead of a hard-coded list of
74-
// interfaces/ops.
89+
// Query BufferViewFlowOpInterface. If the op does not implement that
90+
// interface, try to infer the dependencies from other interfaces that the
91+
// op may implement.
92+
if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
93+
bufferViewFlowOp.populateDependencies(registerDependencies);
94+
for (Value v : op->getResults())
95+
if (isa<BaseMemRefType>(v.getType()) &&
96+
bufferViewFlowOp.isTerminalBuffer(v))
97+
this->terminals.insert(v);
98+
for (Region &r : op->getRegions())
99+
for (BlockArgument v : r.getArguments())
100+
if (isa<BaseMemRefType>(v.getType()) &&
101+
bufferViewFlowOp.isTerminalBuffer(v))
102+
this->terminals.insert(v);
103+
return WalkResult::advance();
104+
}
75105

76106
// Add additional dependencies created by view changes to the alias list.
77107
if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
78-
dependencies[viewInterface.getViewSource()].insert(
79-
viewInterface->getResult(0));
108+
registerDependencies(viewInterface.getViewSource(),
109+
viewInterface->getResult(0));
80110
return WalkResult::advance();
81111
}
82112

@@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) {
131161
return WalkResult::advance();
132162
}
133163

134-
// Unknown op: Assume that all operands alias with all results.
135-
for (Value operand : op->getOperands()) {
136-
if (!isa<BaseMemRefType>(operand.getType()))
137-
continue;
138-
for (Value result : op->getResults()) {
139-
if (!isa<BaseMemRefType>(result.getType()))
140-
continue;
141-
registerDependencies({operand}, {result});
142-
}
164+
// Region terminators are handled together with RegionBranchOpInterface.
165+
if (isa<RegionBranchTerminatorOpInterface>(op))
166+
return WalkResult::advance();
167+
168+
if (isa<CallOpInterface>(op)) {
169+
// This is an intra-function analysis. We have no information about other
170+
// functions. Conservatively assume that each operand may alias with each
171+
// result. Also mark the results are terminals because the function could
172+
// return newly allocated buffers.
173+
populateTerminalValues(op);
174+
for (Value operand : op->getOperands())
175+
for (Value result : op->getResults())
176+
registerDependencies({operand}, {result});
177+
return WalkResult::advance();
143178
}
179+
180+
// We have no information about unknown ops.
181+
populateTerminalValues(op);
182+
144183
return WalkResult::advance();
145184
});
146185
}
186+
187+
bool BufferViewFlowAnalysis::isTerminalBuffer(Value value) const {
188+
assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
189+
return terminals.contains(value);
190+
}

mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
2626
LINK_LIBS PUBLIC
2727
MLIRArithDialect
2828
MLIRBufferizationDialect
29+
MLIRBufferizationTransforms
2930
MLIRControlFlowInterfaces
3031
MLIRFuncDialect
3132
MLIRFunctionInterfaces

0 commit comments

Comments
 (0)