Skip to content

Commit b046dde

Browse files
committed
Add extension mechanism to BufferizationState
1 parent b4a7fa3 commit b046dde

File tree

6 files changed

+123
-16
lines changed

6 files changed

+123
-16
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -582,12 +582,77 @@ class AnalysisState {
582582
/// bufferization process.
583583
class BufferizationState {
584584
public:
585-
/// Get the cached symbol tables.
586-
/// The user is expected to update / invalidate the cached symbol tables if
587-
/// the bufferized operation have the Symbol or SymbolTable traits.
588-
SymbolTableCollection &getSymbolTables();
585+
/// Base class for BufferizationState extensions that allow BufferizationState
586+
/// to contain user-specified information in the state object. The extension
587+
/// mechanism of BufferizationState mirrors the one of OneShotAnalysisState.
588+
class Extension {
589+
public:
590+
/// Base virtual destructor.
591+
// Out-of-line definition ensures symbols are emitted in a single object
592+
// file.
593+
virtual ~Extension();
594+
595+
protected:
596+
/// Constructs an extension of the given state object.
597+
Extension(BufferizationState &state) : state(state) {}
598+
599+
/// Provides read-only access to the parent OneShotAnalysisState object.
600+
const BufferizationState &getBufferizationState() const { return state; }
601+
602+
private:
603+
/// Back-reference to the state that is being extended.
604+
BufferizationState &state;
605+
};
589606

590-
private:
607+
/// Adds a new Extension of the type specified as template parameter,
608+
/// constructing it with the arguments provided. The extension is owned by the
609+
/// BufferizationState. It is expected that the state does not already have an
610+
/// extension of the same type. Extension constructors are expected to take a
611+
/// reference to BufferizationState as first argument, automatically supplied
612+
/// by this call.
613+
template <typename Ty, typename... Args>
614+
Ty &addExtension(Args &&...args) {
615+
static_assert(std::is_base_of<Extension, Ty>::value,
616+
"only a class derived from "
617+
"BufferizationState::Extension is allowed");
618+
auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
619+
auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
620+
assert(result.second && "extension already added");
621+
return *static_cast<Ty *>(result.first->second.get());
622+
}
623+
624+
/// Returns the extension of the specified type.
625+
template <typename Ty>
626+
Ty *getExtension() {
627+
static_assert(std::is_base_of<Extension, Ty>::value,
628+
"only a class derived from "
629+
"BufferizationState::Extension is allowed");
630+
auto iter = extensions.find(TypeID::get<Ty>());
631+
if (iter == extensions.end())
632+
return nullptr;
633+
return static_cast<Ty *>(iter->second.get());
634+
}
635+
636+
/// Returns the extension of the specified type.
637+
template <typename Ty>
638+
const Ty *getExtension() const {
639+
return const_cast<BufferizationState *>(this)->getExtension<Ty>();
640+
}
641+
642+
/// Extensions attached to the state, identified by the TypeID of their type.
643+
/// Only one extension of any given type is allowed.
644+
DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
645+
};
646+
647+
/// Extra bufferization state that is required for bufferization of operations
648+
/// declaring a symbol or a symbol table.
649+
struct SymbolBufferizationState : public BufferizationState::Extension {
650+
SymbolBufferizationState(BufferizationState &state)
651+
: BufferizationState::Extension(state) {}
652+
653+
/// The cached symbol tables.
654+
/// The user is expected to update / invalidate the cached symbol tables if
655+
/// the bufferized operation has the Symbol or SymbolTable traits.
591656
SymbolTableCollection symbolTables;
592657
};
593658

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class GlobalOp;
2929
} // namespace memref
3030

3131
namespace bufferization {
32+
class BufferizationState;
3233

3334
/// A simple analysis that detects allocation operations.
3435
class BufferPlacementAllocs {
@@ -126,6 +127,15 @@ FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
126127
uint64_t alignment,
127128
Attribute memorySpace = {});
128129

130+
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp op,
131+
BufferizationState &state,
132+
uint64_t alignment,
133+
Attribute memorySpace);
134+
135+
void removeSymbol(Operation *op, BufferizationState &state);
136+
137+
void insertSymbol(Operation *op, BufferizationState &state);
138+
129139
} // namespace bufferization
130140
} // namespace mlir
131141

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ struct ConstantOpInterface
4747
// Create global memory segment and replace tensor with memref pointing to
4848
// that memory segment.
4949
FailureOr<memref::GlobalOp> globalOp =
50-
getGlobalFor(constantOp, state.getSymbolTables(),
51-
options.bufferAlignment, memorySpace);
50+
getGlobalFor(constantOp, state, options.bufferAlignment, memorySpace);
5251
if (failed(globalOp))
5352
return failure();
5453
memref::GlobalOp globalMemref = *globalOp;

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,6 @@ void AnalysisState::resetCache() {
125125
insideMutuallyExclusiveRegionsCache.clear();
126126
}
127127

128-
SymbolTableCollection &BufferizationState::getSymbolTables() {
129-
return symbolTables;
130-
}
131-
132128
Region *bufferization::getNextEnclosingRepetitiveRegion(
133129
Region *region, const BufferizationOptions &options) {
134130
assert(isRepetitiveRegion(region, options) && "expected repetitive region");

mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,42 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
159159
global->moveBefore(&moduleOp.front());
160160
return global;
161161
}
162+
163+
namespace mlir::bufferization {
164+
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp op,
165+
BufferizationState &state,
166+
uint64_t alignment,
167+
Attribute memorySpace) {
168+
if (auto *symbolBufferizationState =
169+
state.getExtension<SymbolBufferizationState>()) {
170+
// Use the cached symbol tables.
171+
return getGlobalFor(op, symbolBufferizationState->symbolTables, alignment,
172+
memorySpace);
173+
}
174+
175+
SymbolTableCollection symbolTables;
176+
return getGlobalFor(op, symbolTables, alignment, memorySpace);
177+
}
178+
179+
void removeSymbol(Operation *op, BufferizationState &state) {
180+
if (auto *symbolBufferizationState =
181+
state.getExtension<SymbolBufferizationState>()) {
182+
SymbolTable &symbolTable =
183+
symbolBufferizationState->symbolTables.getSymbolTable(
184+
op->getParentWithTrait<OpTrait::SymbolTable>());
185+
186+
symbolTable.remove(op);
187+
}
188+
}
189+
190+
void insertSymbol(Operation *op, BufferizationState &state) {
191+
if (auto *symbolBufferizationState =
192+
state.getExtension<SymbolBufferizationState>()) {
193+
SymbolTable &symbolTable =
194+
symbolBufferizationState->symbolTables.getSymbolTable(
195+
op->getParentWithTrait<OpTrait::SymbolTable>());
196+
197+
symbolTable.insert(op);
198+
}
199+
}
200+
} // namespace mlir::bufferization

mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
1010

1111
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12+
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
1213
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
1314
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1415

@@ -58,10 +59,7 @@ struct GlobalOpInterface
5859
if (!globalOp.getValue().has_value())
5960
return globalOp.emitError("global op must have a value");
6061

61-
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
62-
globalOp->getParentWithTrait<OpTrait::SymbolTable>());
63-
64-
symbolTable.remove(globalOp);
62+
bufferization::removeSymbol(globalOp, state);
6563

6664
auto tensorType = cast<TensorType>(globalOp.getType());
6765
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
@@ -74,7 +72,7 @@ struct GlobalOpInterface
7472
/*constant=*/!globalOp.getIsMutable(),
7573
/*alignment=*/nullptr);
7674

77-
symbolTable.insert(replacement);
75+
bufferization::insertSymbol(replacement, state);
7876
return success();
7977
}
8078
};

0 commit comments

Comments
 (0)