Skip to content

Commit 6d63cf4

Browse files
committed
fix cast type in getting task record
1 parent bf5ef5a commit 6d63cf4

File tree

5 files changed

+38
-62
lines changed

5 files changed

+38
-62
lines changed

include/swift/ABI/TaskStatus.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,16 @@ class TaskStatusRecord {
4545
TaskStatusRecord(TaskStatusRecordKind kind,
4646
TaskStatusRecord *parent = nullptr)
4747
: Flags(kind) {
48+
getKind();
4849
resetParent(parent);
4950
}
5051

5152
TaskStatusRecord(const TaskStatusRecord &) = delete;
5253
TaskStatusRecord &operator=(const TaskStatusRecord &) = delete;
5354

54-
TaskStatusRecordKind getKind() const { return Flags.getKind(); }
55+
TaskStatusRecordKind getKind() const {
56+
return Flags.getKind();
57+
}
5558

5659
TaskStatusRecord *getParent() const { return Parent; }
5760

@@ -172,15 +175,14 @@ class ChildTaskStatusRecord : public TaskStatusRecord {
172175
/// Group child tasks DO NOT have their own `ChildTaskStatusRecord` entries,
173176
/// and are only tracked by their respective `TaskGroupTaskStatusRecord`.
174177
class TaskGroupTaskStatusRecord : public TaskStatusRecord {
178+
public:
175179
AsyncTask *FirstChild;
176180
AsyncTask *LastChild;
177181

178-
public:
179182
TaskGroupTaskStatusRecord()
180183
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup),
181184
FirstChild(nullptr),
182-
LastChild(nullptr) {
183-
}
185+
LastChild(nullptr) {}
184186

185187
TaskGroupTaskStatusRecord(AsyncTask *child)
186188
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup),
@@ -189,7 +191,8 @@ class TaskGroupTaskStatusRecord : public TaskStatusRecord {
189191
assert(!LastChild || !LastChild->childFragment()->getNextChild());
190192
}
191193

192-
TaskGroup *getGroup() { return reinterpret_cast<TaskGroup *>(this); }
194+
/// Get the task group this record is associated with.
195+
TaskGroup *getGroup();
193196

194197
/// Return the first child linked by this record. This may be null;
195198
/// if not, it (and all of its successors) are guaranteed to satisfy

stdlib/public/BackDeployConcurrency/TaskStatus.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ class TaskGroupTaskStatusRecord : public TaskStatusRecord {
170170
TaskGroupTaskStatusRecord()
171171
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup),
172172
FirstChild(nullptr),
173-
LastChild(nullptr) {
174-
}
173+
LastChild(nullptr) {}
175174

176175
TaskGroupTaskStatusRecord(AsyncTask *child)
177176
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup),
@@ -180,7 +179,9 @@ class TaskGroupTaskStatusRecord : public TaskStatusRecord {
180179
assert(!LastChild || !LastChild->childFragment()->getNextChild());
181180
}
182181

183-
TaskGroup *getGroup() { return reinterpret_cast<TaskGroup *>(this); }
182+
TaskGroup *getGroup() {
183+
return reinterpret_cast<TaskGroup *>(this);
184+
}
184185

185186
/// Return the first child linked by this record. This may be null;
186187
/// if not, it (and all of its successors) are guaranteed to satisfy

stdlib/public/Concurrency/TaskGroup.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050

5151
using namespace swift;
5252

53-
#if 1
53+
#if 0
5454
#define SWIFT_TASK_GROUP_DEBUG_LOG(group, fmt, ...) \
5555
fprintf(stderr, "[%#lx] [%s:%d](%s) group(%p%s) " fmt "\n", \
5656
(unsigned long)Thread::current().platformThreadId(), \
@@ -67,8 +67,8 @@ namespace {
6767
class TaskStatusRecord;
6868
struct TaskGroupStatus;
6969

70-
struct AccumulatingTaskGroup;
71-
struct DiscardingTaskGroup;
70+
class AccumulatingTaskGroup;
71+
class DiscardingTaskGroup;
7272

7373
/******************************************************************************/
7474
/*************************** TASK GROUP BASE **********************************/
@@ -112,9 +112,15 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
112112
waitQueue(nullptr),
113113
successType(T) {}
114114

115+
TaskGroupBase(const TaskGroupBase &) = delete;
116+
115117
public:
116118
virtual ~TaskGroupBase() {}
117119

120+
TaskStatusRecordKind getKind() const {
121+
return Flags.getKind();
122+
}
123+
118124
/// Describes the status of the group.
119125
enum class ReadyStatus : uintptr_t {
120126
/// The task group is empty, no tasks are pending.
@@ -259,7 +265,7 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord {
259265
/// Any TaskGroup always IS its own TaskRecord.
260266
/// This allows us to easily get the group while cancellation is propagated throughout the task tree.
261267
TaskGroupTaskStatusRecord *getTaskRecord() {
262-
return reinterpret_cast<TaskGroupTaskStatusRecord *>(this);
268+
return static_cast<TaskGroupTaskStatusRecord *>(this);
263269
}
264270

265271
// ==== Queue operations ----------------------------------------------------
@@ -824,11 +830,11 @@ static TaskGroupBase *asBaseImpl(TaskGroup *group) {
824830
}
825831
static AccumulatingTaskGroup *asAccumulatingImpl(TaskGroup *group) {
826832
assert(group->isAccumulatingResults());
827-
return reinterpret_cast<AccumulatingTaskGroup*>(group);
833+
return static_cast<AccumulatingTaskGroup*>(reinterpret_cast<TaskGroupBase*>(group));
828834
}
829835
static DiscardingTaskGroup *asDiscardingImpl(TaskGroup *group) {
830836
assert(group->isDiscardingResults());
831-
return reinterpret_cast<DiscardingTaskGroup*>(group);
837+
return static_cast<DiscardingTaskGroup*>(reinterpret_cast<TaskGroupBase*>(group));
832838
}
833839

834840
static TaskGroup *asAbstract(TaskGroupBase *group) {
@@ -849,6 +855,10 @@ bool TaskGroup::isDiscardingResults() {
849855
return asBaseImpl(this)->isDiscardingResults();
850856
}
851857

858+
TaskGroup* TaskGroupTaskStatusRecord::getGroup() {
859+
return reinterpret_cast<TaskGroup *>(static_cast<TaskGroupBase*>(this));
860+
}
861+
852862
// =============================================================================
853863
// ==== initialize -------------------------------------------------------------
854864

@@ -864,7 +874,7 @@ static void swift_taskGroup_initializeWithFlagsImpl(size_t rawGroupFlags,
864874
TaskGroup *group, const Metadata *T) {
865875

866876
TaskGroupFlags groupFlags(rawGroupFlags);
867-
SWIFT_TASK_DEBUG_LOG("(group(%p) create; flags: isDiscardingResults=%d",
877+
SWIFT_TASK_DEBUG_LOG("group(%p) create; flags: isDiscardingResults=%d",
868878
group, groupFlags.isDiscardResults());
869879

870880
TaskGroupBase *impl;
@@ -875,6 +885,8 @@ static void swift_taskGroup_initializeWithFlagsImpl(size_t rawGroupFlags,
875885
}
876886

877887
TaskGroupTaskStatusRecord *record = impl->getTaskRecord();
888+
assert(record->getKind() == swift::TaskStatusRecordKind::TaskGroup);
889+
878890
// ok, now that the group actually is initialized: attach it to the task
879891
addStatusRecord(record, [&](ActiveTaskStatus parentStatus) {
880892
// If the task has already been cancelled, reflect that immediately in
@@ -902,7 +914,8 @@ void TaskGroup::addChildTask(AsyncTask *child) {
902914
// prevents us from racing with cancellation or escalation. We don't
903915
// need to acquire the task group lock because the child list is only
904916
// accessed under the task status record lock.
905-
auto record = asBaseImpl(this)->getTaskRecord();
917+
auto base = asBaseImpl(this);
918+
auto record = base->getTaskRecord();
906919
record->attachChild(child);
907920
}
908921

test/Concurrency/Runtime/async_task_locals_prevent_illegal_use_discarding_taskgroup.swift

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,7 @@ enum TL {
2222
func bindAroundGroupAddTask() async {
2323
await TL.$number.withValue(1111) { // ok
2424
await withTaskGroup(of: Int.self) { group in
25-
// CHECK: error: task-local: detected illegal task-local value binding at {{.*}}illegal_use.swift:[[# @LINE + 1]]
26-
TL.$number.withValue(2222) { // bad!
27-
print("Survived, inside withValue!") // CHECK-NOT: Survived, inside withValue!
28-
group.addTask {
29-
0 // don't actually perform the read, it would be unsafe.
30-
}
31-
}
32-
33-
print("Survived the illegal call!") // CHECK-NOT: Survived the illegal call!
34-
}
35-
}
36-
}
37-
38-
func bindAroundDiscardingGroupAddTask() async {
39-
await TL.$number.withValue(1111) { // ok
40-
await withDiscardingTaskGroup(of: Int.self) { group in
41-
// CHECK: error: task-local: detected illegal task-local value binding at {{.*}}illegal_use.swift:[[# @LINE + 1]]
25+
// CHECK: error: task-local: detected illegal task-local value binding at {{.*}}illegal_use_discarding_taskgroup.swift:[[# @LINE + 1]]
4226
TL.$number.withValue(2222) { // bad!
4327
print("Survived, inside withValue!") // CHECK-NOT: Survived, inside withValue!
4428
group.addTask {
@@ -53,10 +37,6 @@ func bindAroundDiscardingGroupAddTask() async {
5337

5438
@main struct Main {
5539
static func main() async {
56-
if CommandLine.arguments.contains("discarding-task-group") {
57-
await bindAroundGroupAddTask()
58-
} else {
59-
await bindAroundDiscardingGroupAddTask()
60-
}
40+
await bindAroundGroupAddTask()
6141
}
6242
}

test/Concurrency/Runtime/async_task_locals_prevent_illegal_use_taskgroup.swift

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,17 @@
1111
// REQUIRES: concurrency_runtime
1212
// UNSUPPORTED: back_deployment_runtime
1313

14-
@available(SwiftStdlib 5.1, *)
1514
enum TL {
1615
@TaskLocal
1716
static var number: Int = 2
1817
}
1918

2019
// ==== ------------------------------------------------------------------------
2120

22-
func bindAroundGroupAddTask() async {
23-
await TL.$number.withValue(1111) { // ok
24-
await withTaskGroup(of: Int.self) { group in
25-
// CHECK: error: task-local: detected illegal task-local value binding at {{.*}}illegal_use.swift:[[# @LINE + 1]]
26-
TL.$number.withValue(2222) { // bad!
27-
print("Survived, inside withValue!") // CHECK-NOT: Survived, inside withValue!
28-
group.addTask {
29-
0 // don't actually perform the read, it would be unsafe.
30-
}
31-
}
32-
33-
print("Survived the illegal call!") // CHECK-NOT: Survived the illegal call!
34-
}
35-
}
36-
}
37-
3821
func bindAroundDiscardingGroupAddTask() async {
3922
await TL.$number.withValue(1111) { // ok
40-
await withDiscardingTaskGroup(of: Int.self) { group in
41-
// CHECK: error: task-local: detected illegal task-local value binding at {{.*}}illegal_use.swift:[[# @LINE + 1]]
23+
await withTaskGroup(of: Int.self) { group in
24+
// CHECK: error: task-local: detected illegal task-local value binding at {{.*}}illegal_use_taskgroup.swift:[[# @LINE + 1]]
4225
TL.$number.withValue(2222) { // bad!
4326
print("Survived, inside withValue!") // CHECK-NOT: Survived, inside withValue!
4427
group.addTask {
@@ -53,10 +36,6 @@ func bindAroundDiscardingGroupAddTask() async {
5336

5437
@main struct Main {
5538
static func main() async {
56-
if CommandLine.arguments.contains("discarding-task-group") {
57-
await bindAroundGroupAddTask()
58-
} else {
59-
await bindAroundDiscardingGroupAddTask()
60-
}
39+
await bindAroundDiscardingGroupAddTask()
6140
}
6241
}

0 commit comments

Comments
 (0)