Skip to content

Commit 1c66bac

Browse files
author
Jeff Niu
committed
[mlir] Fix try_value_begin_impl for DenseElementsAttr
The previous implementation would still crash if the element type was not iterable. This patch changes SparseElementsAttr to properly implement `try_value_begin_impl` according to ElementsAttr and changes DenseElementsAttr to implement `tryGetValues` as the basis for querying element values. Depends on D132904 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D132958
1 parent 5cf5708 commit 1c66bac

File tree

3 files changed

+154
-228
lines changed

3 files changed

+154
-228
lines changed

mlir/include/mlir/IR/BuiltinAttributes.h

Lines changed: 114 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -392,107 +392,110 @@ class DenseElementsAttr : public Attribute {
392392
return getSplatValue<Attribute>().template cast<T>();
393393
}
394394

395-
/// Return the held element values as a range of integer or floating-point
395+
/// Try to get an iterator of the given type to the start of the held element
396+
/// values. Return failure if the type cannot be iterated.
397+
template <typename T>
398+
auto try_value_begin() const {
399+
auto range = tryGetValues<T>();
400+
using iterator = decltype(range->begin());
401+
return failed(range) ? FailureOr<iterator>(failure()) : range->begin();
402+
}
403+
404+
/// Try to get an iterator of the given type to the end of the held element
405+
/// values. Return failure if the type cannot be iterated.
406+
template <typename T>
407+
auto try_value_end() const {
408+
auto range = tryGetValues<T>();
409+
using iterator = decltype(range->begin());
410+
return failed(range) ? FailureOr<iterator>(failure()) : range->end();
411+
}
412+
413+
/// Return the held element values as a range of the given type.
414+
template <typename T>
415+
auto getValues() const {
416+
auto range = tryGetValues<T>();
417+
assert(succeeded(range) && "element type cannot be iterated");
418+
return std::move(*range);
419+
}
420+
421+
/// Get an iterator of the given type to the start of the held element values.
422+
template <typename T>
423+
auto value_begin() const {
424+
return getValues<T>().begin();
425+
}
426+
427+
/// Get an iterator of the given type to the end of the held element values.
428+
template <typename T>
429+
auto value_end() const {
430+
return getValues<T>().end();
431+
}
432+
433+
/// Try to get the held element values as a range of integer or floating-point
396434
/// values.
397435
template <typename T>
398436
using IntFloatValueTemplateCheckT =
399437
typename std::enable_if<(!std::is_same<T, bool>::value &&
400438
std::numeric_limits<T>::is_integer) ||
401439
is_valid_cpp_fp_type<T>::value>::type;
402440
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
403-
iterator_range_impl<ElementIterator<T>> getValues() const {
404-
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
405-
std::numeric_limits<T>::is_signed));
441+
FailureOr<iterator_range_impl<ElementIterator<T>>> tryGetValues() const {
442+
if (!isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
443+
std::numeric_limits<T>::is_signed))
444+
return failure();
406445
const char *rawData = getRawData().data();
407446
bool splat = isSplat();
408-
return {getType(), ElementIterator<T>(rawData, splat, 0),
409-
ElementIterator<T>(rawData, splat, getNumElements())};
410-
}
411-
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
412-
ElementIterator<T> value_begin() const {
413-
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
414-
std::numeric_limits<T>::is_signed));
415-
return ElementIterator<T>(getRawData().data(), isSplat(), 0);
416-
}
417-
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
418-
ElementIterator<T> value_end() const {
419-
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
420-
std::numeric_limits<T>::is_signed));
421-
return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
447+
return iterator_range_impl<ElementIterator<T>>(
448+
getType(), ElementIterator<T>(rawData, splat, 0),
449+
ElementIterator<T>(rawData, splat, getNumElements()));
422450
}
423451

424-
/// Return the held element values as a range of std::complex.
452+
/// Try to get the held element values as a range of std::complex.
425453
template <typename T, typename ElementT>
426454
using ComplexValueTemplateCheckT =
427455
typename std::enable_if<detail::is_complex_t<T>::value &&
428456
(std::numeric_limits<ElementT>::is_integer ||
429457
is_valid_cpp_fp_type<ElementT>::value)>::type;
430458
template <typename T, typename ElementT = typename T::value_type,
431459
typename = ComplexValueTemplateCheckT<T, ElementT>>
432-
iterator_range_impl<ElementIterator<T>> getValues() const {
433-
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
434-
std::numeric_limits<ElementT>::is_signed));
460+
FailureOr<iterator_range_impl<ElementIterator<T>>> tryGetValues() const {
461+
if (!isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
462+
std::numeric_limits<ElementT>::is_signed))
463+
return failure();
435464
const char *rawData = getRawData().data();
436465
bool splat = isSplat();
437-
return {getType(), ElementIterator<T>(rawData, splat, 0),
438-
ElementIterator<T>(rawData, splat, getNumElements())};
439-
}
440-
template <typename T, typename ElementT = typename T::value_type,
441-
typename = ComplexValueTemplateCheckT<T, ElementT>>
442-
ElementIterator<T> value_begin() const {
443-
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
444-
std::numeric_limits<ElementT>::is_signed));
445-
return ElementIterator<T>(getRawData().data(), isSplat(), 0);
446-
}
447-
template <typename T, typename ElementT = typename T::value_type,
448-
typename = ComplexValueTemplateCheckT<T, ElementT>>
449-
ElementIterator<T> value_end() const {
450-
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
451-
std::numeric_limits<ElementT>::is_signed));
452-
return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
466+
return iterator_range_impl<ElementIterator<T>>(
467+
getType(), ElementIterator<T>(rawData, splat, 0),
468+
ElementIterator<T>(rawData, splat, getNumElements()));
453469
}
454470

455-
/// Return the held element values as a range of StringRef.
471+
/// Try to get the held element values as a range of StringRef.
456472
template <typename T>
457473
using StringRefValueTemplateCheckT =
458474
typename std::enable_if<std::is_same<T, StringRef>::value>::type;
459475
template <typename T, typename = StringRefValueTemplateCheckT<T>>
460-
iterator_range_impl<ElementIterator<StringRef>> getValues() const {
476+
FailureOr<iterator_range_impl<ElementIterator<StringRef>>>
477+
tryGetValues() const {
461478
auto stringRefs = getRawStringData();
462479
const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
463480
bool splat = isSplat();
464-
return {getType(), ElementIterator<StringRef>(ptr, splat, 0),
465-
ElementIterator<StringRef>(ptr, splat, getNumElements())};
466-
}
467-
template <typename T, typename = StringRefValueTemplateCheckT<T>>
468-
ElementIterator<StringRef> value_begin() const {
469-
const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
470-
return ElementIterator<StringRef>(ptr, isSplat(), 0);
471-
}
472-
template <typename T, typename = StringRefValueTemplateCheckT<T>>
473-
ElementIterator<StringRef> value_end() const {
474-
const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
475-
return ElementIterator<StringRef>(ptr, isSplat(), getNumElements());
481+
return iterator_range_impl<ElementIterator<StringRef>>(
482+
getType(), ElementIterator<StringRef>(ptr, splat, 0),
483+
ElementIterator<StringRef>(ptr, splat, getNumElements()));
476484
}
477485

478-
/// Return the held element values as a range of Attributes.
486+
/// Try to get the held element values as a range of Attributes.
479487
template <typename T>
480488
using AttributeValueTemplateCheckT =
481489
typename std::enable_if<std::is_same<T, Attribute>::value>::type;
482490
template <typename T, typename = AttributeValueTemplateCheckT<T>>
483-
iterator_range_impl<AttributeElementIterator> getValues() const {
484-
return {getType(), value_begin<Attribute>(), value_end<Attribute>()};
485-
}
486-
template <typename T, typename = AttributeValueTemplateCheckT<T>>
487-
AttributeElementIterator value_begin() const {
488-
return AttributeElementIterator(*this, 0);
489-
}
490-
template <typename T, typename = AttributeValueTemplateCheckT<T>>
491-
AttributeElementIterator value_end() const {
492-
return AttributeElementIterator(*this, getNumElements());
491+
FailureOr<iterator_range_impl<AttributeElementIterator>>
492+
tryGetValues() const {
493+
return iterator_range_impl<AttributeElementIterator>(
494+
getType(), AttributeElementIterator(*this, 0),
495+
AttributeElementIterator(*this, getNumElements()));
493496
}
494497

495-
/// Return the held element values a range of T, where T is a derived
498+
/// Try to get the held element values a range of T, where T is a derived
496499
/// attribute type.
497500
template <typename T>
498501
using DerivedAttrValueTemplateCheckT =
@@ -510,115 +513,71 @@ class DenseElementsAttr : public Attribute {
510513
T mapElement(Attribute attr) const { return attr.cast<T>(); }
511514
};
512515
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
513-
iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
516+
FailureOr<iterator_range_impl<DerivedAttributeElementIterator<T>>>
517+
tryGetValues() const {
514518
using DerivedIterT = DerivedAttributeElementIterator<T>;
515-
return {getType(), DerivedIterT(value_begin<Attribute>()),
516-
DerivedIterT(value_end<Attribute>())};
517-
}
518-
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
519-
DerivedAttributeElementIterator<T> value_begin() const {
520-
return {value_begin<Attribute>()};
521-
}
522-
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
523-
DerivedAttributeElementIterator<T> value_end() const {
524-
return {value_end<Attribute>()};
519+
return iterator_range_impl<DerivedIterT>(
520+
getType(), DerivedIterT(value_begin<Attribute>()),
521+
DerivedIterT(value_end<Attribute>()));
525522
}
526523

527-
/// Return the held element values as a range of bool. The element type of
524+
/// Try to get the held element values as a range of bool. The element type of
528525
/// this attribute must be of integer type of bitwidth 1.
529526
template <typename T>
530527
using BoolValueTemplateCheckT =
531528
typename std::enable_if<std::is_same<T, bool>::value>::type;
532529
template <typename T, typename = BoolValueTemplateCheckT<T>>
533-
iterator_range_impl<BoolElementIterator> getValues() const {
534-
assert(isValidBool() && "bool is not the value of this elements attribute");
535-
return {getType(), BoolElementIterator(*this, 0),
536-
BoolElementIterator(*this, getNumElements())};
537-
}
538-
template <typename T, typename = BoolValueTemplateCheckT<T>>
539-
BoolElementIterator value_begin() const {
540-
assert(isValidBool() && "bool is not the value of this elements attribute");
541-
return BoolElementIterator(*this, 0);
542-
}
543-
template <typename T, typename = BoolValueTemplateCheckT<T>>
544-
BoolElementIterator value_end() const {
545-
assert(isValidBool() && "bool is not the value of this elements attribute");
546-
return BoolElementIterator(*this, getNumElements());
530+
FailureOr<iterator_range_impl<BoolElementIterator>> tryGetValues() const {
531+
if (!isValidBool())
532+
return failure();
533+
return iterator_range_impl<BoolElementIterator>(
534+
getType(), BoolElementIterator(*this, 0),
535+
BoolElementIterator(*this, getNumElements()));
547536
}
548537

549-
/// Return the held element values as a range of APInts. The element type of
550-
/// this attribute must be of integer type.
538+
/// Try to get the held element values as a range of APInts. The element type
539+
/// of this attribute must be of integer type.
551540
template <typename T>
552541
using APIntValueTemplateCheckT =
553542
typename std::enable_if<std::is_same<T, APInt>::value>::type;
554543
template <typename T, typename = APIntValueTemplateCheckT<T>>
555-
iterator_range_impl<IntElementIterator> getValues() const {
556-
assert(getElementType().isIntOrIndex() && "expected integral type");
557-
return {getType(), raw_int_begin(), raw_int_end()};
558-
}
559-
template <typename T, typename = APIntValueTemplateCheckT<T>>
560-
IntElementIterator value_begin() const {
561-
assert(getElementType().isIntOrIndex() && "expected integral type");
562-
return raw_int_begin();
563-
}
564-
template <typename T, typename = APIntValueTemplateCheckT<T>>
565-
IntElementIterator value_end() const {
566-
assert(getElementType().isIntOrIndex() && "expected integral type");
567-
return raw_int_end();
544+
FailureOr<iterator_range_impl<IntElementIterator>> tryGetValues() const {
545+
if (!getElementType().isIntOrIndex())
546+
return failure();
547+
return iterator_range_impl<IntElementIterator>(getType(), raw_int_begin(),
548+
raw_int_end());
568549
}
569550

570-
/// Return the held element values as a range of complex APInts. The element
571-
/// type of this attribute must be a complex of integer type.
551+
/// Try to get the held element values as a range of complex APInts. The
552+
/// element type of this attribute must be a complex of integer type.
572553
template <typename T>
573554
using ComplexAPIntValueTemplateCheckT = typename std::enable_if<
574555
std::is_same<T, std::complex<APInt>>::value>::type;
575556
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
576-
iterator_range_impl<ComplexIntElementIterator> getValues() const {
577-
return getComplexIntValues();
578-
}
579-
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
580-
ComplexIntElementIterator value_begin() const {
581-
return complex_value_begin();
582-
}
583-
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
584-
ComplexIntElementIterator value_end() const {
585-
return complex_value_end();
557+
FailureOr<iterator_range_impl<ComplexIntElementIterator>>
558+
tryGetValues() const {
559+
return tryGetComplexIntValues();
586560
}
587561

588-
/// Return the held element values as a range of APFloat. The element type of
589-
/// this attribute must be of float type.
562+
/// Try to get the held element values as a range of APFloat. The element type
563+
/// of this attribute must be of float type.
590564
template <typename T>
591565
using APFloatValueTemplateCheckT =
592566
typename std::enable_if<std::is_same<T, APFloat>::value>::type;
593567
template <typename T, typename = APFloatValueTemplateCheckT<T>>
594-
iterator_range_impl<FloatElementIterator> getValues() const {
595-
return getFloatValues();
596-
}
597-
template <typename T, typename = APFloatValueTemplateCheckT<T>>
598-
FloatElementIterator value_begin() const {
599-
return float_value_begin();
600-
}
601-
template <typename T, typename = APFloatValueTemplateCheckT<T>>
602-
FloatElementIterator value_end() const {
603-
return float_value_end();
568+
FailureOr<iterator_range_impl<FloatElementIterator>> tryGetValues() const {
569+
return tryGetFloatValues();
604570
}
605571

606-
/// Return the held element values as a range of complex APFloat. The element
607-
/// type of this attribute must be a complex of float type.
572+
/// Try to get the held element values as a range of complex APFloat. The
573+
/// element type of this attribute must be a complex of float type.
608574
template <typename T>
609575
using ComplexAPFloatValueTemplateCheckT = typename std::enable_if<
610576
std::is_same<T, std::complex<APFloat>>::value>::type;
611577
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
612-
iterator_range_impl<ComplexFloatElementIterator> getValues() const {
613-
return getComplexFloatValues();
614-
}
615-
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
616-
ComplexFloatElementIterator value_begin() const {
617-
return complex_float_value_begin();
618-
}
619-
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
620-
ComplexFloatElementIterator value_end() const {
621-
return complex_float_value_end();
578+
FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
579+
tryGetValues() const {
580+
return tryGetComplexFloatValues();
622581
}
623582

624583
/// Return the raw storage data held by this attribute. Users should generally
@@ -687,16 +646,12 @@ class DenseElementsAttr : public Attribute {
687646
IntElementIterator raw_int_end() const {
688647
return IntElementIterator(*this, getNumElements());
689648
}
690-
iterator_range_impl<ComplexIntElementIterator> getComplexIntValues() const;
691-
ComplexIntElementIterator complex_value_begin() const;
692-
ComplexIntElementIterator complex_value_end() const;
693-
iterator_range_impl<FloatElementIterator> getFloatValues() const;
694-
FloatElementIterator float_value_begin() const;
695-
FloatElementIterator float_value_end() const;
696-
iterator_range_impl<ComplexFloatElementIterator>
697-
getComplexFloatValues() const;
698-
ComplexFloatElementIterator complex_float_value_begin() const;
699-
ComplexFloatElementIterator complex_float_value_end() const;
649+
FailureOr<iterator_range_impl<ComplexIntElementIterator>>
650+
tryGetComplexIntValues() const;
651+
FailureOr<iterator_range_impl<FloatElementIterator>>
652+
tryGetFloatValues() const;
653+
FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
654+
tryGetComplexFloatValues() const;
700655

701656
/// Overload of the raw 'get' method that asserts that the given type is of
702657
/// complex type. This method is used to verify type invariants that the
@@ -973,8 +928,8 @@ class DenseFPElementsAttr : public DenseIntOrFPElementsAttr {
973928
function_ref<APInt(const APFloat &)> mapping) const;
974929

975930
/// Iterator access to the float element values.
976-
iterator begin() const { return float_value_begin(); }
977-
iterator end() const { return float_value_end(); }
931+
iterator begin() const { return tryGetFloatValues()->begin(); }
932+
iterator end() const { return tryGetFloatValues()->end(); }
978933

979934
/// Method for supporting type inquiry through isa, cast and dyn_cast.
980935
static bool classof(Attribute attr);
@@ -1026,12 +981,15 @@ class DenseIntElementsAttr : public DenseIntOrFPElementsAttr {
1026981
//===----------------------------------------------------------------------===//
1027982

1028983
template <typename T>
1029-
auto SparseElementsAttr::value_begin() const -> iterator<T> {
984+
auto SparseElementsAttr::try_value_begin_impl(OverloadToken<T>) const
985+
-> FailureOr<iterator<T>> {
1030986
auto zeroValue = getZeroValue<T>();
1031-
auto valueIt = getValues().value_begin<T>();
987+
auto valueIt = getValues().try_value_begin<T>();
988+
if (failed(valueIt))
989+
return failure();
1032990
const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
1033991
std::function<T(ptrdiff_t)> mapFn =
1034-
[flatSparseIndices{flatSparseIndices}, valueIt{std::move(valueIt)},
992+
[flatSparseIndices{flatSparseIndices}, valueIt{std::move(*valueIt)},
1035993
zeroValue{std::move(zeroValue)}](ptrdiff_t index) {
1036994
// Try to map the current index to one of the sparse indices.
1037995
for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i)

0 commit comments

Comments
 (0)