Skip to content

[Runtime] Fix double-frees in ConcurrentReadableArray. #16794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 51 additions & 21 deletions include/swift/Runtime/Concurrent.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,14 +451,57 @@ template <class ElemTy> struct ConcurrentReadableArray {
Mutex WriterLock;
std::vector<Storage *> FreeList;

/// Register one in-progress reader. Must be balanced by a later call to
/// decrementReaders(). Acquire ordering here — presumably so element reads
/// performed by the reader cannot be reordered before the registration
/// becomes visible to the writer; confirm against the writer's
/// release/acquire protocol in push_back.
void incrementReaders() {
ReaderCount.fetch_add(1, std::memory_order_acquire);
}

/// Unregister one reader previously registered with incrementReaders().
/// Release ordering — presumably so all of the reader's element reads
/// complete before the writer can observe the count reaching zero and
/// deallocate retired storage; confirm against the acquire load in
/// push_back.
void decrementReaders() {
ReaderCount.fetch_sub(1, std::memory_order_release);
}

void deallocateFreeList() {
for (Storage *storage : FreeList)
storage->deallocate();
FreeList.clear();
FreeList.shrink_to_fit();
}

public:
struct Snapshot {
ConcurrentReadableArray *Array;
const ElemTy *Start;
size_t Count;

Snapshot(ConcurrentReadableArray *array, const ElemTy *start, size_t count)
: Array(array), Start(start), Count(count) {}

Snapshot(const Snapshot &other)
: Array(other.Array), Start(other.Start), Count(other.Count) {
Array->incrementReaders();
}

~Snapshot() {
Array->decrementReaders();
}

const ElemTy *begin() { return Start; }
const ElemTy *end() { return Start + Count; }
size_t count() { return Count; }
};

// This type cannot be safely copied, moved, or deleted.
ConcurrentReadableArray(const ConcurrentReadableArray &) = delete;
ConcurrentReadableArray(ConcurrentReadableArray &&) = delete;
ConcurrentReadableArray &operator=(const ConcurrentReadableArray &) = delete;

// Start empty: no element storage is allocated until the first push_back.
ConcurrentReadableArray() : Capacity(0), ReaderCount(0), Elements(nullptr) {}

// Destruction requires that no snapshots are still live; with readers
// outstanding, freeing the retired buffers here would be a use-after-free.
// NOTE(review): only the free list is released here — the current Elements
// storage does not appear to be freed; confirm it is retired onto FreeList
// elsewhere (push_back is not fully visible in this view).
~ConcurrentReadableArray() {
assert(ReaderCount.load(std::memory_order_acquire) == 0 &&
"deallocating ConcurrentReadableArray with outstanding snapshots");
deallocateFreeList();
}

void push_back(const ElemTy &elem) {
ScopedLock guard(WriterLock);

Expand All @@ -482,32 +525,19 @@ template <class ElemTy> struct ConcurrentReadableArray {
storage->Count.store(count + 1, std::memory_order_release);

if (ReaderCount.load(std::memory_order_acquire) == 0)
for (Storage *storage : FreeList)
storage->deallocate();
deallocateFreeList();
}

/// Read the contents of the array. The parameter `f` is called with
/// two parameters: a pointer to the elements in the array, and the
/// count. This represents a snapshot of the contents at the time
/// `read` was called. The pointer becomes invalid after `f` returns.
template <class F> auto read(F f) -> decltype(f(nullptr, 0)) {
ReaderCount.fetch_add(1, std::memory_order_acquire);
Snapshot snapshot() {
incrementReaders();
auto *storage = Elements.load(SWIFT_MEMORY_ORDER_CONSUME);
if (storage == nullptr) {
return Snapshot(this, nullptr, 0);
}

auto count = storage->Count.load(std::memory_order_acquire);
const auto *ptr = storage->data();

decltype(f(nullptr, 0)) result = f(ptr, count);

ReaderCount.fetch_sub(1, std::memory_order_release);

return result;
}

/// Get the current count. It's just a snapshot and may be obsolete immediately.
size_t count() {
return read([](const ElemTy *ptr, size_t count) -> size_t {
return count;
});
return Snapshot(this, ptr, count);
}
};

Expand Down
203 changes: 95 additions & 108 deletions stdlib/public/runtime/ProtocolConformance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,11 @@ void ConformanceState::verify() const {
// Iterate over all of the sections and verify all of the protocol
// descriptors.
auto &Self = const_cast<ConformanceState &>(*this);
Self.SectionsToScan.read([](const ConformanceSection *ptr, size_t count) -> char {
for (size_t i = 0; i < count; i++) {
for (const auto &Record : ptr[i]) {
Record.get()->verify();
}
for (const auto &Section : Self.SectionsToScan.snapshot()) {
for (const auto &Record : Section) {
Record.get()->verify();
}
return 0;
});
}
}
#endif

Expand Down Expand Up @@ -445,7 +442,7 @@ searchInConformanceCache(const Metadata *type,
}

// Check if the negative cache entry is up-to-date.
if (Value->getFailureGeneration() == C.SectionsToScan.count()) {
if (Value->getFailureGeneration() == C.SectionsToScan.snapshot().count()) {
// Negative cache entry is up-to-date. Return failure along with
// the original query type's own cache entry, if we found one.
// (That entry may be out of date but the caller still has use for it.)
Expand Down Expand Up @@ -546,100 +543,94 @@ swift_conformsToProtocolImpl(const Metadata * const type,
auto failureEntry = FoundConformance.failureEntry;

// Prepare to scan conformance records.
size_t scannedCount;
auto returnNull = C.SectionsToScan
.read([&](const ConformanceSection *ptr, size_t count) -> bool {
scannedCount = count;
// Scan only sections that were not scanned yet.
// If we found an out-of-date negative cache entry,
// we need not to re-scan the sections that it covers.
auto startIndex = failureEntry ? failureEntry->getFailureGeneration() : 0;
auto endIndex = count;
auto snapshot = C.SectionsToScan.snapshot();

// If there are no unscanned sections outstanding
// then we can cache failure and give up now.
if (startIndex == endIndex) {
C.cacheFailure(type, protocol, count);
return true;
// Scan only sections that were not scanned yet.
// If we found an out-of-date negative cache entry,
// we need not re-scan the sections that it covers.
auto startIndex = failureEntry ? failureEntry->getFailureGeneration() : 0;
auto endIndex = snapshot.count();

// If there are no unscanned sections outstanding
// then we can cache failure and give up now.
if (startIndex == endIndex) {
C.cacheFailure(type, protocol, snapshot.count());
return nullptr;
}

/// Local function to retrieve the witness table and record the result.
auto recordWitnessTable = [&](const ProtocolConformanceDescriptor &descriptor,
const Metadata *type) {
switch (descriptor.getConformanceKind()) {
case ConformanceFlags::ConformanceKind::WitnessTable:
// If the record provides a nondependent witness table for all
// instances of a generic type, cache it for the generic pattern.
C.cacheSuccess(type, protocol, descriptor.getStaticWitnessTable());
return;

case ConformanceFlags::ConformanceKind::WitnessTableAccessor:
// If the record provides a dependent witness table accessor,
// cache the result for the instantiated type metadata.
C.cacheSuccess(type, protocol, descriptor.getWitnessTable(type));
return;

case ConformanceFlags::ConformanceKind::ConditionalWitnessTableAccessor: {
auto witnessTable = descriptor.getWitnessTable(type);
if (witnessTable)
C.cacheSuccess(type, protocol, witnessTable);
else
C.cacheFailure(type, protocol, snapshot.count());
return;
}
}

/// Local function to retrieve the witness table and record the result.
auto recordWitnessTable = [&](const ProtocolConformanceDescriptor &descriptor,
const Metadata *type) {
switch (descriptor.getConformanceKind()) {
case ConformanceFlags::ConformanceKind::WitnessTable:
// If the record provides a nondependent witness table for all
// instances of a generic type, cache it for the generic pattern.
C.cacheSuccess(type, protocol, descriptor.getStaticWitnessTable());
return;

case ConformanceFlags::ConformanceKind::WitnessTableAccessor:
// If the record provides a dependent witness table accessor,
// cache the result for the instantiated type metadata.
C.cacheSuccess(type, protocol, descriptor.getWitnessTable(type));
return;

case ConformanceFlags::ConformanceKind::ConditionalWitnessTableAccessor: {
auto witnessTable = descriptor.getWitnessTable(type);
if (witnessTable)
C.cacheSuccess(type, protocol, witnessTable);
else
C.cacheFailure(type, protocol, count);
return;
}
}
// Always fail, because we cannot interpret a future conformance
// kind.
C.cacheFailure(type, protocol, snapshot.count());
};

// Always fail, because we cannot interpret a future conformance
// kind.
C.cacheFailure(type, protocol, count);
};

// Really scan conformance records.
for (size_t i = startIndex; i < endIndex; i++) {
auto &section = ptr[i];
// Eagerly pull records for nondependent witnesses into our cache.
for (const auto &record : section) {
auto &descriptor = *record.get();

// If the record applies to a specific type, cache it.
if (auto metadata = descriptor.getCanonicalTypeMetadata()) {
auto P = descriptor.getProtocol();

// Look for an exact match.
if (protocol != P)
continue;

if (!isRelatedType(type, metadata, /*candidateIsMetadata=*/true))
continue;

// Record the witness table.
recordWitnessTable(descriptor, metadata);

// TODO: "Nondependent witness table" probably deserves its own flag.
// An accessor function might still be necessary even if the witness table
// can be shared.
} else if (descriptor.getTypeKind()
== TypeMetadataRecordKind::DirectNominalTypeDescriptor ||
descriptor.getTypeKind()
== TypeMetadataRecordKind::IndirectNominalTypeDescriptor) {
auto R = descriptor.getTypeContextDescriptor();
auto P = descriptor.getProtocol();

// Look for an exact match.
if (protocol != P)
continue;

if (!isRelatedType(type, R, /*candidateIsMetadata=*/false))
continue;

recordWitnessTable(descriptor, type);
}
// Really scan conformance records.
for (size_t i = startIndex; i < endIndex; i++) {
auto &section = snapshot.Start[i];
// Eagerly pull records for nondependent witnesses into our cache.
for (const auto &record : section) {
auto &descriptor = *record.get();

// If the record applies to a specific type, cache it.
if (auto metadata = descriptor.getCanonicalTypeMetadata()) {
auto P = descriptor.getProtocol();

// Look for an exact match.
if (protocol != P)
continue;

if (!isRelatedType(type, metadata, /*candidateIsMetadata=*/true))
continue;

// Record the witness table.
recordWitnessTable(descriptor, metadata);

// TODO: "Nondependent witness table" probably deserves its own flag.
// An accessor function might still be necessary even if the witness table
// can be shared.
} else if (descriptor.getTypeKind()
== TypeMetadataRecordKind::DirectNominalTypeDescriptor ||
descriptor.getTypeKind()
== TypeMetadataRecordKind::IndirectNominalTypeDescriptor) {
auto R = descriptor.getTypeContextDescriptor();
auto P = descriptor.getProtocol();

// Look for an exact match.
if (protocol != P)
continue;

if (!isRelatedType(type, R, /*candidateIsMetadata=*/false))
continue;

recordWitnessTable(descriptor, type);
}
}
return false;
});

if (returnNull) return nullptr;
}

// Conformance scan is complete.
// Search the cache once more, and this time update the cache if necessary.
Expand All @@ -648,7 +639,7 @@ swift_conformsToProtocolImpl(const Metadata * const type,
if (FoundConformance.isAuthoritative) {
return FoundConformance.witnessTable;
} else {
C.cacheFailure(type, protocol, scannedCount);
C.cacheFailure(type, protocol, snapshot.count());
return nullptr;
}
}
Expand All @@ -657,19 +648,15 @@ const TypeContextDescriptor *
swift::_searchConformancesByMangledTypeName(Demangle::NodePointer node) {
auto &C = Conformances.get();

return C.SectionsToScan
.read([&](const ConformanceSection *ptr, size_t count) -> const TypeContextDescriptor * {
for (size_t i = 0; i < count; i++) {
auto &section = ptr[i];
for (const auto &record : section) {
if (auto ntd = record->getTypeContextDescriptor()) {
if (_contextDescriptorMatchesMangling(ntd, node))
return ntd;
}
for (auto &section : C.SectionsToScan.snapshot()) {
for (const auto &record : section) {
if (auto ntd = record->getTypeContextDescriptor()) {
if (_contextDescriptorMatchesMangling(ntd, node))
return ntd;
}
}
return nullptr;
});
}
return nullptr;
}

/// Resolve a reference to a generic parameter to type metadata.
Expand Down
1 change: 1 addition & 0 deletions unittests/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ if(("${SWIFT_HOST_VARIANT_SDK}" STREQUAL "${SWIFT_PRIMARY_VARIANT_SDK}") AND
add_swift_unittest(SwiftRuntimeTests
Array.cpp
CompatibilityOverride.cpp
Concurrent.cpp
Exclusivity.cpp
Metadata.cpp
Mutex.cpp
Expand Down
44 changes: 44 additions & 0 deletions unittests/runtime/Concurrent.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//===--- Concurrent.cpp - Concurrent data structure tests -----------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2018 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#include "swift/Runtime/Concurrent.h"
#include "gtest/gtest.h"

using namespace swift;

TEST(ConcurrentReadableArrayTest, SingleThreaded) {
  ConcurrentReadableArray<size_t> array;

  // Extend the array so it holds the values [0, limit), resuming from
  // however many elements a fresh snapshot says are already present.
  auto extendTo = [&](size_t limit) {
    for (size_t value = array.snapshot().count(); value < limit; value++)
      array.push_back(value);
  };

  // A snapshot must always read back as the exact sequence 0, 1, 2, ...
  auto expectSequential = [&] {
    size_t expected = 0;
    for (auto element : array.snapshot())
      ASSERT_EQ(element, expected++);
  };

  // Verify the empty array, then exercise several growth steps (spanning
  // multiple internal reallocations), re-verifying after each.
  expectSequential();
  for (size_t limit : {size_t(1), size_t(16), size_t(100),
                       size_t(1000), size_t(1000000)}) {
    extendTo(limit);
    expectSequential();
  }
}