Skip to content

Commit 95ddbed

Browse files
committed
[mlir] Split out Python bindings for dialects into separate libs
Historically, the bindings for the Linalg dialect were included into the "core" bindings library because they depended on the C++ implementation of the "core" bindings. The other dialects followed the pattern. Now that this dependency is gone, split out each dialect into a separate Python extension library. Depends On D116649, D116605 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D116662
1 parent 6e4bbbf commit 95ddbed

File tree

12 files changed

+154
-118
lines changed

12 files changed

+154
-118
lines changed

mlir/lib/Bindings/Python/DialectLinalg.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,23 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "Dialects.h"
109
#include "mlir-c/Dialect/Linalg.h"
1110
#include "mlir-c/IR.h"
1211
#include "mlir/Bindings/Python/PybindAdaptors.h"
1312

1413
namespace py = pybind11;
1514

16-
void mlir::python::populateDialectLinalgSubmodule(py::module m) {
15+
static void populateDialectLinalgSubmodule(py::module m) {
1716
m.def(
1817
"fill_builtin_region",
1918
[](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); },
2019
py::arg("op"),
2120
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
2221
"op.");
2322
}
23+
24+
PYBIND11_MODULE(_mlirDialectsLinalg, m) {
25+
m.doc() = "MLIR Linalg dialect.";
26+
27+
populateDialectLinalgSubmodule(m);
28+
}

mlir/lib/Bindings/Python/DialectQuant.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "Dialects.h"
109
#include "mlir-c/Dialect/Quant.h"
1110
#include "mlir-c/IR.h"
1211
#include "mlir/Bindings/Python/PybindAdaptors.h"
@@ -16,16 +15,13 @@ using namespace llvm;
1615
using namespace mlir;
1716
using namespace mlir::python::adaptors;
1817

19-
void mlir::python::populateDialectQuantSubmodule(const py::module &m,
20-
const py::module &irModule) {
21-
auto typeClass = irModule.attr("Type");
22-
18+
static void populateDialectQuantSubmodule(const py::module &m) {
2319
//===-------------------------------------------------------------------===//
2420
// QuantizedType
2521
//===-------------------------------------------------------------------===//
2622

27-
auto quantizedType = mlir_type_subclass(m, "QuantizedType",
28-
mlirTypeIsAQuantizedType, typeClass);
23+
auto quantizedType =
24+
mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType);
2925
quantizedType.def_staticmethod(
3026
"default_minimum_for_integer",
3127
[](bool isSigned, unsigned integralWidth) {
@@ -305,3 +301,9 @@ void mlir::python::populateDialectQuantSubmodule(const py::module &m,
305301
return mlirCalibratedQuantizedTypeGetMax(type);
306302
});
307303
}
304+
305+
PYBIND11_MODULE(_mlirDialectsQuant, m) {
306+
m.doc() = "MLIR Quantization dialect";
307+
308+
populateDialectQuantSubmodule(m);
309+
}

mlir/lib/Bindings/Python/DialectSparseTensor.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "Dialects.h"
109
#include "mlir-c/Dialect/SparseTensor.h"
1110
#include "mlir-c/IR.h"
1211
#include "mlir/Bindings/Python/PybindAdaptors.h"
@@ -16,18 +15,14 @@ using namespace llvm;
1615
using namespace mlir;
1716
using namespace mlir::python::adaptors;
1817

19-
void mlir::python::populateDialectSparseTensorSubmodule(
20-
const py::module &m, const py::module &irModule) {
21-
auto attributeClass = irModule.attr("Attribute");
22-
18+
static void populateDialectSparseTensorSubmodule(const py::module &m) {
2319
py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local())
2420
.value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE)
2521
.value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED)
2622
.value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON);
2723

2824
mlir_attribute_subclass(m, "EncodingAttr",
29-
mlirAttributeIsASparseTensorEncodingAttr,
30-
attributeClass)
25+
mlirAttributeIsASparseTensorEncodingAttr)
3126
.def_classmethod(
3227
"get",
3328
[](py::object cls,
@@ -72,3 +67,8 @@ void mlir::python::populateDialectSparseTensorSubmodule(
7267
return mlirSparseTensorEncodingAttrGetIndexBitWidth(self);
7368
});
7469
}
70+
71+
PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
72+
m.doc() = "MLIR SparseTensor dialect.";
73+
populateDialectSparseTensorSubmodule(m);
74+
}

mlir/lib/Bindings/Python/Dialects.h

Lines changed: 0 additions & 26 deletions
This file was deleted.

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
#include "PybindUtils.h"
1212

13-
#include "Dialects.h"
1413
#include "Globals.h"
1514
#include "IRModule.h"
1615
#include "Pass.h"
@@ -100,13 +99,4 @@ PYBIND11_MODULE(_mlir, m) {
10099
auto passModule =
101100
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
102101
populatePassManagerSubmodule(passModule);
103-
104-
// Define and populate dialect submodules.
105-
auto dialectsModule = m.def_submodule("dialects");
106-
auto linalgModule = dialectsModule.def_submodule("linalg");
107-
populateDialectLinalgSubmodule(linalgModule);
108-
populateDialectSparseTensorSubmodule(
109-
dialectsModule.def_submodule("sparse_tensor"), irModule);
110-
populateDialectQuantSubmodule(dialectsModule.def_submodule("quant"),
111-
irModule);
112102
}

mlir/python/CMakeLists.txt

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core
2525
_mlir_libs/_mlir/__init__.pyi
2626
_mlir_libs/_mlir/ir.pyi
2727
_mlir_libs/_mlir/passmanager.pyi
28-
# TODO: this should be split out into a separate library.
29-
_mlir_libs/_mlir/dialects/quant.pyi
3028
)
3129

3230
declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
@@ -122,7 +120,8 @@ declare_mlir_python_sources(
122120
ADD_TO_PARENT MLIRPythonSources.Dialects
123121
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
124122
SOURCES
125-
dialects/quant.py)
123+
dialects/quant.py
124+
_mlir_libs/_mlir/dialects/quant.pyi)
126125

127126
declare_mlir_dialect_python_bindings(
128127
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -191,9 +190,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
191190
ADD_TO_PARENT MLIRPythonSources.Core
192191
ROOT_DIR "${PYTHON_SOURCE_DIR}"
193192
SOURCES
194-
DialectLinalg.cpp # TODO: Break this out.
195-
DialectSparseTensor.cpp # TODO: Break this out.
196-
DialectQuant.cpp # TODO: Break this out.
197193
MainModule.cpp
198194
IRAffine.cpp
199195
IRAttributes.cpp
@@ -205,7 +201,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
205201
Pass.cpp
206202

207203
# Headers must be included explicitly so they are installed.
208-
Dialects.h
209204
Globals.h
210205
IRModule.h
211206
Pass.h
@@ -219,10 +214,46 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
219214
MLIRCAPIRegistration # TODO: See about dis-aggregating
220215

221216
# Dialects
222-
MLIRCAPILinalg # TODO: Remove when above is removed.
223-
MLIRCAPISparseTensor # TODO: Remove when above is removed.
224217
MLIRCAPIStandard
225-
MLIRCAPIQuant # TODO: Remove when above is removed.
218+
)
219+
220+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
221+
MODULE_NAME _mlirDialectsLinalg
222+
ADD_TO_PARENT MLIRPythonSources.Dialects.linalg
223+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
224+
SOURCES
225+
DialectLinalg.cpp
226+
PRIVATE_LINK_LIBS
227+
LLVMSupport
228+
EMBED_CAPI_LINK_LIBS
229+
MLIRCAPIIR
230+
MLIRCAPILinalg
231+
)
232+
233+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
234+
MODULE_NAME _mlirDialectsQuant
235+
ADD_TO_PARENT MLIRPythonSources.Dialects.quant
236+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
237+
SOURCES
238+
DialectQuant.cpp
239+
PRIVATE_LINK_LIBS
240+
LLVMSupport
241+
EMBED_CAPI_LINK_LIBS
242+
MLIRCAPIIR
243+
MLIRCAPIQuant
244+
)
245+
246+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
247+
MODULE_NAME _mlirDialectsSparseTensor
248+
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
249+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
250+
SOURCES
251+
DialectSparseTensor.cpp
252+
PRIVATE_LINK_LIBS
253+
LLVMSupport
254+
EMBED_CAPI_LINK_LIBS
255+
MLIRCAPIIR
256+
MLIRCAPISparseTensor
226257
)
227258

228259
declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration

mlir/python/mlir/dialects/_linalg_ops_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Optional, Sequence, Union
77
from ..ir import *
88
from ._ods_common import get_default_loc_context
9-
from .._mlir_libs._mlir.dialects.linalg import fill_builtin_region
9+
from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
1010
except ImportError as e:
1111
raise RuntimeError("Error loading imports from extension module") from e
1212

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5+
# Re-export the objects provided by pybind.
6+
from ..._mlir_libs._mlirDialectsLinalg import *
7+
58
# These are the backing OpView classes generated from the linalg tablegen
69
# definitions following these steps:
710
# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
@@ -15,39 +18,39 @@
1518
# C=TensorDef(U, S.M, S.N, output=True)):
1619
# ```
1720
# using the linalg-py eDSL.
18-
# The linalg-py eDSL builds a python representation (PyRepr) that is
21+
# The linalg-py eDSL builds a python representation (PyRepr) that is
1922
# used in following ways:
2023
# 1. PyRepr -> YAML to generate the C++ and Python .td files. These
2124
# then turn into the core C++ Op classes and Python OpView classes
22-
# respectively (made available in _linalg_ops_gen). The generic OpView class
25+
# respectively (made available in _linalg_ops_gen). The generic OpView class
2326
# mechanism makes the C++ classes available to python through the CAPI.
2427
# PyRepr -> YAML currently occurs before compiler compile time.
2528
# The other steps in this category occur at compiler compile time.
26-
# 2. PyRepr -> linalg.core_named_ops calls: piggybacks on the
29+
# 2. PyRepr -> linalg.core_named_ops calls: piggybacks on the
2730
# _linalg_ops_gen classes and the OpView mechanism to build IR at
2831
# runtime in python:
2932
# a. by default, the Named Op Form is emitted, e.g.:
3033
# `linalg.matmul(lhs, rhs, outs=[out])` creates the following IR:
3134
# ```
32-
# %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>)
35+
# %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>)
3336
# outs(%0 : tensor<4x8xf32>)
34-
# -> tensor<4x8xf32>
37+
# -> tensor<4x8xf32>
3538
# ```
3639
# b. by setting emit_generic=True, the Generic Op Form is emitted, e.g.:
3740
# `linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)` creates the following IR:
3841
# ```
39-
# %1 = linalg.generic {indexing_maps = [...], iterator_types = [...]}
40-
# ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>)
42+
# %1 = linalg.generic {indexing_maps = [...], iterator_types = [...]}
43+
# ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>)
4144
# outs(%0 : tensor<4x8xf32>) {
42-
# ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
45+
# ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
4346
# ...
4447
# linalg.yield %3 : f32
45-
# } -> tensor<4x8xf32>
48+
# } -> tensor<4x8xf32>
4649
# ```
4750
# 3. PyRepr -> Runtime Custom Op definitions: directly generates a
4851
# linalg.generic form like in 2.b.
49-
# !!!WARNING!!!: if one creates a runtime custom op with the same name
52+
# !!!WARNING!!!: if one creates a runtime custom op with the same name
5053
# as an existing core named op, step 2. will likely take precedence.
51-
# TODO: guard against surprises and fail create Runtime Custom Ops with
54+
# TODO: guard against surprises and fail create Runtime Custom Ops with
5255
# the same name as existing Core Named Ops.
5356
from .opdsl.ops.core_named_ops import *

mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Dict, List, Sequence, Tuple, Union
66

77
from .....ir import *
8-
from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
98

109
from .... import linalg
1110
from .... import std
@@ -173,7 +172,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
173172
f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
174173

175174
named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
176-
fill_builtin_region(named_op.operation)
175+
linalg.fill_builtin_region(named_op.operation)
177176
# Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
178177
# attribute that the non-yaml path does not. The non-yaml path hardcodes the
179178
# indexing_maps in C++ directly.

mlir/python/mlir/dialects/quant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
from .._mlir_libs._mlir.dialects.quant import *
5+
from .._mlir_libs._mlirDialectsQuant import *

mlir/python/mlir/dialects/sparse_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._sparse_tensor_ops_gen import *
6-
from .._mlir_libs._mlir.dialects.sparse_tensor import *
6+
from .._mlir_libs._mlirDialectsSparseTensor import *
77
from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses

0 commit comments

Comments
 (0)