Skip to content

Commit ae40d62

Browse files
committed
[mlir] Refactor ElementsAttr's value access API
There are several aspects of the API that either aren't easy to use, or are deceptively easy to do the wrong thing. The main change of this commit is to remove all of the `getValue<T>`/`getFlatValue<T>` from ElementsAttr and instead provide operator[] methods on the ranges returned by `getValues<T>`. This provides a much more convenient API for the value ranges. It also removes the easy-to-be-inefficient nature of getValue/getFlatValue, which under the hood would construct a new range for the type `T`. Constructing a range is not necessarily cheap in all cases, and could lead to very poor performance if used within a loop; i.e. if you were to naively write something like: ``` DenseElementsAttr attr = ...; for (int i = 0; i < size; ++i) { // We are internally rebuilding the APFloat value range on each iteration!! APFloat it = attr.getFlatValue<APFloat>(i); } ``` Differential Revision: https://reviews.llvm.org/D113229
1 parent 4a0c89a commit ae40d62

File tree

25 files changed

+241
-315
lines changed

25 files changed

+241
-315
lines changed

mlir/include/mlir/IR/BuiltinAttributeInterfaces.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,33 @@ class ElementsAttrIterator
227227
ElementsAttrIndexer indexer;
228228
ptrdiff_t index;
229229
};
230+
231+
/// This class provides iterator utilities for an ElementsAttr range.
232+
template <typename IteratorT>
233+
class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
234+
public:
235+
using reference = typename IteratorT::reference;
236+
237+
ElementsAttrRange(Type shapeType,
238+
const llvm::iterator_range<IteratorT> &range)
239+
: llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
240+
ElementsAttrRange(Type shapeType, IteratorT beginIt, IteratorT endIt)
241+
: ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
242+
243+
/// Return the value at the given index.
244+
reference operator[](ArrayRef<uint64_t> index) const;
245+
reference operator[](uint64_t index) const {
246+
return *std::next(this->begin(), index);
247+
}
248+
249+
/// Return the size of this range.
250+
size_t size() const { return llvm::size(*this); }
251+
252+
private:
253+
/// The shaped type of the parent ElementsAttr.
254+
Type shapeType;
255+
};
256+
230257
} // namespace detail
231258

232259
//===----------------------------------------------------------------------===//
@@ -256,6 +283,16 @@ verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
256283
//===----------------------------------------------------------------------===//
257284

258285
namespace mlir {
286+
namespace detail {
287+
/// Return the value at the given index.
288+
template <typename IteratorT>
289+
auto ElementsAttrRange<IteratorT>::operator[](ArrayRef<uint64_t> index) const
290+
-> reference {
291+
// Skip to the element corresponding to the flattened index.
292+
return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
293+
}
294+
} // namespace detail
295+
259296
/// Return the elements of this attribute as a value of type 'T'.
260297
template <typename T>
261298
auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {

mlir/include/mlir/IR/BuiltinAttributeInterfaces.td

Lines changed: 32 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -158,27 +158,6 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
158158
];
159159

160160
string ElementsAttrInterfaceAccessors = [{
161-
/// Return the attribute value at the given index. The index is expected to
162-
/// refer to a valid element.
163-
Attribute getValue(ArrayRef<uint64_t> index) const {
164-
return getValue<Attribute>(index);
165-
}
166-
167-
/// Return the value of type 'T' at the given index, where 'T' corresponds
168-
/// to an Attribute type.
169-
template <typename T>
170-
std::enable_if_t<!std::is_same<T, ::mlir::Attribute>::value &&
171-
std::is_base_of<T, ::mlir::Attribute>::value>
172-
getValue(ArrayRef<uint64_t> index) const {
173-
return getValue(index).template dyn_cast_or_null<T>();
174-
}
175-
176-
/// Return the value of type 'T' at the given index.
177-
template <typename T>
178-
T getValue(ArrayRef<uint64_t> index) const {
179-
return getFlatValue<T>(getFlattenedIndex(index));
180-
}
181-
182161
/// Return the number of elements held by this attribute.
183162
int64_t size() const { return getNumElements(); }
184163

@@ -281,6 +260,14 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
281260
// Value Iteration
282261
//===------------------------------------------------------------------===//
283262

263+
/// The iterator for the given element type T.
264+
template <typename T, typename AttrT = ConcreteAttr>
265+
using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
266+
/// The iterator range over the given element T.
267+
template <typename T, typename AttrT = ConcreteAttr>
268+
using iterator_range =
269+
decltype(std::declval<AttrT>().template getValues<T>());
270+
284271
/// Return an iterator to the first element of this attribute as a value of
285272
/// type `T`.
286273
template <typename T>
@@ -292,19 +279,16 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
292279
template <typename T>
293280
auto getValues() const {
294281
auto beginIt = $_attr.template value_begin<T>();
295-
return llvm::make_range(beginIt, std::next(beginIt, size()));
296-
}
297-
/// Return the value at the given flattened index.
298-
template <typename T> T getFlatValue(uint64_t index) const {
299-
return *std::next($_attr.template value_begin<T>(), index);
282+
return detail::ElementsAttrRange<decltype(beginIt)>(
283+
Attribute($_attr).getType(), beginIt, std::next(beginIt, size()));
300284
}
301285
}] # ElementsAttrInterfaceAccessors;
302286

303287
let extraClassDeclaration = [{
304288
template <typename T>
305289
using iterator = detail::ElementsAttrIterator<T>;
306290
template <typename T>
307-
using iterator_range = llvm::iterator_range<iterator<T>>;
291+
using iterator_range = detail::ElementsAttrRange<iterator<T>>;
308292

309293
//===------------------------------------------------------------------===//
310294
// Accessors
@@ -329,8 +313,12 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
329313
uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
330314
return getFlattenedIndex(*this, index);
331315
}
332-
static uint64_t getFlattenedIndex(Attribute elementsAttr,
316+
static uint64_t getFlattenedIndex(Type type,
333317
ArrayRef<uint64_t> index);
318+
static uint64_t getFlattenedIndex(Attribute elementsAttr,
319+
ArrayRef<uint64_t> index) {
320+
return getFlattenedIndex(elementsAttr.getType(), index);
321+
}
334322

335323
/// Returns the number of elements held by this attribute.
336324
int64_t getNumElements() const { return getNumElements(*this); }
@@ -350,13 +338,6 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
350338
!std::is_base_of<Attribute, T>::value,
351339
ResultT>;
352340

353-
/// Return the element of this attribute at the given index as a value of
354-
/// type 'T'.
355-
template <typename T>
356-
T getFlatValue(uint64_t index) const {
357-
return *std::next(value_begin<T>(), index);
358-
}
359-
360341
/// Return the splat value for this attribute. This asserts that the
361342
/// attribute corresponds to a splat.
362343
template <typename T>
@@ -368,7 +349,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
368349
/// Return the elements of this attribute as a value of type 'T'.
369350
template <typename T>
370351
DefaultValueCheckT<T, iterator_range<T>> getValues() const {
371-
return iterator_range<T>(value_begin<T>(), value_end<T>());
352+
return {Attribute::getType(), value_begin<T>(), value_end<T>()};
372353
}
373354
template <typename T>
374355
DefaultValueCheckT<T, iterator<T>> value_begin() const;
@@ -384,12 +365,12 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
384365
llvm::mapped_iterator<iterator<Attribute>, T (*)(Attribute)>;
385366
template <typename T>
386367
using DerivedAttrValueIteratorRange =
387-
llvm::iterator_range<DerivedAttrValueIterator<T>>;
368+
detail::ElementsAttrRange<DerivedAttrValueIterator<T>>;
388369
template <typename T, typename = DerivedAttrValueCheckT<T>>
389370
DerivedAttrValueIteratorRange<T> getValues() const {
390371
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
391-
return llvm::map_range(getValues<Attribute>(),
392-
static_cast<T (*)(Attribute)>(castFn));
372+
return {Attribute::getType(), llvm::map_range(getValues<Attribute>(),
373+
static_cast<T (*)(Attribute)>(castFn))};
393374
}
394375
template <typename T, typename = DerivedAttrValueCheckT<T>>
395376
DerivedAttrValueIterator<T> value_begin() const {
@@ -407,8 +388,10 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
407388
/// return the iterable range. Otherwise, return llvm::None.
408389
template <typename T>
409390
DefaultValueCheckT<T, Optional<iterator_range<T>>> tryGetValues() const {
410-
if (Optional<iterator<T>> beginIt = try_value_begin<T>())
411-
return iterator_range<T>(*beginIt, value_end<T>());
391+
if (Optional<iterator<T>> beginIt = try_value_begin<T>()) {
392+
return iterator_range<T>(Attribute::getType(), *beginIt,
393+
value_end<T>());
394+
}
412395
return llvm::None;
413396
}
414397
template <typename T>
@@ -418,10 +401,15 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
418401
/// return the iterable range. Otherwise, return llvm::None.
419402
template <typename T, typename = DerivedAttrValueCheckT<T>>
420403
Optional<DerivedAttrValueIteratorRange<T>> tryGetValues() const {
404+
auto values = tryGetValues<Attribute>();
405+
if (!values)
406+
return llvm::None;
407+
421408
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
422-
if (auto values = tryGetValues<Attribute>())
423-
return llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn));
424-
return llvm::None;
409+
return DerivedAttrValueIteratorRange<T>(
410+
Attribute::getType(),
411+
llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
412+
);
425413
}
426414
template <typename T, typename = DerivedAttrValueCheckT<T>>
427415
Optional<DerivedAttrValueIterator<T>> try_value_begin() const {

0 commit comments

Comments
 (0)