Skip to content

Commit 922415d

Browse files
committed
Added PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings
1 parent 81168e2 commit 922415d

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
27432743
// __init__.py will subclass it with site-specific functionality and set a
27442744
// "Context" attribute on this module.
27452745
//----------------------------------------------------------------------------
2746+
2747+
// Expose DefaultThreadPool to python
2748+
nb::class_<PyThreadPool>(m, "ThreadPool")
2749+
.def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
2750+
.def("get_max_concurrency", &PyThreadPool::getMaxConcurrency);
2751+
27462752
nb::class_<PyMlirContext>(m, "_BaseContext")
27472753
.def("__init__",
27482754
[](PyMlirContext &self) {
@@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
28142820
mlirContextEnableMultithreading(self.get(), enable);
28152821
},
28162822
nb::arg("enable"))
2823+
.def("set_thread_pool",
2824+
[](PyMlirContext &self, PyThreadPool &pool) {
2825+
mlirContextSetThreadPool(self.get(), pool.get());
2826+
})
28172827
.def(
28182828
"is_registered_operation",
28192829
[](PyMlirContext &self, std::string &name) {

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
#include "mlir-c/IR.h"
2323
#include "mlir-c/IntegerSet.h"
2424
#include "mlir-c/Transforms.h"
25-
#include "mlir/Bindings/Python/NanobindAdaptors.h"
2625
#include "mlir/Bindings/Python/Nanobind.h"
26+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
2727
#include "llvm/ADT/DenseMap.h"
28+
#include "llvm/Support/ThreadPool.h"
2829

2930
namespace mlir {
3031
namespace python {
@@ -158,6 +159,23 @@ class PyThreadContextEntry {
158159
FrameKind frameKind;
159160
};
160161

162+
/// Wrapper around MlirLlvmThreadPool
163+
/// Python object owns the C++ thread pool
164+
class PyThreadPool {
165+
public:
166+
PyThreadPool() {
167+
ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
168+
}
169+
PyThreadPool(const PyThreadPool &) = delete;
170+
PyThreadPool(PyThreadPool &&) = delete;
171+
172+
int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
173+
MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
174+
175+
private:
176+
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
177+
};
178+
161179
/// Wrapper around MlirContext.
162180
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
163181
class PyMlirContext {

0 commit comments

Comments
 (0)