Skip to content

[flang] Check for overflows in RESHAPE folding #68342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flang/include/flang/Evaluate/constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ inline int GetRank(const ConstantSubscripts &s) {
return static_cast<int>(s.size());
}

std::size_t TotalElementCount(const ConstantSubscripts &);
// Returns the number of elements of shape, if no overflow occurs.
std::optional<uint64_t> TotalElementCount(const ConstantSubscripts &shape);

// Validate dimension re-ordering like ORDER in RESHAPE.
// On success, return a vector that can be used as dimOrder in
Expand Down
15 changes: 13 additions & 2 deletions flang/include/flang/Evaluate/initial-image.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@ namespace Fortran::evaluate {

class InitialImage {
public:
enum Result { Ok, NotAConstant, OutOfRange, SizeMismatch, LengthMismatch };
enum Result {
Ok,
NotAConstant,
OutOfRange,
SizeMismatch,
LengthMismatch,
TooManyElems
};

explicit InitialImage(std::size_t bytes) : data_(bytes) {}
InitialImage(InitialImage &&that) = default;
Expand Down Expand Up @@ -60,7 +67,11 @@ class InitialImage {
if (offset < 0 || offset + bytes > data_.size()) {
return OutOfRange;
} else {
auto elements{TotalElementCount(x.shape())};
auto optElements{TotalElementCount(x.shape())};
if (!optElements) {
return TooManyElems;
}
auto elements{*optElements};
auto elementBytes{bytes > 0 ? bytes / elements : 0};
if (elements * elementBytes != bytes) {
return SizeMismatch;
Expand Down
31 changes: 24 additions & 7 deletions flang/lib/Evaluate/constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,18 @@ ConstantSubscript ConstantBounds::SubscriptsToOffset(
return offset;
}

std::size_t TotalElementCount(const ConstantSubscripts &shape) {
return static_cast<std::size_t>(GetSize(shape));
std::optional<uint64_t> TotalElementCount(const ConstantSubscripts &shape) {
uint64_t size{1};
for (auto dim : shape) {
CHECK(dim >= 0);
uint64_t osize{size};
size = osize * dim;
if (size > std::numeric_limits<decltype(dim)>::max() ||
(dim != 0 && size / dim != osize)) {
return std::nullopt;
}
}
return static_cast<uint64_t>(GetSize(shape));
}

bool ConstantBounds::IncrementSubscripts(
Expand Down Expand Up @@ -135,7 +145,7 @@ template <typename RESULT, typename ELEMENT>
ConstantBase<RESULT, ELEMENT>::ConstantBase(
std::vector<Element> &&x, ConstantSubscripts &&sh, Result res)
: ConstantBounds(std::move(sh)), result_{res}, values_(std::move(x)) {
CHECK(size() == TotalElementCount(shape()));
CHECK(TotalElementCount(shape()) && size() == *TotalElementCount(shape()));
}

template <typename RESULT, typename ELEMENT>
Expand All @@ -149,7 +159,9 @@ bool ConstantBase<RESULT, ELEMENT>::operator==(const ConstantBase &that) const {
template <typename RESULT, typename ELEMENT>
auto ConstantBase<RESULT, ELEMENT>::Reshape(
const ConstantSubscripts &dims) const -> std::vector<Element> {
std::size_t n{TotalElementCount(dims)};
std::optional<uint64_t> optN{TotalElementCount(dims)};
CHECK(optN);
uint64_t n{*optN};
CHECK(!empty() || n == 0);
std::vector<Element> elements;
auto iter{values().cbegin()};
Expand Down Expand Up @@ -209,7 +221,8 @@ template <int KIND>
Constant<Type<TypeCategory::Character, KIND>>::Constant(ConstantSubscript len,
std::vector<Scalar<Result>> &&strings, ConstantSubscripts &&sh)
: ConstantBounds(std::move(sh)), length_{len} {
CHECK(strings.size() == TotalElementCount(shape()));
CHECK(TotalElementCount(shape()) &&
strings.size() == *TotalElementCount(shape()));
values_.assign(strings.size() * length_,
static_cast<typename Scalar<Result>::value_type>(' '));
ConstantSubscript at{0};
Expand All @@ -236,7 +249,9 @@ bool Constant<Type<TypeCategory::Character, KIND>>::empty() const {
template <int KIND>
std::size_t Constant<Type<TypeCategory::Character, KIND>>::size() const {
if (length_ == 0) {
return TotalElementCount(shape());
std::optional<uint64_t> n{TotalElementCount(shape())};
CHECK(n);
return *n;
} else {
return static_cast<ConstantSubscript>(values_.size()) / length_;
}
Expand Down Expand Up @@ -274,7 +289,9 @@ auto Constant<Type<TypeCategory::Character, KIND>>::Substring(
template <int KIND>
auto Constant<Type<TypeCategory::Character, KIND>>::Reshape(
ConstantSubscripts &&dims) const -> Constant<Result> {
std::size_t n{TotalElementCount(dims)};
std::optional<uint64_t> optN{TotalElementCount(dims)};
CHECK(optN);
uint64_t n{*optN};
CHECK(!empty() || n == 0);
std::vector<Element> elements;
ConstantSubscript at{0},
Expand Down
4 changes: 3 additions & 1 deletion flang/lib/Evaluate/fold-designator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,9 @@ ConstantObjectPointer ConstantObjectPointer::From(
FoldingContext &context, const Expr<SomeType> &expr) {
auto extents{GetConstantExtents(context, expr)};
CHECK(extents);
std::size_t elements{TotalElementCount(*extents)};
std::optional<uint64_t> optElements{TotalElementCount(*extents)};
CHECK(optElements);
uint64_t elements{*optElements};
CHECK(elements > 0);
int rank{GetRank(*extents)};
ConstantSubscripts at(rank, 1);
Expand Down
84 changes: 51 additions & 33 deletions flang/lib/Evaluate/fold-implementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,13 @@ Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
CHECK(rank == GetRank(shape));
// Compute all the scalar values of the results
std::vector<Scalar<TR>> results;
if (TotalElementCount(shape) > 0) {
std::optional<uint64_t> n{TotalElementCount(shape)};
if (!n) {
context.messages().Say(
"Too many elements in elemental intrinsic function result"_err_en_US);
return Expr<TR>{std::move(funcRef)};
}
if (*n > 0) {
ConstantBounds bounds{shape};
ConstantSubscripts resultIndex(rank, 1);
ConstantSubscripts argIndex[]{std::get<I>(*args)->lbounds()...};
Expand Down Expand Up @@ -879,33 +885,40 @@ template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
context_.messages().Say(
"'shape=' argument must not have a negative extent"_err_en_US);
} else {
int rank{GetRank(shape.value())};
std::size_t resultElements{TotalElementCount(shape.value())};
std::optional<std::vector<int>> dimOrder;
if (order) {
dimOrder = ValidateDimensionOrder(rank, *order);
}
std::vector<int> *dimOrderPtr{dimOrder ? &dimOrder.value() : nullptr};
if (order && !dimOrder) {
context_.messages().Say("Invalid 'order=' argument in RESHAPE"_err_en_US);
} else if (resultElements > source->size() && (!pad || pad->empty())) {
std::optional<uint64_t> optResultElement{TotalElementCount(shape.value())};
if (!optResultElement) {
context_.messages().Say(
"Too few elements in 'source=' argument and 'pad=' "
"argument is not present or has null size"_err_en_US);
"'shape=' argument has too many elements"_err_en_US);
} else {
Constant<T> result{!source->empty() || !pad
? source->Reshape(std::move(shape.value()))
: pad->Reshape(std::move(shape.value()))};
ConstantSubscripts subscripts{result.lbounds()};
auto copied{result.CopyFrom(*source,
std::min(source->size(), resultElements), subscripts, dimOrderPtr)};
if (copied < resultElements) {
CHECK(pad);
copied += result.CopyFrom(
*pad, resultElements - copied, subscripts, dimOrderPtr);
int rank{GetRank(shape.value())};
uint64_t resultElements{*optResultElement};
std::optional<std::vector<int>> dimOrder;
if (order) {
dimOrder = ValidateDimensionOrder(rank, *order);
}
std::vector<int> *dimOrderPtr{dimOrder ? &dimOrder.value() : nullptr};
if (order && !dimOrder) {
context_.messages().Say(
"Invalid 'order=' argument in RESHAPE"_err_en_US);
} else if (resultElements > source->size() && (!pad || pad->empty())) {
context_.messages().Say(
"Too few elements in 'source=' argument and 'pad=' "
"argument is not present or has null size"_err_en_US);
} else {
Constant<T> result{!source->empty() || !pad
? source->Reshape(std::move(shape.value()))
: pad->Reshape(std::move(shape.value()))};
ConstantSubscripts subscripts{result.lbounds()};
auto copied{result.CopyFrom(*source,
std::min(source->size(), resultElements), subscripts, dimOrderPtr)};
if (copied < resultElements) {
CHECK(pad);
copied += result.CopyFrom(
*pad, resultElements - copied, subscripts, dimOrderPtr);
}
CHECK(copied == resultElements);
return Expr<T>{std::move(result)};
}
CHECK(copied == resultElements);
return Expr<T>{std::move(result)};
}
}
// Invalid, prevent re-folding
Expand Down Expand Up @@ -944,14 +957,19 @@ template <typename T> Expr<T> Folder<T>::SPREAD(FunctionRef<T> &&funcRef) {
ConstantSubscripts shape{source->shape()};
shape.insert(shape.begin() + *dim - 1, *ncopies);
Constant<T> spread{source->Reshape(std::move(shape))};
std::vector<int> dimOrder;
for (int j{0}; j < sourceRank; ++j) {
dimOrder.push_back(j < *dim - 1 ? j : j + 1);
}
dimOrder.push_back(*dim - 1);
ConstantSubscripts at{spread.lbounds()}; // all 1
spread.CopyFrom(*source, TotalElementCount(spread.shape()), at, &dimOrder);
return Expr<T>{std::move(spread)};
std::optional<uint64_t> n{TotalElementCount(spread.shape())};
if (!n) {
context_.messages().Say("Too many elements in SPREAD result"_err_en_US);
} else {
std::vector<int> dimOrder;
for (int j{0}; j < sourceRank; ++j) {
dimOrder.push_back(j < *dim - 1 ? j : j + 1);
}
dimOrder.push_back(*dim - 1);
ConstantSubscripts at{spread.lbounds()}; // all 1
spread.CopyFrom(*source, *n, at, &dimOrder);
return Expr<T>{std::move(spread)};
}
}
// Invalid, prevent re-folding
return MakeInvalidIntrinsic(std::move(funcRef));
Expand Down
10 changes: 8 additions & 2 deletions flang/lib/Evaluate/initial-image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ auto InitialImage::Add(ConstantSubscript offset, std::size_t bytes,
if (offset < 0 || offset + bytes > data_.size()) {
return OutOfRange;
} else {
auto elements{TotalElementCount(x.shape())};
auto optElements{TotalElementCount(x.shape())};
if (!optElements) {
return TooManyElems;
}
auto elements{*optElements};
auto elementBytes{bytes > 0 ? bytes / elements : 0};
if (elements * elementBytes != bytes) {
return SizeMismatch;
Expand Down Expand Up @@ -89,7 +93,9 @@ class AsConstantHelper {
}
using Const = Constant<T>;
using Scalar = typename Const::Element;
std::size_t elements{TotalElementCount(extents_)};
std::optional<uint64_t> optElements{TotalElementCount(extents_)};
CHECK(optElements);
uint64_t elements{*optElements};
std::vector<Scalar> typedValue(elements);
auto elemBytes{ToInt64(type_.MeasureSizeInBytes(
context_, GetRank(extents_) > 0, charLength_))};
Expand Down
2 changes: 2 additions & 0 deletions flang/lib/Semantics/data-to-inits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ bool DataInitializationCompiler<DSV>::InitElement(
"DATA statement value '%s' for '%s' has the wrong length"_warn_en_US,
folded.AsFortran(), DescribeElement());
return true;
} else if (status == evaluate::InitialImage::TooManyElems) {
exprAnalyzer_.Say("DATA statement has too many elements"_err_en_US);
} else {
CHECK(exprAnalyzer_.context().AnyFatalError());
}
Expand Down
4 changes: 4 additions & 0 deletions flang/test/Semantics/reshape.f90
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ program reshaper
integer, parameter :: array21(I64_MAX - 2 : I64_MAX) = [1, 2, 3]
integer, parameter :: array22(2) = RESHAPE(array21, [2])

integer(8), parameter :: huge_shape(2) = [I64_MAX, I64_MAX]
!ERROR: 'shape=' argument has too many elements
integer :: array23(I64_MAX, I64_MAX) = RESHAPE([1, 2, 3], huge_shape)

!ERROR: Size of 'shape=' argument must not be greater than 15
CALL ext_sub(RESHAPE([(n, n=1,20)], &
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]))
Expand Down