Skip to content

Commit 7358c26

Browse files
authored
[flang] Check for overflows in RESHAPE folding (#68342)
TotalElementCount() was modified to return std::optional<uint64_t>, where std::nullopt means overflow occurred. Besides the additional check in RESHAPE folding, all callers of TotalElementCount() were changed, to also check for overflows.
1 parent 77e88db commit 7358c26

File tree

8 files changed

+107
-46
lines changed

8 files changed

+107
-46
lines changed

flang/include/flang/Evaluate/constant.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ inline int GetRank(const ConstantSubscripts &s) {
4646
return static_cast<int>(s.size());
4747
}
4848

49-
std::size_t TotalElementCount(const ConstantSubscripts &);
49+
// Returns the number of elements of shape, if no overflow occurs.
50+
std::optional<uint64_t> TotalElementCount(const ConstantSubscripts &shape);
5051

5152
// Validate dimension re-ordering like ORDER in RESHAPE.
5253
// On success, return a vector that can be used as dimOrder in

flang/include/flang/Evaluate/initial-image.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@ namespace Fortran::evaluate {
2222

2323
class InitialImage {
2424
public:
25-
enum Result { Ok, NotAConstant, OutOfRange, SizeMismatch, LengthMismatch };
25+
enum Result {
26+
Ok,
27+
NotAConstant,
28+
OutOfRange,
29+
SizeMismatch,
30+
LengthMismatch,
31+
TooManyElems
32+
};
2633

2734
explicit InitialImage(std::size_t bytes) : data_(bytes) {}
2835
InitialImage(InitialImage &&that) = default;
@@ -60,7 +67,11 @@ class InitialImage {
6067
if (offset < 0 || offset + bytes > data_.size()) {
6168
return OutOfRange;
6269
} else {
63-
auto elements{TotalElementCount(x.shape())};
70+
auto optElements{TotalElementCount(x.shape())};
71+
if (!optElements) {
72+
return TooManyElems;
73+
}
74+
auto elements{*optElements};
6475
auto elementBytes{bytes > 0 ? bytes / elements : 0};
6576
if (elements * elementBytes != bytes) {
6677
return SizeMismatch;

flang/lib/Evaluate/constant.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,18 @@ ConstantSubscript ConstantBounds::SubscriptsToOffset(
8080
return offset;
8181
}
8282

83-
std::size_t TotalElementCount(const ConstantSubscripts &shape) {
84-
return static_cast<std::size_t>(GetSize(shape));
83+
std::optional<uint64_t> TotalElementCount(const ConstantSubscripts &shape) {
84+
uint64_t size{1};
85+
for (auto dim : shape) {
86+
CHECK(dim >= 0);
87+
uint64_t osize{size};
88+
size = osize * dim;
89+
if (size > std::numeric_limits<decltype(dim)>::max() ||
90+
(dim != 0 && size / dim != osize)) {
91+
return std::nullopt;
92+
}
93+
}
94+
return static_cast<uint64_t>(GetSize(shape));
8595
}
8696

8797
bool ConstantBounds::IncrementSubscripts(
@@ -135,7 +145,7 @@ template <typename RESULT, typename ELEMENT>
135145
ConstantBase<RESULT, ELEMENT>::ConstantBase(
136146
std::vector<Element> &&x, ConstantSubscripts &&sh, Result res)
137147
: ConstantBounds(std::move(sh)), result_{res}, values_(std::move(x)) {
138-
CHECK(size() == TotalElementCount(shape()));
148+
CHECK(TotalElementCount(shape()) && size() == *TotalElementCount(shape()));
139149
}
140150

141151
template <typename RESULT, typename ELEMENT>
@@ -149,7 +159,9 @@ bool ConstantBase<RESULT, ELEMENT>::operator==(const ConstantBase &that) const {
149159
template <typename RESULT, typename ELEMENT>
150160
auto ConstantBase<RESULT, ELEMENT>::Reshape(
151161
const ConstantSubscripts &dims) const -> std::vector<Element> {
152-
std::size_t n{TotalElementCount(dims)};
162+
std::optional<uint64_t> optN{TotalElementCount(dims)};
163+
CHECK(optN);
164+
uint64_t n{*optN};
153165
CHECK(!empty() || n == 0);
154166
std::vector<Element> elements;
155167
auto iter{values().cbegin()};
@@ -209,7 +221,8 @@ template <int KIND>
209221
Constant<Type<TypeCategory::Character, KIND>>::Constant(ConstantSubscript len,
210222
std::vector<Scalar<Result>> &&strings, ConstantSubscripts &&sh)
211223
: ConstantBounds(std::move(sh)), length_{len} {
212-
CHECK(strings.size() == TotalElementCount(shape()));
224+
CHECK(TotalElementCount(shape()) &&
225+
strings.size() == *TotalElementCount(shape()));
213226
values_.assign(strings.size() * length_,
214227
static_cast<typename Scalar<Result>::value_type>(' '));
215228
ConstantSubscript at{0};
@@ -236,7 +249,9 @@ bool Constant<Type<TypeCategory::Character, KIND>>::empty() const {
236249
template <int KIND>
237250
std::size_t Constant<Type<TypeCategory::Character, KIND>>::size() const {
238251
if (length_ == 0) {
239-
return TotalElementCount(shape());
252+
std::optional<uint64_t> n{TotalElementCount(shape())};
253+
CHECK(n);
254+
return *n;
240255
} else {
241256
return static_cast<ConstantSubscript>(values_.size()) / length_;
242257
}
@@ -274,7 +289,9 @@ auto Constant<Type<TypeCategory::Character, KIND>>::Substring(
274289
template <int KIND>
275290
auto Constant<Type<TypeCategory::Character, KIND>>::Reshape(
276291
ConstantSubscripts &&dims) const -> Constant<Result> {
277-
std::size_t n{TotalElementCount(dims)};
292+
std::optional<uint64_t> optN{TotalElementCount(dims)};
293+
CHECK(optN);
294+
uint64_t n{*optN};
278295
CHECK(!empty() || n == 0);
279296
std::vector<Element> elements;
280297
ConstantSubscript at{0},

flang/lib/Evaluate/fold-designator.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,9 @@ ConstantObjectPointer ConstantObjectPointer::From(
373373
FoldingContext &context, const Expr<SomeType> &expr) {
374374
auto extents{GetConstantExtents(context, expr)};
375375
CHECK(extents);
376-
std::size_t elements{TotalElementCount(*extents)};
376+
std::optional<uint64_t> optElements{TotalElementCount(*extents)};
377+
CHECK(optElements);
378+
uint64_t elements{*optElements};
377379
CHECK(elements > 0);
378380
int rank{GetRank(*extents)};
379381
ConstantSubscripts at(rank, 1);

flang/lib/Evaluate/fold-implementation.h

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,13 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
492492
CHECK(rank == GetRank(shape));
493493
// Compute all the scalar values of the results
494494
std::vector<Scalar<TR>> results;
495-
if (TotalElementCount(shape) > 0) {
495+
std::optional<uint64_t> n{TotalElementCount(shape)};
496+
if (!n) {
497+
context.messages().Say(
498+
"Too many elements in elemental intrinsic function result"_err_en_US);
499+
return Expr<TR>{std::move(funcRef)};
500+
}
501+
if (*n > 0) {
496502
ConstantBounds bounds{shape};
497503
ConstantSubscripts resultIndex(rank, 1);
498504
ConstantSubscripts argIndex[]{std::get<I>(*args)->lbounds()...};
@@ -879,33 +885,40 @@ template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
879885
context_.messages().Say(
880886
"'shape=' argument must not have a negative extent"_err_en_US);
881887
} else {
882-
int rank{GetRank(shape.value())};
883-
std::size_t resultElements{TotalElementCount(shape.value())};
884-
std::optional<std::vector<int>> dimOrder;
885-
if (order) {
886-
dimOrder = ValidateDimensionOrder(rank, *order);
887-
}
888-
std::vector<int> *dimOrderPtr{dimOrder ? &dimOrder.value() : nullptr};
889-
if (order && !dimOrder) {
890-
context_.messages().Say("Invalid 'order=' argument in RESHAPE"_err_en_US);
891-
} else if (resultElements > source->size() && (!pad || pad->empty())) {
888+
std::optional<uint64_t> optResultElement{TotalElementCount(shape.value())};
889+
if (!optResultElement) {
892890
context_.messages().Say(
893-
"Too few elements in 'source=' argument and 'pad=' "
894-
"argument is not present or has null size"_err_en_US);
891+
"'shape=' argument has too many elements"_err_en_US);
895892
} else {
896-
Constant<T> result{!source->empty() || !pad
897-
? source->Reshape(std::move(shape.value()))
898-
: pad->Reshape(std::move(shape.value()))};
899-
ConstantSubscripts subscripts{result.lbounds()};
900-
auto copied{result.CopyFrom(*source,
901-
std::min(source->size(), resultElements), subscripts, dimOrderPtr)};
902-
if (copied < resultElements) {
903-
CHECK(pad);
904-
copied += result.CopyFrom(
905-
*pad, resultElements - copied, subscripts, dimOrderPtr);
893+
int rank{GetRank(shape.value())};
894+
uint64_t resultElements{*optResultElement};
895+
std::optional<std::vector<int>> dimOrder;
896+
if (order) {
897+
dimOrder = ValidateDimensionOrder(rank, *order);
898+
}
899+
std::vector<int> *dimOrderPtr{dimOrder ? &dimOrder.value() : nullptr};
900+
if (order && !dimOrder) {
901+
context_.messages().Say(
902+
"Invalid 'order=' argument in RESHAPE"_err_en_US);
903+
} else if (resultElements > source->size() && (!pad || pad->empty())) {
904+
context_.messages().Say(
905+
"Too few elements in 'source=' argument and 'pad=' "
906+
"argument is not present or has null size"_err_en_US);
907+
} else {
908+
Constant<T> result{!source->empty() || !pad
909+
? source->Reshape(std::move(shape.value()))
910+
: pad->Reshape(std::move(shape.value()))};
911+
ConstantSubscripts subscripts{result.lbounds()};
912+
auto copied{result.CopyFrom(*source,
913+
std::min(source->size(), resultElements), subscripts, dimOrderPtr)};
914+
if (copied < resultElements) {
915+
CHECK(pad);
916+
copied += result.CopyFrom(
917+
*pad, resultElements - copied, subscripts, dimOrderPtr);
918+
}
919+
CHECK(copied == resultElements);
920+
return Expr<T>{std::move(result)};
906921
}
907-
CHECK(copied == resultElements);
908-
return Expr<T>{std::move(result)};
909922
}
910923
}
911924
// Invalid, prevent re-folding
@@ -944,14 +957,19 @@ template <typename T> Expr<T> Folder<T>::SPREAD(FunctionRef<T> &&funcRef) {
944957
ConstantSubscripts shape{source->shape()};
945958
shape.insert(shape.begin() + *dim - 1, *ncopies);
946959
Constant<T> spread{source->Reshape(std::move(shape))};
947-
std::vector<int> dimOrder;
948-
for (int j{0}; j < sourceRank; ++j) {
949-
dimOrder.push_back(j < *dim - 1 ? j : j + 1);
950-
}
951-
dimOrder.push_back(*dim - 1);
952-
ConstantSubscripts at{spread.lbounds()}; // all 1
953-
spread.CopyFrom(*source, TotalElementCount(spread.shape()), at, &dimOrder);
954-
return Expr<T>{std::move(spread)};
960+
std::optional<uint64_t> n{TotalElementCount(spread.shape())};
961+
if (!n) {
962+
context_.messages().Say("Too many elements in SPREAD result"_err_en_US);
963+
} else {
964+
std::vector<int> dimOrder;
965+
for (int j{0}; j < sourceRank; ++j) {
966+
dimOrder.push_back(j < *dim - 1 ? j : j + 1);
967+
}
968+
dimOrder.push_back(*dim - 1);
969+
ConstantSubscripts at{spread.lbounds()}; // all 1
970+
spread.CopyFrom(*source, *n, at, &dimOrder);
971+
return Expr<T>{std::move(spread)};
972+
}
955973
}
956974
// Invalid, prevent re-folding
957975
return MakeInvalidIntrinsic(std::move(funcRef));

flang/lib/Evaluate/initial-image.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ auto InitialImage::Add(ConstantSubscript offset, std::size_t bytes,
1818
if (offset < 0 || offset + bytes > data_.size()) {
1919
return OutOfRange;
2020
} else {
21-
auto elements{TotalElementCount(x.shape())};
21+
auto optElements{TotalElementCount(x.shape())};
22+
if (!optElements) {
23+
return TooManyElems;
24+
}
25+
auto elements{*optElements};
2226
auto elementBytes{bytes > 0 ? bytes / elements : 0};
2327
if (elements * elementBytes != bytes) {
2428
return SizeMismatch;
@@ -89,7 +93,9 @@ class AsConstantHelper {
8993
}
9094
using Const = Constant<T>;
9195
using Scalar = typename Const::Element;
92-
std::size_t elements{TotalElementCount(extents_)};
96+
std::optional<uint64_t> optElements{TotalElementCount(extents_)};
97+
CHECK(optElements);
98+
uint64_t elements{*optElements};
9399
std::vector<Scalar> typedValue(elements);
94100
auto elemBytes{ToInt64(type_.MeasureSizeInBytes(
95101
context_, GetRank(extents_) > 0, charLength_))};

flang/lib/Semantics/data-to-inits.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,8 @@ bool DataInitializationCompiler<DSV>::InitElement(
462462
"DATA statement value '%s' for '%s' has the wrong length"_warn_en_US,
463463
folded.AsFortran(), DescribeElement());
464464
return true;
465+
} else if (status == evaluate::InitialImage::TooManyElems) {
466+
exprAnalyzer_.Say("DATA statement has too many elements"_err_en_US);
465467
} else {
466468
CHECK(exprAnalyzer_.context().AnyFatalError());
467469
}

flang/test/Semantics/reshape.f90

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ program reshaper
4949
integer, parameter :: array21(I64_MAX - 2 : I64_MAX) = [1, 2, 3]
5050
integer, parameter :: array22(2) = RESHAPE(array21, [2])
5151

52+
integer(8), parameter :: huge_shape(2) = [I64_MAX, I64_MAX]
53+
!ERROR: 'shape=' argument has too many elements
54+
integer :: array23(I64_MAX, I64_MAX) = RESHAPE([1, 2, 3], huge_shape)
55+
5256
!ERROR: Size of 'shape=' argument must not be greater than 15
5357
CALL ext_sub(RESHAPE([(n, n=1,20)], &
5458
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]))

0 commit comments

Comments
 (0)