Skip to content

[Concurrency] DiscardingTaskGroup (rev 3) #62914

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/swift/ABI/MetadataValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -2383,6 +2383,24 @@ enum class TaskOptionRecordKind : uint8_t {
RunInline = UINT8_MAX,
};

/// Flags for TaskGroup.
class TaskGroupFlags : public FlagSet<uint32_t> {
public:
enum {
// 8 bits are reserved for future use
/// Request the TaskGroup to immediately release completed tasks,
/// and not store their results. This also effectively disables `next()`.
TaskGroup_DiscardResults = 8,
};

explicit TaskGroupFlags(uint32_t bits) : FlagSet(bits) {}
constexpr TaskGroupFlags() {}

FLAGSET_DEFINE_FLAG_ACCESSORS(TaskGroup_DiscardResults,
isDiscardResults,
setIsDiscardResults)
};

/// Flags for cancellation records.
class TaskStatusRecordFlags : public FlagSet<size_t> {
public:
Expand Down
8 changes: 8 additions & 0 deletions include/swift/ABI/TaskGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ class alignas(Alignment_TaskGroup) TaskGroup {

// Provide accessor for task group's status record
TaskGroupTaskStatusRecord *getTaskRecord();

/// The group is a `TaskGroup` that accumulates results.
bool isAccumulatingResults() {
return !isDiscardingResults();
}

/// The group is a `DiscardingTaskGroup` that discards results.
bool isDiscardingResults();
};

} // end namespace swift
Expand Down
13 changes: 8 additions & 5 deletions include/swift/ABI/TaskStatus.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,16 @@ class TaskStatusRecord {
TaskStatusRecord(TaskStatusRecordKind kind,
TaskStatusRecord *parent = nullptr)
: Flags(kind) {
getKind();
resetParent(parent);
}

TaskStatusRecord(const TaskStatusRecord &) = delete;
TaskStatusRecord &operator=(const TaskStatusRecord &) = delete;

TaskStatusRecordKind getKind() const { return Flags.getKind(); }
TaskStatusRecordKind getKind() const {
return Flags.getKind();
}

TaskStatusRecord *getParent() const { return Parent; }

Expand Down Expand Up @@ -172,15 +175,14 @@ class ChildTaskStatusRecord : public TaskStatusRecord {
/// Group child tasks DO NOT have their own `ChildTaskStatusRecord` entries,
/// and are only tracked by their respective `TaskGroupTaskStatusRecord`.
class TaskGroupTaskStatusRecord : public TaskStatusRecord {
public:
AsyncTask *FirstChild;
AsyncTask *LastChild;

public:
TaskGroupTaskStatusRecord()
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup),
FirstChild(nullptr),
LastChild(nullptr) {
}
LastChild(nullptr) {}

TaskGroupTaskStatusRecord(AsyncTask *child)
: TaskStatusRecord(TaskStatusRecordKind::TaskGroup),
Expand All @@ -189,7 +191,8 @@ class TaskGroupTaskStatusRecord : public TaskStatusRecord {
assert(!LastChild || !LastChild->childFragment()->getNextChild());
}

TaskGroup *getGroup() { return reinterpret_cast<TaskGroup *>(this); }
/// Get the task group this record is associated with.
TaskGroup *getGroup();

/// Return the first child linked by this record. This may be null;
/// if not, it (and all of its successors) are guaranteed to satisfy
Expand Down
4 changes: 4 additions & 0 deletions include/swift/AST/Builtins.def
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,10 @@ BUILTIN_MISC_OPERATION(ResumeThrowingContinuationThrowing,
BUILTIN_MISC_OPERATION(CreateTaskGroup,
"createTaskGroup", "", Special)

/// Create a task group, with options.
BUILTIN_MISC_OPERATION(CreateTaskGroupWithFlags,
"createTaskGroupWithFlags", "", Special)

/// Destroy a task group.
BUILTIN_MISC_OPERATION(DestroyTaskGroup,
"destroyTaskGroup", "", Special)
Expand Down
41 changes: 38 additions & 3 deletions include/swift/Runtime/Concurrency.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,15 @@ void swift_task_future_wait_throwing(
/// func swift_taskGroup_wait_next_throwing(
/// waitingTask: Builtin.NativeObject, // current task
/// group: Builtin.RawPointer
/// ) async -> T
/// ) async throws -> T
/// \endcode
SWIFT_EXPORT_FROM(swift_Concurrency)
SWIFT_CC(swiftasync)
void swift_taskGroup_wait_next_throwing(
OpaqueValue *resultPointer, SWIFT_ASYNC_CONTEXT AsyncContext *callerContext,
TaskGroup *group, ThrowingTaskFutureWaitContinuationFunction *resumeFn,
OpaqueValue *resultPointer,
SWIFT_ASYNC_CONTEXT AsyncContext *callerContext,
TaskGroup *group,
ThrowingTaskFutureWaitContinuationFunction *resumeFn,
AsyncContext *callContext);

/// Initialize a `TaskGroup` in the passed `group` memory location.
Expand All @@ -205,6 +207,17 @@ void swift_taskGroup_wait_next_throwing(
SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
void swift_taskGroup_initialize(TaskGroup *group, const Metadata *T);

/// Initialize a `TaskGroup` in the passed `group` memory location.
/// The caller is responsible for retaining and managing the group's lifecycle.
///
/// Its Swift signature is
///
/// \code
/// func swift_taskGroup_initialize(flags: Int, group: Builtin.RawPointer)
/// \endcode
SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
void swift_taskGroup_initializeWithFlags(size_t flags, TaskGroup *group, const Metadata *T);

/// Attach a child task to the parent task's task group record.
///
/// This function MUST be called from the AsyncTask running the task group.
Expand Down Expand Up @@ -276,6 +289,28 @@ void swift_taskGroup_cancelAll(TaskGroup *group);
SWIFT_EXPORT_FROM(swift_Concurrency) SWIFT_CC(swift)
bool swift_taskGroup_isCancelled(TaskGroup *group);

/// Wait until all pending tasks from the task group have completed.
/// If this task group is accumulating results, this also discards all those results.
///
/// This can be called from any thread. Its Swift signature is
///
/// \code
/// func swift_taskGroup_waitAll(
/// waitingTask: Builtin.NativeObject, // current task
/// group: Builtin.RawPointer,
/// bodyError: Swift.Error?
/// ) async throws
/// \endcode
SWIFT_EXPORT_FROM(swift_Concurrency)
SWIFT_CC(swiftasync)
void swift_taskGroup_waitAll(
OpaqueValue *resultPointer,
SWIFT_ASYNC_CONTEXT AsyncContext *callerContext,
TaskGroup *group,
SwiftError *bodyError,
ThrowingTaskFutureWaitContinuationFunction *resumeFn,
AsyncContext *callContext);

/// Check the readyQueue of a task group, return true if it has no pending tasks.
///
/// This can be called from any thread. Its Swift signature is
Expand Down
12 changes: 12 additions & 0 deletions include/swift/Runtime/RuntimeFunctions.def
Original file line number Diff line number Diff line change
Expand Up @@ -2060,6 +2060,18 @@ FUNCTION(TaskGroupInitialize,
ATTRS(NoUnwind),
EFFECT(Concurrency))

// void swift_taskGroup_initializeWithFlags(size_t flags, TaskGroup *group);
FUNCTION(TaskGroupInitializeWithFlags,
swift_taskGroup_initializeWithFlags, SwiftCC,
ConcurrencyAvailability,
RETURNS(VoidTy),
ARGS(SizeTy, // flags
Int8PtrTy, // group
TypeMetadataPtrTy // T.Type
),
ATTRS(NoUnwind),
EFFECT(Concurrency))

// void swift_taskGroup_destroy(TaskGroup *group);
FUNCTION(TaskGroupDestroy,
swift_taskGroup_destroy, SwiftCC,
Expand Down
20 changes: 20 additions & 0 deletions lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,24 @@ static ValueDecl *getCreateTaskGroup(ASTContext &ctx, Identifier id) {
_rawPointer);
}

static ValueDecl *getCreateTaskGroupWithFlags(ASTContext &ctx, Identifier id) {
ModuleDecl *M = ctx.TheBuiltinModule;
DeclContext *DC = &M->getMainFile(FileUnitKind::Builtin);
SynthesisContext SC(ctx, DC);

BuiltinFunctionBuilder builder(ctx);

// int
builder.addParameter(makeConcrete(ctx.getIntType())); // 0 flags

// T.self
builder.addParameter(makeMetatype(makeGenericParam(0))); // 1 ChildTaskResult.Type

// -> Builtin.RawPointer
builder.setResult(makeConcrete(synthesizeType(SC, _rawPointer)));
return builder.build(id);
}

static ValueDecl *getDestroyTaskGroup(ASTContext &ctx, Identifier id) {
return getBuiltinFunction(ctx, id, _thin,
_parameters(_rawPointer),
Expand Down Expand Up @@ -2908,6 +2926,8 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {

case BuiltinValueKind::CreateTaskGroup:
return getCreateTaskGroup(Context, Id);
case BuiltinValueKind::CreateTaskGroupWithFlags:
return getCreateTaskGroupWithFlags(Context, Id);

case BuiltinValueKind::DestroyTaskGroup:
return getDestroyTaskGroup(Context, Id);
Expand Down
3 changes: 3 additions & 0 deletions lib/IRGen/Callee.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ namespace irgen {
AsyncLetGetThrowing,
AsyncLetFinish,
TaskGroupWaitNext,
TaskGroupWaitAll,
DistributedExecuteTarget,
};

Expand Down Expand Up @@ -247,6 +248,7 @@ namespace irgen {
case SpecialKind::AsyncLetGetThrowing:
case SpecialKind::AsyncLetFinish:
case SpecialKind::TaskGroupWaitNext:
case SpecialKind::TaskGroupWaitAll:
return true;
case SpecialKind::DistributedExecuteTarget:
return false;
Expand Down Expand Up @@ -277,6 +279,7 @@ namespace irgen {
case SpecialKind::AsyncLetGetThrowing:
case SpecialKind::AsyncLetFinish:
case SpecialKind::TaskGroupWaitNext:
case SpecialKind::TaskGroupWaitAll:
return true;
case SpecialKind::DistributedExecuteTarget:
return false;
Expand Down
15 changes: 14 additions & 1 deletion lib/IRGen/GenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,22 @@ void irgen::emitBuiltinCall(IRGenFunction &IGF, const BuiltinInfo &Builtin,
}

if (Builtin.ID == BuiltinValueKind::CreateTaskGroup) {
llvm::Value *groupFlags = nullptr;
// Claim metadata pointer.
(void)args.claimAll();
out.add(emitCreateTaskGroup(IGF, substitutions));
out.add(emitCreateTaskGroup(IGF, substitutions, groupFlags));
return;
}

if (Builtin.ID == BuiltinValueKind::CreateTaskGroupWithFlags) {
auto groupFlags = args.claimNext();
// Claim the remaining metadata pointer.
if (args.size() == 1) {
(void)args.claimNext();
} else if (args.size() > 1) {
llvm_unreachable("createTaskGroupWithFlags expects 1 or 2 arguments");
}
out.add(emitCreateTaskGroup(IGF, substitutions, groupFlags));
return;
}

Expand Down
1 change: 1 addition & 0 deletions lib/IRGen/GenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ FunctionPointerKind::getStaticAsyncContextSize(IRGenModule &IGM) const {
case SpecialKind::AsyncLetGetThrowing:
case SpecialKind::AsyncLetFinish:
case SpecialKind::TaskGroupWaitNext:
case SpecialKind::TaskGroupWaitAll:
case SpecialKind::DistributedExecuteTarget:
// The current guarantee for all of these functions is the same.
// See TaskFutureWaitAsyncContext.
Expand Down
14 changes: 10 additions & 4 deletions lib/IRGen/GenConcurrency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ void irgen::emitEndAsyncLet(IRGenFunction &IGF, llvm::Value *alet) {
}

llvm::Value *irgen::emitCreateTaskGroup(IRGenFunction &IGF,
SubstitutionMap subs) {
SubstitutionMap subs,
llvm::Value *groupFlags) {
auto ty = llvm::ArrayType::get(IGF.IGM.Int8PtrTy, NumWords_TaskGroup);
auto address = IGF.createAlloca(ty, Alignment(Alignment_TaskGroup));
auto group = IGF.Builder.CreateBitCast(address.getAddress(),
Expand All @@ -282,9 +283,14 @@ llvm::Value *irgen::emitCreateTaskGroup(IRGenFunction &IGF,
auto resultType = subs.getReplacementTypes()[0]->getCanonicalType();
auto resultTypeMetadata = IGF.emitAbstractTypeMetadataRef(resultType);

auto *call =
IGF.Builder.CreateCall(IGF.IGM.getTaskGroupInitializeFunctionPointer(),
{group, resultTypeMetadata});
llvm::CallInst *call;
if (groupFlags) {
call = IGF.Builder.CreateCall(IGF.IGM.getTaskGroupInitializeWithFlagsFunctionPointer(),
{groupFlags, group, resultTypeMetadata});
} else {
call = IGF.Builder.CreateCall(IGF.IGM.getTaskGroupInitializeFunctionPointer(),
{group, resultTypeMetadata});
}
call->setDoesNotThrow();
call->setCallingConv(IGF.IGM.SwiftCC);

Expand Down
3 changes: 2 additions & 1 deletion lib/IRGen/GenConcurrency.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ llvm::Value *emitBuiltinStartAsyncLet(IRGenFunction &IGF,
void emitEndAsyncLet(IRGenFunction &IGF, llvm::Value *alet);

/// Emit the createTaskGroup builtin.
llvm::Value *emitCreateTaskGroup(IRGenFunction &IGF, SubstitutionMap subs);
llvm::Value *emitCreateTaskGroup(IRGenFunction &IGF, SubstitutionMap subs,
llvm::Value *groupFlags);

/// Emit the destroyTaskGroup builtin.
void emitDestroyTaskGroup(IRGenFunction &IGF, llvm::Value *group);
Expand Down
3 changes: 3 additions & 0 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2658,6 +2658,9 @@ FunctionPointer::Kind irgen::classifyFunctionPointerKind(SILFunction *fn) {
if (name.equals("swift_taskGroup_wait_next_throwing"))
return SpecialKind::TaskGroupWaitNext;

if (name.equals("swift_taskGroup_waitAll"))
return SpecialKind::TaskGroupWaitAll;

if (name.equals("swift_distributed_execute_target"))
return SpecialKind::DistributedExecuteTarget;
}
Expand Down
1 change: 1 addition & 0 deletions lib/SIL/IR/OperandOwnership.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ BUILTIN_OPERAND_OWNERSHIP(DestroyingConsume, EndAsyncLet)
BUILTIN_OPERAND_OWNERSHIP(DestroyingConsume, StartAsyncLetWithLocalBuffer)
BUILTIN_OPERAND_OWNERSHIP(DestroyingConsume, EndAsyncLetLifetime)
BUILTIN_OPERAND_OWNERSHIP(InstantaneousUse, CreateTaskGroup)
BUILTIN_OPERAND_OWNERSHIP(InstantaneousUse, CreateTaskGroupWithFlags)
BUILTIN_OPERAND_OWNERSHIP(InstantaneousUse, DestroyTaskGroup)

BUILTIN_OPERAND_OWNERSHIP(ForwardingConsume, COWBufferForReading)
Expand Down
1 change: 1 addition & 0 deletions lib/SIL/IR/ValueOwnership.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ CONSTANT_OWNERSHIP_BUILTIN(None, EndAsyncLet)
CONSTANT_OWNERSHIP_BUILTIN(None, StartAsyncLetWithLocalBuffer)
CONSTANT_OWNERSHIP_BUILTIN(None, EndAsyncLetLifetime)
CONSTANT_OWNERSHIP_BUILTIN(None, CreateTaskGroup)
CONSTANT_OWNERSHIP_BUILTIN(None, CreateTaskGroupWithFlags)
CONSTANT_OWNERSHIP_BUILTIN(None, DestroyTaskGroup)
CONSTANT_OWNERSHIP_BUILTIN(None, TaskRunInline)
CONSTANT_OWNERSHIP_BUILTIN(None, Copy)
Expand Down
1 change: 1 addition & 0 deletions lib/SIL/Utils/MemAccessUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2505,6 +2505,7 @@ static void visitBuiltinAddress(BuiltinInst *builtin,
case BuiltinValueKind::EndAsyncLet:
case BuiltinValueKind::EndAsyncLetLifetime:
case BuiltinValueKind::CreateTaskGroup:
case BuiltinValueKind::CreateTaskGroupWithFlags:
case BuiltinValueKind::DestroyTaskGroup:
return;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ static bool isBarrier(SILInstruction *inst) {
case BuiltinValueKind::EndAsyncLet:
case BuiltinValueKind::EndAsyncLetLifetime:
case BuiltinValueKind::CreateTaskGroup:
case BuiltinValueKind::CreateTaskGroupWithFlags:
case BuiltinValueKind::DestroyTaskGroup:
case BuiltinValueKind::StackAlloc:
case BuiltinValueKind::StackDealloc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ OVERRIDE_TASK_GROUP(taskGroup_initialize, void,
SWIFT_EXPORT_FROM(swift_Concurrency), SWIFT_CC(swift),
swift::, (TaskGroup *group, const Metadata *T), (group, T))

OVERRIDE_TASK_GROUP(taskGroup_initializeWithFlags, void,
SWIFT_EXPORT_FROM(swift_Concurrency), SWIFT_CC(swift),
swift::, (size_t flags, TaskGroup *group, const Metadata *T), (flags, group, T))

OVERRIDE_TASK_STATUS(taskGroup_attachChild, void,
SWIFT_EXPORT_FROM(swift_Concurrency), SWIFT_CC(swift),
swift::, (TaskGroup *group, AsyncTask *child),
Expand Down
Loading