
Commit 427045b

[mlir][python] enable memref.subview

1 parent 66ef690

5 files changed (+315, -2 lines)

mlir/include/mlir-c/BuiltinTypes.h
Lines changed: 7 additions & 0 deletions

@@ -408,6 +408,13 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type);
 /// Returns the memory space of the given MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
 
+/// Returns the strides of the MemRef if the layout map is in strided form.
+/// Both strides and offset are out params. strides must point to pre-allocated
+/// memory of length equal to the rank of the memref.
+MLIR_CAPI_EXPORTED void mlirMemRefTypeGetStridesAndOffset(MlirType type,
+                                                          int64_t *strides,
+                                                          int64_t *offset);
+
 /// Returns the memory space of the given Unranked MemRef type.
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirUnrankedMemrefGetMemorySpace(MlirType type);

mlir/lib/Bindings/Python/IRTypes.cpp
Lines changed: 9 additions & 0 deletions

@@ -618,6 +618,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
             return mlirMemRefTypeGetLayout(self);
           },
           "The layout of the MemRef type.")
+      .def_property_readonly(
+          "strides_and_offset",
+          [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+            std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+            int64_t offset;
+            mlirMemRefTypeGetStridesAndOffset(self, strides.data(), &offset);
+            return {strides, offset};
+          },
+          "The strides and offset of the MemRef type.")
       .def_property_readonly(
           "affine_map",
           [](PyMemRefType &self) -> PyAffineMap {
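Read back from Python, the new property unpacks to a ([strides], offset) tuple. A minimal sketch of its use (illustrative, not part of the commit; assumes bindings built with this patch and uses the existing Type.parse / MemRefType cast APIs):

    from mlir.ir import Context, MemRefType, Type

    with Context():
        # Explicit strided layout: the property reports it directly.
        t = MemRefType(Type.parse("memref<3x3xi32, strided<[10, 1], offset: 11>>"))
        strides, offset = t.strides_and_offset
        assert (strides, offset) == ([10, 1], 11)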

mlir/lib/CAPI/IR/BuiltinTypes.cpp
Lines changed: 14 additions & 0 deletions

@@ -16,6 +16,8 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
 
+#include <algorithm>
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -426,6 +428,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
   return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
 }
 
+void mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *strides,
+                                       int64_t *offset) {
+  MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
+  std::pair<SmallVector<int64_t>, int64_t> stridesOffsets =
+      getStridesAndOffset(memrefType);
+  assert(stridesOffsets.first.size() == memrefType.getRank() &&
+         "Strides and rank don't match for memref");
+  (void)std::copy(stridesOffsets.first.begin(), stridesOffsets.first.end(),
+                  strides);
+  *offset = stridesOffsets.second;
+}
+
 MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
   return wrap(UnrankedMemRefType::getTypeID());
 }
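The C wrapper delegates to mlir::getStridesAndOffset, which for the default (identity) layout yields suffix products of the shape and a zero offset; the same rule reappears as default_strides in the Python wrapper below. A small sketch of that behavior as seen through the new property (illustrative, assuming bindings built with this patch):

    from mlir.ir import Context, MemRefType, Type

    with Context():
        t = MemRefType(Type.parse("memref<10x22x333xi32>"))
        strides, offset = t.strides_and_offset
        # Row-major suffix products: [22 * 333, 333, 1], offset 0.
        assert strides == [22 * 333, 333, 1] and offset == 0
        assert len(strides) == t.rank  # the invariant asserted in the wrapper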

mlir/python/mlir/dialects/memref.py
Lines changed: 121 additions & 0 deletions

@@ -1,5 +1,126 @@
 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+import operator
+from itertools import accumulate
+from typing import Optional
 
 from ._memref_ops_gen import *
+from .arith import ConstantOp
+from .transform.structured import _dispatch_mixed_values, MixedValues
+from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType
+
+
+def _is_constant(i):
+    return isinstance(i, Value) and isinstance(i.owner.opview, ConstantOp)
+
+
+def _is_static(i):
+    return (isinstance(i, int) and not ShapedType.is_dynamic_size(i)) or _is_constant(i)
+
+
+def _infer_memref_subview_result_type(
+    source_memref_type, offsets, static_sizes, static_strides
+):
+    source_strides, source_offset = source_memref_type.strides_and_offset
+    # "Canonicalize" from tuple|list -> list.
+    offsets, static_sizes, static_strides, source_strides = map(
+        list, (offsets, static_sizes, static_strides, source_strides)
+    )
+
+    assert all(
+        all(_is_static(i) for i in s)
+        for s in [
+            static_sizes,
+            static_strides,
+            source_strides,
+        ]
+    ), "Only inferring from python or mlir integer constant is supported"
+
+    for s in [offsets, static_sizes, static_strides]:
+        for idx, i in enumerate(s):
+            if _is_constant(i):
+                s[idx] = i.owner.opview.literal_value
+
+    if any(not _is_static(i) for i in offsets + [source_offset]):
+        target_offset = ShapedType.get_dynamic_size()
+    else:
+        target_offset = source_offset
+        for offset, target_stride in zip(offsets, source_strides):
+            target_offset += offset * target_stride
+
+    target_strides = []
+    for source_stride, static_stride in zip(source_strides, static_strides):
+        target_strides.append(source_stride * static_stride)
+
+    # If default striding then no need to complicate things for downstream ops
+    # (e.g., expand_shape).
+    default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
+    if target_strides == default_strides and target_offset == 0:
+        layout = None
+    else:
+        layout = StridedLayoutAttr.get(target_offset, target_strides)
+    return (
+        offsets,
+        static_sizes,
+        static_strides,
+        MemRefType.get(
+            static_sizes,
+            source_memref_type.element_type,
+            layout,
+            source_memref_type.memory_space,
+        ),
+    )
+
+
+_generated_subview = subview
+
+
+def subview(
+    source: Value,
+    offsets: MixedValues,
+    sizes: MixedValues,
+    strides: MixedValues,
+    *,
+    result_type: Optional[MemRefType] = None,
+    loc=None,
+    ip=None,
+):
+    if offsets is None:
+        offsets = []
+    if sizes is None:
+        sizes = []
+    if strides is None:
+        strides = []
+    source_strides, source_offset = source.type.strides_and_offset
+    if result_type is None and all(
+        all(_is_static(i) for i in s) for s in [sizes, strides, source_strides]
+    ):
+        # If any are arith.constant results then this will canonicalize to
+        # python int (which can then be used to fully specify the subview).
+        (
+            offsets,
+            sizes,
+            strides,
+            result_type,
+        ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
+    else:
+        assert (
+            result_type is not None
+        ), "mixed static/dynamic offset/sizes/strides requires explicit result type"
+
+    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
+    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
+    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)
+
+    return _generated_subview(
+        result_type,
+        source,
+        offsets,
+        sizes,
+        strides,
+        static_offsets,
+        static_sizes,
+        static_strides,
+        loc=loc,
+        ip=ip,
+    )
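Taken together, the wrapper lets fully static subviews omit the result type, while mixed static/dynamic cases must pass result_type explicitly. A minimal end-to-end sketch (illustrative, mirroring the tests below; assumes bindings built with this patch):

    import mlir.dialects.memref as memref
    import mlir.extras.types as T
    from mlir.ir import Context, InsertionPoint, Location, Module

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            # Fully static offsets/sizes/strides, so the result type is
            # inferred: memref<3x3xi32, strided<[10, 1], offset: 11>>.
            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
        print(module)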

mlir/test/python/dialects/memref.py
Lines changed: 164 additions & 2 deletions

@@ -1,9 +1,11 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from mlir.ir import *
-import mlir.dialects.func as func
+import mlir.dialects.arith as arith
+import numpy as np
 import mlir.dialects.memref as memref
 import mlir.extras.types as T
+from mlir.dialects.memref import _infer_memref_subview_result_type
+from mlir.ir import *
 
 
 def run(f):
@@ -88,3 +89,164 @@ def testMemRefAttr():
         memref.global_("objFifo_in0", T.memref(16, T.i32()))
     # CHECK: memref.global @objFifo_in0 : memref<16xi32>
     print(module)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics
+@run
+def testSubViewOpInferReturnTypeSemantics():
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
+            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
+            print(x.owner)
+
+            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
+            assert y.owner.verify()
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+            print(y.owner)
+
+            z = memref.subview(
+                x,
+                [arith.constant(T.index(), 1), 1],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
+            print(z.owner)
+
+            z = memref.subview(
+                x,
+                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
+            print(z.owner)
+
+            s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
+            z = memref.subview(
+                x,
+                [s, 0],
+                [3, 3],
+                [1, 1],
+            )
+            # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
+            print(z)
+
+            try:
+                _infer_memref_subview_result_type(
+                    x.type,
+                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                    [ShapedType.get_dynamic_size(), 3],
+                    [1, 1],
+                )
+            except AssertionError as e:
+                # CHECK: Only inferring from python or mlir integer constant is supported
+                print(e)
+
+            try:
+                memref.subview(
+                    x,
+                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
+                    [ShapedType.get_dynamic_size(), 3],
+                    [1, 1],
+                )
+            except AssertionError as e:
+                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
+                print(e)
+
+            layout = StridedLayoutAttr.get(ShapedType.get_dynamic_size(), [10, 1])
+            x = memref.alloc(
+                T.memref(
+                    10,
+                    10,
+                    T.i32(),
+                    layout=layout,
+                ),
+                [],
+                [arith.constant(T.index(), 42)],
+            )
+            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
+            print(x.owner)
+            y = memref.subview(
+                x,
+                [1, 1],
+                [3, 3],
+                [1, 1],
+                result_type=T.memref(3, 3, T.i32(), layout=layout),
+            )
+            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
+            print(y.owner)
+
+
+# CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing
+@run
+def testSubViewOpInferReturnTypeExtensiveSlicing():
+    def check_strides_offset(memref, np_view):
+        layout = memref.type.layout
+        dtype_size_in_bytes = np_view.dtype.itemsize
+        golden_strides = (np.array(np_view.strides) // dtype_size_in_bytes).tolist()
+        golden_offset = (
+            np_view.ctypes.data - np_view.base.ctypes.data
+        ) // dtype_size_in_bytes
+
+        assert (layout.strides, layout.offset) == (golden_strides, golden_offset)
+
+    with Context() as ctx, Location.unknown(ctx):
+        module = Module.create()
+        with InsertionPoint(module.body):
+            shape = (10, 22, 333, 4444)
+            golden_mem = np.zeros(shape, dtype=np.int32)
+            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+
+            # fmt: off
+            check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, ...])
+            check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 333, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 333, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 333, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
+            check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
+            check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
+            check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 333, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
+            check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
+            check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 4444), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
+            # fmt: on
+
+            # Default strides and offset mean no StridedLayoutAttr, which means an AffineMap layout.
+            assert memref.subview(
+                mem1, (0, 0, 0, 0), (10, 22, 333, 4444), (1, 1, 1, 1)
+            ).type.layout == AffineMapAttr.get(
+                AffineMap.get(
+                    4,
+                    0,
+                    [
+                        AffineDimExpr.get(0),
+                        AffineDimExpr.get(1),
+                        AffineDimExpr.get(2),
+                        AffineDimExpr.get(3),
+                    ],
+                )
+            )
+
+            shape = (7, 22, 333, 4444)
+            golden_mem = np.zeros(shape, dtype=np.int32)
+            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+            # fmt: off
+            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 333, 4444), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
+            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 4444), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
+            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
+            check_strides_offset(memref.subview(mem2, (0, 0, 100, 1000), (7, 22, 20, 20), (1, 1, 5, 50)), golden_mem[:, :, 100:200:5, 1000:2000:50])
+            # fmt: on
+
+            shape = (8, 8)
+            golden_mem = np.zeros(shape, dtype=np.int32)
+            # fmt: off
+            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
+            check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
+            check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
+            # fmt: on
