Skip to content

Commit 270b468

Browse files
committed
Add encoding argument to tensor.empty Python function
1 parent b8b036a commit 270b468

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

mlir/python/mlir/dialects/tensor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
from typing import Optional
45

56
from ._tensor_ops_gen import *
67
from ._tensor_ops_gen import _Dialect
@@ -25,6 +26,7 @@ def __init__(
2526
sizes: Sequence[Union[int, Value]],
2627
element_type: Type,
2728
*,
29+
encoding: Optional[Attribute] = None,
2830
loc=None,
2931
ip=None,
3032
):
@@ -40,19 +42,20 @@ def __init__(
4042
else:
4143
static_sizes.append(ShapedType.get_dynamic_size())
4244
dynamic_sizes.append(s)
43-
result_type = RankedTensorType.get(static_sizes, element_type)
45+
result_type = RankedTensorType.get(static_sizes, element_type, encoding)
4446
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
4547

4648

4749
def empty(
4850
sizes: Sequence[Union[int, Value]],
4951
element_type: Type,
5052
*,
53+
encoding: Optional[Attribute] = None,
5154
loc=None,
5255
ip=None,
5356
) -> _ods_cext.ir.Value:
5457
return _get_op_result_or_op_results(
55-
EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)
58+
EmptyOp(sizes=sizes, element_type=element_type, encoding=encoding, loc=loc, ip=ip)
5659
)
5760

5861

mlir/test/python/dialects/sparse_tensor/dialect.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
from mlir.ir import *
4-
from mlir.dialects import sparse_tensor as st
4+
from mlir.dialects import sparse_tensor as st, tensor
55
import textwrap
66

77

@@ -219,9 +219,12 @@ def testEncodingAttrOnTensorType():
219219
)
220220
)
221221
)
222-
tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
223-
# CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>>
224-
print(tt)
225-
# CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
226-
print(tt.encoding)
227-
assert tt.encoding == encoding
222+
for tt in (
223+
RankedTensorType.get((1024,), F32Type.get(), encoding=encoding),
224+
tensor.empty((1024,), F32Type.get(), encoding=encoding),
225+
):
226+
# CHECK: tensor<1024xf32, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>>
227+
print(tt)
228+
# CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
229+
print(tt.encoding)
230+
assert tt.encoding == encoding

0 commit comments

Comments
 (0)