Skip to content

Commit b164f23

Browse files
committed
[mlir][python] support taking ops instead of values in op constructors
Introduce support for accepting ops instead of values when constructing ops. A single-result op can be used instead of a value, including in lists of values, and any op can be used instead of a list of values. This is similar to, but more powerful, than the C++ API that allows for implicitly casting an OpType to Value if it is statically known to have a single result - the cast in Python is based on the op dynamically having a single result, and also handles the multi-result case. This allows to build IR in a more concise way: op = dialect.produce_multiple_results() other = dialect.produce_single_result() dialect.consume_multiple_results(other, op) instead of having to access the results manually op = dialect.produce.multiple_results() other = dialect.produce_single_result() dialect.consume_multiple_results(other.result, op.operation.results) The dispatch is implemented directly in Python and is triggered automatically for autogenerated OpView subclasses. Extension OpView classes should use the functions provided in ods_common.py if they want to implement this behavior. An alternative could be to implement the dispatch in the C++ bindings code, but it would require to forward opaque types through all Python functions down to a binding call, which makes it hard to inspect them in Python, e.g., to obtain the types of values. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D111306
1 parent cb879d0 commit b164f23

File tree

9 files changed

+271
-117
lines changed

9 files changed

+271
-117
lines changed

mlir/python/mlir/dialects/_linalg_ops_ext.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
except ImportError as e:
1111
raise RuntimeError("Error loading imports from extension module") from e
1212

13+
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
1314

1415
def isa(cls: Type, ty: Type):
1516
try:
@@ -26,11 +27,12 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
2627
results = []
2728
if isa(RankedTensorType, output.type):
2829
results = [output.type]
29-
op = self.build_generic(results=results,
30-
operands=[value, output],
31-
attributes=None,
32-
loc=loc,
33-
ip=ip)
30+
op = self.build_generic(
31+
results=results,
32+
operands=[_get_op_result_or_value(o) for o in [value, output]],
33+
attributes=None,
34+
loc=loc,
35+
ip=ip)
3436
OpView.__init__(self, op)
3537
linalgDialect = Context.current.get_dialect_descriptor("linalg")
3638
fill_builtin_region(linalgDialect, self.operation)

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
# Provide a convenient name for sub-packages to resolve the main C-extension
66
# with a relative import.
77
from .._mlir_libs import _mlir as _cext
8+
from typing import Sequence as _Sequence, Union as _Union
89

910
__all__ = [
1011
"equally_sized_accessor",
1112
"extend_opview_class",
1213
"get_default_loc_context",
14+
"get_op_result_or_value",
15+
"get_op_results_or_values",
1316
"segmented_accessor",
1417
]
1518

@@ -118,3 +121,38 @@ def get_default_loc_context(location=None):
118121
# Location.current raises ValueError if there is no current location.
119122
return _cext.ir.Location.current.context
120123
return location.context
124+
125+
126+
def get_op_result_or_value(
127+
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]
128+
) -> _cext.ir.Value:
129+
"""Returns the given value or the single result of the given op.
130+
131+
This is useful to implement op constructors so that they can take other ops as
132+
arguments instead of requiring the caller to extract results for every op.
133+
Raises ValueError if provided with an op that doesn't have a single result.
134+
"""
135+
if isinstance(arg, _cext.ir.OpView):
136+
return arg.operation.result
137+
elif isinstance(arg, _cext.ir.Operation):
138+
return arg.result
139+
else:
140+
assert isinstance(arg, _cext.ir.Value)
141+
return arg
142+
143+
144+
def get_op_results_or_values(
145+
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]]
146+
) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
147+
"""Returns the given sequence of values or the results of the given op.
148+
149+
This is useful to implement op constructors so that they can take other ops as
150+
lists of arguments instead of requiring the caller to extract results for
151+
every op.
152+
"""
153+
if isinstance(arg, _cext.ir.OpView):
154+
return arg.operation.results
155+
elif isinstance(arg, _cext.ir.Operation):
156+
return arg.results
157+
else:
158+
return arg

mlir/python/mlir/dialects/_scf_ops_ext.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
except ImportError as e:
88
raise RuntimeError("Error loading imports from extension module") from e
99

10-
from typing import Any, Sequence
11-
10+
from typing import Any, Optional, Sequence, Union
11+
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
1212

1313
class ForOp:
1414
"""Specialization for the SCF for op class."""
@@ -17,7 +17,8 @@ def __init__(self,
1717
lower_bound,
1818
upper_bound,
1919
step,
20-
iter_args: Sequence[Any] = [],
20+
iter_args: Optional[Union[Operation, OpView,
21+
Sequence[Value]]] = None,
2122
*,
2223
loc=None,
2324
ip=None):
@@ -26,14 +27,22 @@ def __init__(self,
2627
- `lower_bound` is the value to use as lower bound of the loop.
2728
- `upper_bound` is the value to use as upper bound of the loop.
2829
- `step` is the value to use as loop step.
29-
- `iter_args` is a list of additional loop-carried arguments.
30+
- `iter_args` is a list of additional loop-carried arguments or an operation
31+
producing them as results.
3032
"""
33+
if iter_args is None:
34+
iter_args = []
35+
iter_args = _get_op_results_or_values(iter_args)
36+
3137
results = [arg.type for arg in iter_args]
3238
super().__init__(
3339
self.build_generic(
3440
regions=1,
3541
results=results,
36-
operands=[lower_bound, upper_bound, step] + list(iter_args),
42+
operands=[
43+
_get_op_result_or_value(o)
44+
for o in [lower_bound, upper_bound, step]
45+
] + list(iter_args),
3746
loc=loc,
3847
ip=ip))
3948
self.regions[0].blocks.append(IndexType.get(), *results)

mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,23 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
from typing import Dict, List
5+
from typing import Dict, List, Sequence, Union
66

77
from contextlib import contextmanager
88
import functools
99
import inspect
1010
import threading
1111

1212
from ..... import ir
13+
from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
1314
from .comprehension import *
1415
from .config import *
1516
from .emitter import *
1617

1718
_CONTEXT = threading.local()
1819

20+
StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList,
21+
Sequence[Union[ir.Value, ir.Operation, ir.OpView]]]
1922

2023
@contextmanager
2124
def bind_op_def(model: LinalgOpDef):
@@ -37,14 +40,24 @@ def current_op_def() -> LinalgOpDef:
3740
"but none is set. Did you mean to call this in an op definition?")
3841

3942

43+
def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
44+
if isinstance(outs, (ir.Operation, ir.OpView)):
45+
return _get_op_results_or_values(outs)
46+
elif isinstance(outs, ir.OpResultList):
47+
return outs
48+
49+
return [_get_op_result_or_value(o) for o in outs]
50+
51+
4052
class DefinedOpCallable:
4153
"""Callable that wraps any defined op function."""
4254

4355
def __init__(self, op_name: str, model: LinalgOpDef):
4456
self.op_name = op_name
4557
self.model = model
4658

47-
def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
59+
def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value],
60+
outs: StructuredOpOuts, **kwargs):
4861
"""Emits the corresponding op definition as IR.
4962
5063
Most arguments are passed through to the underlying emitter. The following
@@ -73,17 +86,19 @@ def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
7386
emit_generic or not ctx.is_registered_operation(fully_qualified_name))
7487

7588
op_config = op_configs[0]
89+
out_values = _prepare_structured_op_outs(outs)
90+
in_values = [_get_op_result_or_value(i) for i in ins]
7691
if op_config.structured_op:
7792
if emit_generic:
7893
return emit_generic_structured_op(
79-
op_config.structured_op, *ins, outs=outs, **kwargs)
94+
op_config.structured_op, *in_values, outs=out_values, **kwargs)
8095
else:
8196
return emit_named_structured_op(
8297
op_config.structured_op,
8398
self.op_name,
8499
self.model.metadata.cpp_class_name,
85-
*ins,
86-
outs=outs,
100+
*in_values,
101+
outs=out_values,
87102
**kwargs)
88103

89104
raise NotImplementedError(

mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
from typing import Dict, Sequence
5+
from typing import Dict, List, Sequence, Tuple, Union
66

77
from .....ir import *
88
from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region
99

1010
from .... import linalg
1111
from .... import std
1212
from .... import math
13+
from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
1314

1415
from .scalar_expr import *
1516
from .config import *
@@ -18,8 +19,10 @@
1819
__all__ = [
1920
"emit_generic_structured_op",
2021
"emit_named_structured_op",
22+
"ValueList",
2123
]
2224

25+
ValueList = Union[Sequence[Value], OpResultList]
2326

2427
def isa(cls: Type, ty: Type):
2528
try:
@@ -30,17 +33,18 @@ def isa(cls: Type, ty: Type):
3033

3134

3235
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
33-
*ins: Value, outs: Sequence[Value],
36+
*ins: Value, outs: ValueList,
3437
**attrs: Sequence[int]):
3538
all_arg_defs = op_config.ordered_operands
3639
in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"]
3740
out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"]
3841
attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"]
3942

40-
# Verify outs is a sequence.
41-
if not isinstance(outs, Sequence):
42-
raise ValueError(f"Expected named argument outs to have type Sequence "
43-
f"but got {type(outs)}")
43+
# Verify outs is a sequence or a list of results.
44+
if not isinstance(outs, (Sequence, OpResultList)):
45+
raise ValueError(
46+
f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}"
47+
)
4448

4549
# Arity validation.
4650
if len(ins) != len(in_arg_defs):
@@ -122,7 +126,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
122126

123127

124128
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
125-
outs: Sequence[Value], **attrs: Sequence[int]):
129+
outs: ValueList, **attrs: Sequence[int]):
126130
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
127131
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
128132
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@@ -153,8 +157,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
153157

154158

155159
def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
156-
op_class_name: str, *ins: Value,
157-
outs: Sequence[Value], **attrs: Sequence[int]):
160+
op_class_name: str, *ins: Value, outs: ValueList,
161+
**attrs: Sequence[int]):
158162
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
159163
indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
160164
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
@@ -355,11 +359,11 @@ def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
355359
return std.MinUIOp(lhs.type, lhs, rhs).result
356360
raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
357361

358-
def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
359-
in_arg_defs: Sequence[OperandDefConfig],
360-
ins: Sequence[Value],
361-
out_arg_defs: Sequence[OperandDefConfig],
362-
outs: Sequence[Value]):
362+
def _infer_structured_outs(
363+
op_config: LinalgStructuredOpConfig,
364+
in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value],
365+
out_arg_defs: Sequence[OperandDefConfig],
366+
outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]:
363367
"""Infers implicit outs and output types.
364368
365369
Respects existing contents of outs if not empty.

mlir/test/mlir-tblgen/op-python-bindings.td

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
2424
// CHECK: operands = []
2525
// CHECK: results = []
2626
// CHECK: attributes = {}
27-
// CHECK: operands.append(variadic1)
28-
// CHECK: operands.append(non_variadic)
29-
// CHECK: if variadic2 is not None: operands.append(variadic2)
27+
// CHECK: operands.append(_get_op_results_or_values(variadic1))
28+
// CHECK: operands.append(_get_op_result_or_value(non_variadic))
29+
// CHECK: if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2))
3030
// CHECK: _ods_successors = None
3131
// CHECK: super().__init__(self.build_generic(
3232
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -150,8 +150,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
150150
// CHECK: operands = []
151151
// CHECK: results = []
152152
// CHECK: attributes = {}
153-
// CHECK: operands.append(_gen_arg_0)
154-
// CHECK: operands.append(_gen_arg_2)
153+
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
154+
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
155155
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
156156
// CHECK: _ods_get_default_loc_context(loc))
157157
// CHECK: if is_ is not None: attributes["is"] = is_
@@ -197,9 +197,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
197197
// CHECK: results.append(i32)
198198
// CHECK: results.append(_gen_res_1)
199199
// CHECK: results.append(i64)
200-
// CHECK: operands.append(_gen_arg_0)
201-
// CHECK: operands.append(f32)
202-
// CHECK: operands.append(_gen_arg_2)
200+
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
201+
// CHECK: operands.append(_get_op_result_or_value(f32))
202+
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
203203
// CHECK: _ods_successors = None
204204
// CHECK: super().__init__(self.build_generic(
205205
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -230,8 +230,8 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
230230
// CHECK: operands = []
231231
// CHECK: results = []
232232
// CHECK: attributes = {}
233-
// CHECK: operands.append(non_variadic)
234-
// CHECK: operands.extend(variadic)
233+
// CHECK: operands.append(_get_op_result_or_value(non_variadic))
234+
// CHECK: operands.extend(_get_op_results_or_values(variadic))
235235
// CHECK: _ods_successors = None
236236
// CHECK: super().__init__(self.build_generic(
237237
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -285,7 +285,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
285285
// CHECK: operands = []
286286
// CHECK: results = []
287287
// CHECK: attributes = {}
288-
// CHECK: operands.append(in_)
288+
// CHECK: operands.append(_get_op_result_or_value(in_))
289289
// CHECK: _ods_successors = None
290290
// CHECK: super().__init__(self.build_generic(
291291
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -353,8 +353,8 @@ def SimpleOp : TestOp<"simple"> {
353353
// CHECK: attributes = {}
354354
// CHECK: results.append(i64)
355355
// CHECK: results.append(f64)
356-
// CHECK: operands.append(i32)
357-
// CHECK: operands.append(f32)
356+
// CHECK: operands.append(_get_op_result_or_value(i32))
357+
// CHECK: operands.append(_get_op_result_or_value(f32))
358358
// CHECK: _ods_successors = None
359359
// CHECK: super().__init__(self.build_generic(
360360
// CHECK: attributes=attributes, results=results, operands=operands,

mlir/test/python/dialects/linalg/ops.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,30 @@ def generic_form(lhs, rhs):
185185
return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True)
186186

187187
print(module)
188+
189+
190+
# CHECK-LABEL: TEST: testOpResultFromOtherOp
191+
@run
192+
def testOpResultFromOtherOp():
193+
with Context(), Location.unknown():
194+
module = Module.create()
195+
f32 = F32Type.get()
196+
with InsertionPoint(module.body):
197+
198+
@builtin.FuncOp.from_py_func(
199+
RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8),
200+
f32))
201+
def pass_an_op_directly(arg0, arg1):
202+
one = std.ConstantOp(F32Type.get(), 1.0)
203+
# CHECK: %[[LHS:.*]] = linalg.fill
204+
lhs = linalg.FillOp(arg0, one)
205+
# CHECK: %[[RHS:.*]] = linalg.fill
206+
rhs = linalg.FillOp(arg1, one)
207+
# CHECK: %[[INIT:.*]] = linalg.init_tensor
208+
init = linalg.InitTensorOp([4, 8], f32)
209+
# CHECK: linalg.matmul
210+
# CHECK: ins(%[[LHS]], %[[RHS]]
211+
# CHECK: outs(%[[INIT]]
212+
return linalg.matmul(lhs, rhs, outs=init)
213+
214+
print(module)

0 commit comments

Comments
 (0)