Skip to content

Commit d213245

Browse files
committed
incorporate comments
1 parent b6cb174 commit d213245

File tree

7 files changed

+237
-193
lines changed

7 files changed

+237
-193
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,9 +411,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
411411
/// Returns the strides of the MemRef if the layout map is in strided form.
412412
/// Both strides and offset are out params. strides must point to pre-allocated
413413
/// memory of length equal to the rank of the memref.
414-
MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
415-
int64_t *strides,
416-
int64_t *offset);
414+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(
415+
MlirType type, int64_t *strides, int64_t *offset);
417416

418417
/// Returns the memory spcae of the given Unranked MemRef type.
419418
MLIR_CAPI_EXPORTED MlirAttribute

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
#include "mlir-c/BuiltinAttributes.h"
1414
#include "mlir-c/BuiltinTypes.h"
15+
#include "mlir-c/Support.h"
16+
1517
#include <optional>
1618

1719
namespace py = pybind11;
@@ -618,12 +620,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
618620
return mlirMemRefTypeGetLayout(self);
619621
},
620622
"The layout of the MemRef type.")
621-
.def_property_readonly(
622-
"strides_and_offset",
623+
.def(
624+
"get_strides_and_offset",
623625
[](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
624626
std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
625627
int64_t offset;
626-
mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
628+
if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
629+
self, strides.data(), &offset)))
630+
throw std::runtime_error(
631+
"failed to extract strides and offset from memref");
627632
return {strides, offset};
628633
},
629634
"The strides and offset of the MemRef type.")

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
#include "mlir-c/BuiltinTypes.h"
1010
#include "mlir-c/AffineMap.h"
1111
#include "mlir-c/IR.h"
12+
#include "mlir-c/Support.h"
1213
#include "mlir/CAPI/AffineMap.h"
1314
#include "mlir/CAPI/IR.h"
1415
#include "mlir/CAPI/Support.h"
1516
#include "mlir/IR/AffineMap.h"
1617
#include "mlir/IR/BuiltinTypes.h"
1718
#include "mlir/IR/Types.h"
19+
#include "mlir/Support/LogicalResult.h"
1820

1921
#include <algorithm>
2022

@@ -428,16 +430,16 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
428430
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
429431
}
430432

431-
void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
432-
int64_t *offset) {
433+
MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
434+
int64_t *strides,
435+
int64_t *offset) {
433436
MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
434-
std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
435-
getStridesAndOffset(memrefType);
436-
assert(stridesOffsets.first.size() == memrefType.getRank() &&
437-
"Strides and rank don't match for memref");
438-
(void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
439-
strides);
440-
*offset = stridesOffsets.second;
437+
SmallVector<int64_t> strides_;
438+
if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
439+
return mlirLogicalResultFailure();
440+
441+
(void)std::copy(strides_.begin(), strides_.end(), strides);
442+
return mlirLogicalResultSuccess();
441443
}
442444

443445
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 171 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,30 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
# Provide a convenient name for sub-packages to resolve the main C-extension
6-
# with a relative import.
7-
from .._mlir_libs import _mlir as _cext
85
from typing import (
6+
List as _List,
7+
Optional as _Optional,
98
Sequence as _Sequence,
9+
Tuple as _Tuple,
1010
Type as _Type,
1111
TypeVar as _TypeVar,
1212
Union as _Union,
1313
)
1414

15+
from .._mlir_libs import _mlir as _cext
16+
from ..ir import (
17+
ArrayAttr,
18+
Attribute,
19+
BoolAttr,
20+
DenseI64ArrayAttr,
21+
IntegerAttr,
22+
IntegerType,
23+
OpView,
24+
Operation,
25+
ShapedType,
26+
Value,
27+
)
28+
1529
__all__ = [
1630
"equally_sized_accessor",
1731
"get_default_loc_context",
@@ -138,3 +152,157 @@ def get_op_result_or_op_results(
138152
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
139153
ResultValueT = _Union[ResultValueTypeTuple]
140154
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
155+
156+
StaticIntLike = _Union[int, IntegerAttr]
157+
ValueLike = _Union[Operation, OpView, Value]
158+
MixedInt = _Union[StaticIntLike, ValueLike]
159+
160+
IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
161+
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]
162+
163+
BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
164+
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]
165+
166+
MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
167+
168+
DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]
169+
170+
171+
def _dispatch_dynamic_index_list(
172+
indices: _Union[DynamicIndexList, ArrayAttr],
173+
) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
174+
"""Dispatches a list of indices to the appropriate form.
175+
176+
This is similar to the custom `DynamicIndexList` directive upstream:
177+
provided indices may be in the form of dynamic SSA values or static values,
178+
and they may be scalable (i.e., as a singleton list) or not. This function
179+
dispatches each index into its respective form. It also extracts the SSA
180+
values and static indices from various similar structures, respectively.
181+
"""
182+
dynamic_indices = []
183+
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
184+
scalable_indices = [False] * len(indices)
185+
186+
# ArrayAttr: Extract index values.
187+
if isinstance(indices, ArrayAttr):
188+
indices = [idx for idx in indices]
189+
190+
def process_nonscalable_index(i, index):
191+
"""Processes any form of non-scalable index.
192+
193+
Returns False if the given index was scalable and thus remains
194+
unprocessed; True otherwise.
195+
"""
196+
if isinstance(index, int):
197+
static_indices[i] = index
198+
elif isinstance(index, IntegerAttr):
199+
static_indices[i] = index.value # pytype: disable=attribute-error
200+
elif isinstance(index, (Operation, Value, OpView)):
201+
dynamic_indices.append(index)
202+
else:
203+
return False
204+
return True
205+
206+
# Process each index at a time.
207+
for i, index in enumerate(indices):
208+
if not process_nonscalable_index(i, index):
209+
# If it wasn't processed, it must be a scalable index, which is
210+
# provided as a _Sequence of one value, so extract and process that.
211+
scalable_indices[i] = True
212+
assert len(index) == 1
213+
ret = process_nonscalable_index(i, index[0])
214+
assert ret
215+
216+
return dynamic_indices, static_indices, scalable_indices
217+
218+
219+
# Dispatches `MixedValues` that all represents integers in various forms into
220+
# the following three categories:
221+
# - `dynamic_values`: a list of `Value`s, potentially from op results;
222+
# - `packed_values`: a value handle, potentially from an op result, associated
223+
# to one or more payload operations of integer type;
224+
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
225+
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
226+
# The input is in the form for `packed_values`, only that result is set and the
227+
# other two are empty. Otherwise, the input can be a mix of the other two forms,
228+
# and for each dynamic value, a special value is added to the `static_values`.
229+
def _dispatch_mixed_values(
230+
values: MixedValues,
231+
) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
232+
dynamic_values = []
233+
packed_values = None
234+
static_values = None
235+
if isinstance(values, ArrayAttr):
236+
static_values = values
237+
elif isinstance(values, (Operation, Value, OpView)):
238+
packed_values = values
239+
else:
240+
static_values = []
241+
for size in values or []:
242+
if isinstance(size, int):
243+
static_values.append(size)
244+
else:
245+
static_values.append(ShapedType.get_dynamic_size())
246+
dynamic_values.append(size)
247+
static_values = DenseI64ArrayAttr.get(static_values)
248+
249+
return (dynamic_values, packed_values, static_values)
250+
251+
252+
def _get_value_or_attribute_value(
253+
value_or_attr: _Union[any, Attribute, ArrayAttr]
254+
) -> any:
255+
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
256+
return value_or_attr.value
257+
if isinstance(value_or_attr, ArrayAttr):
258+
return _get_value_list(value_or_attr)
259+
return value_or_attr
260+
261+
262+
def _get_value_list(
263+
sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr]
264+
) -> _Sequence[any]:
265+
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
266+
267+
268+
def _get_int_array_attr(
269+
values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
270+
) -> ArrayAttr:
271+
if values is None:
272+
return None
273+
274+
# Turn into a Python list of Python ints.
275+
values = _get_value_list(values)
276+
277+
# Make an ArrayAttr of IntegerAttrs out of it.
278+
return ArrayAttr.get(
279+
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
280+
)
281+
282+
283+
def _get_int_array_array_attr(
284+
values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
285+
) -> ArrayAttr:
286+
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
287+
288+
The input has to be a collection of a collection of integers, where any
289+
Python _Sequence and ArrayAttr are admissible collections and Python ints and
290+
any IntegerAttr are admissible integers. Both levels of collections are
291+
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
292+
If the input is None, an empty ArrayAttr is returned.
293+
"""
294+
if values is None:
295+
return None
296+
297+
# Make sure the outer level is a list.
298+
values = _get_value_list(values)
299+
300+
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
301+
# Sequences. Make sure the nested values are all lists.
302+
values = [_get_value_list(nested) for nested in values]
303+
304+
# Turn each nested list into an ArrayAttr.
305+
values = [_get_int_array_attr(nested) for nested in values]
306+
307+
# Turn the outer list into an ArrayAttr.
308+
return ArrayAttr.get(values)

mlir/python/mlir/dialects/memref.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,52 @@
66
from typing import Optional
77

88
from ._memref_ops_gen import *
9-
from .arith import ConstantOp
10-
from .transform.structured import _dispatch_mixed_values, MixedValues
9+
from ._ods_common import _dispatch_mixed_values, MixedValues
10+
from .arith import ConstantOp, _is_integer_like_type
1111
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
1212

1313

14-
def _is_constant(i):
15-
return isinstance(i, Value) and isinstance(i.owner.opview, ConstantOp)
14+
def _is_constant_int_like(i):
15+
return (
16+
isinstance(i, Value)
17+
and isinstance(i.owner.opview, ConstantOp)
18+
and _is_integer_like_type(i.type)
19+
)
1620

1721

18-
def _is_static(i):
19-
return (isinstance(i, int) and not ShapedType.is_dynamic_size(i)) or _is_constant(i)
22+
def _is_static_int_like(i):
23+
return (
24+
isinstance(i, int) and not ShapedType.is_dynamic_size(i)
25+
) or _is_constant_int_like(i)
2026

2127

2228
def _infer_memref_subview_result_type(
2329
source_memref_type, offsets, static_sizes, static_strides
2430
):
25-
source_strides, source_offset = source_memref_type.strides_and_offset
31+
source_strides, source_offset = source_memref_type.get_strides_and_offset()
2632
# "canonicalize" from tuple|list -> list
2733
offsets, static_sizes, static_strides, source_strides = map(
2834
list, (offsets, static_sizes, static_strides, source_strides)
2935
)
3036

31-
assert all(
32-
all(_is_static(i) for i in s)
37+
if not all(
38+
all(_is_static_int_like(i) for i in s)
3339
for s in [
3440
static_sizes,
3541
static_strides,
3642
source_strides,
3743
]
38-
), f"Only inferring from python or mlir integer constant is supported"
44+
):
45+
raise ValueError(
46+
"Only inferring from python or mlir integer constant is supported."
47+
)
3948

4049
for s in [offsets, static_sizes, static_strides]:
4150
for idx, i in enumerate(s):
42-
if _is_constant(i):
51+
if _is_constant_int_like(i):
4352
s[idx] = i.owner.opview.literal_value
4453

45-
if any(not _is_static(i) for i in offsets + [source_offset]):
54+
if any(not _is_static_int_like(i) for i in offsets + [source_offset]):
4655
target_offset = ShapedType.get_dynamic_size()
4756
else:
4857
target_offset = source_offset
@@ -91,22 +100,22 @@ def subview(
91100
sizes = []
92101
if strides is None:
93102
strides = []
94-
source_strides, source_offset = source.type.strides_and_offset
103+
source_strides, source_offset = source.type.get_strides_and_offset()
95104
if result_type is None and all(
96-
all(_is_static(i) for i in s) for s in [sizes, strides, source_strides]
105+
all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides]
97106
):
98107
# If any are arith.constant results then this will canonicalize to python int
99-
# (which can then be used to fully specific the subview).
108+
# (which can then be used to fully specify the subview).
100109
(
101110
offsets,
102111
sizes,
103112
strides,
104113
result_type,
105114
) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
106-
else:
107-
assert (
108-
result_type is not None
109-
), "mixed static/dynamic offset/sizes/strides requires explicit result type"
115+
elif result_type is None:
116+
raise ValueError(
117+
"mixed static/dynamic offset/sizes/strides requires explicit result type."
118+
)
110119

111120
offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
112121
sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)

0 commit comments

Comments
 (0)