Skip to content

[mlir][bufferization] Add BufferViewFlowOpInterface #78718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H

namespace mlir {
class DialectRegistry;

namespace arith {
void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace arith
} // namespace mlir

#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- BufferViewFlowOpInterface.h - Buffer View Flow Analysis --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_

#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

namespace mlir {
class ValueRange;

namespace bufferization {

using RegisterDependenciesFn = std::function<void(ValueRange, ValueRange)>;

} // namespace bufferization
} // namespace mlir

#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc"

#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//===-- BufferViewFlowOpInterface.td - Buffer View Flow ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef BUFFER_VIEW_FLOW_OP_INTERFACE
#define BUFFER_VIEW_FLOW_OP_INTERFACE

include "mlir/IR/OpBase.td"

def BufferViewFlowOpInterface :
OpInterface<"BufferViewFlowOpInterface"> {
let description = [{
An op interface for the buffer view flow analysis. This interface describes
buffer dependencies between operands and op results/region entry block
arguments.
}];
let cppNamespace = "::mlir::bufferization";
let methods = [
InterfaceMethod<
/*desc=*/[{
Populate buffer dependencies between operands and op results/region
entry block arguments.

Implementations should register dependencies between an operand ("X")
and an op result/region entry block argument ("Y") if Y may depend
on X. Y depends on X if Y and X are the same buffer or if Y is a
subview of X.

Example:
```
%r = arith.select %c, %m1, %m2 : memref<5xf32>
```
In the above example, %0 may depend on %m1 or %m2 and a correct
interface implementation should call:
- "registerDependenciesFn(%m1, %r)".
- "registerDependenciesFn(%m2, %r)"
}],
/*retType=*/"void",
/*methodName=*/"populateDependencies",
/*args=*/(ins
"::mlir::bufferization::RegisterDependenciesFn"
:$registerDependenciesFn)
>,
InterfaceMethod<
/*desc=*/[{
Return "true" if the given value may be a terminal buffer. A buffer
value is "terminal" if it cannot be traced back any further in the
buffer view flow analysis.

Examples: A buffer could be terminal because:
- it is a newly allocated buffer (e.g., "memref.alloc"),
- or: because there is not enough compile-time information available
to make a definite decision (e.g., "memref.realloc" may reallocate
but we do not know for sure; another example are call ops where we
would have to analyze the body of the callee).

Implementations can assume that the given SSA value is an OpResult of
this operation or a region entry block argument of this operation.
}],
/*retType=*/"bool",
/*methodName=*/"mayBeTerminalBuffer",
/*args=*/(ins "Value":$value),
/*methodBody=*/"",
/*defaultImplementation=*/"return false;"
>,
];
}

#endif // BUFFER_VIEW_FLOW_OP_INTERFACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
add_mlir_interface(AllocationOpInterface)
add_mlir_interface(BufferDeallocationOpInterface)
add_mlir_interface(BufferizableOpInterface)
add_mlir_interface(BufferViewFlowOpInterface)

set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,19 @@ class BufferViewFlowAnalysis {
/// results have to be changed.
void rename(Value from, Value to);

/// Returns "true" if the given value may be a terminal.
bool mayBeTerminalBuffer(Value value) const;

private:
/// This function constructs a mapping from values to its immediate
/// dependencies.
void build(Operation *op);

/// Maps values to all immediate dependencies this value can have.
ValueMapT dependencies;

/// A set of all SSA values that may be terminal buffers.
DenseSet<Value> terminals;
};

} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
#define MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H

namespace mlir {
class DialectRegistry;

namespace memref {
void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace memref
} // namespace mlir

#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
4 changes: 4 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
Expand Down Expand Up @@ -52,6 +53,7 @@
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
Expand Down Expand Up @@ -148,6 +150,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
affine::registerValueBoundsOpInterfaceExternalModels(registry);
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
Expand All @@ -157,6 +160,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
linalg::registerAllDialectInterfaceImplementations(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerMemorySlotExternalModels(registry);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

namespace mlir {
namespace arith {
namespace {

struct SelectOpInterface
: public BufferViewFlowOpInterface::ExternalModel<SelectOpInterface,
SelectOp> {
void
populateDependencies(Operation *op,
RegisterDependenciesFn registerDependenciesFn) const {
auto selectOp = cast<SelectOp>(op);

// Either one of the true/false value may be selected at runtime.
registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult());
registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult());
}
};

} // namespace
} // namespace arith
} // namespace mlir

void arith::registerBufferViewFlowOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
SelectOp::attachInterface<SelectOpInterface>(*ctx);
});
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
BufferViewFlowOpInterfaceImpl.cpp
EmulateUnsupportedFloats.cpp
EmulateWideInt.cpp
EmulateNarrowType.cpp
Expand Down
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
BufferDeallocationOpInterface.cpp
BufferizationOps.cpp
BufferizationDialect.cpp
BufferViewFlowOpInterface.cpp
UnstructuredControlFlow.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@

#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"

#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;
using namespace mlir::bufferization;

/// Constructs a new alias analysis using the op provided.
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
Expand Down Expand Up @@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
void BufferViewFlowAnalysis::build(Operation *op) {
// Registers all dependencies of the given values.
auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
for (auto [value, dep] : llvm::zip(values, dependencies))
for (auto [value, dep] : llvm::zip_equal(values, dependencies))
this->dependencies[value].insert(dep);
};

// Mark all buffer results and buffer region entry block arguments of the
// given op as terminals.
auto populateTerminalValues = [&](Operation *op) {
for (Value v : op->getResults())
if (isa<BaseMemRefType>(v.getType()))
this->terminals.insert(v);
for (Region &r : op->getRegions())
for (BlockArgument v : r.getArguments())
if (isa<BaseMemRefType>(v.getType()))
this->terminals.insert(v);
};

op->walk([&](Operation *op) {
// TODO: We should have an op interface instead of a hard-coded list of
// interfaces/ops.
// Query BufferViewFlowOpInterface. If the op does not implement that
// interface, try to infer the dependencies from other interfaces that the
// op may implement.
if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
bufferViewFlowOp.populateDependencies(registerDependencies);
for (Value v : op->getResults())
if (isa<BaseMemRefType>(v.getType()) &&
bufferViewFlowOp.mayBeTerminalBuffer(v))
this->terminals.insert(v);
for (Region &r : op->getRegions())
for (BlockArgument v : r.getArguments())
if (isa<BaseMemRefType>(v.getType()) &&
bufferViewFlowOp.mayBeTerminalBuffer(v))
this->terminals.insert(v);
return WalkResult::advance();
}

// Add additional dependencies created by view changes to the alias list.
if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
dependencies[viewInterface.getViewSource()].insert(
viewInterface->getResult(0));
registerDependencies(viewInterface.getViewSource(),
viewInterface->getResult(0));
return WalkResult::advance();
}

Expand Down Expand Up @@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) {
return WalkResult::advance();
}

// Unknown op: Assume that all operands alias with all results.
for (Value operand : op->getOperands()) {
if (!isa<BaseMemRefType>(operand.getType()))
continue;
for (Value result : op->getResults()) {
if (!isa<BaseMemRefType>(result.getType()))
continue;
registerDependencies({operand}, {result});
}
// Region terminators are handled together with RegionBranchOpInterface.
if (isa<RegionBranchTerminatorOpInterface>(op))
return WalkResult::advance();

if (isa<CallOpInterface>(op)) {
// This is an intra-function analysis. We have no information about other
// functions. Conservatively assume that each operand may alias with each
// result. Also mark the results are terminals because the function could
// return newly allocated buffers.
populateTerminalValues(op);
for (Value operand : op->getOperands())
for (Value result : op->getResults())
registerDependencies({operand}, {result});
return WalkResult::advance();
}

// We have no information about unknown ops.
populateTerminalValues(op);

return WalkResult::advance();
});
}

bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
return terminals.contains(value);
}
Loading