File tree Expand file tree Collapse file tree 2 files changed +29
-1
lines changed Expand file tree Collapse file tree 2 files changed +29
-1
lines changed Original file line number Diff line number Diff line change @@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
2743
2743
// __init__.py will subclass it with site-specific functionality and set a
2744
2744
// "Context" attribute on this module.
2745
2745
// ----------------------------------------------------------------------------
2746
+
2747
+ // Expose DefaultThreadPool to python
2748
+ nb::class_<PyThreadPool>(m, " ThreadPool" )
2749
+ .def (" __init__" , [](PyThreadPool &self) { new (&self) PyThreadPool (); })
2750
+ .def (" get_max_concurrency" , &PyThreadPool::getMaxConcurrency);
2751
+
2746
2752
nb::class_<PyMlirContext>(m, " _BaseContext" )
2747
2753
.def (" __init__" ,
2748
2754
[](PyMlirContext &self) {
@@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
2814
2820
mlirContextEnableMultithreading (self.get (), enable);
2815
2821
},
2816
2822
nb::arg (" enable" ))
2823
+ .def (" set_thread_pool" ,
2824
+ [](PyMlirContext &self, PyThreadPool &pool) {
2825
+ mlirContextSetThreadPool (self.get (), pool.get ());
2826
+ })
2817
2827
.def (
2818
2828
" is_registered_operation" ,
2819
2829
[](PyMlirContext &self, std::string &name) {
Original file line number Diff line number Diff line change 22
22
#include " mlir-c/IR.h"
23
23
#include " mlir-c/IntegerSet.h"
24
24
#include " mlir-c/Transforms.h"
25
- #include " mlir/Bindings/Python/NanobindAdaptors.h"
26
25
#include " mlir/Bindings/Python/Nanobind.h"
26
+ #include " mlir/Bindings/Python/NanobindAdaptors.h"
27
27
#include " llvm/ADT/DenseMap.h"
28
+ #include " llvm/Support/ThreadPool.h"
28
29
29
30
namespace mlir {
30
31
namespace python {
@@ -158,6 +159,23 @@ class PyThreadContextEntry {
158
159
FrameKind frameKind;
159
160
};
160
161
162
+ // / Wrapper around MlirLlvmThreadPool
163
+ // / Python object owns the C++ thread pool
164
+ class PyThreadPool {
165
+ public:
166
+ PyThreadPool () {
167
+ ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
168
+ }
169
+ PyThreadPool (const PyThreadPool &) = delete ;
170
+ PyThreadPool (PyThreadPool &&) = delete ;
171
+
172
+ int getMaxConcurrency () const { return ownedThreadPool->getMaxConcurrency (); }
173
+ MlirLlvmThreadPool get () { return wrap (ownedThreadPool.get ()); }
174
+
175
+ private:
176
+ std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
177
+ };
178
+
161
179
// / Wrapper around MlirContext.
162
180
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
163
181
class PyMlirContext {
You can’t perform that action at this time.
0 commit comments