Skip to content

Commit 5cf5708

Browse files
author
Jeff Niu
committed
[mlir][ElementsAttr] Change value_begin_impl to try_value_begin_impl
This patch changes `value_begin_impl` to a faillable `try_value_begin_impl` so that specific cases can fail iteration if the type doesn't match the internal storage. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D132904
1 parent bac3aed commit 5cf5708

File tree

7 files changed

+111
-57
lines changed

7 files changed

+111
-57
lines changed

mlir/include/mlir/IR/BuiltinAttributeInterfaces.td

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,33 +54,36 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
5454
using NonContiguousIterableTypesT = std::tuple<APInt, Attribute>;
5555
```
5656

57-
* Provide a `iterator value_begin_impl(OverloadToken<T>) const` overload for
58-
each iterable type
57+
* Provide a `FailureOr<iterator> try_value_begin_impl(OverloadToken<T>) const`
58+
overload for each iterable type
5959

6060
These overloads should return an iterator to the start of the range for the
61-
respective iterable type. Consider the example i64 elements attribute
62-
described in the previous section. This attribute may define the
63-
value_begin_impl overloads like so:
61+
respective iterable type or fail if the type cannot be iterated. Consider
62+
the example i64 elements attribute described in the previous section. This
63+
attribute may define the value_begin_impl overloads like so:
6464

6565
```c++
6666
/// Provide begin iterators for the various iterable types.
6767
/// * uint64_t
68-
auto value_begin_impl(OverloadToken<uint64_t>) const {
68+
FailureOr<const uint64_t *>
69+
value_begin_impl(OverloadToken<uint64_t>) const {
6970
return getElements().begin();
7071
}
7172
/// * APInt
7273
auto value_begin_impl(OverloadToken<llvm::APInt>) const {
73-
return llvm::map_range(getElements(), [=](uint64_t value) {
74+
auto it = llvm::map_range(getElements(), [=](uint64_t value) {
7475
return llvm::APInt(/*numBits=*/64, value);
7576
}).begin();
77+
return FailureOr<decltype(it)>(std::move(it));
7678
}
7779
/// * Attribute
7880
auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
7981
mlir::Type elementType = getType().getElementType();
80-
return llvm::map_range(getElements(), [=](uint64_t value) {
82+
auto it = llvm::map_range(getElements(), [=](uint64_t value) {
8183
return mlir::IntegerAttr::get(elementType,
8284
llvm::APInt(/*numBits=*/64, value));
8385
}).begin();
86+
return FailureOr<decltype(it)>(std::move(it));
8487
}
8588
```
8689

@@ -244,18 +247,22 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
244247
/*isSplat=*/false, nullptr);
245248
}
246249

247-
auto valueIt = $_attr.value_begin_impl(OverloadToken<T>());
250+
auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
251+
if (::mlir::failed(valueIt))
252+
return ::mlir::failure();
248253
return ::mlir::detail::ElementsAttrIndexer::contiguous(
249-
$_attr.isSplat(), &*valueIt);
254+
$_attr.isSplat(), &**valueIt);
250255
}
251256
/// Build an indexer for the given type `T`, which is represented via a
252257
/// non-contiguous range.
253258
template <typename T>
254259
::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
255260
/*isContiguous*/std::false_type) const {
256-
auto valueIt = $_attr.value_begin_impl(OverloadToken<T>());
261+
auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
262+
if (::mlir::failed(valueIt))
263+
return ::mlir::failure();
257264
return ::mlir::detail::ElementsAttrIndexer::nonContiguous(
258-
$_attr.isSplat(), valueIt);
265+
$_attr.isSplat(), *valueIt);
259266
}
260267

261268
public:
@@ -275,7 +282,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
275282
/// type `T`.
276283
template <typename T>
277284
auto value_begin() const {
278-
return $_attr.value_begin_impl(OverloadToken<T>());
285+
return *$_attr.try_value_begin_impl(OverloadToken<T>());
279286
}
280287

281288
/// Return the elements of this attribute as a value of type 'T'.

mlir/include/mlir/IR/BuiltinAttributes.td

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,20 @@ def Builtin_DenseArray : Builtin_Attr<
215215
/// ElementsAttr implementation.
216216
using ContiguousIterableTypesT =
217217
std::tuple<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
218-
const bool *value_begin_impl(OverloadToken<bool>) const;
219-
const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
220-
const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
221-
const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
222-
const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
223-
const float *value_begin_impl(OverloadToken<float>) const;
224-
const double *value_begin_impl(OverloadToken<double>) const;
218+
FailureOr<const bool *>
219+
try_value_begin_impl(OverloadToken<bool>) const;
220+
FailureOr<const int8_t *>
221+
try_value_begin_impl(OverloadToken<int8_t>) const;
222+
FailureOr<const int16_t *>
223+
try_value_begin_impl(OverloadToken<int16_t>) const;
224+
FailureOr<const int32_t *>
225+
try_value_begin_impl(OverloadToken<int32_t>) const;
226+
FailureOr<const int64_t *>
227+
try_value_begin_impl(OverloadToken<int64_t>) const;
228+
FailureOr<const float *>
229+
try_value_begin_impl(OverloadToken<float>) const;
230+
FailureOr<const double *>
231+
try_value_begin_impl(OverloadToken<double>) const;
225232
}];
226233

227234
let genVerifyDecl = 1;
@@ -292,10 +299,11 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
292299
APFloat, std::complex<APFloat>
293300
>;
294301

295-
/// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
302+
/// Provide a `try_value_begin_impl` to enable iteration within
303+
/// ElementsAttr.
296304
template <typename T>
297-
auto value_begin_impl(OverloadToken<T>) const {
298-
return value_begin<T>();
305+
auto try_value_begin_impl(OverloadToken<T>) const {
306+
return ::mlir::success(value_begin<T>());
299307
}
300308

301309
/// Convert endianess of input ArrayRef for big-endian(BE) machines. All of
@@ -421,10 +429,11 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
421429
using ContiguousIterableTypesT = std::tuple<StringRef>;
422430
using NonContiguousIterableTypesT = std::tuple<Attribute>;
423431

424-
/// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
432+
/// Provide a `try_value_begin_impl` to enable iteration within
433+
/// ElementsAttr.
425434
template <typename T>
426-
auto value_begin_impl(OverloadToken<T>) const {
427-
return value_begin<T>();
435+
auto try_value_begin_impl(OverloadToken<T>) const {
436+
return ::mlir::success(value_begin<T>());
428437
}
429438

430439
protected:
@@ -892,10 +901,11 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
892901
>;
893902
using ElementsAttr::Trait<SparseElementsAttr>::getValues;
894903

895-
/// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
904+
/// Provide a `try_value_begin_impl` to enable iteration within
905+
/// ElementsAttr.
896906
template <typename T>
897-
auto value_begin_impl(OverloadToken<T>) const {
898-
return value_begin<T>();
907+
auto try_value_begin_impl(OverloadToken<T>) const {
908+
return ::mlir::success(value_begin<T>());
899909
}
900910

901911
template <typename T>

mlir/include/mlir/Support/LogicalResult.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ class [[nodiscard]] FailureOr : public Optional<T> {
9999
using Optional<T>::has_value;
100100
};
101101

102+
/// Wrap a value on the success path in a FailureOr of the same value type.
103+
template <typename T,
104+
typename = std::enable_if_t<!std::is_convertible_v<T, bool>>>
105+
inline auto success(T &&t) {
106+
return FailureOr<std::decay_t<T>>(std::forward<T>(t));
107+
}
108+
102109
/// This class represents success/failure for parsing-like operations that find
103110
/// it important to chain together failable operations with `||`. This is an
104111
/// extended version of `LogicalResult` that allows for explicit conversion to

mlir/lib/IR/BuiltinAttributes.cpp

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -763,26 +763,47 @@ DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
763763
return success();
764764
}
765765

766-
const bool *DenseArrayAttr::value_begin_impl(OverloadToken<bool>) const {
767-
return cast<DenseBoolArrayAttr>().asArrayRef().begin();
768-
}
769-
const int8_t *DenseArrayAttr::value_begin_impl(OverloadToken<int8_t>) const {
770-
return cast<DenseI8ArrayAttr>().asArrayRef().begin();
771-
}
772-
const int16_t *DenseArrayAttr::value_begin_impl(OverloadToken<int16_t>) const {
773-
return cast<DenseI16ArrayAttr>().asArrayRef().begin();
774-
}
775-
const int32_t *DenseArrayAttr::value_begin_impl(OverloadToken<int32_t>) const {
776-
return cast<DenseI32ArrayAttr>().asArrayRef().begin();
777-
}
778-
const int64_t *DenseArrayAttr::value_begin_impl(OverloadToken<int64_t>) const {
779-
return cast<DenseI64ArrayAttr>().asArrayRef().begin();
780-
}
781-
const float *DenseArrayAttr::value_begin_impl(OverloadToken<float>) const {
782-
return cast<DenseF32ArrayAttr>().asArrayRef().begin();
783-
}
784-
const double *DenseArrayAttr::value_begin_impl(OverloadToken<double>) const {
785-
return cast<DenseF64ArrayAttr>().asArrayRef().begin();
766+
FailureOr<const bool *>
767+
DenseArrayAttr::try_value_begin_impl(OverloadToken<bool>) const {
768+
if (auto attr = dyn_cast<DenseBoolArrayAttr>())
769+
return attr.asArrayRef().begin();
770+
return failure();
771+
}
772+
FailureOr<const int8_t *>
773+
DenseArrayAttr::try_value_begin_impl(OverloadToken<int8_t>) const {
774+
if (auto attr = dyn_cast<DenseI8ArrayAttr>())
775+
return attr.asArrayRef().begin();
776+
return failure();
777+
}
778+
FailureOr<const int16_t *>
779+
DenseArrayAttr::try_value_begin_impl(OverloadToken<int16_t>) const {
780+
if (auto attr = dyn_cast<DenseI16ArrayAttr>())
781+
return attr.asArrayRef().begin();
782+
return failure();
783+
}
784+
FailureOr<const int32_t *>
785+
DenseArrayAttr::try_value_begin_impl(OverloadToken<int32_t>) const {
786+
if (auto attr = dyn_cast<DenseI32ArrayAttr>())
787+
return attr.asArrayRef().begin();
788+
return failure();
789+
}
790+
FailureOr<const int64_t *>
791+
DenseArrayAttr::try_value_begin_impl(OverloadToken<int64_t>) const {
792+
if (auto attr = dyn_cast<DenseI64ArrayAttr>())
793+
return attr.asArrayRef().begin();
794+
return failure();
795+
}
796+
FailureOr<const float *>
797+
DenseArrayAttr::try_value_begin_impl(OverloadToken<float>) const {
798+
if (auto attr = dyn_cast<DenseF32ArrayAttr>())
799+
return attr.asArrayRef().begin();
800+
return failure();
801+
}
802+
FailureOr<const double *>
803+
DenseArrayAttr::try_value_begin_impl(OverloadToken<double>) const {
804+
if (auto attr = dyn_cast<DenseF64ArrayAttr>())
805+
return attr.asArrayRef().begin();
806+
return failure();
786807
}
787808

788809
namespace {

mlir/test/IR/elements-attr-interface.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,24 @@ arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
2828
arith.constant dense<> : tensor<0xi64>
2929

3030
// expected-error@below {{Test iterating `bool`: true, false, true, false, true, false}}
31+
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
3132
arith.constant array<i1: true, false, true, false, true, false>
3233
// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
34+
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
3335
arith.constant array<i8: 10, 11, -12, 13, 14>
3436
// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
37+
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
3538
arith.constant array<i16: 10, 11, -12, 13, 14>
3639
// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
40+
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
3741
arith.constant array<i32: 10, 11, -12, 13, 14>
3842
// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
3943
arith.constant array<i64: 10, 11, -12, 13, 14>
4044
// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
45+
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
4146
arith.constant array<f32: 10., 11., -12., 13., 14.>
4247
// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
48+
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
4349
arith.constant array<f64: 10., 11., -12., 13., 14.>
4450

4551
// Check that we handle an external constant parsed from the config.

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,23 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
9494

9595
/// Provide begin iterators for the various iterable types.
9696
// * uint64_t
97-
auto value_begin_impl(OverloadToken<uint64_t>) const {
97+
mlir::FailureOr<const uint64_t *>
98+
try_value_begin_impl(OverloadToken<uint64_t>) const {
9899
return getElements().begin();
99100
}
100101
// * Attribute
101-
auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
102+
auto try_value_begin_impl(OverloadToken<mlir::Attribute>) const {
102103
mlir::Type elementType = getType().getElementType();
103-
return llvm::map_range(getElements(), [=](uint64_t value) {
104+
return mlir::success(llvm::map_range(getElements(), [=](uint64_t value) {
104105
return mlir::IntegerAttr::get(elementType,
105106
llvm::APInt(/*numBits=*/64, value));
106-
}).begin();
107+
}).begin());
107108
}
108109
// * APInt
109-
auto value_begin_impl(OverloadToken<llvm::APInt>) const {
110-
return llvm::map_range(getElements(), [=](uint64_t value) {
110+
auto try_value_begin_impl(OverloadToken<llvm::APInt>) const {
111+
return mlir::success(llvm::map_range(getElements(), [=](uint64_t value) {
111112
return llvm::APInt(/*numBits=*/64, value);
112-
}).begin();
113+
}).begin());
113114
}
114115
}];
115116
let genVerifyDecl = 1;
@@ -257,7 +258,8 @@ def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
257258

258259
/// Provide begin iterators for the various iterable types.
259260
// * uint64_t
260-
auto value_begin_impl(OverloadToken<uint64_t>) const {
261+
mlir::FailureOr<const uint64_t *>
262+
try_value_begin_impl(OverloadToken<uint64_t>) const {
261263
return getElements().begin();
262264
}
263265
}];

mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ struct TestElementsAttrInterface
6464
.Case([&](DenseF64ArrayAttr attr) {
6565
testElementsAttrIteration<double>(op, attr, "double");
6666
});
67+
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
6768
continue;
6869
}
6970
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");

0 commit comments

Comments
 (0)