Skip to content

Commit b0be562

Browse files
Merge pull request #34829 from aschwaighofer/irgen_get_await_async_continuation
IRGen: get/await_async_continuation support
2 parents c189398 + 505a6ee commit b0be562

File tree

11 files changed

+546
-35
lines changed

11 files changed

+546
-35
lines changed

lib/IRGen/GenCall.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4581,3 +4581,17 @@ void irgen::emitAsyncReturn(IRGenFunction &IGF, AsyncContextLayout &asyncLayout,
45814581
auto call = IGF.Builder.CreateCall(fnPtr, Args);
45824582
call->setTailCall();
45834583
}
4584+
4585+
FunctionPointer
4586+
IRGenFunction::getFunctionPointerForResumeIntrinsic(llvm::Value *resume) {
4587+
auto *fnTy = llvm::FunctionType::get(
4588+
IGM.VoidTy, {IGM.Int8PtrTy, IGM.Int8PtrTy, IGM.Int8PtrTy},
4589+
false /*vaargs*/);
4590+
auto signature =
4591+
Signature(fnTy, IGM.constructInitialAttributes(), IGM.SwiftCC);
4592+
auto fnPtr = FunctionPointer(
4593+
FunctionPointer::KindTy::Function,
4594+
Builder.CreateBitOrPointerCast(resume, fnTy->getPointerTo()),
4595+
PointerAuthInfo(), signature);
4596+
return fnPtr;
4597+
}

lib/IRGen/GenFunc.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2395,21 +2395,31 @@ llvm::Function *IRGenFunction::getOrCreateResumePrjFn() {
23952395
},
23962396
false /*isNoInline*/));
23972397
}
2398-
23992398
llvm::Function *
24002399
IRGenFunction::createAsyncDispatchFn(const FunctionPointer &fnPtr,
24012400
ArrayRef<llvm::Value *> args) {
24022401
SmallVector<llvm::Type*, 8> argTys;
2403-
argTys.push_back(IGM.Int8PtrTy); // Function pointer to be called.
24042402
for (auto arg : args) {
24052403
auto *ty = arg->getType();
24062404
argTys.push_back(ty);
24072405
}
2406+
return createAsyncDispatchFn(fnPtr, argTys);
2407+
}
2408+
2409+
llvm::Function *
2410+
IRGenFunction::createAsyncDispatchFn(const FunctionPointer &fnPtr,
2411+
ArrayRef<llvm::Type *> argTypes) {
2412+
SmallVector<llvm::Type*, 8> argTys;
2413+
argTys.push_back(IGM.Int8PtrTy); // Function pointer to be called.
2414+
for (auto ty : argTypes) {
2415+
argTys.push_back(ty);
2416+
}
24082417
auto calleeFnPtrType = fnPtr.getRawPointer()->getType();
24092418
auto *dispatchFnTy =
24102419
llvm::FunctionType::get(IGM.VoidTy, argTys, false /*vaargs*/);
24112420
llvm::SmallString<40> name;
2412-
llvm::raw_svector_ostream(name) << "__swift_suspend_dispatch_" << args.size();
2421+
llvm::raw_svector_ostream(name)
2422+
<< "__swift_suspend_dispatch_" << argTypes.size();
24132423
llvm::Function *dispatch =
24142424
llvm::Function::Create(dispatchFnTy, llvm::Function::InternalLinkage,
24152425
llvm::StringRef(name), &IGM.Module);

lib/IRGen/IRGenFunction.cpp

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llvm/Support/CommandLine.h"
2424
#include "llvm/Support/raw_ostream.h"
2525

26+
#include "Callee.h"
2627
#include "Explosion.h"
2728
#include "IRGenDebugInfo.h"
2829
#include "IRGenFunction.h"
@@ -511,3 +512,263 @@ llvm::Value *IRGenFunction::alignUpToMaximumAlignment(llvm::Type *sizeTy, llvm::
511512
auto *invertedMask = Builder.CreateNot(alignMask);
512513
return Builder.CreateAnd(Builder.CreateAdd(val, alignMask), invertedMask);
513514
}
515+
516+
/// Returns the current task \p currTask as an UnsafeContinuation at +1.
517+
static llvm::Value *unsafeContinuationFromTask(IRGenFunction &IGF,
518+
SILType unsafeContinuationTy,
519+
llvm::Value *currTask) {
520+
auto &IGM = IGF.IGM;
521+
auto &Builder = IGF.Builder;
522+
523+
auto &rawPonterTI = IGM.getRawPointerTypeInfo();
524+
auto object =
525+
Builder.CreateBitOrPointerCast(currTask, rawPonterTI.getStorageType());
526+
527+
// Wrap the native object in the UnsafeContinuation struct.
528+
// struct UnsafeContinuation<T> {
529+
// let _continuation : Builtin.RawPointer
530+
// }
531+
auto &unsafeContinuationTI =
532+
cast<LoadableTypeInfo>(IGF.getTypeInfo(unsafeContinuationTy));
533+
auto unsafeContinuationStructTy =
534+
cast<llvm::StructType>(unsafeContinuationTI.getStorageType());
535+
auto fieldTy =
536+
cast<llvm::StructType>(unsafeContinuationStructTy->getElementType(0));
537+
auto reference =
538+
Builder.CreateBitOrPointerCast(object, fieldTy->getElementType(0));
539+
auto field =
540+
Builder.CreateInsertValue(llvm::UndefValue::get(fieldTy), reference, 0);
541+
auto unsafeContinuation = Builder.CreateInsertValue(
542+
llvm::UndefValue::get(unsafeContinuationStructTy), field, 0);
543+
544+
return unsafeContinuation;
545+
}
546+
547+
void IRGenFunction::emitGetAsyncContinuation(SILType unsafeContinuationTy,
548+
StackAddress resultAddr,
549+
Explosion &out) {
550+
// Create the continuation.
551+
// void current_sil_function(AsyncTask *currTask, Executor *currExecutor,
552+
// AsyncContext *currCtxt) {
553+
//
554+
// A continuation is the current AsyncTask 'currTask' with:
555+
// currTask->ResumeTask = @llvm.coro.async.resume();
556+
// currTask->ResumeContext = &continuation_context;
557+
//
558+
// Where:
559+
//
560+
// struct {
561+
// AsyncContext *resumeCtxt;
562+
// void *awaitSynchronization;
563+
// SwiftError *errResult;
564+
// union {
565+
// IndirectResult *result;
566+
// DirectResult *result;
567+
// };
568+
// ExecutorRef *resumeExecutor;
569+
// } continuation_context; // local variable of current_sil_function
570+
//
571+
// continuation_context.resumeCtxt = currCtxt;
572+
// continuation_context.errResult = nulllptr;
573+
// continuation_context.result = ... // local alloca.
574+
// continuation_context.resumeExecutor = .. // current executor
575+
576+
auto currTask = getAsyncTask();
577+
auto unsafeContinuation =
578+
unsafeContinuationFromTask(*this, unsafeContinuationTy, currTask);
579+
580+
// Create and setup the continuation context for UnsafeContinuation<T>.
581+
// continuation_context.resumeCtxt = currCtxt;
582+
// continuation_context.errResult = nulllptr;
583+
// continuation_context.result = ... // local alloca T
584+
auto pointerAlignment = IGM.getPointerAlignment();
585+
auto continuationContext =
586+
createAlloca(IGM.AsyncContinuationContextTy, pointerAlignment);
587+
AsyncCoroutineCurrentContinuationContext = continuationContext.getAddress();
588+
// TODO: add lifetime with matching lifetime in await_async_continuation
589+
auto contResumeAddr =
590+
Builder.CreateStructGEP(continuationContext.getAddress(), 0);
591+
Builder.CreateStore(getAsyncContext(),
592+
Address(contResumeAddr, pointerAlignment));
593+
auto contErrResultAddr =
594+
Builder.CreateStructGEP(continuationContext.getAddress(), 2);
595+
Builder.CreateStore(
596+
llvm::Constant::getNullValue(
597+
contErrResultAddr->getType()->getPointerElementType()),
598+
Address(contErrResultAddr, pointerAlignment));
599+
auto contResultAddr =
600+
Builder.CreateStructGEP(continuationContext.getAddress(), 3);
601+
if (!resultAddr.getAddress().isValid()) {
602+
assert(unsafeContinuationTy.getASTType()
603+
->castTo<BoundGenericType>()
604+
->getGenericArgs()
605+
.size() == 1 &&
606+
"expect UnsafeContinuation<T> to have one generic arg");
607+
auto resultTy = IGM.getLoweredType(unsafeContinuationTy.getASTType()
608+
->castTo<BoundGenericType>()
609+
->getGenericArgs()[0]
610+
->getCanonicalType());
611+
auto &resultTI = getTypeInfo(resultTy);
612+
auto resultAddr =
613+
resultTI.allocateStack(*this, resultTy, "async.continuation.result");
614+
Builder.CreateStore(Builder.CreateBitOrPointerCast(
615+
resultAddr.getAddress().getAddress(),
616+
contResultAddr->getType()->getPointerElementType()),
617+
Address(contResultAddr, pointerAlignment));
618+
} else {
619+
Builder.CreateStore(Builder.CreateBitOrPointerCast(
620+
resultAddr.getAddress().getAddress(),
621+
contResultAddr->getType()->getPointerElementType()),
622+
Address(contResultAddr, pointerAlignment));
623+
}
624+
// continuation_context.resumeExecutor = // current executor
625+
auto contExecutorRefAddr =
626+
Builder.CreateStructGEP(continuationContext.getAddress(), 4);
627+
Builder.CreateStore(
628+
Builder.CreateBitOrPointerCast(
629+
getAsyncExecutor(),
630+
contExecutorRefAddr->getType()->getPointerElementType()),
631+
Address(contExecutorRefAddr, pointerAlignment));
632+
633+
// Fill the current task (i.e the continuation) with the continuation
634+
// information.
635+
// currTask->ResumeTask = @llvm.coro.async.resume();
636+
assert(currTask->getType() == IGM.SwiftTaskPtrTy);
637+
auto currTaskResumeTaskAddr = Builder.CreateStructGEP(currTask,3);
638+
auto coroResume =
639+
Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_async_resume, {});
640+
641+
assert(AsyncCoroutineCurrentResume == nullptr &&
642+
"Don't support nested get_async_continuation");
643+
AsyncCoroutineCurrentResume = coroResume;
644+
Builder.CreateStore(
645+
Builder.CreateBitOrPointerCast(coroResume, IGM.FunctionPtrTy),
646+
Address(currTaskResumeTaskAddr, pointerAlignment));
647+
// currTask->ResumeContext = &continuation_context;
648+
auto currTaskResumeCtxtAddr = Builder.CreateStructGEP(currTask, 4);
649+
Builder.CreateStore(
650+
Builder.CreateBitOrPointerCast(continuationContext.getAddress(),
651+
IGM.SwiftContextPtrTy),
652+
Address(currTaskResumeCtxtAddr, pointerAlignment));
653+
654+
// Publish all the writes.
655+
// continuation_context.awaitSynchronization =(atomic release) nullptr;
656+
auto contAwaitSyncAddr =
657+
Builder.CreateStructGEP(continuationContext.getAddress(), 1);
658+
auto null = llvm::ConstantInt::get(
659+
contAwaitSyncAddr->getType()->getPointerElementType(), 0);
660+
auto atomicStore =
661+
Builder.CreateStore(null, Address(contAwaitSyncAddr, pointerAlignment));
662+
atomicStore->setAtomic(llvm::AtomicOrdering::Release,
663+
llvm::SyncScope::System);
664+
out.add(unsafeContinuation);
665+
}
666+
667+
void IRGenFunction::emitAwaitAsyncContinuation(
668+
SILType unsafeContinuationTy, bool isIndirectResult,
669+
Explosion &outDirectResult, llvm::BasicBlock *&normalBB,
670+
llvm::PHINode *&optionalErrorResult, llvm::BasicBlock *&optionalErrorBB) {
671+
assert(AsyncCoroutineCurrentContinuationContext && "no active continuation");
672+
auto pointerAlignment = IGM.getPointerAlignment();
673+
674+
// First check whether the await reached this point first. Meaning we still
675+
// have to wait for the continuation result. If the await reaches first we
676+
// abort the control flow here (resuming the continuation will execute the
677+
// remaining control flow).
678+
auto contAwaitSyncAddr =
679+
Builder.CreateStructGEP(AsyncCoroutineCurrentContinuationContext, 1);
680+
auto null = llvm::ConstantInt::get(
681+
contAwaitSyncAddr->getType()->getPointerElementType(), 0);
682+
auto one = llvm::ConstantInt::get(
683+
contAwaitSyncAddr->getType()->getPointerElementType(), 1);
684+
auto results = Builder.CreateAtomicCmpXchg(
685+
contAwaitSyncAddr, null, one,
686+
llvm::AtomicOrdering::Release /*success ordering*/,
687+
llvm::AtomicOrdering::Acquire /* failure ordering */,
688+
llvm::SyncScope::System);
689+
auto firstAtAwait = Builder.CreateExtractValue(results, 1);
690+
auto contBB = createBasicBlock("await.async.maybe.resume");
691+
auto abortBB = createBasicBlock("await.async.abort");
692+
Builder.CreateCondBr(firstAtAwait, abortBB, contBB);
693+
Builder.emitBlock(abortBB);
694+
{
695+
// We are first to the sync point. Abort. The continuation's result is not
696+
// available yet.
697+
emitCoroutineOrAsyncExit();
698+
}
699+
700+
auto contBB2 = createBasicBlock("await.async.resume");
701+
Builder.emitBlock(contBB);
702+
{
703+
// Setup the suspend point.
704+
SmallVector<llvm::Value *, 8> arguments;
705+
arguments.push_back(AsyncCoroutineCurrentResume);
706+
auto resumeProjFn = getOrCreateResumePrjFn();
707+
arguments.push_back(
708+
Builder.CreateBitOrPointerCast(resumeProjFn, IGM.Int8PtrTy));
709+
// The dispatch function just calls the resume point.
710+
auto resumeFnPtr =
711+
getFunctionPointerForResumeIntrinsic(AsyncCoroutineCurrentResume);
712+
arguments.push_back(Builder.CreateBitOrPointerCast(
713+
createAsyncDispatchFn(resumeFnPtr,
714+
{IGM.Int8PtrTy, IGM.Int8PtrTy, IGM.Int8PtrTy}),
715+
IGM.Int8PtrTy));
716+
arguments.push_back(AsyncCoroutineCurrentResume);
717+
arguments.push_back(
718+
Builder.CreateBitOrPointerCast(getAsyncTask(), IGM.Int8PtrTy));
719+
arguments.push_back(
720+
Builder.CreateBitOrPointerCast(getAsyncExecutor(), IGM.Int8PtrTy));
721+
arguments.push_back(Builder.CreateBitOrPointerCast(
722+
AsyncCoroutineCurrentContinuationContext, IGM.Int8PtrTy));
723+
Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_suspend_async, arguments);
724+
725+
auto results = Builder.CreateAtomicCmpXchg(
726+
contAwaitSyncAddr, null, one,
727+
llvm::AtomicOrdering::Release /*success ordering*/,
728+
llvm::AtomicOrdering::Acquire /* failure ordering */,
729+
llvm::SyncScope::System);
730+
// Again, are we first at the wait (can only reach that state after
731+
// continuation.resume/abort is called)? If so abort to wait for the end of
732+
// the await point to be reached.
733+
auto firstAtAwait = Builder.CreateExtractValue(results, 1);
734+
Builder.CreateCondBr(firstAtAwait, abortBB, contBB2);
735+
}
736+
737+
Builder.emitBlock(contBB2);
738+
auto contBB3 = createBasicBlock("await.async.normal");
739+
if (optionalErrorBB) {
740+
auto contErrResultAddr = Address(
741+
Builder.CreateStructGEP(AsyncCoroutineCurrentContinuationContext, 2),
742+
pointerAlignment);
743+
auto errorRes = Builder.CreateLoad(contErrResultAddr);
744+
auto nullError = llvm::Constant::getNullValue(errorRes->getType());
745+
auto hasError = Builder.CreateICmpNE(errorRes, nullError);
746+
optionalErrorResult->addIncoming(errorRes, Builder.GetInsertBlock());
747+
Builder.CreateCondBr(hasError, optionalErrorBB, contBB3);
748+
} else {
749+
Builder.CreateBr(contBB3);
750+
}
751+
752+
Builder.emitBlock(contBB3);
753+
if (!isIndirectResult) {
754+
auto contResultAddrAddr =
755+
Builder.CreateStructGEP(AsyncCoroutineCurrentContinuationContext, 3);
756+
auto resultAddrVal =
757+
Builder.CreateLoad(Address(contResultAddrAddr, pointerAlignment));
758+
// Take the result.
759+
auto resultTy = IGM.getLoweredType(unsafeContinuationTy.getASTType()
760+
->castTo<BoundGenericType>()
761+
->getGenericArgs()[0]
762+
->getCanonicalType());
763+
auto &resultTI = cast<LoadableTypeInfo>(getTypeInfo(resultTy));
764+
auto resultStorageTy = resultTI.getStorageType();
765+
auto resultAddr =
766+
Address(Builder.CreateBitOrPointerCast(resultAddrVal,
767+
resultStorageTy->getPointerTo()),
768+
resultTI.getFixedAlignment());
769+
resultTI.loadAsTake(*this, resultAddr, outDirectResult);
770+
}
771+
Builder.CreateBr(normalBB);
772+
AsyncCoroutineCurrentResume = nullptr;
773+
AsyncCoroutineCurrentContinuationContext = nullptr;
774+
}

lib/IRGen/IRGenFunction.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,21 @@ class IRGenFunction {
135135
llvm::Function *getOrCreateResumePrjFn();
136136
llvm::Function *createAsyncDispatchFn(const FunctionPointer &fnPtr,
137137
ArrayRef<llvm::Value *> args);
138+
llvm::Function *createAsyncDispatchFn(const FunctionPointer &fnPtr,
139+
ArrayRef<llvm::Type *> argTypes);
140+
141+
void emitGetAsyncContinuation(SILType silTy, StackAddress optionalResultAddr,
142+
Explosion &out);
143+
144+
void emitAwaitAsyncContinuation(SILType unsafeContinuationTy,
145+
bool isIndirectResult,
146+
Explosion &outDirectResult,
147+
llvm::BasicBlock *&normalBB,
148+
llvm::PHINode *&optionalErrorPhi,
149+
llvm::BasicBlock *&optionalErrorBB);
150+
151+
FunctionPointer
152+
getFunctionPointerForResumeIntrinsic(llvm::Value *resumeIntrinsic);
138153

139154
private:
140155
void emitPrologue();
@@ -145,8 +160,16 @@ class IRGenFunction {
145160
llvm::Value *CalleeErrorResultSlot = nullptr;
146161
llvm::Value *CallerErrorResultSlot = nullptr;
147162
llvm::Value *CoroutineHandle = nullptr;
163+
llvm::Value *AsyncCoroutineCurrentResume = nullptr;
164+
llvm::Value *AsyncCoroutineCurrentContinuationContext = nullptr;
148165
bool IsAsync = false;
149166

167+
/// The unique block that calls @llvm.coro.end.
168+
llvm::BasicBlock *CoroutineExitBlock = nullptr;
169+
170+
public:
171+
void emitCoroutineOrAsyncExit();
172+
150173
//--- Helper methods -----------------------------------------------------------
151174
public:
152175

lib/IRGen/IRGenModule.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,14 @@ IRGenModule::IRGenModule(IRGenerator &irgen,
593593
AsyncFunctionPointerTy = createStructType(*this, "swift.async_func_pointer",
594594
{RelativeAddressTy, Int32Ty}, true);
595595
SwiftContextTy = createStructType(*this, "swift.context", {});
596-
SwiftTaskTy = createStructType(*this, "swift.task", {});
596+
auto *ContextPtrTy = llvm::PointerType::getUnqual(SwiftContextTy);
597+
SwiftTaskTy = createStructType(*this, "swift.task", {
598+
Int8PtrTy, Int8PtrTy, // Job.SchedulerPrivate
599+
Int64Ty, // Job.Flags
600+
FunctionPtrTy, // Job.RunJob/Job.ResumeTask
601+
ContextPtrTy, // Task.ResumeContext
602+
Int64Ty // Task.Status
603+
});
597604
SwiftExecutorTy = createStructType(*this, "swift.executor", {});
598605
AsyncFunctionPointerPtrTy = AsyncFunctionPointerTy->getPointerTo(DefaultAS);
599606
SwiftContextPtrTy = SwiftContextTy->getPointerTo(DefaultAS);
@@ -612,6 +619,11 @@ IRGenModule::IRGenModule(IRGenerator &irgen,
612619
*this, "swift.async_task_and_context",
613620
{ SwiftTaskPtrTy, SwiftContextPtrTy });
614621

622+
AsyncContinuationContextTy = createStructType(
623+
*this, "swift.async_continuation_context",
624+
{SwiftContextPtrTy, SizeTy, ErrorPtrTy, OpaquePtrTy, SwiftExecutorPtrTy});
625+
AsyncContinuationContextPtrTy = AsyncContinuationContextTy->getPointerTo();
626+
615627
DifferentiabilityWitnessTy = createStructType(
616628
*this, "swift.differentiability_witness", {Int8PtrTy, Int8PtrTy});
617629
}

0 commit comments

Comments
 (0)