Skip to content

[flang][runtime] Accept 128-bit integer SHIFT values in CSHIFT/EOSHIFT #75246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions flang/runtime/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,31 @@ static inline RT_API_ATTRS std::int64_t GetInt64(
}
}

static inline RT_API_ATTRS std::optional<std::int64_t> GetInt64Safe(
const char *p, std::size_t bytes, Terminator &terminator) {
switch (bytes) {
case 1:
return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 1> *>(p);
case 2:
return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 2> *>(p);
case 4:
return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 4> *>(p);
case 8:
return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 8> *>(p);
case 16: {
using Int128 = CppTypeFor<TypeCategory::Integer, 16>;
auto n{*reinterpret_cast<const Int128 *>(p)};
std::int64_t result = n;
if (result == n) {
return result;
}
return std::nullopt;
}
default:
terminator.Crash("GetInt64Safe: no case for %zd bytes", bytes);
}
}

template <typename INT>
inline RT_API_ATTRS bool SetInteger(INT &x, int kind, std::int64_t value) {
switch (kind) {
Expand Down
37 changes: 23 additions & 14 deletions flang/runtime/transformational.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ class ShiftControl {
}
}
}
} else if (auto count{GetInt64Safe(
shift_.OffsetElement<char>(), shiftElemLen_, terminator_)}) {
shiftCount_ = *count;
} else {
shiftCount_ =
GetInt64(shift_.OffsetElement<char>(), shiftElemLen_, terminator_);
terminator_.Crash("%s: SHIFT= value exceeds 64 bits", which);
}
}
RT_API_ATTRS SubscriptValue GetShift(const SubscriptValue resultAt[]) const {
Expand All @@ -67,8 +69,10 @@ class ShiftControl {
++k;
}
}
return GetInt64(
shift_.Element<char>(shiftAt), shiftElemLen_, terminator_);
auto count{GetInt64Safe(
shift_.Element<char>(shiftAt), shiftElemLen_, terminator_)};
RUNTIME_CHECK(terminator_, count.has_value());
return *count;
} else {
return shiftCount_; // invariant count extracted in Init()
}
Expand Down Expand Up @@ -719,12 +723,15 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
std::size_t resultElements{1};
SubscriptValue shapeSubscript{shape.GetDimension(0).LowerBound()};
for (int j{0}; j < resultRank; ++j, ++shapeSubscript) {
resultExtent[j] = GetInt64(
shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator);
if (resultExtent[j] < 0) {
auto extent{GetInt64Safe(
shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator)};
if (!extent) {
terminator.Crash("RESHAPE: value of SHAPE(%d) exceeds 64 bits", j + 1);
} else if (*extent < 0) {
terminator.Crash("RESHAPE: bad value for SHAPE(%d)=%jd", j + 1,
static_cast<std::intmax_t>(resultExtent[j]));
static_cast<std::intmax_t>(*extent));
}
resultExtent[j] = *extent;
resultElements *= resultExtent[j];
}

Expand Down Expand Up @@ -762,14 +769,16 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
SubscriptValue orderSubscript{order->GetDimension(0).LowerBound()};
std::size_t orderElementBytes{order->ElementBytes()};
for (SubscriptValue j{0}; j < resultRank; ++j, ++orderSubscript) {
auto k{GetInt64(order->Element<char>(&orderSubscript), orderElementBytes,
terminator)};
if (k < 1 || k > resultRank || ((values >> k) & 1)) {
auto k{GetInt64Safe(order->Element<char>(&orderSubscript),
orderElementBytes, terminator)};
if (!k) {
terminator.Crash("RESHAPE: ORDER element value exceeds 64 bits");
} else if (*k < 1 || *k > resultRank || ((values >> *k) & 1)) {
terminator.Crash("RESHAPE: bad value for ORDER element (%jd)",
static_cast<std::intmax_t>(k));
static_cast<std::intmax_t>(*k));
}
values |= std::uint64_t{1} << k;
dimOrder[j] = k - 1;
values |= std::uint64_t{1} << *k;
dimOrder[j] = *k - 1;
}
} else {
for (int j{0}; j < resultRank; ++j) {
Expand Down