Skip to content

Commit 30d6189

Browse files
committed
[mlir] provide C API and Python bindings for symbol tables
Symbol tables are a largely useful top-level IR construct, for example, they make it easy to access functions in a module by name instead of traversing the list of module's operations to find the corresponding function. Depends On D112886 Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D112821
1 parent feec2d9 commit 30d6189

File tree

9 files changed

+404
-54
lines changed

9 files changed

+404
-54
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ DEFINE_C_API_STRUCT(MlirOperation, void);
5454
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
5555
DEFINE_C_API_STRUCT(MlirBlock, void);
5656
DEFINE_C_API_STRUCT(MlirRegion, void);
57+
DEFINE_C_API_STRUCT(MlirSymbolTable, void);
5758

5859
DEFINE_C_API_STRUCT(MlirAttribute, const void);
5960
DEFINE_C_API_STRUCT(MlirIdentifier, const void);
@@ -738,6 +739,47 @@ MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
738739
/// Returns the hash value of the type id.
739740
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
740741

742+
//===----------------------------------------------------------------------===//
743+
// Symbol and SymbolTable API.
744+
//===----------------------------------------------------------------------===//
745+
746+
/// Returns the name of the attribute used to store symbol names compatible with
747+
/// symbol tables.
748+
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName();
749+
750+
/// Creates a symbol table for the given operation. If the operation does not
751+
/// have the SymbolTable trait, returns a null symbol table.
752+
MLIR_CAPI_EXPORTED MlirSymbolTable
753+
mlirSymbolTableCreate(MlirOperation operation);
754+
755+
/// Returns true if the symbol table is null.
756+
static inline bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable) {
757+
return !symbolTable.ptr;
758+
}
759+
760+
/// Destroys the symbol table created with mlirSymbolTableCreate. This does not
761+
/// affect the operations in the table.
762+
MLIR_CAPI_EXPORTED void mlirSymbolTableDestroy(MlirSymbolTable symbolTable);
763+
764+
/// Looks up a symbol with the given name in the given symbol table and returns
765+
/// the operation that corresponds to the symbol. If the symbol cannot be found,
766+
/// returns a null operation.
767+
MLIR_CAPI_EXPORTED MlirOperation
768+
mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name);
769+
770+
/// Inserts the given operation into the given symbol table. The operation must
771+
/// have the symbol trait. If the symbol table already has a symbol with the
772+
/// same name, renames the symbol being inserted to ensure name uniqueness. Note
773+
/// that this does not move the operation itself into the block of the symbol
774+
/// table operation, this should be done separately. Returns the name of the
775+
/// symbol after insertion.
776+
MLIR_CAPI_EXPORTED MlirAttribute
777+
mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation);
778+
779+
/// Removes the given operation from the symbol table and erases it.
780+
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable,
781+
MlirOperation operation);
782+
741783
#ifdef __cplusplus
742784
}
743785
#endif

mlir/include/mlir-c/Support.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ inline static MlirStringRef mlirStringRefCreate(const char *str,
7979
MLIR_CAPI_EXPORTED MlirStringRef
8080
mlirStringRefCreateFromCString(const char *str);
8181

82+
/// Returns true if two string references are equal, false otherwise.
83+
MLIR_CAPI_EXPORTED bool mlirStringRefEqual(MlirStringRef string,
84+
MlirStringRef other);
85+
8286
/// A callback for returning string references.
8387
///
8488
/// This function is called back by the functions that need to return a

mlir/include/mlir/CAPI/IR.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
2727
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
2828
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
2929
DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
30+
DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable);
3031

3132
DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)
3233
DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier)

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,6 +1530,57 @@ PyValue PyValue::createFromCapsule(pybind11::object capsule) {
15301530
return PyValue(ownerRef, value);
15311531
}
15321532

1533+
//------------------------------------------------------------------------------
1534+
// PySymbolTable.
1535+
//------------------------------------------------------------------------------
1536+
1537+
PySymbolTable::PySymbolTable(PyOperationBase &operation)
1538+
: operation(operation.getOperation().getRef()) {
1539+
symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1540+
if (mlirSymbolTableIsNull(symbolTable)) {
1541+
throw py::cast_error("Operation is not a Symbol Table.");
1542+
}
1543+
}
1544+
1545+
py::object PySymbolTable::dunderGetItem(const std::string &name) {
1546+
operation->checkValid();
1547+
MlirOperation symbol = mlirSymbolTableLookup(
1548+
symbolTable, mlirStringRefCreate(name.data(), name.length()));
1549+
if (mlirOperationIsNull(symbol))
1550+
throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1551+
1552+
return PyOperation::forOperation(operation->getContext(), symbol,
1553+
operation.getObject())
1554+
->createOpView();
1555+
}
1556+
1557+
void PySymbolTable::erase(PyOperationBase &symbol) {
1558+
operation->checkValid();
1559+
symbol.getOperation().checkValid();
1560+
mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1561+
// The operation is also erased, so we must invalidate it. There may be Python
1562+
// references to this operation so we don't want to delete it from the list of
1563+
// live operations here.
1564+
symbol.getOperation().valid = false;
1565+
}
1566+
1567+
void PySymbolTable::dunderDel(const std::string &name) {
1568+
py::object operation = dunderGetItem(name);
1569+
erase(py::cast<PyOperationBase &>(operation));
1570+
}
1571+
1572+
PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1573+
operation->checkValid();
1574+
symbol.getOperation().checkValid();
1575+
MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1576+
symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1577+
if (mlirAttributeIsNull(symbolAttr))
1578+
throw py::value_error("Expected operation to have a symbol name.");
1579+
return PyAttribute(
1580+
symbol.getOperation().getContext(),
1581+
mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1582+
}
1583+
15331584
namespace {
15341585
/// CRTP base class for Python MLIR values that subclass Value and should be
15351586
/// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2670,6 +2721,20 @@ void mlir::python::populateIRCore(py::module &m) {
26702721
PyBlockArgument::bind(m);
26712722
PyOpResult::bind(m);
26722723

2724+
//----------------------------------------------------------------------------
2725+
// Mapping of SymbolTable.
2726+
//----------------------------------------------------------------------------
2727+
py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
2728+
.def(py::init<PyOperationBase &>())
2729+
.def("__getitem__", &PySymbolTable::dunderGetItem)
2730+
.def("insert", &PySymbolTable::insert)
2731+
.def("erase", &PySymbolTable::erase)
2732+
.def("__delitem__", &PySymbolTable::dunderDel)
2733+
.def("__contains__", [](PySymbolTable &table, const std::string &name) {
2734+
return !mlirOperationIsNull(mlirSymbolTableLookup(
2735+
table, mlirStringRefCreate(name.data(), name.length())));
2736+
});
2737+
26732738
// Container bindings.
26742739
PyBlockArgumentList::bind(m);
26752740
PyBlockIterator::bind(m);

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class DefaultingPyMlirContext;
3232
class PyModule;
3333
class PyOperation;
3434
class PyType;
35+
class PySymbolTable;
3536
class PyValue;
3637

3738
/// Template for a reference to a concrete type which captures a python
@@ -513,6 +514,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
513514
bool valid = true;
514515

515516
friend class PyOperationBase;
517+
friend class PySymbolTable;
516518
};
517519

518520
/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
@@ -876,6 +878,38 @@ class PyIntegerSet : public BaseContextObject {
876878
MlirIntegerSet integerSet;
877879
};
878880

881+
/// Bindings for MLIR symbol tables.
882+
class PySymbolTable {
883+
public:
884+
/// Constructs a symbol table for the given operation.
885+
explicit PySymbolTable(PyOperationBase &operation);
886+
887+
/// Destroys the symbol table.
888+
~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); }
889+
890+
/// Returns the symbol (opview) with the given name, throws if there is no
891+
/// such symbol in the table.
892+
pybind11::object dunderGetItem(const std::string &name);
893+
894+
/// Removes the given operation from the symbol table and erases it.
895+
void erase(PyOperationBase &symbol);
896+
897+
/// Removes the operation with the given name from the symbol table and erases
898+
/// it, throws if there is no such symbol in the table.
899+
void dunderDel(const std::string &name);
900+
901+
/// Inserts the given operation into the symbol table. The operation must have
902+
/// the symbol trait.
903+
PyAttribute insert(PyOperationBase &symbol);
904+
905+
/// Casts the bindings class into the C API structure.
906+
operator MlirSymbolTable() { return symbolTable; }
907+
908+
private:
909+
PyOperationRef operation;
910+
MlirSymbolTable symbolTable;
911+
};
912+
879913
void populateIRAffine(pybind11::module &m);
880914
void populateIRAttributes(pybind11::module &m);
881915
void populateIRCore(pybind11::module &m);

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,36 @@ bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
763763
size_t mlirTypeIDHashValue(MlirTypeID typeID) {
764764
return hash_value(unwrap(typeID));
765765
}
766+
767+
//===----------------------------------------------------------------------===//
768+
// Symbol and SymbolTable API.
769+
//===----------------------------------------------------------------------===//
770+
771+
MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
772+
return wrap(SymbolTable::getSymbolAttrName());
773+
}
774+
775+
MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
776+
if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
777+
return wrap(static_cast<SymbolTable *>(nullptr));
778+
return wrap(new SymbolTable(unwrap(operation)));
779+
}
780+
781+
void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
782+
delete unwrap(symbolTable);
783+
}
784+
785+
MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
786+
MlirStringRef name) {
787+
return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length)));
788+
}
789+
790+
MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
791+
MlirOperation operation) {
792+
return wrap(unwrap(symbolTable)->insert(unwrap(operation)));
793+
}
794+
795+
void mlirSymbolTableErase(MlirSymbolTable symbolTable,
796+
MlirOperation operation) {
797+
unwrap(symbolTable)->erase(unwrap(operation));
798+
}

mlir/lib/CAPI/IR/Support.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir-c/Support.h"
10+
#include "llvm/ADT/StringRef.h"
1011

1112
#include <cstring>
1213

1314
MlirStringRef mlirStringRefCreateFromCString(const char *str) {
1415
return mlirStringRefCreate(str, strlen(str));
1516
}
17+
18+
bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
19+
return llvm::StringRef(string.data, string.length) ==
20+
llvm::StringRef(other.data, other.length);
21+
}

0 commit comments

Comments
 (0)