Skip to content

Commit e270da9

Browse files
committed
[Concurrency] Correct handling of mixed nested values
1 parent d43e61d commit e270da9

File tree

3 files changed

+48
-15
lines changed

3 files changed

+48
-15
lines changed

include/swift/ABI/TaskLocal.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ class TaskLocal {
4747
/// The task local binding was created inside the body of a `withTaskGroup`,
4848
/// and therefore must either copy it, or crash when a child task is created
4949
/// using 'group.addTask' and it would refer to this task local.
50+
///
51+
/// Items of this kind must be copied by a group child task for access
52+
/// safety reasons, as otherwise the pop would happen before the child task
53+
/// has completed.
5054
IsNextCreatedInTaskGroupBody = 0b10,
5155
};
5256

@@ -126,6 +130,8 @@ class TaskLocal {
126130
}
127131

128132
void relinkNext(Item* nextOverride) {
133+
fprintf(stderr, "[%s:%d](%s) try relink item:%p\n", __FILE_NAME__, __LINE__, __FUNCTION__, this);
134+
fprintf(stderr, "[%s:%d](%s) try relink to target:%p\n", __FILE_NAME__, __LINE__, __FUNCTION__, nextOverride);
129135
assert(!getNext() &&
130136
"Can only relink task local item that was not pointing at anything yet");
131137
assert(nextOverride->isNextLinkPointer() ||
@@ -149,7 +155,7 @@ class TaskLocal {
149155
NextLinkType::IsNext;
150156
}
151157

152-
bool isNextCreatedInTaskGroupBody() const {
158+
bool isNextLinkPointerCreatedInTaskGroupBody() const {
153159
return static_cast<NextLinkType>(next & statusMask) ==
154160
NextLinkType::IsNextCreatedInTaskGroupBody;
155161
}

stdlib/public/Concurrency/TaskLocal.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,8 @@ void TaskLocal::Storage::pushValue(AsyncTask *task,
436436

437437
valueType->vw_initializeWithTake(item->getStoragePtr(), value);
438438
head = item;
439-
SWIFT_TASK_LOCAL_DEBUG_LOG(item->key, "Created link item:%p", item);
439+
SWIFT_TASK_LOCAL_DEBUG_LOG(item->key, "Created link item:%p, in group body:%d",
440+
item, inTaskGroupBody);
440441
}
441442

442443
bool TaskLocal::Storage::popValue(AsyncTask *task) {
@@ -513,29 +514,35 @@ void TaskLocal::Storage::copyToOnlyOnlyFromCurrent(AsyncTask *target) {
513514
std::set<const HeapObject*> copied = {};
514515

515516
auto item = head;
516-
TaskLocal::Item *lastCopiedItem = nullptr;
517+
TaskLocal::Item *copiedHead = nullptr;
517518
while (item) {
518519
// we only have to copy an item if it is the most recent binding of a key.
519520
// i.e. if we've already seen an item for this key, we can skip it.
520521
if (copied.emplace(item->key).second) {
521522

522-
if (!item->isNextCreatedInTaskGroupBody() && lastCopiedItem) {
523-
SWIFT_TASK_LOCAL_DEBUG_LOG(item->key, "break out, next item is not within body, item value [%p]", item->getStoragePtr());
523+
if (!item->isNextLinkPointerCreatedInTaskGroupBody() && copiedHead) {
524+
SWIFT_TASK_LOCAL_DEBUG_LOG(item->key, "break out, next item is not within body, item:%p", item);
524525
// The next item is not the "risky one" so we can directly link to it,
525526
// as we would have within normal child task relationships. E.g. this is
526527
// a parent or next pointer to a "safe" (withValue { withTaskGroup { ... } })
527528
// binding, so we re-link our current head to point at this item.
528-
lastCopiedItem->relinkNext(item);
529+
copiedHead->relinkNext(item);
529530
break;
530531
}
531532

532-
lastCopiedItem = item->copyTo(target);
533+
auto copy = item->copyTo(target);
534+
if (!copiedHead) {
535+
SWIFT_TASK_LOCAL_DEBUG_LOG(item->key, "store copied head item:%p",
536+
copiedHead);
537+
copiedHead = copy;
538+
}
539+
533540
SWIFT_TASK_LOCAL_DEBUG_LOG(item->key, "copy value [%p] to target:%p, item:%p, copied:%p",
534-
item->getStoragePtr(), target, item, lastCopiedItem);
541+
item->getStoragePtr(), target, item, copiedHead);
535542

536543
// If we didn't copy an item, e.g. because it was a pointer to parent,
537544
// break out of the loop and keep pointing at parent still.
538-
if (lastCopiedItem == nullptr) {
545+
if (!copy) {
539546
SWIFT_TASK_LOCAL_DEBUG_LOG(item->key, "break out, next is %p", 0);
540547
break;
541548
}

test/Concurrency/Runtime/async_task_locals_in_task_group_may_need_to_copy.swift

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
// RUN: %target-run-simple-swift( -plugin-path %swift-plugin-dir -Xfrontend -disable-availability-checking -parse-as-library %import-libdispatch) | %FileCheck %s --dump-input=always
1+
// RUN: %target-run-simple-swift( -plugin-path %swift-plugin-dir -Xfrontend -strict-concurrency=complete -Xfrontend -disable-availability-checking -parse-as-library %import-libdispatch) | %FileCheck %s --dump-input=always
22

33
// REQUIRES: executable_test
44
// REQUIRES: concurrency
55
// REQUIRES: libdispatch
66
// REQUIRES: concurrency_runtime
77
// UNSUPPORTED: back_deployment_runtime
88

9-
@available(SwiftStdlib 5.1, *)
109
enum TL {
1110
@TaskLocal
1211
static var one: Int = 1
@@ -22,22 +21,43 @@ func test() async {
2221
await TL.$one.withValue(11) {
2322
await TL.$one.withValue(1111) {
2423
await withTaskGroup(of: Void.self) { group in
24+
2525
TL.$two.withValue(2222) {
2626
group.addTask { // will have to copy the `2222`
2727
print("Survived, one: \(TL.one) @ \(#fileID):\(#line)") // CHECK: Survived, one: 1111
2828
print("Survived, two: \(TL.two) @ \(#fileID):\(#line)") // CHECK: Survived, two: 2222
2929
}
3030
}
31+
await group.next()
32+
print("--")
3133

34+
TL.$two.withValue(2) {
35+
TL.$two.withValue(22) {
36+
TL.$two.withValue(2222) {
37+
group.addTask { // will have to copy the `2222`
38+
print("Survived, one: \(TL.one) @ \(#fileID):\(#line)") // CHECK: Survived, one: 1111
39+
print("Survived, two: \(TL.two) @ \(#fileID):\(#line)") // CHECK: Survived, two: 2222
40+
}
41+
}
42+
}
43+
}
3244
await group.next()
3345
print("--")
3446

3547
TL.$two.withValue(2) {
3648
TL.$two.withValue(22) {
37-
TL.$two.withValue(2222) {
38-
group.addTask { // will have to copy the `2222`
39-
print("Survived, one: \(TL.one) @ \(#fileID):\(#line)") // CHECK: Survived, one: 1111
40-
print("Survived, two: \(TL.two) @ \(#fileID):\(#line)") // CHECK: Survived, two: 2222
49+
TL.$three.withValue(33) {
50+
TL.$two.withValue(2222) {
51+
group.addTask { // will have to copy the `2222`
52+
print("Survived, one: \(TL.one) @ \(#fileID):\(#line)") // CHECK: Survived, one: 1111
53+
print("Survived, two: \(TL.two) @ \(#fileID):\(#line)") // CHECK: Survived, two: 2222
54+
55+
TL.$three.withValue(3333) {
56+
print("Survived, one: \(TL.one) @ \(#fileID):\(#line)") // CHECK: Survived, one: 1111
57+
print("Survived, two: \(TL.two) @ \(#fileID):\(#line)") // CHECK: Survived, two: 2222
58+
print("Survived, three: \(TL.three) @ \(#fileID):\(#line)") // CHECK: Survived, three: 3333
59+
}
60+
}
4161
}
4262
}
4363
}

0 commit comments

Comments
 (0)