Skip to content

Added PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings #130109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 10, 2025

Conversation

vfdev-5
Copy link
Contributor

@vfdev-5 vfdev-5 commented Mar 6, 2025

Description:

  • Exposed MlirLlvmThreadPool as PyThreadPool in MLIR Python bindings
  • Add tests

Context:

In JAX ir.Context are used with disabled multi-threading to avoid caching multiple threading pools:
https://github.com/jax-ml/jax/blob/623865fe9538100d877ba9d36f788d0f95a11ed2/jax/_src/interpreters/mlir.py#L606-L611
However, when context has enabled multithreading it also uses locks on the StorageUniquers and this can be helpful to avoid data races in the multi-threaded execution (for example with free-threaded cpython, jax-ml/jax#26272).
With this PR user can enable the multi-threading: 1) enables additional locking and 2) set a shared threading pool such that cached contexts can have one global pool.

cc @hawkinsp

@llvmbot llvmbot added the mlir label Mar 6, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-mlir

Author: vfdev (vfdev-5)

Changes

cc @hawkinsp


Full diff: https://github.com/llvm/llvm-project/pull/130109.diff

2 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+10)
  • (modified) mlir/lib/Bindings/Python/IRModule.h (+19-1)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..1ec52a1a9bcd4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   // __init__.py will subclass it with site-specific functionality and set a
   // "Context" attribute on this module.
   //----------------------------------------------------------------------------
+
+  // Expose DefaultThreadPool to python
+  nb::class_<PyThreadPool>(m, "ThreadPool")
+      .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
+      .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency);
+
   nb::class_<PyMlirContext>(m, "_BaseContext")
       .def("__init__",
            [](PyMlirContext &self) {
@@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             mlirContextEnableMultithreading(self.get(), enable);
           },
           nb::arg("enable"))
+      .def("set_thread_pool",
+           [](PyMlirContext &self, PyThreadPool &pool) {
+             mlirContextSetThreadPool(self.get(), pool.get());
+           })
       .def(
           "is_registered_operation",
           [](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1ed6240a6ca69..b7bbd646d982e 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -22,9 +22,10 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Transforms.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ThreadPool.h"
 
 namespace mlir {
 namespace python {
@@ -158,6 +159,23 @@ class PyThreadContextEntry {
   FrameKind frameKind;
 };
 
+/// Wrapper around MlirLlvmThreadPool
+/// Python object owns the C++ thread pool
+class PyThreadPool {
+public:
+  PyThreadPool() {
+    ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
+  }
+  PyThreadPool(const PyThreadPool &) = delete;
+  PyThreadPool(PyThreadPool &&) = delete;
+
+  int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
+  MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
+
+private:
+  std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
+};
+
 /// Wrapper around MlirContext.
 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
 class PyMlirContext {

@vfdev-5 vfdev-5 marked this pull request as draft March 6, 2025 13:52
@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from cb796b9 to feedd57 Compare March 6, 2025 14:50
@vfdev-5 vfdev-5 marked this pull request as ready for review March 6, 2025 14:50
@llvmbot llvmbot added llvm:support bazel "Peripheral" support tier build system: utils/bazel labels Mar 6, 2025
@vfdev-5 vfdev-5 marked this pull request as draft March 6, 2025 14:56
@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from feedd57 to 6e0960a Compare March 6, 2025 14:58
@vfdev-5 vfdev-5 marked this pull request as ready for review March 6, 2025 15:03
@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from 6e0960a to 998821b Compare March 7, 2025 10:47
@llvmbot llvmbot added the mlir:python MLIR Python bindings label Mar 7, 2025
@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from 998821b to cb63466 Compare March 7, 2025 10:49
@vfdev-5
Copy link
Contributor Author

vfdev-5 commented Mar 7, 2025

@joker-eph in the last commit (7db1c50) I added thread_pool arg to mlir.ir.Context constructor. Also exposed get_num_threads and _mlir_thread_pool_ptr methods:

  • updated MLIR C API: added mlirContextGetNumThreads, mlirContextGetThreadPool methods

_mlir_thread_pool_ptr is an internal method to ensure that C++ thread pool is correctly set up.

Let me know if this works. Thanks!

Copy link

github-actions bot commented Mar 7, 2025

✅ With the latest revision this PR passed the Python code formatter.

@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from cb63466 to 7db1c50 Compare March 7, 2025 10:56
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG with one minor thing to check.

@vfdev-5 vfdev-5 force-pushed the mlir-python-expose-threadpool branch from 7db1c50 to a557554 Compare March 8, 2025 13:37
@joker-eph joker-eph merged commit ab18cc2 into llvm:main Mar 10, 2025
11 checks passed
@vfdev-5 vfdev-5 deleted the mlir-python-expose-threadpool branch March 10, 2025 10:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel llvm:support mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants