Skip to content

Commit 7db1c50

Browse files
committed
Added get_num_threads and _mlir_thread_pool_ptr methods to _BaseContext
Added thread_pool arg to the constructor: `mlir.ir.Context(thread_pool=tp)`
1 parent e0c17c5 commit 7db1c50

File tree

6 files changed

+60
-6
lines changed

6 files changed

+60
-6
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,15 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
162162
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
163163
MlirLlvmThreadPool threadPool);
164164

165+
/// Gets the number of threads of the thread pool of the context when
166+
/// multithreading is enabled. Returns 1 if no multithreading.
167+
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context);
168+
169+
/// Gets the thread pool of the context when enabled multithreading, otherwise
170+
/// an assertion is raised.
171+
MLIR_CAPI_EXPORTED MlirLlvmThreadPool
172+
mlirContextGetThreadPool(MlirContext context);
173+
165174
//===----------------------------------------------------------------------===//
166175
// Dialect API.
167176
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2747,7 +2747,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
27472747
// Expose DefaultThreadPool to python
27482748
nb::class_<PyThreadPool>(m, "ThreadPool")
27492749
.def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
2750-
.def("get_max_concurrency", &PyThreadPool::getMaxConcurrency);
2750+
.def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
2751+
.def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);
27512752

27522753
nb::class_<PyMlirContext>(m, "_BaseContext")
27532754
.def("__init__",
@@ -2822,8 +2823,23 @@ void mlir::python::populateIRCore(nb::module_ &m) {
28222823
nb::arg("enable"))
28232824
.def("set_thread_pool",
28242825
[](PyMlirContext &self, PyThreadPool &pool) {
2826+
// we should disable multi-threading first before setting
2827+
// new thread pool otherwise the assert in
2828+
// MLIRContext::setThreadPool will be raised.
2829+
mlirContextEnableMultithreading(self.get(), false);
28252830
mlirContextSetThreadPool(self.get(), pool.get());
28262831
})
2832+
.def("get_num_threads",
2833+
[](PyMlirContext &self) {
2834+
return mlirContextGetNumThreads(self.get());
2835+
})
2836+
.def("_mlir_thread_pool_ptr",
2837+
[](PyMlirContext &self) {
2838+
MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
2839+
std::stringstream ss;
2840+
ss << pool.ptr;
2841+
return ss.str();
2842+
})
28272843
.def(
28282844
"is_registered_operation",
28292845
[](PyMlirContext &self, std::string &name) {

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#define MLIR_BINDINGS_PYTHON_IRMODULES_H
1212

1313
#include <optional>
14+
#include <sstream>
1415
#include <utility>
1516
#include <vector>
1617

@@ -172,6 +173,12 @@ class PyThreadPool {
172173
int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
173174
MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
174175

176+
std::string _mlir_thread_pool_ptr() const {
177+
std::stringstream ss;
178+
ss << ownedThreadPool.get();
179+
return ss.str();
180+
}
181+
175182
private:
176183
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
177184
};

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,14 @@ void mlirContextSetThreadPool(MlirContext context,
114114
unwrap(context)->setThreadPool(*unwrap(threadPool));
115115
}
116116

117+
unsigned mlirContextGetNumThreads(MlirContext context) {
118+
return unwrap(context)->getNumThreads();
119+
}
120+
121+
MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) {
122+
return wrap(&unwrap(context)->getThreadPool());
123+
}
124+
117125
//===----------------------------------------------------------------------===//
118126
// Dialect API.
119127
//===----------------------------------------------------------------------===//

mlir/python/mlir/_mlir_libs/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,18 @@ def process_initializer_module(module_name):
148148
break
149149

150150
class Context(ir._BaseContext):
151-
def __init__(self, load_on_create_dialects=None, *args, **kwargs):
151+
def __init__(
152+
self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs
153+
):
152154
super().__init__(*args, **kwargs)
153155
self.append_dialect_registry(get_dialect_registry())
154156
for hook in post_init_hooks:
155157
hook(self)
156158
if not disable_multithreading:
157-
self.enable_multithreading(True)
159+
if thread_pool is None:
160+
self.enable_multithreading(True)
161+
else:
162+
self.set_thread_pool(thread_pool)
158163
if load_on_create_dialects is not None:
159164
logger.debug(
160165
"Loading all dialects from load_on_create_dialects arg %r",

mlir/test/python/ir/context_lifecycle.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,18 @@
5555
tp = mlir.ir.ThreadPool()
5656
assert tp.get_max_concurrency() > 0
5757
c5 = mlir.ir.Context()
58-
c5.enable_multithreading(False)
5958
c5.set_thread_pool(tp)
59+
assert c5.get_num_threads() == tp.get_max_concurrency()
60+
assert c5._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
6061
c6 = mlir.ir.Context()
61-
c6.enable_multithreading(False)
6262
c6.set_thread_pool(tp)
63-
assert mlir.ir.Context._get_live_count() == 2
63+
assert c6.get_num_threads() == tp.get_max_concurrency()
64+
assert c6._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
65+
c7 = mlir.ir.Context(thread_pool=tp)
66+
assert c7.get_num_threads() == tp.get_max_concurrency()
67+
assert c7._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
68+
assert mlir.ir.Context._get_live_count() == 3
69+
c5 = None
70+
c6 = None
71+
c7 = None
72+
gc.collect()

0 commit comments

Comments
 (0)