Skip to content

Commit a2288a8

Browse files
authored
[mlir][python] remove mixins (llvm#68853)
This PR replaces the mixin `OpView` extension mechanism with the standard inheritance mechanism. Why? Firstly, mixins are not very pythonic (inheritance is usually used for this), a little convoluted, and too "tight" (can only be used in the immediately adjacent `_ext.py`). Secondly, it (mixins) are now blocking are correct implementation of "value builders" (see [here](llvm#68764)) where the problem becomes how to choose the correct base class that the value builder should call. This PR looks big/complicated but appearances are deceiving; 4 things were needed to make this work: 1. Drop `skipDefaultBuilders` in `OpPythonBindingGen::emitDefaultOpBuilders` 2. Former mixin extension classes are converted to inherit from the generated `OpView` instead of being "mixins" a. extension classes that simply were calling into an already generated `super().__init__` continue to do so b. (almost all) extension classes that were calling `self.build_generic` because of a lack of default builder being generated can now also just call `super().__init__` 3. To handle the [lone single use-case](https://sourcegraph.com/search?q=context%3Aglobal+select_opview_mixin&patternType=standard&sm=1&groupBy=repo) of `select_opview_mixin`, namely [linalg](https://github.com/llvm/llvm-project/blob/main/mlir/python/mlir/dialects/_linalg_ops_ext.py#L38), only a small change was necessary in `opdsl/lang/emitter.py` (thanks to the emission/generation of default builders/`__init__`s) 4. since the `extend_opview_class` decorator is removed, we need a way to register extension classes as the desired `OpView` that `op.opview` conjures into existence; so we do the standard thing and just enable replacing the existing registered `OpView` i.e., `register_operation(_Dialect, replace=True)`. Note, the upgrade path for the common case is to change an extension to inherit from the generated builder and decorate it with `register_operation(_Dialect, replace=True)`. In the slightly more complicated case where `super().__init(self.build_generic(...))` is called in the extension's `__init__`, this needs to be updated to call `__init__` in `OpView`, i.e., the grandparent (see updated docs). Note, also `<DIALECT>_ext.py` files/modules will no longer be automatically loaded. Note, the PR has 3 base commits that look funny but this was done for the purpose of tracking the line history of moving the `<DIALECT>_ops_ext.py` class into `<DIALECT>.py` and updating (commit labeled "fix").
1 parent a30095a commit a2288a8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2814
-2920
lines changed

mlir/docs/Bindings/Python.md

Lines changed: 58 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,90 +1017,79 @@ very generic signature.
10171017

10181018
#### Extending Generated Op Classes
10191019

1020-
Note that this is a rather complex mechanism and this section errs on the side
1021-
of explicitness. Users are encouraged to find an example and duplicate it if
1022-
they don't feel the need to understand the subtlety. The `builtin` dialect
1023-
provides some relatively simple examples.
1024-
10251020
As mentioned above, the build system generates Python sources like
10261021
`_{DIALECT_NAMESPACE}_ops_gen.py` for each dialect with Python bindings. It is
1027-
often desirable to to use these generated classes as a starting point for
1028-
further customization, so an extension mechanism is provided to make this easy
1029-
(you are always free to do ad-hoc patching in your `{DIALECT_NAMESPACE}.py` file
1030-
but we prefer a more standard mechanism that is applied uniformly).
1022+
often desirable to use these generated classes as a starting point for
1023+
further customization, so an extension mechanism is provided to make this easy.
1024+
This mechanism uses conventional inheritance combined with `OpView` registration.
1025+
For example, the default builder for `arith.constant`
1026+
1027+
```python
1028+
class ConstantOp(_ods_ir.OpView):
1029+
OPERATION_NAME = "arith.constant"
1030+
1031+
_ODS_REGIONS = (0, True)
1032+
1033+
def __init__(self, value, *, loc=None, ip=None):
1034+
...
1035+
```
10311036

1032-
To provide extensions, add a `_{DIALECT_NAMESPACE}_ops_ext.py` file to the
1033-
`dialects` module (i.e. adjacent to your `{DIALECT_NAMESPACE}.py` top-level and
1034-
the `*_ops_gen.py` file). Using the `builtin` dialect and `FuncOp` as an
1035-
example, the generated code will include an import like this:
1037+
expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
1038+
Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`:
10361039

10371040
```python
1038-
try:
1039-
from . import _builtin_ops_ext as _ods_ext_module
1040-
except ImportError:
1041-
_ods_ext_module = None
1041+
from typing import Union
1042+
1043+
from mlir.ir import Type, IntegerAttr, FloatAttr
1044+
from mlir.dialects._arith_ops_gen import _Dialect, ConstantOp
1045+
from mlir.dialects._ods_common import _cext
1046+
1047+
@_cext.register_operation(_Dialect, replace=True)
1048+
class ConstantOpExt(ConstantOp):
1049+
def __init__(
1050+
self, result: Type, value: Union[int, float], *, loc=None, ip=None
1051+
):
1052+
if isinstance(value, int):
1053+
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
1054+
elif isinstance(value, float):
1055+
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
1056+
else:
1057+
raise NotImplementedError(f"Building `arith.constant` not supported for {result=} {value=}")
10421058
```
10431059

1044-
Then for each generated concrete `OpView` subclass, it will apply a decorator
1045-
like:
1060+
which enables building an instance of `arith.constant` like so:
10461061

10471062
```python
1048-
@_ods_cext.register_operation(_Dialect)
1049-
@_ods_extend_opview_class(_ods_ext_module)
1050-
class FuncOp(_ods_ir.OpView):
1063+
from mlir.ir import F32Type
1064+
1065+
a = ConstantOpExt(F32Type.get(), 42.42)
1066+
b = ConstantOpExt(IntegerType.get_signless(32), 42)
10511067
```
10521068

1053-
See the `_ods_common.py` `extend_opview_class` function for details of the
1054-
mechanism. At a high level:
1055-
1056-
* If the extension module exists, locate an extension class for the op (in
1057-
this example, `FuncOp`):
1058-
* First by looking for an attribute with the exact name in the extension
1059-
module.
1060-
* Falling back to calling a `select_opview_mixin(parent_opview_cls)`
1061-
function defined in the extension module.
1062-
* If a mixin class is found, a new subclass is dynamically created that
1063-
multiply inherits from `({_builtin_ops_ext.FuncOp},
1064-
_builtin_ops_gen.FuncOp)`.
1065-
1066-
The mixin class should not inherit from anything (i.e. directly extends `object`
1067-
only). The facility is typically used to define custom `__init__` methods,
1068-
properties, instance methods and static methods. Due to the inheritance
1069-
ordering, the mixin class can act as though it extends the generated `OpView`
1070-
subclass in most contexts (i.e. `issubclass(_builtin_ops_ext.FuncOp, OpView)`
1071-
will return `False` but usage generally allows you treat it as duck typed as an
1072-
`OpView`).
1073-
1074-
There are a couple of recommendations, given how the class hierarchy is defined:
1075-
1076-
* For static methods that need to instantiate the actual "leaf" op (which is
1077-
dynamically generated and would result in circular dependencies to try to
1078-
reference by name), prefer to use `@classmethod` and the concrete subclass
1079-
will be provided as your first `cls` argument. See
1080-
`_builtin_ops_ext.FuncOp.from_py_func` as an example.
1081-
* If seeking to replace the generated `__init__` method entirely, you may
1082-
actually want to invoke the super-super-class `mlir.ir.OpView` constructor
1083-
directly, as it takes an `mlir.ir.Operation`, which is likely what you are
1084-
constructing (i.e. the generated `__init__` method likely adds more API
1085-
constraints than you want to expose in a custom builder).
1086-
1087-
A pattern that comes up frequently is wanting to provide a sugared `__init__`
1088-
method which has optional or type-polymorphism/implicit conversions but to
1089-
otherwise want to invoke the default op building logic. For such cases, it is
1090-
recommended to use an idiom such as:
1069+
Note, three key aspects of the extension mechanism in this example:
1070+
1071+
1. `ConstantOpExt` directly inherits from the generated `ConstantOp`;
1072+
2. in this, simplest, case all that's required is a call to the super class' initializer, i.e., `super().__init__(...)`;
1073+
3. in order to register `ConstantOpExt` as the preferred `OpView` that is returned by `mlir.ir.Operation.opview` (see [Operations, Regions and Blocks](#operations-regions-and-blocks))
1074+
we decorate the class with `@_cext.register_operation(_Dialect, replace=True)`, **where the `replace=True` must be used**.
1075+
1076+
In some more complex cases it might be necessary to explicitly build the `OpView` through `OpView.build_generic` (see [Default Builder](#default-builder)), just as is performed by the generated builders.
1077+
I.e., we must call `OpView.build_generic` **and pass the result to `OpView.__init__`**, where the small issue becomes that the latter is already overridden by the generated builder.
1078+
Thus, we must call a method of a super class' super class (the "grandparent"); for example:
10911079

10921080
```python
1093-
def __init__(self, sugar, spice, *, loc=None, ip=None):
1094-
... massage into result_type, operands, attributes ...
1095-
OpView.__init__(self, self.build_generic(
1096-
results=[result_type],
1097-
operands=operands,
1098-
attributes=attributes,
1099-
loc=loc,
1100-
ip=ip))
1081+
from mlir.dialects._scf_ops_gen import _Dialect, ForOp
1082+
from mlir.dialects._ods_common import _cext
1083+
1084+
@_cext.register_operation(_Dialect, replace=True)
1085+
class ForOpExt(ForOp):
1086+
def __init__(self, lower_bound, upper_bound, step, iter_args, *, loc=None, ip=None):
1087+
...
1088+
super(ForOp, self).__init__(self.build_generic(...))
11011089
```
11021090

1103-
Refer to the documentation for `build_generic` for more information.
1091+
where `OpView.__init__` is called via `super(ForOp, self).__init__`.
1092+
Note, there are alternatives ways to implement this (e.g., explicitly writing `OpView.__init__`); see any discussion on Python inheritance.
11041093

11051094
## Providing Python bindings for a dialect
11061095

mlir/lib/Bindings/Python/Globals.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ class PyGlobals {
7777
pybind11::object pyClass);
7878

7979
/// Adds a concrete implementation operation class.
80-
/// Raises an exception if the mapping already exists.
80+
/// Raises an exception if the mapping already exists and replace == false.
8181
/// This is intended to be called by implementation code.
8282
void registerOperationImpl(const std::string &operationName,
83-
pybind11::object pyClass);
83+
pybind11::object pyClass, bool replace = false);
8484

8585
/// Returns the custom Attribute builder for Attribute kind.
8686
std::optional<pybind11::function>

mlir/lib/Bindings/Python/IRModule.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
9696
}
9797

9898
void PyGlobals::registerOperationImpl(const std::string &operationName,
99-
py::object pyClass) {
99+
py::object pyClass, bool replace) {
100100
py::object &found = operationClassMap[operationName];
101-
if (found) {
101+
if (found && !replace) {
102102
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
103103
"' is already registered.")
104104
.str());

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) {
4141
"dialect_namespace"_a, "dialect_class"_a,
4242
"Testing hook for directly registering a dialect")
4343
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
44-
"operation_name"_a, "operation_class"_a,
44+
"operation_name"_a, "operation_class"_a, "replace"_a = false,
4545
"Testing hook for directly registering an operation");
4646

4747
// Aside from making the globals accessible to python, having python manage
@@ -63,20 +63,21 @@ PYBIND11_MODULE(_mlir, m) {
6363
"Class decorator for registering a custom Dialect wrapper");
6464
m.def(
6565
"register_operation",
66-
[](const py::object &dialectClass) -> py::cpp_function {
66+
[](const py::object &dialectClass, bool replace) -> py::cpp_function {
6767
return py::cpp_function(
68-
[dialectClass](py::object opClass) -> py::object {
68+
[dialectClass, replace](py::object opClass) -> py::object {
6969
std::string operationName =
7070
opClass.attr("OPERATION_NAME").cast<std::string>();
71-
PyGlobals::get().registerOperationImpl(operationName, opClass);
71+
PyGlobals::get().registerOperationImpl(operationName, opClass,
72+
replace);
7273

7374
// Dict-stuff the new opClass by name onto the dialect class.
7475
py::object opClassName = opClass.attr("__name__");
7576
dialectClass.attr(opClassName) = opClass;
7677
return opClass;
7778
});
7879
},
79-
"dialect_class"_a,
80+
"dialect_class"_a, "replace"_a = false,
8081
"Produce a class decorator for registering an Operation class as part of "
8182
"a dialect");
8283
m.def(

mlir/python/CMakeLists.txt

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings(
5252
TD_FILE dialects/AffineOps.td
5353
SOURCES
5454
dialects/affine.py
55-
dialects/_affine_ops_ext.py
5655
DIALECT_NAME affine
5756
GEN_ENUM_BINDINGS)
5857

@@ -78,7 +77,6 @@ declare_mlir_dialect_python_bindings(
7877
TD_FILE dialects/BufferizationOps.td
7978
SOURCES
8079
dialects/bufferization.py
81-
dialects/_bufferization_ops_ext.py
8280
DIALECT_NAME bufferization
8381
GEN_ENUM_BINDINGS_TD_FILE
8482
"../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
@@ -90,7 +88,6 @@ declare_mlir_dialect_python_bindings(
9088
TD_FILE dialects/BuiltinOps.td
9189
SOURCES
9290
dialects/builtin.py
93-
dialects/_builtin_ops_ext.py
9491
DIALECT_NAME builtin)
9592

9693
declare_mlir_dialect_python_bindings(
@@ -115,7 +112,6 @@ declare_mlir_dialect_python_bindings(
115112
TD_FILE dialects/FuncOps.td
116113
SOURCES
117114
dialects/func.py
118-
dialects/_func_ops_ext.py
119115
DIALECT_NAME func)
120116

121117
declare_mlir_dialect_python_bindings(
@@ -131,7 +127,6 @@ declare_mlir_dialect_python_bindings(
131127
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
132128
TD_FILE dialects/LinalgOps.td
133129
SOURCES
134-
dialects/_linalg_ops_ext.py
135130
SOURCES_GLOB
136131
dialects/linalg/*.py
137132
DIALECT_NAME linalg
@@ -152,7 +147,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects
152147
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
153148
TD_FILE dialects/TransformPDLExtensionOps.td
154149
SOURCES
155-
dialects/_transform_pdl_extension_ops_ext.py
156150
dialects/transform/pdl.py
157151
DIALECT_NAME transform
158152
EXTENSION_NAME transform_pdl_extension)
@@ -162,7 +156,6 @@ declare_mlir_dialect_python_bindings(
162156
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
163157
TD_FILE dialects/TransformOps.td
164158
SOURCES
165-
dialects/_transform_ops_ext.py
166159
dialects/transform/__init__.py
167160
_mlir_libs/_mlir/dialects/transform/__init__.pyi
168161
DIALECT_NAME transform
@@ -175,7 +168,6 @@ declare_mlir_dialect_extension_python_bindings(
175168
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
176169
TD_FILE dialects/BufferizationTransformOps.td
177170
SOURCES
178-
dialects/_bufferization_transform_ops_ext.py
179171
dialects/transform/bufferization.py
180172
DIALECT_NAME transform
181173
EXTENSION_NAME bufferization_transform)
@@ -185,7 +177,6 @@ declare_mlir_dialect_extension_python_bindings(
185177
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
186178
TD_FILE dialects/GPUTransformOps.td
187179
SOURCES
188-
dialects/_gpu_transform_ops_ext.py
189180
dialects/transform/gpu.py
190181
DIALECT_NAME transform
191182
EXTENSION_NAME gpu_transform)
@@ -195,7 +186,6 @@ declare_mlir_dialect_extension_python_bindings(
195186
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
196187
TD_FILE dialects/SCFLoopTransformOps.td
197188
SOURCES
198-
dialects/_loop_transform_ops_ext.py
199189
dialects/transform/loop.py
200190
DIALECT_NAME transform
201191
EXTENSION_NAME loop_transform)
@@ -205,7 +195,6 @@ declare_mlir_dialect_extension_python_bindings(
205195
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
206196
TD_FILE dialects/MemRefTransformOps.td
207197
SOURCES
208-
dialects/_memref_transform_ops_ext.py
209198
dialects/transform/memref.py
210199
DIALECT_NAME transform
211200
EXTENSION_NAME memref_transform)
@@ -224,7 +213,6 @@ declare_mlir_dialect_extension_python_bindings(
224213
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
225214
TD_FILE dialects/LinalgStructuredTransformOps.td
226215
SOURCES
227-
dialects/_structured_transform_ops_ext.py
228216
dialects/transform/structured.py
229217
DIALECT_NAME transform
230218
EXTENSION_NAME structured_transform
@@ -246,7 +234,6 @@ declare_mlir_dialect_extension_python_bindings(
246234
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
247235
TD_FILE dialects/TensorTransformOps.td
248236
SOURCES
249-
dialects/_tensor_transform_ops_ext.py
250237
dialects/transform/tensor.py
251238
DIALECT_NAME transform
252239
EXTENSION_NAME tensor_transform)
@@ -276,7 +263,6 @@ declare_mlir_dialect_python_bindings(
276263
TD_FILE dialects/ArithOps.td
277264
SOURCES
278265
dialects/arith.py
279-
dialects/_arith_ops_ext.py
280266
DIALECT_NAME arith
281267
GEN_ENUM_BINDINGS)
282268

@@ -286,7 +272,6 @@ declare_mlir_dialect_python_bindings(
286272
TD_FILE dialects/MemRefOps.td
287273
SOURCES
288274
dialects/memref.py
289-
dialects/_memref_ops_ext.py
290275
DIALECT_NAME memref)
291276

292277
declare_mlir_dialect_python_bindings(
@@ -295,7 +280,6 @@ declare_mlir_dialect_python_bindings(
295280
TD_FILE dialects/MLProgramOps.td
296281
SOURCES
297282
dialects/ml_program.py
298-
dialects/_ml_program_ops_ext.py
299283
DIALECT_NAME ml_program)
300284

301285
declare_mlir_dialect_python_bindings(
@@ -339,7 +323,6 @@ declare_mlir_dialect_python_bindings(
339323
TD_FILE dialects/PDLOps.td
340324
SOURCES
341325
dialects/pdl.py
342-
dialects/_pdl_ops_ext.py
343326
_mlir_libs/_mlir/dialects/pdl.pyi
344327
DIALECT_NAME pdl)
345328

@@ -357,7 +340,6 @@ declare_mlir_dialect_python_bindings(
357340
TD_FILE dialects/SCFOps.td
358341
SOURCES
359342
dialects/scf.py
360-
dialects/_scf_ops_ext.py
361343
DIALECT_NAME scf)
362344

363345
declare_mlir_dialect_python_bindings(
@@ -383,7 +365,6 @@ declare_mlir_dialect_python_bindings(
383365
TD_FILE dialects/TensorOps.td
384366
SOURCES
385367
dialects/tensor.py
386-
dialects/_tensor_ops_ext.py
387368
DIALECT_NAME tensor)
388369

389370
declare_mlir_dialect_python_bindings(

0 commit comments

Comments
 (0)