Skip to content

Commit 2177a17

Browse files
authored
[flang][runtime] Avoid call recursion in CopyElement runtime. (#101421)
Device compilers may fail to identify the maximum stack size required by a kernel that calls CopyElement, because of potential recursive calls. To avoid this, we can use a dynamically allocated Stack. To avoid dynamic allocations on the host for simple cases, the Stack implementation has reserved space (which ends up being allocated on the program stack). I tested both the pre-allocated and the 0-reserve implementations on the host, and all passed. The actual reserve values may be tuned as needed.
1 parent 98e4413 commit 2177a17

File tree

3 files changed

+291
-64
lines changed

3 files changed

+291
-64
lines changed

flang/runtime/copy.cpp

Lines changed: 155 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7,83 +7,178 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "copy.h"
10+
#include "stack.h"
1011
#include "terminator.h"
1112
#include "type-info.h"
1213
#include "flang/Runtime/allocatable.h"
1314
#include "flang/Runtime/descriptor.h"
1415
#include <cstring>
1516

1617
namespace Fortran::runtime {
18+
namespace {
19+
using StaticDescTy = StaticDescriptor<maxRank, true, 0>;
20+
21+
// A structure describing the data copy that needs to be done
22+
// from one descriptor to another. It is a helper structure
23+
// for CopyElement.
24+
struct CopyDescriptor {
25+
// A constructor specifying all members explicitly.
26+
RT_API_ATTRS CopyDescriptor(const Descriptor &to, const SubscriptValue toAt[],
27+
const Descriptor &from, const SubscriptValue fromAt[],
28+
std::size_t elements, bool usesStaticDescriptors = false)
29+
: to_(to), from_(from), elements_(elements),
30+
usesStaticDescriptors_(usesStaticDescriptors) {
31+
for (int dim{0}; dim < to.rank(); ++dim) {
32+
toAt_[dim] = toAt[dim];
33+
}
34+
for (int dim{0}; dim < from.rank(); ++dim) {
35+
fromAt_[dim] = fromAt[dim];
36+
}
37+
}
38+
// The number of elements to copy is initialized from the to descriptor.
39+
// The current element subscripts are initialized from the lower bounds
40+
// of the to and from descriptors.
41+
RT_API_ATTRS CopyDescriptor(const Descriptor &to, const Descriptor &from,
42+
bool usesStaticDescriptors = false)
43+
: to_(to), from_(from), elements_(to.Elements()),
44+
usesStaticDescriptors_(usesStaticDescriptors) {
45+
to.GetLowerBounds(toAt_);
46+
from.GetLowerBounds(fromAt_);
47+
}
48+
49+
// Descriptor of the destination.
50+
const Descriptor &to_;
51+
// A subscript specifying the current element position to copy to.
52+
SubscriptValue toAt_[maxRank];
53+
// Descriptor of the source.
54+
const Descriptor &from_;
55+
// A subscript specifying the current element position to copy from.
56+
SubscriptValue fromAt_[maxRank];
57+
// Number of elements left to copy.
58+
std::size_t elements_;
59+
// Must be true, if the to and from descriptors are allocated
60+
// by the CopyElement runtime. The allocated memory belongs
61+
// to a separate stack that needs to be popped in correspondence
62+
// with popping such a CopyDescriptor node.
63+
bool usesStaticDescriptors_;
64+
};
65+
66+
// A pair of StaticDescTy elements.
67+
struct StaticDescriptorsPair {
68+
StaticDescTy to;
69+
StaticDescTy from;
70+
};
71+
} // namespace
72+
1773
RT_OFFLOAD_API_GROUP_BEGIN
1874

1975
RT_API_ATTRS void CopyElement(const Descriptor &to, const SubscriptValue toAt[],
2076
const Descriptor &from, const SubscriptValue fromAt[],
2177
Terminator &terminator) {
22-
char *toPtr{to.Element<char>(toAt)};
23-
char *fromPtr{from.Element<char>(fromAt)};
24-
RUNTIME_CHECK(terminator, to.ElementBytes() == from.ElementBytes());
25-
std::memcpy(toPtr, fromPtr, to.ElementBytes());
26-
// Deep copy allocatable and automatic components if any.
27-
if (const auto *addendum{to.Addendum()}) {
28-
if (const auto *derived{addendum->derivedType()};
29-
derived && !derived->noDestructionNeeded()) {
30-
RUNTIME_CHECK(terminator,
31-
from.Addendum() && derived == from.Addendum()->derivedType());
32-
const Descriptor &componentDesc{derived->component()};
33-
const typeInfo::Component *component{
34-
componentDesc.OffsetElement<typeInfo::Component>()};
35-
std::size_t nComponents{componentDesc.Elements()};
36-
for (std::size_t j{0}; j < nComponents; ++j, ++component) {
37-
if (component->genre() == typeInfo::Component::Genre::Allocatable ||
38-
component->genre() == typeInfo::Component::Genre::Automatic) {
39-
Descriptor &toDesc{
40-
*reinterpret_cast<Descriptor *>(toPtr + component->offset())};
41-
if (toDesc.raw().base_addr != nullptr) {
42-
toDesc.set_base_addr(nullptr);
43-
RUNTIME_CHECK(terminator, toDesc.Allocate() == CFI_SUCCESS);
44-
const Descriptor &fromDesc{*reinterpret_cast<const Descriptor *>(
45-
fromPtr + component->offset())};
46-
CopyArray(toDesc, fromDesc, terminator);
47-
}
48-
} else if (component->genre() == typeInfo::Component::Genre::Data &&
49-
component->derivedType() &&
50-
!component->derivedType()->noDestructionNeeded()) {
51-
SubscriptValue extents[maxRank];
52-
const typeInfo::Value *bounds{component->bounds()};
53-
for (int dim{0}; dim < component->rank(); ++dim) {
54-
SubscriptValue lb{bounds[2 * dim].GetValue(&to).value_or(0)};
55-
SubscriptValue ub{bounds[2 * dim + 1].GetValue(&to).value_or(0)};
56-
extents[dim] = ub >= lb ? ub - lb + 1 : 0;
78+
#if !defined(RT_DEVICE_COMPILATION)
79+
constexpr unsigned copyStackReserve{16};
80+
constexpr unsigned descriptorStackReserve{6};
81+
#else
82+
// Always use dynamic allocation on the device to avoid
83+
// big stack sizes. This may be tuned as needed.
84+
constexpr unsigned copyStackReserve{0};
85+
constexpr unsigned descriptorStackReserve{0};
86+
#endif
87+
// Keep a stack of CopyDescriptor's to avoid recursive calls.
88+
Stack<CopyDescriptor, copyStackReserve> copyStack{terminator};
89+
// Keep a separate stack of StaticDescTy pairs. These descriptors
90+
// may be used for representing copies of Component::Genre::Data
91+
// components (since they do not have their descriptors allocated
92+
// in memory).
93+
Stack<StaticDescriptorsPair, descriptorStackReserve> descriptorsStack{
94+
terminator};
95+
copyStack.emplace(to, toAt, from, fromAt, /*elements=*/std::size_t{1});
96+
97+
while (!copyStack.empty()) {
98+
CopyDescriptor &currentCopy{copyStack.top()};
99+
std::size_t &elements{currentCopy.elements_};
100+
if (elements == 0) {
101+
// This copy has been exhausted.
102+
if (currentCopy.usesStaticDescriptors_) {
103+
// Pop the static descriptors, if they were used
104+
// for the current copy.
105+
descriptorsStack.pop();
106+
}
107+
copyStack.pop();
108+
continue;
109+
}
110+
const Descriptor &curTo{currentCopy.to_};
111+
SubscriptValue *curToAt{currentCopy.toAt_};
112+
const Descriptor &curFrom{currentCopy.from_};
113+
SubscriptValue *curFromAt{currentCopy.fromAt_};
114+
char *toPtr{curTo.Element<char>(curToAt)};
115+
char *fromPtr{curFrom.Element<char>(curFromAt)};
116+
RUNTIME_CHECK(terminator, curTo.ElementBytes() == curFrom.ElementBytes());
117+
// TODO: the memcpy can be optimized when both to and from are contiguous.
118+
// Moreover, if we came here from an Component::Genre::Data component,
119+
// all the per-element copies are redundant, because the parent
120+
// has already been copied as a whole.
121+
std::memcpy(toPtr, fromPtr, curTo.ElementBytes());
122+
--elements;
123+
if (elements != 0) {
124+
curTo.IncrementSubscripts(curToAt);
125+
curFrom.IncrementSubscripts(curFromAt);
126+
}
127+
128+
// Deep copy allocatable and automatic components if any.
129+
if (const auto *addendum{curTo.Addendum()}) {
130+
if (const auto *derived{addendum->derivedType()};
131+
derived && !derived->noDestructionNeeded()) {
132+
RUNTIME_CHECK(terminator,
133+
curFrom.Addendum() && derived == curFrom.Addendum()->derivedType());
134+
const Descriptor &componentDesc{derived->component()};
135+
const typeInfo::Component *component{
136+
componentDesc.OffsetElement<typeInfo::Component>()};
137+
std::size_t nComponents{componentDesc.Elements()};
138+
for (std::size_t j{0}; j < nComponents; ++j, ++component) {
139+
if (component->genre() == typeInfo::Component::Genre::Allocatable ||
140+
component->genre() == typeInfo::Component::Genre::Automatic) {
141+
Descriptor &toDesc{
142+
*reinterpret_cast<Descriptor *>(toPtr + component->offset())};
143+
if (toDesc.raw().base_addr != nullptr) {
144+
toDesc.set_base_addr(nullptr);
145+
RUNTIME_CHECK(terminator, toDesc.Allocate() == CFI_SUCCESS);
146+
const Descriptor &fromDesc{*reinterpret_cast<const Descriptor *>(
147+
fromPtr + component->offset())};
148+
copyStack.emplace(toDesc, fromDesc);
149+
}
150+
} else if (component->genre() == typeInfo::Component::Genre::Data &&
151+
component->derivedType() &&
152+
!component->derivedType()->noDestructionNeeded()) {
153+
SubscriptValue extents[maxRank];
154+
const typeInfo::Value *bounds{component->bounds()};
155+
std::size_t elements{1};
156+
for (int dim{0}; dim < component->rank(); ++dim) {
157+
SubscriptValue lb{bounds[2 * dim].GetValue(&curTo).value_or(0)};
158+
SubscriptValue ub{
159+
bounds[2 * dim + 1].GetValue(&curTo).value_or(0)};
160+
extents[dim] = ub >= lb ? ub - lb + 1 : 0;
161+
elements *= extents[dim];
162+
}
163+
if (elements != 0) {
164+
const typeInfo::DerivedType &compType{*component->derivedType()};
165+
// Place a pair of static descriptors onto the descriptors stack.
166+
descriptorsStack.emplace();
167+
StaticDescriptorsPair &descs{descriptorsStack.top()};
168+
Descriptor &toCompDesc{descs.to.descriptor()};
169+
toCompDesc.Establish(compType, toPtr + component->offset(),
170+
component->rank(), extents);
171+
Descriptor &fromCompDesc{descs.from.descriptor()};
172+
fromCompDesc.Establish(compType, fromPtr + component->offset(),
173+
component->rank(), extents);
174+
copyStack.emplace(toCompDesc, fromCompDesc,
175+
/*usesStaticDescriptors=*/true);
176+
}
57177
}
58-
const typeInfo::DerivedType &compType{*component->derivedType()};
59-
StaticDescriptor<maxRank, true, 0> toStaticDescriptor;
60-
Descriptor &toCompDesc{toStaticDescriptor.descriptor()};
61-
toCompDesc.Establish(compType, toPtr + component->offset(),
62-
component->rank(), extents);
63-
StaticDescriptor<maxRank, true, 0> fromStaticDescriptor;
64-
Descriptor &fromCompDesc{fromStaticDescriptor.descriptor()};
65-
fromCompDesc.Establish(compType, fromPtr + component->offset(),
66-
component->rank(), extents);
67-
CopyArray(toCompDesc, fromCompDesc, terminator);
68178
}
69179
}
70180
}
71181
}
72182
}
73-
74-
RT_API_ATTRS void CopyArray(
75-
const Descriptor &to, const Descriptor &from, Terminator &terminator) {
76-
std::size_t elements{to.Elements()};
77-
RUNTIME_CHECK(terminator, elements == from.Elements());
78-
SubscriptValue toAt[maxRank], fromAt[maxRank];
79-
to.GetLowerBounds(toAt);
80-
from.GetLowerBounds(fromAt);
81-
while (elements-- > 0) {
82-
CopyElement(to, toAt, from, fromAt, terminator);
83-
to.IncrementSubscripts(toAt);
84-
from.IncrementSubscripts(fromAt);
85-
}
86-
}
87-
88183
RT_OFFLOAD_API_GROUP_END
89184
} // namespace Fortran::runtime

flang/runtime/copy.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,5 @@ namespace Fortran::runtime {
2121
RT_API_ATTRS void CopyElement(const Descriptor &to, const SubscriptValue toAt[],
2222
const Descriptor &from, const SubscriptValue fromAt[], Terminator &);
2323

24-
// Copies data from one allocated descriptor's array to another.
25-
RT_API_ATTRS void CopyArray(
26-
const Descriptor &to, const Descriptor &from, Terminator &);
27-
2824
} // namespace Fortran::runtime
2925
#endif // FORTRAN_RUNTIME_COPY_H_

flang/runtime/stack.h

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
//===-- runtime/stack.h -----------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Trivial implementation of stack that can be used on all targets.
10+
// It is a list based stack with dynamic allocation/deallocation
11+
// of the list nodes.
12+
13+
#ifndef FORTRAN_RUNTIME_STACK_H
14+
#define FORTRAN_RUNTIME_STACK_H
15+
16+
#include "terminator.h"
17+
#include "flang/Runtime/memory.h"
18+
19+
namespace Fortran::runtime {
20+
// Storage for the Stack elements of type T.
21+
template <typename T, unsigned N> struct StackStorage {
22+
void *getElement(unsigned i) {
23+
if (i < N) {
24+
return storage[i];
25+
} else {
26+
return nullptr;
27+
}
28+
}
29+
const void *getElement(unsigned i) const {
30+
if (i < N) {
31+
return storage[i];
32+
} else {
33+
return nullptr;
34+
}
35+
}
36+
37+
private:
38+
// Storage to hold N elements of type T.
39+
// It is declared as an array of bytes to avoid
40+
// default construction (if any is implied by type T).
41+
alignas(T) char storage[N][sizeof(T)];
42+
};
43+
44+
// 0-size specialization that provides no storage.
45+
template <typename T> struct alignas(T) StackStorage<T, 0> {
46+
void *getElement(unsigned) { return nullptr; }
47+
const void *getElement(unsigned) const { return nullptr; }
48+
};
49+
50+
// A LIFO stack of T. An element is constructed in the storage inherited
// from StackStorage<T, N> whenever getElement() returns a slot for its
// index; otherwise it is placed in a list node obtained through New<List>
// and released with FreeMemory when it is popped.
template <typename T, unsigned N = 0> class Stack : public StackStorage<T, N> {
public:
  Stack() = delete;
  Stack(const Stack &) = delete;
  Stack(Stack &&) = delete;
  // The terminator is used by the internal RUNTIME_CHECKs and is handed
  // to New<List> whenever a list node has to be allocated.
  RT_API_ATTRS Stack(Terminator &terminator) : terminator_{terminator} {}
  // Pops, and thereby destroys, every element still on the stack.
  RT_API_ATTRS ~Stack() {
    while (!empty()) {
      pop();
    }
  }
  // Pushes a copy of object.
  RT_API_ATTRS void push(const T &object) {
    if (void *ptr{this->getElement(size_)}) {
      new (ptr) T{object};
    } else {
      top_ = New<List>{terminator_}(top_, object).release();
    }
    ++size_;
  }
  // Pushes object by moving from it.
  RT_API_ATTRS void push(T &&object) {
    if (void *ptr{this->getElement(size_)}) {
      new (ptr) T{std::move(object)};
    } else {
      top_ = New<List>{terminator_}(top_, std::move(object)).release();
    }
    ++size_;
  }
  // Constructs a new top element in place from args.
  template <typename... Args> RT_API_ATTRS void emplace(Args &&...args) {
    if (void *ptr{this->getElement(size_)}) {
      new (ptr) T{std::forward<Args>(args)...};
    } else {
      top_ =
          New<List>{terminator_}(top_, std::forward<Args>(args)...).release();
    }
    ++size_;
  }
  // Returns the most recently pushed element; the stack must not be empty.
  RT_API_ATTRS T &top() {
    RUNTIME_CHECK(terminator_, size_ > 0);
    if (void *ptr{this->getElement(size_ - 1)}) {
      return *reinterpret_cast<T *>(ptr);
    } else {
      RUNTIME_CHECK(terminator_, top_);
      return top_->object_;
    }
  }
  RT_API_ATTRS const T &top() const {
    RUNTIME_CHECK(terminator_, size_ > 0);
    // On a const Stack the const overload of getElement() is selected and
    // it returns const void *, so the pointer has to be const-qualified.
    if (const void *ptr{this->getElement(size_ - 1)}) {
      return *reinterpret_cast<const T *>(ptr);
    } else {
      RUNTIME_CHECK(terminator_, top_);
      return top_->object_;
    }
  }
  // Destroys and removes the top element; the stack must not be empty.
  RT_API_ATTRS void pop() {
    RUNTIME_CHECK(terminator_, size_ > 0);
    if (void *ptr{this->getElement(size_ - 1)}) {
      reinterpret_cast<T *>(ptr)->~T();
    } else {
      RUNTIME_CHECK(terminator_, top_);
      // Remember the successor before the node is destroyed and freed.
      List *next{top_->next_};
      top_->~List();
      FreeMemory(top_);
      top_ = next;
    }
    --size_;
  }
  RT_API_ATTRS bool empty() const { return size_ == 0; }

private:
  // Node of the singly linked list that holds the elements for which
  // getElement() provides no slot.
  struct List {
    template <typename... Args>
    RT_API_ATTRS List(List *next, Args &&...args)
        : next_(next), object_(std::forward<Args>(args)...) {}
    RT_API_ATTRS List(List *next, const T &object)
        : next_(next), object_(object) {}
    RT_API_ATTRS List(List *next, T &&object)
        : next_(next), object_(std::move(object)) {}
    List *next_{nullptr};
    T object_;
  };
  // Most recently allocated list node, or nullptr when there is none.
  List *top_{nullptr};
  // Total number of elements currently on the stack.
  std::size_t size_{0};
  Terminator &terminator_;
};
135+
} // namespace Fortran::runtime
136+
#endif // FORTRAN_RUNTIME_STACK_H

0 commit comments

Comments
 (0)