@@ -158,27 +158,6 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
158
158
];
159
159
160
160
string ElementsAttrInterfaceAccessors = [{
161
- /// Return the attribute value at the given index. The index is expected to
162
- /// refer to a valid element.
163
- Attribute getValue(ArrayRef<uint64_t> index) const {
164
- return getValue<Attribute>(index);
165
- }
166
-
167
- /// Return the value of type 'T' at the given index, where 'T' corresponds
168
- /// to an Attribute type.
169
- template <typename T>
170
- std::enable_if_t<!std::is_same<T, ::mlir::Attribute>::value &&
171
- std::is_base_of<T, ::mlir::Attribute>::value>
172
- getValue(ArrayRef<uint64_t> index) const {
173
- return getValue(index).template dyn_cast_or_null<T>();
174
- }
175
-
176
- /// Return the value of type 'T' at the given index.
177
- template <typename T>
178
- T getValue(ArrayRef<uint64_t> index) const {
179
- return getFlatValue<T>(getFlattenedIndex(index));
180
- }
181
-
182
161
/// Return the number of elements held by this attribute.
183
162
int64_t size() const { return getNumElements(); }
184
163
@@ -281,6 +260,14 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
281
260
// Value Iteration
282
261
//===------------------------------------------------------------------===//
283
262
263
+ /// The iterator for the given element type T.
264
+ template <typename T, typename AttrT = ConcreteAttr>
265
+ using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
266
+ /// The iterator range over the given element T.
267
+ template <typename T, typename AttrT = ConcreteAttr>
268
+ using iterator_range =
269
+ decltype(std::declval<AttrT>().template getValues<T>());
270
+
284
271
/// Return an iterator to the first element of this attribute as a value of
285
272
/// type `T`.
286
273
template <typename T>
@@ -292,19 +279,16 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
292
279
template <typename T>
293
280
auto getValues() const {
294
281
auto beginIt = $_attr.template value_begin<T>();
295
- return llvm::make_range(beginIt, std::next(beginIt, size()));
296
- }
297
- /// Return the value at the given flattened index.
298
- template <typename T> T getFlatValue(uint64_t index) const {
299
- return *std::next($_attr.template value_begin<T>(), index);
282
+ return detail::ElementsAttrRange<decltype(beginIt)>(
283
+ Attribute($_attr).getType(), beginIt, std::next(beginIt, size()));
300
284
}
301
285
}] # ElementsAttrInterfaceAccessors;
302
286
303
287
let extraClassDeclaration = [{
304
288
template <typename T>
305
289
using iterator = detail::ElementsAttrIterator<T>;
306
290
template <typename T>
307
- using iterator_range = llvm::iterator_range <iterator<T>>;
291
+ using iterator_range = detail::ElementsAttrRange <iterator<T>>;
308
292
309
293
//===------------------------------------------------------------------===//
310
294
// Accessors
@@ -329,8 +313,12 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
329
313
uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
330
314
return getFlattenedIndex(*this, index);
331
315
}
332
- static uint64_t getFlattenedIndex(Attribute elementsAttr ,
316
+ static uint64_t getFlattenedIndex(Type type ,
333
317
ArrayRef<uint64_t> index);
318
+ static uint64_t getFlattenedIndex(Attribute elementsAttr,
319
+ ArrayRef<uint64_t> index) {
320
+ return getFlattenedIndex(elementsAttr.getType(), index);
321
+ }
334
322
335
323
/// Returns the number of elements held by this attribute.
336
324
int64_t getNumElements() const { return getNumElements(*this); }
@@ -350,13 +338,6 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
350
338
!std::is_base_of<Attribute, T>::value,
351
339
ResultT>;
352
340
353
- /// Return the element of this attribute at the given index as a value of
354
- /// type 'T'.
355
- template <typename T>
356
- T getFlatValue(uint64_t index) const {
357
- return *std::next(value_begin<T>(), index);
358
- }
359
-
360
341
/// Return the splat value for this attribute. This asserts that the
361
342
/// attribute corresponds to a splat.
362
343
template <typename T>
@@ -368,7 +349,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
368
349
/// Return the elements of this attribute as a value of type 'T'.
369
350
template <typename T>
370
351
DefaultValueCheckT<T, iterator_range<T>> getValues() const {
371
- return iterator_range<T>( value_begin<T>(), value_end<T>()) ;
352
+ return {Attribute::getType(), value_begin<T>(), value_end<T>()} ;
372
353
}
373
354
template <typename T>
374
355
DefaultValueCheckT<T, iterator<T>> value_begin() const;
@@ -384,12 +365,12 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
384
365
llvm::mapped_iterator<iterator<Attribute>, T (*)(Attribute)>;
385
366
template <typename T>
386
367
using DerivedAttrValueIteratorRange =
387
- llvm::iterator_range <DerivedAttrValueIterator<T>>;
368
+ detail::ElementsAttrRange <DerivedAttrValueIterator<T>>;
388
369
template <typename T, typename = DerivedAttrValueCheckT<T>>
389
370
DerivedAttrValueIteratorRange<T> getValues() const {
390
371
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
391
- return llvm::map_range(getValues<Attribute>(),
392
- static_cast<T (*)(Attribute)>(castFn));
372
+ return {Attribute::getType(), llvm::map_range(getValues<Attribute>(),
373
+ static_cast<T (*)(Attribute)>(castFn))} ;
393
374
}
394
375
template <typename T, typename = DerivedAttrValueCheckT<T>>
395
376
DerivedAttrValueIterator<T> value_begin() const {
@@ -407,8 +388,10 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
407
388
/// return the iterable range. Otherwise, return llvm::None.
408
389
template <typename T>
409
390
DefaultValueCheckT<T, Optional<iterator_range<T>>> tryGetValues() const {
410
- if (Optional<iterator<T>> beginIt = try_value_begin<T>())
411
- return iterator_range<T>(*beginIt, value_end<T>());
391
+ if (Optional<iterator<T>> beginIt = try_value_begin<T>()) {
392
+ return iterator_range<T>(Attribute::getType(), *beginIt,
393
+ value_end<T>());
394
+ }
412
395
return llvm::None;
413
396
}
414
397
template <typename T>
@@ -418,10 +401,15 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
418
401
/// return the iterable range. Otherwise, return llvm::None.
419
402
template <typename T, typename = DerivedAttrValueCheckT<T>>
420
403
Optional<DerivedAttrValueIteratorRange<T>> tryGetValues() const {
404
+ auto values = tryGetValues<Attribute>();
405
+ if (!values)
406
+ return llvm::None;
407
+
421
408
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
422
- if (auto values = tryGetValues<Attribute>())
423
- return llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn));
424
- return llvm::None;
409
+ return DerivedAttrValueIteratorRange<T>(
410
+ Attribute::getType(),
411
+ llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
412
+ );
425
413
}
426
414
template <typename T, typename = DerivedAttrValueCheckT<T>>
427
415
Optional<DerivedAttrValueIterator<T>> try_value_begin() const {
0 commit comments