Skip to content

Commit 4eefc8d

Browse files
authored
[MLIR][Python] enhance python api for tensor.empty (#103087)
Since we have extended `EmptyOp`, maybe we should also provide a corresponding `tensor.empty` method. In the downstream usage, I tend to use APIs with all lowercase letters to create ops, so having a `tensor.empty` to replace the extended `tensor.EmptyOp` would keep my code style consistent.
1 parent 7d5281a commit 4eefc8d

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

mlir/python/mlir/dialects/tensor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from typing import Sequence, Union
1515
from ._ods_common import _cext as _ods_cext
16+
from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
1617

1718

1819
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -43,6 +44,18 @@ def __init__(
4344
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
4445

4546

47+
def empty(
48+
sizes: Sequence[Union[int, Value]],
49+
element_type: Type,
50+
*,
51+
loc=None,
52+
ip=None,
53+
) -> _ods_cext.ir.Value:
54+
return _get_op_result_or_op_results(
55+
EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)
56+
)
57+
58+
4659
generate = region_op(
4760
lambda result, dynamic_extents: GenerateOp(result, dynamic_extents),
4861
terminator=lambda args: YieldOp(args[0]),

mlir/test/python/dialects/linalg/opdsl/emit_matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def matmul_poly(
6363
RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
6464
)
6565
def test_matmul_mono(lhs, rhs):
66-
init_result = tensor.EmptyOp([4, 8], f32)
67-
return matmul_mono(lhs, rhs, outs=[init_result.result])
66+
init_result = tensor.empty([4, 8], f32)
67+
return matmul_mono(lhs, rhs, outs=[init_result])
6868

6969
# CHECK-LABEL: @test_i8i8i32_matmul
7070
# CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)

mlir/test/python/dialects/linalg/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def testNamedStructuredOpGenericForm():
9797
RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), f32)
9898
)
9999
def named_form(lhs, rhs):
100-
init_result = tensor.EmptyOp([4, 8], f32)
100+
init_result = tensor.empty([4, 8], f32)
101101
# CHECK: "linalg.matmul"(%{{.*}})
102102
# CHECK-SAME: cast = #linalg.type_fn<cast_signed>
103103
# CHECK-SAME: operandSegmentSizes = array<i32: 2, 1>
@@ -106,7 +106,7 @@ def named_form(lhs, rhs):
106106
# CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32
107107
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
108108
# CHECK-NEXT: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
109-
return linalg.matmul(lhs, rhs, outs=[init_result.result])
109+
return linalg.matmul(lhs, rhs, outs=[init_result])
110110

111111
module.operation.print(print_generic_op_form=True)
112112

0 commit comments

Comments
 (0)