Skip to content

[mlir][python] enable memref.subview #79393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir-c/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,12 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
/// Returns the memory space of the given MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);

/// Returns the strides of the MemRef if the layout map is in strided form.
/// Both strides and offset are out params. strides must point to pre-allocated
/// memory of length equal to the rank of the memref.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(
MlirType type, int64_t *strides, int64_t *offset);

/// Returns the memory spcae of the given Unranked MemRef type.
MLIR_CAPI_EXPORTED MlirAttribute
mlirUnrankedMemrefGetMemorySpace(MlirType type);
Expand Down
14 changes: 14 additions & 0 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"

#include <optional>

namespace py = pybind11;
Expand Down Expand Up @@ -618,6 +620,18 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
return mlirMemRefTypeGetLayout(self);
},
"The layout of the MemRef type.")
.def(
"get_strides_and_offset",
[](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
int64_t offset;
if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
self, strides.data(), &offset)))
throw std::runtime_error(
"Failed to extract strides and offset from memref.");
return {strides, offset};
},
"The strides and offset of the MemRef type.")
.def_property_readonly(
"affine_map",
[](PyMemRefType &self) -> PyAffineMap {
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/CAPI/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"

#include <algorithm>

using namespace mlir;

Expand Down Expand Up @@ -426,6 +430,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
}

MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
int64_t *strides,
int64_t *offset) {
MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
SmallVector<int64_t> strides_;
if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
return mlirLogicalResultFailure();

(void)std::copy(strides_.begin(), strides_.end(), strides);
return mlirLogicalResultSuccess();
}

MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
return wrap(UnrankedMemRefType::getTypeID());
}
Expand Down
174 changes: 171 additions & 3 deletions mlir/python/mlir/dialects/_ods_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,30 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
from typing import (
List as _List,
Optional as _Optional,
Sequence as _Sequence,
Tuple as _Tuple,
Type as _Type,
TypeVar as _TypeVar,
Union as _Union,
)

from .._mlir_libs import _mlir as _cext
from ..ir import (
ArrayAttr,
Attribute,
BoolAttr,
DenseI64ArrayAttr,
IntegerAttr,
IntegerType,
OpView,
Operation,
ShapedType,
Value,
)

__all__ = [
"equally_sized_accessor",
"get_default_loc_context",
Expand Down Expand Up @@ -138,3 +152,157 @@ def get_op_result_or_op_results(
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

StaticIntLike = _Union[int, IntegerAttr]
ValueLike = _Union[Operation, OpView, Value]
MixedInt = _Union[StaticIntLike, ValueLike]

IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]

BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]

MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]

DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]


def _dispatch_dynamic_index_list(
indices: _Union[DynamicIndexList, ArrayAttr],
) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
"""Dispatches a list of indices to the appropriate form.
This is similar to the custom `DynamicIndexList` directive upstream:
provided indices may be in the form of dynamic SSA values or static values,
and they may be scalable (i.e., as a singleton list) or not. This function
dispatches each index into its respective form. It also extracts the SSA
values and static indices from various similar structures, respectively.
"""
dynamic_indices = []
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
scalable_indices = [False] * len(indices)

# ArrayAttr: Extract index values.
if isinstance(indices, ArrayAttr):
indices = [idx for idx in indices]

def process_nonscalable_index(i, index):
"""Processes any form of non-scalable index.
Returns False if the given index was scalable and thus remains
unprocessed; True otherwise.
"""
if isinstance(index, int):
static_indices[i] = index
elif isinstance(index, IntegerAttr):
static_indices[i] = index.value # pytype: disable=attribute-error
elif isinstance(index, (Operation, Value, OpView)):
dynamic_indices.append(index)
else:
return False
return True

# Process each index at a time.
for i, index in enumerate(indices):
if not process_nonscalable_index(i, index):
# If it wasn't processed, it must be a scalable index, which is
# provided as a _Sequence of one value, so extract and process that.
scalable_indices[i] = True
assert len(index) == 1
ret = process_nonscalable_index(i, index[0])
assert ret

return dynamic_indices, static_indices, scalable_indices


# Dispatches `MixedValues` that all represents integers in various forms into
# the following three categories:
# - `dynamic_values`: a list of `Value`s, potentially from op results;
# - `packed_values`: a value handle, potentially from an op result, associated
# to one or more payload operations of integer type;
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
# The input is in the form for `packed_values`, only that result is set and the
# other two are empty. Otherwise, the input can be a mix of the other two forms,
# and for each dynamic value, a special value is added to the `static_values`.
def _dispatch_mixed_values(
values: MixedValues,
) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
dynamic_values = []
packed_values = None
static_values = None
if isinstance(values, ArrayAttr):
static_values = values
elif isinstance(values, (Operation, Value, OpView)):
packed_values = values
else:
static_values = []
for size in values or []:
if isinstance(size, int):
static_values.append(size)
else:
static_values.append(ShapedType.get_dynamic_size())
dynamic_values.append(size)
static_values = DenseI64ArrayAttr.get(static_values)

return (dynamic_values, packed_values, static_values)


def _get_value_or_attribute_value(
value_or_attr: _Union[any, Attribute, ArrayAttr]
) -> any:
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
return value_or_attr.value
if isinstance(value_or_attr, ArrayAttr):
return _get_value_list(value_or_attr)
return value_or_attr


def _get_value_list(
sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr]
) -> _Sequence[any]:
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]


def _get_int_array_attr(
values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
) -> ArrayAttr:
if values is None:
return None

# Turn into a Python list of Python ints.
values = _get_value_list(values)

# Make an ArrayAttr of IntegerAttrs out of it.
return ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
)


def _get_int_array_array_attr(
values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
The input has to be a collection of a collection of integers, where any
Python _Sequence and ArrayAttr are admissible collections and Python ints and
any IntegerAttr are admissible integers. Both levels of collections are
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
If the input is None, an empty ArrayAttr is returned.
"""
if values is None:
return None

# Make sure the outer level is a list.
values = _get_value_list(values)

# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
# Sequences. Make sure the nested values are all lists.
values = [_get_value_list(nested) for nested in values]

# Turn each nested list into an ArrayAttr.
values = [_get_int_array_attr(nested) for nested in values]

# Turn the outer list into an ArrayAttr.
return ArrayAttr.get(values)
Loading