Skip to content

Commit 93470fb

Browse files
authored
Merge pull request #16794 from mikeash/fix-ConcurrentReadableArray-double-free
[Runtime] Fix double-frees in ConcurrentReadableArray.
2 parents b0aeae4 + a4863c4 commit 93470fb

File tree

4 files changed

+191
-129
lines changed

4 files changed

+191
-129
lines changed

include/swift/Runtime/Concurrent.h

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -451,14 +451,57 @@ template <class ElemTy> struct ConcurrentReadableArray {
451451
Mutex WriterLock;
452452
std::vector<Storage *> FreeList;
453453

454+
void incrementReaders() {
455+
ReaderCount.fetch_add(1, std::memory_order_acquire);
456+
}
457+
458+
void decrementReaders() {
459+
ReaderCount.fetch_sub(1, std::memory_order_release);
460+
}
461+
462+
void deallocateFreeList() {
463+
for (Storage *storage : FreeList)
464+
storage->deallocate();
465+
FreeList.clear();
466+
FreeList.shrink_to_fit();
467+
}
468+
454469
public:
470+
struct Snapshot {
471+
ConcurrentReadableArray *Array;
472+
const ElemTy *Start;
473+
size_t Count;
474+
475+
Snapshot(ConcurrentReadableArray *array, const ElemTy *start, size_t count)
476+
: Array(array), Start(start), Count(count) {}
477+
478+
Snapshot(const Snapshot &other)
479+
: Array(other.Array), Start(other.Start), Count(other.Count) {
480+
Array->incrementReaders();
481+
}
482+
483+
~Snapshot() {
484+
Array->decrementReaders();
485+
}
486+
487+
const ElemTy *begin() { return Start; }
488+
const ElemTy *end() { return Start + Count; }
489+
size_t count() { return Count; }
490+
};
491+
455492
// This type cannot be safely copied, moved, or deleted.
456493
ConcurrentReadableArray(const ConcurrentReadableArray &) = delete;
457494
ConcurrentReadableArray(ConcurrentReadableArray &&) = delete;
458495
ConcurrentReadableArray &operator=(const ConcurrentReadableArray &) = delete;
459496

460497
ConcurrentReadableArray() : Capacity(0), ReaderCount(0), Elements(nullptr) {}
461498

499+
~ConcurrentReadableArray() {
500+
assert(ReaderCount.load(std::memory_order_acquire) == 0 &&
501+
"deallocating ConcurrentReadableArray with outstanding snapshots");
502+
deallocateFreeList();
503+
}
504+
462505
void push_back(const ElemTy &elem) {
463506
ScopedLock guard(WriterLock);
464507

@@ -482,32 +525,19 @@ template <class ElemTy> struct ConcurrentReadableArray {
482525
storage->Count.store(count + 1, std::memory_order_release);
483526

484527
if (ReaderCount.load(std::memory_order_acquire) == 0)
485-
for (Storage *storage : FreeList)
486-
storage->deallocate();
528+
deallocateFreeList();
487529
}
488530

489-
/// Read the contents of the array. The parameter `f` is called with
490-
/// two parameters: a pointer to the elements in the array, and the
491-
/// count. This represents a snapshot of the contents at the time
492-
/// `read` was called. The pointer becomes invalid after `f` returns.
493-
template <class F> auto read(F f) -> decltype(f(nullptr, 0)) {
494-
ReaderCount.fetch_add(1, std::memory_order_acquire);
531+
Snapshot snapshot() {
532+
incrementReaders();
495533
auto *storage = Elements.load(SWIFT_MEMORY_ORDER_CONSUME);
534+
if (storage == nullptr) {
535+
return Snapshot(this, nullptr, 0);
536+
}
537+
496538
auto count = storage->Count.load(std::memory_order_acquire);
497539
const auto *ptr = storage->data();
498-
499-
decltype(f(nullptr, 0)) result = f(ptr, count);
500-
501-
ReaderCount.fetch_sub(1, std::memory_order_release);
502-
503-
return result;
504-
}
505-
506-
/// Get the current count. It's just a snapshot and may be obsolete immediately.
507-
size_t count() {
508-
return read([](const ElemTy *ptr, size_t count) -> size_t {
509-
return count;
510-
});
540+
return Snapshot(this, ptr, count);
511541
}
512542
};
513543

stdlib/public/runtime/ProtocolConformance.cpp

Lines changed: 95 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -326,14 +326,11 @@ void ConformanceState::verify() const {
326326
// Iterate over all of the sections and verify all of the protocol
327327
// descriptors.
328328
auto &Self = const_cast<ConformanceState &>(*this);
329-
Self.SectionsToScan.read([](const ConformanceSection *ptr, size_t count) -> char {
330-
for (size_t i = 0; i < count; i++) {
331-
for (const auto &Record : ptr[i]) {
332-
Record.get()->verify();
333-
}
329+
for (const auto &Section : Self.SectionsToScan.snapshot()) {
330+
for (const auto &Record : Section) {
331+
Record.get()->verify();
334332
}
335-
return 0;
336-
});
333+
}
337334
}
338335
#endif
339336

@@ -445,7 +442,7 @@ searchInConformanceCache(const Metadata *type,
445442
}
446443

447444
// Check if the negative cache entry is up-to-date.
448-
if (Value->getFailureGeneration() == C.SectionsToScan.count()) {
445+
if (Value->getFailureGeneration() == C.SectionsToScan.snapshot().count()) {
449446
// Negative cache entry is up-to-date. Return failure along with
450447
// the original query type's own cache entry, if we found one.
451448
// (That entry may be out of date but the caller still has use for it.)
@@ -546,100 +543,94 @@ swift_conformsToProtocolImpl(const Metadata * const type,
546543
auto failureEntry = FoundConformance.failureEntry;
547544

548545
// Prepare to scan conformance records.
549-
size_t scannedCount;
550-
auto returnNull = C.SectionsToScan
551-
.read([&](const ConformanceSection *ptr, size_t count) -> bool {
552-
scannedCount = count;
553-
// Scan only sections that were not scanned yet.
554-
// If we found an out-of-date negative cache entry,
555-
// we need not to re-scan the sections that it covers.
556-
auto startIndex = failureEntry ? failureEntry->getFailureGeneration() : 0;
557-
auto endIndex = count;
546+
auto snapshot = C.SectionsToScan.snapshot();
558547

559-
// If there are no unscanned sections outstanding
560-
// then we can cache failure and give up now.
561-
if (startIndex == endIndex) {
562-
C.cacheFailure(type, protocol, count);
563-
return true;
548+
// Scan only sections that were not scanned yet.
549+
// If we found an out-of-date negative cache entry,
550+
// we need not to re-scan the sections that it covers.
551+
auto startIndex = failureEntry ? failureEntry->getFailureGeneration() : 0;
552+
auto endIndex = snapshot.count();
553+
554+
// If there are no unscanned sections outstanding
555+
// then we can cache failure and give up now.
556+
if (startIndex == endIndex) {
557+
C.cacheFailure(type, protocol, snapshot.count());
558+
return nullptr;
559+
}
560+
561+
/// Local function to retrieve the witness table and record the result.
562+
auto recordWitnessTable = [&](const ProtocolConformanceDescriptor &descriptor,
563+
const Metadata *type) {
564+
switch (descriptor.getConformanceKind()) {
565+
case ConformanceFlags::ConformanceKind::WitnessTable:
566+
// If the record provides a nondependent witness table for all
567+
// instances of a generic type, cache it for the generic pattern.
568+
C.cacheSuccess(type, protocol, descriptor.getStaticWitnessTable());
569+
return;
570+
571+
case ConformanceFlags::ConformanceKind::WitnessTableAccessor:
572+
// If the record provides a dependent witness table accessor,
573+
// cache the result for the instantiated type metadata.
574+
C.cacheSuccess(type, protocol, descriptor.getWitnessTable(type));
575+
return;
576+
577+
case ConformanceFlags::ConformanceKind::ConditionalWitnessTableAccessor: {
578+
auto witnessTable = descriptor.getWitnessTable(type);
579+
if (witnessTable)
580+
C.cacheSuccess(type, protocol, witnessTable);
581+
else
582+
C.cacheFailure(type, protocol, snapshot.count());
583+
return;
584+
}
564585
}
565586

566-
/// Local function to retrieve the witness table and record the result.
567-
auto recordWitnessTable = [&](const ProtocolConformanceDescriptor &descriptor,
568-
const Metadata *type) {
569-
switch (descriptor.getConformanceKind()) {
570-
case ConformanceFlags::ConformanceKind::WitnessTable:
571-
// If the record provides a nondependent witness table for all
572-
// instances of a generic type, cache it for the generic pattern.
573-
C.cacheSuccess(type, protocol, descriptor.getStaticWitnessTable());
574-
return;
575-
576-
case ConformanceFlags::ConformanceKind::WitnessTableAccessor:
577-
// If the record provides a dependent witness table accessor,
578-
// cache the result for the instantiated type metadata.
579-
C.cacheSuccess(type, protocol, descriptor.getWitnessTable(type));
580-
return;
581-
582-
case ConformanceFlags::ConformanceKind::ConditionalWitnessTableAccessor: {
583-
auto witnessTable = descriptor.getWitnessTable(type);
584-
if (witnessTable)
585-
C.cacheSuccess(type, protocol, witnessTable);
586-
else
587-
C.cacheFailure(type, protocol, count);
588-
return;
589-
}
590-
}
587+
// Always fail, because we cannot interpret a future conformance
588+
// kind.
589+
C.cacheFailure(type, protocol, snapshot.count());
590+
};
591591

592-
// Always fail, because we cannot interpret a future conformance
593-
// kind.
594-
C.cacheFailure(type, protocol, count);
595-
};
596-
597-
// Really scan conformance records.
598-
for (size_t i = startIndex; i < endIndex; i++) {
599-
auto &section = ptr[i];
600-
// Eagerly pull records for nondependent witnesses into our cache.
601-
for (const auto &record : section) {
602-
auto &descriptor = *record.get();
603-
604-
// If the record applies to a specific type, cache it.
605-
if (auto metadata = descriptor.getCanonicalTypeMetadata()) {
606-
auto P = descriptor.getProtocol();
607-
608-
// Look for an exact match.
609-
if (protocol != P)
610-
continue;
611-
612-
if (!isRelatedType(type, metadata, /*candidateIsMetadata=*/true))
613-
continue;
614-
615-
// Record the witness table.
616-
recordWitnessTable(descriptor, metadata);
617-
618-
// TODO: "Nondependent witness table" probably deserves its own flag.
619-
// An accessor function might still be necessary even if the witness table
620-
// can be shared.
621-
} else if (descriptor.getTypeKind()
622-
== TypeMetadataRecordKind::DirectNominalTypeDescriptor ||
623-
descriptor.getTypeKind()
624-
== TypeMetadataRecordKind::IndirectNominalTypeDescriptor) {
625-
auto R = descriptor.getTypeContextDescriptor();
626-
auto P = descriptor.getProtocol();
627-
628-
// Look for an exact match.
629-
if (protocol != P)
630-
continue;
631-
632-
if (!isRelatedType(type, R, /*candidateIsMetadata=*/false))
633-
continue;
634-
635-
recordWitnessTable(descriptor, type);
636-
}
592+
// Really scan conformance records.
593+
for (size_t i = startIndex; i < endIndex; i++) {
594+
auto &section = snapshot.Start[i];
595+
// Eagerly pull records for nondependent witnesses into our cache.
596+
for (const auto &record : section) {
597+
auto &descriptor = *record.get();
598+
599+
// If the record applies to a specific type, cache it.
600+
if (auto metadata = descriptor.getCanonicalTypeMetadata()) {
601+
auto P = descriptor.getProtocol();
602+
603+
// Look for an exact match.
604+
if (protocol != P)
605+
continue;
606+
607+
if (!isRelatedType(type, metadata, /*candidateIsMetadata=*/true))
608+
continue;
609+
610+
// Record the witness table.
611+
recordWitnessTable(descriptor, metadata);
612+
613+
// TODO: "Nondependent witness table" probably deserves its own flag.
614+
// An accessor function might still be necessary even if the witness table
615+
// can be shared.
616+
} else if (descriptor.getTypeKind()
617+
== TypeMetadataRecordKind::DirectNominalTypeDescriptor ||
618+
descriptor.getTypeKind()
619+
== TypeMetadataRecordKind::IndirectNominalTypeDescriptor) {
620+
auto R = descriptor.getTypeContextDescriptor();
621+
auto P = descriptor.getProtocol();
622+
623+
// Look for an exact match.
624+
if (protocol != P)
625+
continue;
626+
627+
if (!isRelatedType(type, R, /*candidateIsMetadata=*/false))
628+
continue;
629+
630+
recordWitnessTable(descriptor, type);
637631
}
638632
}
639-
return false;
640-
});
641-
642-
if (returnNull) return nullptr;
633+
}
643634

644635
// Conformance scan is complete.
645636
// Search the cache once more, and this time update the cache if necessary.
@@ -648,7 +639,7 @@ swift_conformsToProtocolImpl(const Metadata * const type,
648639
if (FoundConformance.isAuthoritative) {
649640
return FoundConformance.witnessTable;
650641
} else {
651-
C.cacheFailure(type, protocol, scannedCount);
642+
C.cacheFailure(type, protocol, snapshot.count());
652643
return nullptr;
653644
}
654645
}
@@ -657,19 +648,15 @@ const TypeContextDescriptor *
657648
swift::_searchConformancesByMangledTypeName(Demangle::NodePointer node) {
658649
auto &C = Conformances.get();
659650

660-
return C.SectionsToScan
661-
.read([&](const ConformanceSection *ptr, size_t count) -> const TypeContextDescriptor * {
662-
for (size_t i = 0; i < count; i++) {
663-
auto &section = ptr[i];
664-
for (const auto &record : section) {
665-
if (auto ntd = record->getTypeContextDescriptor()) {
666-
if (_contextDescriptorMatchesMangling(ntd, node))
667-
return ntd;
668-
}
651+
for (auto &section : C.SectionsToScan.snapshot()) {
652+
for (const auto &record : section) {
653+
if (auto ntd = record->getTypeContextDescriptor()) {
654+
if (_contextDescriptorMatchesMangling(ntd, node))
655+
return ntd;
669656
}
670657
}
671-
return nullptr;
672-
});
658+
}
659+
return nullptr;
673660
}
674661

675662
/// Resolve a reference to a generic parameter to type metadata.

unittests/runtime/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ if(("${SWIFT_HOST_VARIANT_SDK}" STREQUAL "${SWIFT_PRIMARY_VARIANT_SDK}") AND
3535
add_swift_unittest(SwiftRuntimeTests
3636
Array.cpp
3737
CompatibilityOverride.cpp
38+
Concurrent.cpp
3839
Exclusivity.cpp
3940
Metadata.cpp
4041
Mutex.cpp

unittests/runtime/Concurrent.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===--- Concurrent.cpp - Concurrent data structure tests -----------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2018 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "swift/Runtime/Concurrent.h"
14+
#include "gtest/gtest.h"
15+
16+
using namespace swift;
17+
18+
TEST(ConcurrentReadableArrayTest, SingleThreaded) {
19+
ConcurrentReadableArray<size_t> array;
20+
21+
auto add = [&](size_t limit) {
22+
for (size_t i = array.snapshot().count(); i < limit; i++)
23+
array.push_back(i);
24+
};
25+
auto check = [&]{
26+
size_t i = 0;
27+
for (auto element : array.snapshot()) {
28+
ASSERT_EQ(element, i);
29+
i++;
30+
}
31+
};
32+
33+
check();
34+
add(1);
35+
check();
36+
add(16);
37+
check();
38+
add(100);
39+
check();
40+
add(1000);
41+
check();
42+
add(1000000);
43+
check();
44+
}

0 commit comments

Comments
 (0)