Skip to content

Added PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings #130109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
MlirLlvmThreadPool threadPool);

/// Gets the number of threads of the thread pool of the context when
/// multithreading is enabled. Returns 1 if no multithreading.
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context);

/// Gets the thread pool of the context when enabled multithreading, otherwise
/// an assertion is raised.
MLIR_CAPI_EXPORTED MlirLlvmThreadPool
mlirContextGetThreadPool(MlirContext context);

//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2743,6 +2743,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// __init__.py will subclass it with site-specific functionality and set a
// "Context" attribute on this module.
//----------------------------------------------------------------------------

// Expose DefaultThreadPool to python
nb::class_<PyThreadPool>(m, "ThreadPool")
.def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
.def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
.def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);

nb::class_<PyMlirContext>(m, "_BaseContext")
.def("__init__",
[](PyMlirContext &self) {
Expand Down Expand Up @@ -2814,6 +2821,25 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirContextEnableMultithreading(self.get(), enable);
},
nb::arg("enable"))
.def("set_thread_pool",
[](PyMlirContext &self, PyThreadPool &pool) {
// we should disable multi-threading first before setting
// new thread pool otherwise the assert in
// MLIRContext::setThreadPool will be raised.
mlirContextEnableMultithreading(self.get(), false);
mlirContextSetThreadPool(self.get(), pool.get());
})
.def("get_num_threads",
[](PyMlirContext &self) {
return mlirContextGetNumThreads(self.get());
})
.def("_mlir_thread_pool_ptr",
[](PyMlirContext &self) {
MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
std::stringstream ss;
ss << pool.ptr;
return ss.str();
})
.def(
"is_registered_operation",
[](PyMlirContext &self, std::string &name) {
Expand Down
27 changes: 26 additions & 1 deletion mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#define MLIR_BINDINGS_PYTHON_IRMODULES_H

#include <optional>
#include <sstream>
#include <utility>
#include <vector>

Expand All @@ -22,9 +23,10 @@
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ThreadPool.h"

namespace mlir {
namespace python {
Expand Down Expand Up @@ -158,6 +160,29 @@ class PyThreadContextEntry {
FrameKind frameKind;
};

/// Wrapper around MlirLlvmThreadPool
/// Python object owns the C++ thread pool
class PyThreadPool {
public:
PyThreadPool() {
ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
}
PyThreadPool(const PyThreadPool &) = delete;
PyThreadPool(PyThreadPool &&) = delete;

int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }

std::string _mlir_thread_pool_ptr() const {
std::stringstream ss;
ss << ownedThreadPool.get();
return ss.str();
}

private:
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
};

/// Wrapper around MlirContext.
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
class PyMlirContext {
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ void mlirContextSetThreadPool(MlirContext context,
unwrap(context)->setThreadPool(*unwrap(threadPool));
}

unsigned mlirContextGetNumThreads(MlirContext context) {
return unwrap(context)->getNumThreads();
}

MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) {
return wrap(&unwrap(context)->getThreadPool());
}

//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 14 additions & 2 deletions mlir/python/mlir/_mlir_libs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,25 @@ def process_initializer_module(module_name):
break

class Context(ir._BaseContext):
def __init__(self, load_on_create_dialects=None, *args, **kwargs):
def __init__(
self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.append_dialect_registry(get_dialect_registry())
for hook in post_init_hooks:
hook(self)
if disable_multithreading and thread_pool is not None:
raise ValueError(
"Context constructor has given thread_pool argument, "
"but disable_multithreading flag is True. "
"Please, set thread_pool argument to None or "
"set disable_multithreading flag to False."
)
if not disable_multithreading:
self.enable_multithreading(True)
if thread_pool is None:
self.enable_multithreading(True)
else:
self.set_thread_pool(thread_pool)
if load_on_create_dialects is not None:
logger.debug(
"Loading all dialects from load_on_create_dialects arg %r",
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/python/ir/context_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,26 @@
assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
c5 = mlir.ir.Context._CAPICreate(c4_capsule)
assert c4 is c5
c4 = None
c5 = None
gc.collect()

# Create a global threadpool and use it in two contexts
tp = mlir.ir.ThreadPool()
assert tp.get_max_concurrency() > 0
c5 = mlir.ir.Context()
c5.set_thread_pool(tp)
assert c5.get_num_threads() == tp.get_max_concurrency()
assert c5._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
c6 = mlir.ir.Context()
c6.set_thread_pool(tp)
assert c6.get_num_threads() == tp.get_max_concurrency()
assert c6._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
c7 = mlir.ir.Context(thread_pool=tp)
assert c7.get_num_threads() == tp.get_max_concurrency()
assert c7._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
assert mlir.ir.Context._get_live_count() == 3
c5 = None
c6 = None
c7 = None
gc.collect()
Loading