@@ -70,71 +70,6 @@ def store(
70
70
return get_op_result_or_op_results (StoreOp (value , mem , indices , loc = loc , ip = ip ))
71
71
72
72
73
- def subview (
74
- source : "MemRef" ,
75
- offsets : Optional [Sequence [Value ]] = None ,
76
- strides : Optional [Sequence [Value ]] = None ,
77
- static_offsets : Optional [Sequence [int ]] = None ,
78
- static_sizes : Optional [Sequence [int ]] = None ,
79
- static_strides : Optional [Sequence [int ]] = None ,
80
- * ,
81
- loc = None ,
82
- ip = None ,
83
- ):
84
- if loc is None :
85
- loc = get_user_code_loc ()
86
- if offsets is None :
87
- offsets = []
88
- if static_offsets is None :
89
- static_offsets = []
90
- if strides is None :
91
- strides = []
92
- if static_strides is None :
93
- static_strides = []
94
- assert static_sizes , f"this convenience method only handles static sizes"
95
- sizes = []
96
- wrong_type = T .memref (* static_sizes , source .dtype )
97
- if offsets and static_offsets :
98
- assert all (s == S for s in static_offsets )
99
- if strides and static_strides :
100
- assert all (s == S for s in static_strides )
101
- val = memref .subview (
102
- wrong_type ,
103
- source ,
104
- offsets ,
105
- sizes ,
106
- strides ,
107
- static_offsets ,
108
- static_sizes ,
109
- static_strides ,
110
- loc = loc ,
111
- ip = ip ,
112
- )
113
- # dumbest hack ever - the default builder doesn't connect to inferReturnTypes
114
- # but the diag message does
115
- try :
116
- val .owner .verify ()
117
- return val
118
- except MLIRError as e :
119
- diag = str (e .error_diagnostics [0 ])
120
- correct_type = re .findall (r"'memref<(.*)>'" , diag )
121
- assert len (correct_type ) == 1
122
- correct_type = Type .parse (f"memref<{ correct_type [0 ]} >" )
123
- val .owner .erase ()
124
- return memref .subview (
125
- correct_type ,
126
- source ,
127
- offsets ,
128
- sizes ,
129
- strides ,
130
- static_offsets ,
131
- static_sizes ,
132
- static_strides ,
133
- loc = loc ,
134
- ip = ip ,
135
- )
136
-
137
-
138
73
@register_value_caster (MemRefType .static_typeid )
139
74
class MemRef (Value ):
140
75
def __str__ (self ):
@@ -266,16 +201,15 @@ def _subview(
266
201
if indexer .is_constant ():
267
202
out = subview (
268
203
out ,
269
- static_offsets = indexer .static_offsets (),
270
- static_sizes = indexer .static_sizes (),
271
- static_strides = indexer .static_strides (),
204
+ offsets = indexer .static_offsets (),
205
+ sizes = indexer .static_sizes (),
206
+ strides = indexer .static_strides (),
272
207
loc = loc ,
273
208
ip = ip ,
274
209
)
275
210
else :
276
211
# special tile case
277
212
offsets = [None ] * len (indexer .in_shape )
278
- static_offsets = [None ] * len (indexer .in_shape )
279
213
static_sizes = [None ] * len (indexer .in_shape )
280
214
static_strides = [None ] * len (indexer .in_shape )
281
215
for i , ind in enumerate (indexer .indices ):
@@ -292,15 +226,13 @@ def _subview(
292
226
and ind .step .is_constant ()
293
227
):
294
228
offsets [i ] = ind .start
295
- static_offsets [i ] = S
296
229
static_sizes [i ] = maybe_size .literal_value
297
230
static_strides [i ] = (
298
231
ind .step .literal_value if isinstance (ind .step , Scalar ) else ind .step
299
232
)
300
233
else :
301
234
raise RuntimeError (f"indexing not supported { indexer .indices } " )
302
235
offsets = list (filter (None , offsets ))
303
- static_offsets = list (filter (None , static_offsets ))
304
236
static_sizes = list (filter (None , static_sizes ))
305
237
static_strides = list (filter (None , static_strides ))
306
238
assert (
@@ -312,9 +244,8 @@ def _subview(
312
244
out = subview (
313
245
out ,
314
246
offsets = offsets ,
315
- static_offsets = static_offsets ,
316
- static_sizes = static_sizes ,
317
- static_strides = static_strides ,
247
+ sizes = static_sizes ,
248
+ strides = static_strides ,
318
249
loc = loc ,
319
250
ip = ip ,
320
251
)
0 commit comments