Skip to content

Commit 6ae6cad

Browse files
committed
fix striding calcs for memref (depends on llvm/llvm-project#79393)
1 parent 8cb90ec commit 6ae6cad

File tree

3 files changed

+217
-149
lines changed

3 files changed

+217
-149
lines changed

mlir/extras/dialects/ext/memref.py

Lines changed: 5 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -70,71 +70,6 @@ def store(
7070
return get_op_result_or_op_results(StoreOp(value, mem, indices, loc=loc, ip=ip))
7171

7272

73-
def subview(
74-
source: "MemRef",
75-
offsets: Optional[Sequence[Value]] = None,
76-
strides: Optional[Sequence[Value]] = None,
77-
static_offsets: Optional[Sequence[int]] = None,
78-
static_sizes: Optional[Sequence[int]] = None,
79-
static_strides: Optional[Sequence[int]] = None,
80-
*,
81-
loc=None,
82-
ip=None,
83-
):
84-
if loc is None:
85-
loc = get_user_code_loc()
86-
if offsets is None:
87-
offsets = []
88-
if static_offsets is None:
89-
static_offsets = []
90-
if strides is None:
91-
strides = []
92-
if static_strides is None:
93-
static_strides = []
94-
assert static_sizes, f"this convenience method only handles static sizes"
95-
sizes = []
96-
wrong_type = T.memref(*static_sizes, source.dtype)
97-
if offsets and static_offsets:
98-
assert all(s == S for s in static_offsets)
99-
if strides and static_strides:
100-
assert all(s == S for s in static_strides)
101-
val = memref.subview(
102-
wrong_type,
103-
source,
104-
offsets,
105-
sizes,
106-
strides,
107-
static_offsets,
108-
static_sizes,
109-
static_strides,
110-
loc=loc,
111-
ip=ip,
112-
)
113-
# dumbest hack ever - the default builder doesn't connect to inferReturnTypes
114-
# but the diag message does
115-
try:
116-
val.owner.verify()
117-
return val
118-
except MLIRError as e:
119-
diag = str(e.error_diagnostics[0])
120-
correct_type = re.findall(r"'memref<(.*)>'", diag)
121-
assert len(correct_type) == 1
122-
correct_type = Type.parse(f"memref<{correct_type[0]}>")
123-
val.owner.erase()
124-
return memref.subview(
125-
correct_type,
126-
source,
127-
offsets,
128-
sizes,
129-
strides,
130-
static_offsets,
131-
static_sizes,
132-
static_strides,
133-
loc=loc,
134-
ip=ip,
135-
)
136-
137-
13873
@register_value_caster(MemRefType.static_typeid)
13974
class MemRef(Value):
14075
def __str__(self):
@@ -266,16 +201,15 @@ def _subview(
266201
if indexer.is_constant():
267202
out = subview(
268203
out,
269-
static_offsets=indexer.static_offsets(),
270-
static_sizes=indexer.static_sizes(),
271-
static_strides=indexer.static_strides(),
204+
offsets=indexer.static_offsets(),
205+
sizes=indexer.static_sizes(),
206+
strides=indexer.static_strides(),
272207
loc=loc,
273208
ip=ip,
274209
)
275210
else:
276211
# special tile case
277212
offsets = [None] * len(indexer.in_shape)
278-
static_offsets = [None] * len(indexer.in_shape)
279213
static_sizes = [None] * len(indexer.in_shape)
280214
static_strides = [None] * len(indexer.in_shape)
281215
for i, ind in enumerate(indexer.indices):
@@ -292,15 +226,13 @@ def _subview(
292226
and ind.step.is_constant()
293227
):
294228
offsets[i] = ind.start
295-
static_offsets[i] = S
296229
static_sizes[i] = maybe_size.literal_value
297230
static_strides[i] = (
298231
ind.step.literal_value if isinstance(ind.step, Scalar) else ind.step
299232
)
300233
else:
301234
raise RuntimeError(f"indexing not supported {indexer.indices}")
302235
offsets = list(filter(None, offsets))
303-
static_offsets = list(filter(None, static_offsets))
304236
static_sizes = list(filter(None, static_sizes))
305237
static_strides = list(filter(None, static_strides))
306238
assert (
@@ -312,9 +244,8 @@ def _subview(
312244
out = subview(
313245
out,
314246
offsets=offsets,
315-
static_offsets=static_offsets,
316-
static_sizes=static_sizes,
317-
static_strides=static_strides,
247+
sizes=static_sizes,
248+
strides=static_strides,
318249
loc=loc,
319250
ip=ip,
320251
)

mlir/extras/testing/testing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
from ..context import MLIRContext, mlir_mod_ctx
1313
from .generate_test_checks import main
1414
from ..runtime.refbackend import LLVMJITBackend
15+
from ...ir import Module
1516

1617

1718
def filecheck(correct: str, module):
19+
if isinstance(module, Module):
20+
assert module.operation.verify()
1821
filecheck_name = "FileCheck"
1922
if platform.system() == "Windows":
2023
filecheck_name += ".exe"

0 commit comments

Comments
 (0)