Skip to content

Commit 0e9fe7e

Browse files
committed
[Concurrency] Relax task-local mis-use prevention in task groups
1 parent c64a95b commit 0e9fe7e

File tree

7 files changed

+141
-35
lines changed

7 files changed

+141
-35
lines changed

include/swift/ABI/TaskLocal.h

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ class TaskLocal {
4444
/// lookups by skipping empty parent tasks during get(), and explained
4545
/// in depth in `createParentLink`.
4646
IsParent = 0b01,
47+
/// The task local binding was created inside the body of a `withTaskGroup`,
48+
/// and therefore must either copy it, or crash when a child task is created
49+
/// using 'group.addTask' and it would refer to this task local.
50+
IsNextCreatedInTaskGroupBody = 0b10,
4751
};
4852

4953
class Item {
@@ -101,10 +105,20 @@ class TaskLocal {
101105
/// the Item linked list into the appropriate parent.
102106
static Item *createParentLink(AsyncTask *task, AsyncTask *parent);
103107

108+
static Item *createLink(AsyncTask *task,
109+
const HeapObject *key,
110+
const Metadata *valueType,
111+
bool inTaskGroupBody);
112+
104113
static Item *createLink(AsyncTask *task,
105114
const HeapObject *key,
106115
const Metadata *valueType);
107116

117+
static Item *createLinkInTaskGroup(
118+
AsyncTask *task,
119+
const HeapObject *key,
120+
const Metadata *valueType);
121+
108122
void destroy(AsyncTask *task);
109123

110124
Item *getNext() {
@@ -115,6 +129,16 @@ class TaskLocal {
115129
return static_cast<NextLinkType>(next & statusMask);
116130
}
117131

132+
bool isNextLinkPointer() const {
133+
return static_cast<NextLinkType>(next & statusMask) ==
134+
NextLinkType::IsNext;
135+
}
136+
137+
bool IsNextCreatedInTaskGroupBody() const {
138+
return static_cast<NextLinkType>(next & statusMask) ==
139+
NextLinkType::IsNextCreatedInTaskGroupBody;
140+
}
141+
118142
/// Item does not contain any actual value, and is only used to point at
119143
/// a specific parent item.
120144
bool isEmpty() const {
@@ -136,9 +160,9 @@ class TaskLocal {
136160
if (valueType) {
137161
size_t alignment = valueType->vw_alignment();
138162
return (offset + alignment - 1) & ~(alignment - 1);
139-
} else {
140-
return offset;
141163
}
164+
165+
return offset;
142166
}
143167

144168
/// Determine the size of the item given a particular value type.
@@ -200,6 +224,8 @@ class TaskLocal {
200224
/// can be safely disposed of.
201225
bool popValue(AsyncTask *task);
202226

227+
NextLinkType peekHeadLinkType() const;
228+
203229
/// Copy all task-local bindings to the target task.
204230
///
205231
/// The new bindings allocate their own items and can out-live the current task.

stdlib/public/Concurrency/Task.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -729,19 +729,31 @@ swift_task_create_commonImpl(size_t rawTaskCreateFlags,
729729
assert(initialContextSize >= sizeof(FutureAsyncContext));
730730
}
731731

732-
// Add to the task group, if requested.
733-
if (taskCreateFlags.addPendingGroupTaskUnconditionally()) {
734-
assert(group && "Missing group");
735-
swift_taskGroup_addPending(group, /*unconditionally=*/true);
736-
}
737-
738732
AsyncTask *parent = nullptr;
739733
AsyncTask *currentTask = swift_task_getCurrent();
740734
if (jobFlags.task_isChildTask()) {
741735
parent = currentTask;
742736
assert(parent != nullptr && "creating a child task with no active task");
743737
}
744738

739+
if (group) {
740+
assert(parent && "a task created in a group must be a child task");
741+
742+
// Prevent task-local misuse;
743+
// We must not allow an addTask {} wrapped immediately with a withValue {}
744+
auto taskLocalHeadLinkType = parent->_private().Local.peekHeadLinkType();
745+
if (taskLocalHeadLinkType == swift::TaskLocal::NextLinkType::IsNextCreatedInTaskGroupBody) {
746+
swift_task_reportIllegalTaskLocalBindingWithinWithTaskGroup(nullptr, 0, true, 0);
747+
abort();
748+
}
749+
750+
// Add to the task group, if requested.
751+
if (taskCreateFlags.addPendingGroupTaskUnconditionally()) {
752+
assert(group && "Missing group");
753+
swift_taskGroup_addPending(group, /*unconditionally=*/true);
754+
}
755+
}
756+
745757
// Start with user specified priority at creation time (if any)
746758
JobPriority basePriority = (taskCreateFlags.getRequestedPriority());
747759

stdlib/public/Concurrency/TaskLocal.cpp

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ TaskLocal::Item::createParentLink(AsyncTask *task, AsyncTask *parent) {
171171
static_cast<uintptr_t>(NextLinkType::IsParent);
172172
break;
173173
case NextLinkType::IsNext:
174+
case NextLinkType::IsNextCreatedInTaskGroupBody:
174175
if (parentHead->getNext()) {
175176
assert(false && "empty taskValue head in parent task, yet parent's 'head' is `IsNext`, "
176177
"this should not happen, as it implies the parent must have stored some value.");
@@ -194,7 +195,8 @@ TaskLocal::Item::createParentLink(AsyncTask *task, AsyncTask *parent) {
194195
TaskLocal::Item*
195196
TaskLocal::Item::createLink(AsyncTask *task,
196197
const HeapObject *key,
197-
const Metadata *valueType) {
198+
const Metadata *valueType,
199+
bool inTaskGroupBody) {
198200
size_t amountToAllocate = Item::itemSize(valueType);
199201
void *allocation = task ? _swift_task_alloc_specific(task, amountToAllocate)
200202
: malloc(amountToAllocate);
@@ -203,11 +205,28 @@ TaskLocal::Item::createLink(AsyncTask *task,
203205
auto next = task ? task->_private().Local.head
204206
: FallbackTaskLocalStorage::get()->head;
205207
item->next = reinterpret_cast<uintptr_t>(next) |
206-
static_cast<uintptr_t>(NextLinkType::IsNext);
208+
static_cast<uintptr_t>(
209+
inTaskGroupBody ? NextLinkType::IsNextCreatedInTaskGroupBody
210+
: NextLinkType::IsNext);
207211

208212
return item;
209213
}
210214

215+
TaskLocal::Item*
216+
TaskLocal::Item::createLink(AsyncTask *task,
217+
const HeapObject *key,
218+
const Metadata *valueType) {
219+
return createLink(task, key, valueType, /*=inTaskGroupBody=*/false);
220+
}
221+
222+
223+
TaskLocal::Item*
224+
TaskLocal::Item::createLinkInTaskGroup(AsyncTask *task,
225+
const HeapObject *key,
226+
const Metadata *valueType) {
227+
return createLink(task, key, valueType, /*=inTaskGroupBody=*/true);
228+
}
229+
211230

212231
void TaskLocal::Item::copyTo(AsyncTask *target) {
213232
assert(target && "TaskLocal item attempt to copy to null target task!");
@@ -232,13 +251,11 @@ void TaskLocal::Item::copyTo(AsyncTask *target) {
232251

233252
SWIFT_CC(swift)
234253
static void swift_task_reportIllegalTaskLocalBindingWithinWithTaskGroupImpl(
235-
const unsigned char *file, uintptr_t fileLength,
236-
bool fileIsASCII, uintptr_t line) {
254+
const unsigned char *_unused_file, uintptr_t _unused_fileLength,
255+
bool _unused_fileIsASCII, uintptr_t _unused_line) {
237256

238-
char *message;
239-
swift_asprintf(
240-
&message,
241-
"error: task-local: detected illegal task-local value binding at %.*s:%d.\n"
257+
char *message =
258+
"error: task-local: detected illegal task-local value binding.\n"
242259
"Task-local values must only be set in a structured-context, such as: "
243260
"around any (synchronous or asynchronous function invocation), "
244261
"around an 'async let' declaration, or around a 'with(Throwing)TaskGroup(...){ ... }' "
@@ -272,9 +289,7 @@ static void swift_task_reportIllegalTaskLocalBindingWithinWithTaskGroupImpl(
272289
" }\n"
273290
"\n"
274291
" group.addTask { ... }\n"
275-
" }\n",
276-
(int)fileLength, file,
277-
(int)line);
292+
" }\n";
278293

279294
if (_swift_shouldReportFatalErrorsToDebugger()) {
280295
RuntimeErrorDetails details = {
@@ -329,6 +344,7 @@ void TaskLocal::Storage::destroy(AsyncTask *task) {
329344
auto linkType = item->getNextLinkType();
330345
switch (linkType) {
331346
case TaskLocal::NextLinkType::IsNext:
347+
case TaskLocal::NextLinkType::IsNextCreatedInTaskGroupBody:
332348
next = item->getNext();
333349
item->destroy(task);
334350
item = next;
@@ -351,8 +367,25 @@ void TaskLocal::Storage::pushValue(AsyncTask *task,
351367
/* +1 */ OpaqueValue *value,
352368
const Metadata *valueType) {
353369
assert(value && "Task local value must not be nil");
370+
assert(swift_task_getCurrent() == task &&
371+
"must only be pushing task locals onto current task");
372+
373+
// We're in a task group body.
374+
// We specifically need to prevent this pattern:
375+
//
376+
// $number.withValue(0xBAADF00D) { // push
377+
// group.addTask { ... }
378+
// } // pop! BOOM!
379+
//
380+
// because the end of the withValue scope would pop the value,
381+
// and thus if the child task didn't copy the value, it'd refer to a bad
382+
// memory location at this point.
383+
384+
TaskLocal::Item* item = Item::createLink(
385+
task, key, valueType,
386+
/*inTaskGroupBody=*/swift_task_hasTaskGroupStatusRecord());
387+
354388

355-
auto item = Item::createLink(task, key, valueType);
356389
valueType->vw_initializeWithTake(item->getStoragePtr(), value);
357390
head = item;
358391
}
@@ -367,6 +400,11 @@ bool TaskLocal::Storage::popValue(AsyncTask *task) {
367400
return head != nullptr;
368401
}
369402

403+
TaskLocal::NextLinkType
404+
TaskLocal::Storage::peekHeadLinkType() const {
405+
return head->getNextLinkType();
406+
}
407+
370408
OpaqueValue* TaskLocal::Storage::getValue(AsyncTask *task,
371409
const HeapObject *key) {
372410
assert(key && "TaskLocal key must not be null.");

stdlib/public/Concurrency/TaskLocal.swift

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,6 @@ public final class TaskLocal<Value: Sendable>: Sendable, CustomStringConvertible
252252
operation: () async throws -> R,
253253
isolation: isolated (any Actor)?,
254254
file: String = #fileID, line: UInt = #line) async rethrows -> R {
255-
// check if we're not trying to bind a value from an illegal context; this may crash
256-
_checkIllegalTaskLocalBindingWithinWithTaskGroup(file: file, line: line)
257-
258255
_taskLocalValuePush(key: key, value: consume valueDuringOperation)
259256
defer { _taskLocalValuePop() }
260257

@@ -269,9 +266,6 @@ public final class TaskLocal<Value: Sendable>: Sendable, CustomStringConvertible
269266
internal func withValueImpl<R>(_ valueDuringOperation: __owned Value,
270267
operation: () async throws -> R,
271268
file: String = #fileID, line: UInt = #line) async rethrows -> R {
272-
// check if we're not trying to bind a value from an illegal context; this may crash
273-
_checkIllegalTaskLocalBindingWithinWithTaskGroup(file: file, line: line)
274-
275269
_taskLocalValuePush(key: key, value: consume valueDuringOperation)
276270
defer { _taskLocalValuePop() }
277271

@@ -296,9 +290,6 @@ public final class TaskLocal<Value: Sendable>: Sendable, CustomStringConvertible
296290
@discardableResult
297291
public func withValue<R>(_ valueDuringOperation: Value, operation: () throws -> R,
298292
file: String = #fileID, line: UInt = #line) rethrows -> R {
299-
// check if we're not trying to bind a value from an illegal context; this may crash
300-
_checkIllegalTaskLocalBindingWithinWithTaskGroup(file: file, line: line)
301-
302293
_taskLocalValuePush(key: key, value: valueDuringOperation)
303294
defer { _taskLocalValuePop() }
304295

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: %target-run-simple-swift( -plugin-path %swift-plugin-dir -Xfrontend -disable-availability-checking -parse-as-library %import-libdispatch) 2>&1 | %FileCheck %s --dump-input=always
2+
3+
// REQUIRES: executable_test
4+
// REQUIRES: concurrency
5+
// REQUIRES: libdispatch
6+
// REQUIRES: concurrency_runtime
7+
// UNSUPPORTED: back_deployment_runtime
8+
9+
@available(SwiftStdlib 5.1, *)
10+
enum TL {
11+
@TaskLocal
12+
static var number: Int = 2
13+
}
14+
15+
// ==== ------------------------------------------------------------------------
16+
17+
func bindAroundGroupAddTask() async {
18+
await TL.$number.withValue(1111) { // ok
19+
await withTaskGroup(of: Int.self) { group in
20+
// CHECK-NOT: error: task-local: detected illegal
21+
22+
TL.$number.withValue(2222) { // this is OK, there's no addTask being wrapped
23+
print("Survived, inside withValue, value: \(TL.number)") // CHECK: Survived, inside withValue, value: 2222
24+
}
25+
26+
group.addTask {
27+
print("Survived, inside addTask, value: \(TL.number)") // CHECK: Survived, inside addTask, value: 1111
28+
return TL.number
29+
}
30+
}
31+
print("Survived, done") // CHECK: Survived, done
32+
}
33+
}
34+
35+
@main struct Main {
36+
static func main() async {
37+
await bindAroundGroupAddTask()
38+
}
39+
}

test/Concurrency/Runtime/async_task_locals_prevent_illegal_use_discarding_taskgroup.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ enum TL {
2222
func bindAroundGroupAddTask() async {
2323
await TL.$number.withValue(1111) { // ok
2424
await withTaskGroup(of: Int.self) { group in
25-
// CHECK: error: task-local: detected illegal task-local value binding at {{.*}}illegal_use_discarding_taskgroup.swift:[[# @LINE + 1]]
25+
// CHECK: error: task-local: detected illegal task-local value binding
2626
TL.$number.withValue(2222) { // bad!
27-
print("Survived, inside withValue!") // CHECK-NOT: Survived, inside withValue!
2827
group.addTask {
29-
0 // don't actually perform the read, it would be unsafe.
28+
print("Survived, inside withValue!") // CHECK-NOT: Survived, inside withValue!
29+
return 0 // don't actually perform the read, it would be unsafe.
3030
}
3131
}
3232

test/Concurrency/Runtime/async_task_locals_prevent_illegal_use_taskgroup.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ enum TL {
2121
func bindAroundDiscardingGroupAddTask() async {
2222
await TL.$number.withValue(1111) { // ok
2323
await withTaskGroup(of: Int.self) { group in
24-
// CHECK: error: task-local: detected illegal task-local value binding at {{.*}}illegal_use_taskgroup.swift:[[# @LINE + 1]]
24+
// CHECK: error: task-local: detected illegal task-local value binding
2525
TL.$number.withValue(2222) { // bad!
26-
print("Survived, inside withValue!") // CHECK-NOT: Survived, inside withValue!
2726
group.addTask {
28-
0 // don't actually perform the read, it would be unsafe.
27+
print("Survived, inside withValue!") // CHECK-NOT: Survived, inside withValue!
28+
return 0 // don't actually perform the read, it would be unsafe.
2929
}
3030
}
3131

0 commit comments

Comments
 (0)