Skip to content

[flang][runtime] Interoperable POINTER deallocation validation #96100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flang/include/flang/Runtime/pointer.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ bool RTDECL(PointerIsAssociated)(const Descriptor &);
bool RTDECL(PointerIsAssociatedWith)(
const Descriptor &, const Descriptor *target);

// Fortran POINTERs are allocated with an extra validation word after their
// payloads in order to detect erroneous deallocations later.
RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t);
RT_API_ATTRS bool ValidatePointerPayload(const ISO::CFI_cdesc_t &);

} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_POINTER_H_
10 changes: 7 additions & 3 deletions flang/runtime/ISO_Fortran_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "terminator.h"
#include "flang/ISO_Fortran_binding_wrapper.h"
#include "flang/Runtime/descriptor.h"
#include "flang/Runtime/pointer.h"
#include "flang/Runtime/type-code.h"
#include <cstdlib>

Expand Down Expand Up @@ -75,7 +76,7 @@ RT_API_ATTRS int CFI_allocate(CFI_cdesc_t *descriptor,
dim->sm = byteSize;
byteSize *= extent;
}
void *p{byteSize ? std::malloc(byteSize) : std::malloc(1)};
void *p{runtime::AllocateValidatedPointerPayload(byteSize)};
if (!p && byteSize) {
return CFI_ERROR_MEM_ALLOCATION;
}
Expand All @@ -91,8 +92,11 @@ RT_API_ATTRS int CFI_deallocate(CFI_cdesc_t *descriptor) {
if (descriptor->version != CFI_VERSION) {
return CFI_INVALID_DESCRIPTOR;
}
if (descriptor->attribute != CFI_attribute_allocatable &&
descriptor->attribute != CFI_attribute_pointer) {
if (descriptor->attribute == CFI_attribute_pointer) {
if (!runtime::ValidatePointerPayload(*descriptor)) {
return CFI_INVALID_DESCRIPTOR;
}
} else if (descriptor->attribute != CFI_attribute_allocatable) {
// Non-interoperable object
return CFI_INVALID_DESCRIPTOR;
}
Expand Down
11 changes: 10 additions & 1 deletion flang/runtime/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,16 @@ RT_API_ATTRS int Descriptor::Destroy(
}
}

RT_API_ATTRS int Descriptor::Deallocate() { return ISO::CFI_deallocate(&raw_); }
RT_API_ATTRS int Descriptor::Deallocate() {
ISO::CFI_cdesc_t &descriptor{raw()};
if (!descriptor.base_addr) {
return CFI_ERROR_BASE_ADDR_NULL;
} else {
std::free(descriptor.base_addr);
descriptor.base_addr = nullptr;
return CFI_SUCCESS;
}
}

RT_API_ATTRS bool Descriptor::DecrementSubscripts(
SubscriptValue *subscript, const int *permutation) const {
Expand Down
67 changes: 42 additions & 25 deletions flang/runtime/pointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,23 @@ void RTDEF(PointerAssociateRemapping)(Descriptor &pointer,
}
}

RT_API_ATTRS void *AllocateValidatedPointerPayload(std::size_t byteSize) {
// Add space for a footer to validate during deallocation.
constexpr std::size_t align{sizeof(std::uintptr_t)};
byteSize = ((byteSize / align) + 1) * align;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
byteSize = ((byteSize / align) + 1) * align;
byteSize = ((byteSize + align - 1) / align) * align;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If byteSize is initially zero, I want the result to be align, not zero.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why? If byteSize is initially zero, the total allocation will end up being 2 * align - isn't align bytes allocation enough?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

((0/align)+1)*align -> (0+1)*align -> 1*align -> align, yes?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, byteSize is equal to align, and then total (below) is equal to align + sizeof(std::uintptr_t) - this is how much memory is being malloc'ed. I am asking why we need to malloc more than align bytes of memory.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't. Will fix. I was misunderstanding your comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::size_t total{byteSize + sizeof(std::uintptr_t)};
void *p{std::malloc(total)};
if (p) {
// Fill the footer word with the XOR of the ones' complement of
// the base address, which is a value that would be highly unlikely
// to appear accidentally at the right spot.
std::uintptr_t *footer{
reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
*footer = ~reinterpret_cast<std::uintptr_t>(p);
}
return p;
}

int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
Expand All @@ -137,22 +154,12 @@ int RTDEF(PointerAllocate)(Descriptor &pointer, bool hasStat,
elementBytes = pointer.raw().elem_len = 0;
}
std::size_t byteSize{pointer.Elements() * elementBytes};
// Add space for a footer to validate during DEALLOCATE.
constexpr std::size_t align{sizeof(std::uintptr_t)};
byteSize = ((byteSize + align - 1) / align) * align;
std::size_t total{byteSize + sizeof(std::uintptr_t)};
void *p{std::malloc(total)};
void *p{AllocateValidatedPointerPayload(byteSize)};
if (!p) {
return ReturnError(terminator, CFI_ERROR_MEM_ALLOCATION, errMsg, hasStat);
}
pointer.set_base_addr(p);
pointer.SetByteStrides();
// Fill the footer word with the XOR of the ones' complement of
// the base address, which is a value that would be highly unlikely
// to appear accidentally at the right spot.
std::uintptr_t *footer{
reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
*footer = ~reinterpret_cast<std::uintptr_t>(p);
int stat{StatOk};
if (const DescriptorAddendum * addendum{pointer.Addendum()}) {
if (const auto *derived{addendum->derivedType()}) {
Expand All @@ -176,6 +183,27 @@ int RTDEF(PointerAllocateSource)(Descriptor &pointer, const Descriptor &source,
return stat;
}

static RT_API_ATTRS std::size_t GetByteSize(
const ISO::CFI_cdesc_t &descriptor) {
std::size_t rank{descriptor.rank};
const ISO::CFI_dim_t *dim{descriptor.dim};
std::size_t byteSize{descriptor.elem_len};
for (std::size_t j{0}; j < rank; ++j) {
byteSize *= dim[j].extent;
}
return byteSize;
}

bool RT_API_ATTRS ValidatePointerPayload(const ISO::CFI_cdesc_t &desc) {
std::size_t byteSize{GetByteSize(desc)};
constexpr std::size_t align{sizeof(std::uintptr_t)};
byteSize = ((byteSize / align) + 1) * align;
const void *p{desc.base_addr};
const std::uintptr_t *footer{reinterpret_cast<const std::uintptr_t *>(
static_cast<const char *>(p) + byteSize)};
return *footer == ~reinterpret_cast<std::uintptr_t>(p);
}

int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
Expand All @@ -185,20 +213,9 @@ int RTDEF(PointerDeallocate)(Descriptor &pointer, bool hasStat,
if (!pointer.IsAllocated()) {
return ReturnError(terminator, StatBaseNull, errMsg, hasStat);
}
if (executionEnvironment.checkPointerDeallocation) {
// Validate the footer. This should fail if the pointer doesn't
// span the entire object, or the object was not allocated as a
// pointer.
std::size_t byteSize{pointer.Elements() * pointer.ElementBytes()};
constexpr std::size_t align{sizeof(std::uintptr_t)};
byteSize = ((byteSize + align - 1) / align) * align;
void *p{pointer.raw().base_addr};
std::uintptr_t *footer{
reinterpret_cast<std::uintptr_t *>(static_cast<char *>(p) + byteSize)};
if (*footer != ~reinterpret_cast<std::uintptr_t>(p)) {
return ReturnError(
terminator, StatBadPointerDeallocation, errMsg, hasStat);
}
if (executionEnvironment.checkPointerDeallocation &&
!ValidatePointerPayload(pointer.raw())) {
return ReturnError(terminator, StatBadPointerDeallocation, errMsg, hasStat);
}
return ReturnError(terminator,
pointer.Destroy(/*finalize=*/true, /*destroyPointers=*/true, &terminator),
Expand Down
Loading