
Commit ec294eb
[mlir][linalg] Add an InitTensorOp python builder.
* This has the API I want but I am not thrilled with the implementation.
  There are various things that could be improved both about the way that
  Python builders are mapped and the way the Linalg ops are factored to
  increase code sharing between C++/Python.
* Landing this as-is since it at least makes the InitTensorOp usable with
  the right API. Will refactor underneath in follow-ons.

Differential Revision: https://reviews.llvm.org/D99000
1 parent 3240910 commit ec294eb

4 files changed: 85 additions, 23 deletions

mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py

Lines changed: 42 additions & 0 deletions
@@ -2,6 +2,48 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from typing import Optional, Sequence, Union
+from ..ir import *
+from ._ods_common import get_default_loc_context
+
+
+class InitTensorOp:
+  """Extends the linalg.init_tensor op."""
+
+  def __init__(self,
+               sizes: Union[Sequence[int], Sequence[Value]],
+               element_type: Type,
+               *,
+               loc=None,
+               ip=None):
+    """Constructs an `init_tensor` with either static or dynamic sizes."""
+    context = get_default_loc_context(loc)
+    operands = []
+    attributes = {}
+    # TODO: Refactor the InitTensorOp to take an element type attribute and
+    # then use normal result type inference, unifying the Python and C++ side
+    # with a standard mechanism (versus stashing that in builders).
+    if sizes and isinstance(sizes[0], Value):
+      # Dynamic sizes.
+      operands.extend(sizes)
+      static_size_ints = [-1] * len(sizes)
+      result_type = RankedTensorType.get(static_size_ints, element_type)
+    else:
+      # Static sizes.
+      result_type = RankedTensorType.get(sizes, element_type)
+      static_size_ints = sizes
+
+    index_type = IndexType.get(context)
+    attributes["static_sizes"] = ArrayAttr.get(
+        [IntegerAttr.get(index_type, s) for s in static_size_ints],
+        context=context)
+    op = self.build_generic(results=[result_type],
+                            operands=operands,
+                            attributes=attributes,
+                            loc=loc,
+                            ip=ip)
+    OpView.__init__(self, op)
+
 
 class StructuredOpMixin:
   """All structured ops use the same mixin class."""

mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py

Lines changed: 11 additions & 8 deletions
@@ -17,14 +17,15 @@ def extend_opview_class(ext_module):
   """Decorator to extend an OpView class from an extension module.
 
   Extension modules can expose various entry-points:
+    Stand-alone class with the same name as a parent OpView class (i.e.
+      "ReturnOp"). A name-based match is attempted first before falling back
+      to a below mechanism.
+
     def select_opview_mixin(parent_opview_cls):
       If defined, allows an appropriate mixin class to be selected dynamically
       based on the parent OpView class. Should return NotImplemented if a
       decision is not made.
 
-    Stand-alone class with the same name as a parent OpView class (i.e.
-      "ReturnOp").
-
   Args:
     ext_module: A module from which to locate extensions. Can be None if not
       available.

@@ -38,16 +39,18 @@ def class_decorator(parent_opview_cls: type):
     if ext_module is None:
       return parent_opview_cls
     mixin_cls = NotImplemented
+    # First try to resolve by name.
     try:
-      select_mixin = getattr(ext_module, "select_opview_mixin")
+      mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
     except AttributeError:
-      # Try to default resolve it.
+      # Fall back to a select_opview_mixin hook.
       try:
-        mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
+        select_mixin = getattr(ext_module, "select_opview_mixin")
       except AttributeError:
         pass
-    else:
-      mixin_cls = select_mixin(parent_opview_cls)
+      else:
+        mixin_cls = select_mixin(parent_opview_cls)
+
     if mixin_cls is NotImplemented or mixin_cls is None:
       return parent_opview_cls
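
With this reordering, the name-based match takes precedence: `extend_opview_class` first looks in the extension module for a class named exactly like the generated OpView, and only consults the `select_opview_mixin` hook when no such class exists. A minimal sketch of an extension module exposing both entry-points (the file name is hypothetical; "ReturnOp" is the docstring's own example):

# _example_ops_ext.py -- hypothetical extension module.

# Entry-point 1 (tried first): a stand-alone class whose name matches the
# generated OpView class.
class ReturnOp:
  """Extends the generated ReturnOp."""

# Entry-point 2 (fallback): a hook that may pick a mixin dynamically.
# Returning NotImplemented (or None) keeps the generated class as-is.
def select_opview_mixin(parent_opview_cls):
  return NotImplemented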

mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py

Lines changed: 1 addition & 9 deletions
@@ -55,15 +55,7 @@ def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
   @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                RankedTensorType.get((16, 8), f32))
   def test_matmul_mono(lhs, rhs):
-    # TODO: Enable outs inference and add sugar for InitTensorOp
-    # construction.
-    init_result = linalg.InitTensorOp(result=RankedTensorType.get((4, 8),
-                                                                  f32),
-                                      static_sizes=ArrayAttr.get([
-                                          IntegerAttr.get(IndexType.get(), 4),
-                                          IntegerAttr.get(IndexType.get(), 8)
-                                      ]),
-                                      sizes=[])
+    init_result = linalg.InitTensorOp([4, 8], f32)
     return matmul_mono(lhs, rhs, outs=[init_result.result])
 
 # CHECK-LABEL: @test_i8i8i32_matmul

mlir/test/Bindings/Python/dialects/linalg/ops.py

Lines changed: 31 additions & 6 deletions
@@ -9,9 +9,39 @@
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  return f
+
+
+# CHECK-LABEL: TEST: testInitTensor
+@run
+def testInitTensor():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      # CHECK-LABEL: func @static_sizes
+      # CHECK: %0 = linalg.init_tensor [3, 4] : tensor<3x4xf32>
+      @builtin.FuncOp.from_py_func()
+      def static_sizes():
+        return linalg.InitTensorOp([3, 4], f32)
+
+      # CHECK-LABEL: func @dynamic_sizes
+      # CHECK: %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+      @builtin.FuncOp.from_py_func(IndexType.get(), IndexType.get())
+      def dynamic_sizes(d0, d1):
+        return linalg.InitTensorOp([d0, d1], f32)
+
+      # CHECK-LABEL: func @zero_d
+      # CHECK: %0 = linalg.init_tensor [] : tensor<f32>
+      @builtin.FuncOp.from_py_func()
+      def zero_d():
+        return linalg.InitTensorOp([], f32)
+
+    print(module)
 
 
 # CHECK-LABEL: TEST: testStructuredOpOnTensors
+@run
 def testStructuredOpOnTensors():
   with Context() as ctx, Location.unknown():
     module = Module.create()

@@ -31,10 +61,8 @@ def testStructuredOpOnTensors():
     print(module)
 
 
-run(testStructuredOpOnTensors)
-
-
 # CHECK-LABEL: TEST: testStructuredOpOnBuffers
+@run
 def testStructuredOpOnBuffers():
   with Context() as ctx, Location.unknown():
     module = Module.create()

@@ -52,6 +80,3 @@ def testStructuredOpOnBuffers():
 
   # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
   print(module)
-
-
-run(testStructuredOpOnBuffers)
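
A side note on the test harness: adding `return f` to `run()` is what lets it serve as a decorator, so each test executes at definition time and the trailing `run(...)` calls can be dropped. Schematically (hypothetical test name):

# Before: define, then invoke explicitly.
def testFoo():
  ...
run(testFoo)

# After: @run executes the test at definition time and keeps the name bound.
@run
def testFoo():
  ...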
