Commit 894d88a
[mlir][python] Add facility for extending generated python ODS.
* This isn't exclusive with other mechanisms for more ODS-centric op
  definitions, but based on discussions, we feel that we will always benefit
  from a python escape hatch, and that is the most natural way to write things
  that don't fit the mold.
* I suspect this facility needs further tweaking, and once it settles, I'll
  document it and add more tests.
* Added extensions for linalg, since it is unusable without them, and
  continued to evolve my e2e example.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D94752
1 parent: b99147b
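In practice, the escape hatch pairs each generated dialect module with an optional hand-written _<dialect>.py extension module whose hooks are picked up by the new extend_opview_class decorator (see mlir/lib/Bindings/Python/mlir/dialects/__init__.py below). A minimal sketch of such a module follows; the dialect and op names ("mydialect", "AddOp") are hypothetical and not part of this commit:

# Hypothetical _mydialect.py (all names here are illustrative only).


class AddOpMixin:
  """Convenience builder mixed into the generated AddOp."""

  def __init__(self, lhs, rhs, result_type, loc=None, ip=None):
    # Delegate to the ODS-generated default builder.
    super().__init__(
        self._ods_build_default(operands=[lhs, rhs],
                                results=[result_type],
                                loc=loc,
                                ip=ip))


def select_opview_mixin(parent_opview_cls):
  # Returning NotImplemented leaves the generated class untouched.
  if parent_opview_cls.__name__ == "AddOp":
    return AddOpMixin
  return NotImplemented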

7 files changed: +161 additions, -44 deletions

mlir/examples/python/.style.yapf

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+[style]
+based_on_style = google
+column_limit = 80
+indent_width = 2

mlir/examples/python/linalg_matmul.py

Lines changed: 48 additions & 38 deletions

@@ -15,59 +15,69 @@
 
 # TODO: This should be in the core API.
 def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
-    """Creates a |func| op.
+  """Creates a |func| op.
 
   TODO: This should really be in the MLIR API.
   Returns:
    (operation, entry_block)
  """
-    attrs = {
-        "type": TypeAttr.get(func_type),
-        "sym_name": StringAttr.get(name),
-    }
-    op = Operation.create("func", regions=1, attributes=attrs)
-    body_region = op.regions[0]
-    entry_block = body_region.blocks.append(*func_type.inputs)
-    return op, entry_block
+  attrs = {
+      "type": TypeAttr.get(func_type),
+      "sym_name": StringAttr.get(name),
+  }
+  op = Operation.create("func", regions=1, attributes=attrs)
+  body_region = op.regions[0]
+  entry_block = body_region.blocks.append(*func_type.inputs)
+  return op, entry_block
 
 
-# TODO: Generate customs builder vs patching one in.
-def PatchMatmulOpInit(self, lhs, rhs, result, loc=None, ip=None):
-    super(linalg.MatmulOp, self).__init__(
-        self._ods_build_default(operands=[[lhs, rhs], [result]],
-                                results=[],
-                                loc=loc,
-                                ip=ip))
+def build_matmul_buffers_func(func_name, m, k, n, dtype):
+  lhs_type = MemRefType.get(dtype, [m, k])
+  rhs_type = MemRefType.get(dtype, [k, n])
+  result_type = MemRefType.get(dtype, [m, n])
+  # TODO: There should be a one-liner for this.
+  func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
+  _, entry = FuncOp(func_name, func_type)
+  lhs, rhs, result = entry.arguments
+  with InsertionPoint(entry):
+    op = linalg.MatmulOp([lhs, rhs], [result])
     # TODO: Implement support for SingleBlockImplicitTerminator
-    block = self.regions[0].blocks.append()
+    block = op.regions[0].blocks.append()
     with InsertionPoint(block):
      linalg.YieldOp(values=[])
 
-linalg.MatmulOp.__init__ = PatchMatmulOpInit
+    std.ReturnOp([])
 
 
-def build_matmul_func(func_name, m, k, n, dtype):
-    lhs_type = MemRefType.get(dtype, [m, k])
-    rhs_type = MemRefType.get(dtype, [k, n])
-    result_type = MemRefType.get(dtype, [m, n])
-    # TODO: There should be a one-liner for this.
-    func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
-    _, entry = FuncOp(func_name, func_type)
-    lhs, rhs, result = entry.arguments
-    with InsertionPoint(entry):
-        linalg.MatmulOp(lhs, rhs, result)
-        std.ReturnOp([])
+def build_matmul_tensors_func(func_name, m, k, n, dtype):
+  # TODO: MemRefType and TensorTypes should not have inverted dtype/shapes
+  # from each other.
+  lhs_type = RankedTensorType.get([m, k], dtype)
+  rhs_type = RankedTensorType.get([k, n], dtype)
+  result_type = RankedTensorType.get([m, n], dtype)
+  # TODO: There should be a one-liner for this.
+  func_type = FunctionType.get([lhs_type, rhs_type], [result_type])
+  _, entry = FuncOp(func_name, func_type)
+  lhs, rhs = entry.arguments
+  with InsertionPoint(entry):
+    op = linalg.MatmulOp([lhs, rhs], results=[result_type])
+    # TODO: Implement support for SingleBlockImplicitTerminator
+    block = op.regions[0].blocks.append()
+    with InsertionPoint(block):
+      linalg.YieldOp(values=[])
+    std.ReturnOp([op.result])
 
 
 def run():
-    with Context() as c, Location.unknown():
-        module = Module.create()
-        # TODO: This at_block_terminator vs default construct distinction feels
-        # wrong and is error-prone.
-        with InsertionPoint.at_block_terminator(module.body):
-            build_matmul_func('main', 18, 32, 96, F32Type.get())
+  with Context() as c, Location.unknown():
+    module = Module.create()
+    # TODO: This at_block_terminator vs default construct distinction feels
+    # wrong and is error-prone.
+    with InsertionPoint.at_block_terminator(module.body):
+      build_matmul_buffers_func('main_buffers', 18, 32, 96, F32Type.get())
+      build_matmul_tensors_func('main_tensors', 18, 32, 96, F32Type.get())
 
-        print(module)
-        print(module.operation.get_asm(print_generic_op_form=True))
+    print(module)
 
 
-if __name__ == '__main__': run()
+if __name__ == '__main__':
+  run()
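The two builders above drive the same generated MatmulOp through the new list-based builder from the linalg extension: on buffers, the destination memref is passed as an output operand and the op has no results; on tensors, there is no output operand and the op instead carries a result type. Both calls, verbatim from the diff:

# Buffer form: destination memref is an output operand; no op results.
op = linalg.MatmulOp([lhs, rhs], [result])

# Tensor form: no output operand; the produced tensor is the op result.
op = linalg.MatmulOp([lhs, rhs], results=[result_type])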

mlir/lib/Bindings/Python/.style.yapf

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+[style]
+based_on_style = google
+column_limit = 80
+indent_width = 2

mlir/lib/Bindings/Python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ set(PY_SRC_FILES
   mlir/_dlloader.py
   mlir/ir.py
   mlir/dialects/__init__.py
+  mlir/dialects/_linalg.py
   mlir/ir.py
   mlir/passmanager.py
   mlir/transforms/__init__.py

mlir/lib/Bindings/Python/mlir/dialects/__init__.py

Lines changed: 66 additions & 4 deletions

@@ -5,7 +5,68 @@
 # Re-export the parent _cext so that every level of the API can get it locally.
 from .. import _cext
 
-def _segmented_accessor(elements, raw_segments, idx):
+__all__ = [
+    "equally_sized_accessor",
+    "extend_opview_class",
+    "get_default_loc_context",
+    "segmented_accessor",
+]
+
+
+def extend_opview_class(ext_module):
+  """Decorator to extend an OpView class from an extension module.
+
+  Extension modules can expose various entry-points:
+    def select_opview_mixin(parent_opview_cls):
+      If defined, allows an appropriate mixin class to be selected dynamically
+      based on the parent OpView class. Should return NotImplemented if a
+      decision is not made.
+
+    Stand-alone class with the same name as a parent OpView class (i.e.
+    "ReturnOp").
+
+  Args:
+    ext_module: A module from which to locate extensions. Can be None if not
+      available.
+
+  Returns:
+    A decorator that takes an OpView subclass and further extends it as
+    needed.
+  """
+
+  def class_decorator(parent_opview_cls: type):
+    if ext_module is None:
+      return parent_opview_cls
+    mixin_cls = NotImplemented
+    try:
+      select_mixin = getattr(ext_module, "select_opview_mixin")
+    except AttributeError:
+      # Try to default resolve it.
+      try:
+        select_mixin = getattr(ext_module, parent_opview_cls.__name__)
+      except AttributeError:
+        pass
+    else:
+      mixin_cls = select_mixin(parent_opview_cls)
+    if mixin_cls is NotImplemented or mixin_cls is None:
+      return parent_opview_cls
+
+    # Have a mixin_cls. Create an appropriate subclass.
+    try:
+
+      class LocalOpView(mixin_cls, parent_opview_cls):
+        pass
+    except TypeError as e:
+      raise TypeError(
+          f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e
+    LocalOpView.__name__ = parent_opview_cls.__name__
+    LocalOpView.__qualname__ = parent_opview_cls.__qualname__
+    return LocalOpView
+
+  return class_decorator
+
+
+def segmented_accessor(elements, raw_segments, idx):
   """
   Returns a slice of elements corresponding to the idx-th segment.
 
@@ -20,8 +81,8 @@ def _segmented_accessor(elements, raw_segments, idx):
   return elements[start:end]
 
 
-def _equally_sized_accessor(elements, n_variadic, n_preceding_simple,
-                            n_preceding_variadic):
+def equally_sized_accessor(elements, n_variadic, n_preceding_simple,
+                           n_preceding_variadic):
   """
   Returns a starting position and a number of elements per variadic group
   assuming equally-sized groups and the given numbers of preceding groups.
 
@@ -42,7 +103,8 @@ def _equally_sized_accessor(elements, n_variadic, n_preceding_simple,
   start = n_preceding_simple + n_preceding_variadic * elements_per_group
   return start, elements_per_group
 
-def _get_default_loc_context(location = None):
+
+def get_default_loc_context(location=None):
   """
   Returns a context in which the defaulted location is created. If the location
   is None, takes the current location from the stack, raises ValueError if there
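To see how the decorator composes, here is a small sketch that applies extend_opview_class by hand to a stand-in class; FakeOpView, Mixin, and the ad-hoc module are invented for illustration and assume the in-tree mlir package is importable:

import types

from mlir.dialects import extend_opview_class


class FakeOpView:
  """Stand-in for a tblgen-generated OpView subclass."""
  OPERATION_NAME = "test.fake"


class Mixin:

  def hello(self):
    return "extended"


# Ad-hoc extension module exposing the select_opview_mixin entry-point.
ext = types.ModuleType("_test_ext")
ext.select_opview_mixin = lambda parent_cls: Mixin

Extended = extend_opview_class(ext)(FakeOpView)
assert Extended.__name__ == "FakeOpView"  # Original name is preserved.
assert issubclass(Extended, FakeOpView)   # Generated class remains a base.
assert Extended().hello() == "extended"   # Mixin methods become available.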
mlir/lib/Bindings/Python/mlir/dialects/_linalg.py

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+
+class StructuredOpMixin:
+  """All structured ops use the same mixin class."""
+
+  def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
+    if outputs and results:
+      raise ValueError(
+          "Structured ops must have outputs or results, but not both.")
+    super().__init__(
+        self._ods_build_default(operands=[list(inputs),
+                                          list(outputs)],
+                                results=list(results),
+                                loc=loc,
+                                ip=ip))
+
+
+def select_opview_mixin(parent_opview_cls):
+  # TODO: This shouldn't be a heuristic: we should have a way to annotate
+  # the OpView to note that it is a structured op.
+  if ("__init__" not in parent_opview_cls.__dict__ and
+      hasattr(parent_opview_cls, "inputs") and
+      hasattr(parent_opview_cls, "outputs") and
+      hasattr(parent_opview_cls, "result_tensors")):
+    return StructuredOpMixin
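The select_opview_mixin heuristic above can be exercised directly. A sketch with stand-in classes (invented here, not part of the commit) shows which class shapes opt in to StructuredOpMixin:

# Stand-ins for tblgen-generated OpView classes (illustrative only).


class GeneratedStructuredOp:
  """Has the accessors the heuristic looks for and no custom __init__."""
  inputs = property()
  outputs = property()
  result_tensors = property()


class GeneratedYieldOp:
  """Lacks the structured-op accessors, so it is left untouched."""


assert select_opview_mixin(GeneratedStructuredOp) is StructuredOpMixin
assert select_opview_mixin(GeneratedYieldOp) is None  # Falls through.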

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 10 additions & 2 deletions

@@ -23,12 +23,19 @@ using namespace mlir;
 using namespace mlir::tblgen;
 
 /// File header and includes.
+/// {0} is the dialect namespace.
 constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from . import _cext as _ods_cext
-from . import _segmented_accessor as _ods_segmented_accessor, _equally_sized_accessor as _ods_equally_sized_accessor, _get_default_loc_context as _ods_get_default_loc_context
+from . import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
 _ods_ir = _ods_cext.ir
+
+try:
+  from . import _{0} as _ods_ext_module
+except ImportError:
+  _ods_ext_module = None
+
 )Py";
 
 /// Template for dialect class:
@@ -46,6 +53,7 @@ class _Dialect(_ods_ir.Dialect):
 /// {1} is the operation name.
 constexpr const char *opClassTemplate = R"Py(
 @_ods_cext.register_operation(_Dialect)
+@_ods_extend_opview_class(_ods_ext_module)
 class {0}(_ods_ir.OpView):
   OPERATION_NAME = "{1}"
 )Py";
@@ -706,7 +714,7 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
   AttributeClasses attributeClasses;
   constructAttributeMapping(records, attributeClasses);
 
-  os << fileHeader;
+  os << llvm::formatv(fileHeader, clDialectName.getValue());
   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
     Operator op(rec);
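Instantiated for the linalg dialect, i.e. with {0} substituted by "linalg" via llvm::formatv, the emitted module header now begins as follows (modulo whitespace), which is what wires the generated op classes to _linalg.py when it is present:

# Autogenerated by mlir-tblgen; don't manually edit.

from . import _cext as _ods_cext
from . import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context
_ods_ir = _ods_cext.ir

try:
  from . import _linalg as _ods_ext_module
except ImportError:
  _ods_ext_module = None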
