Skip to content

Commit 6774c0a

Browse files
committed
[SILGen] Implement async completion bridging via checked continuations
Because `CheckedContinuation` is not a @Frozen struct we have to use `Any` to store it in @block_storage indirectly. If the flag is enabled, we'd emit a block storage with `Any` and initialize the existential with stack allocated `CheckedContinuation` formed from `UnsafeContinuation`. Inside of the completion handler `Any` is going to be projected and cast back to `CheckedContinuation`. (cherry picked from commit 4b22620)
1 parent d0d153b commit 6774c0a

File tree

6 files changed

+491
-35
lines changed

6 files changed

+491
-35
lines changed

include/swift/AST/KnownSDKTypes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ KNOWN_SDK_TYPE_DECL(ObjectiveC, ObjCBool, StructDecl, 0)
3737

3838
// TODO(async): These might move to the stdlib module when concurrency is
3939
// standardized
40+
KNOWN_SDK_TYPE_DECL(Concurrency, CheckedContinuation, NominalTypeDecl, 2)
4041
KNOWN_SDK_TYPE_DECL(Concurrency, UnsafeContinuation, NominalTypeDecl, 2)
4142
KNOWN_SDK_TYPE_DECL(Concurrency, MainActor, NominalTypeDecl, 0)
4243
KNOWN_SDK_TYPE_DECL(Concurrency, Job, StructDecl, 0) // TODO: remove in favor of ExecutorJob

lib/SILGen/ResultPlan.cpp

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,75 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
738738
// A foreign async function shouldn't have any indirect results.
739739
}
740740

741+
std::tuple</*blockStorage=*/SILValue, /*blockStorageType=*/CanType,
742+
/*continuationType=*/CanType>
743+
emitBlockStorage(SILGenFunction &SGF, SILLocation loc, bool throws) {
744+
auto &ctx = SGF.getASTContext();
745+
746+
// Wrap the Builtin.RawUnsafeContinuation in an
747+
// UnsafeContinuation<T, E>.
748+
auto *unsafeContinuationDecl = ctx.getUnsafeContinuationDecl();
749+
auto errorTy = throws ? ctx.getErrorExistentialType() : ctx.getNeverType();
750+
auto continuationTy =
751+
BoundGenericType::get(unsafeContinuationDecl, /*parent=*/Type(),
752+
{calleeTypeInfo.substResultType, errorTy})
753+
->getCanonicalType();
754+
755+
auto wrappedContinuation = SGF.B.createStruct(
756+
loc, SILType::getPrimitiveObjectType(continuationTy), {continuation});
757+
758+
const bool checkedBridging = ctx.LangOpts.UseCheckedAsyncObjCBridging;
759+
760+
// If checked bridging is enabled, wrap that continuation again in a
761+
// CheckedContinuation<T, E>
762+
if (checkedBridging) {
763+
auto *checkedContinuationDecl = ctx.getCheckedContinuationDecl();
764+
continuationTy =
765+
BoundGenericType::get(checkedContinuationDecl, /*parent=*/Type(),
766+
{calleeTypeInfo.substResultType, errorTy})
767+
->getCanonicalType();
768+
}
769+
770+
auto blockStorageTy = SILBlockStorageType::get(
771+
checkedBridging ? ctx.TheAnyType : continuationTy);
772+
auto blockStorage = SGF.emitTemporaryAllocation(
773+
loc, SILType::getPrimitiveAddressType(blockStorageTy));
774+
775+
auto continuationAddr = SGF.B.createProjectBlockStorage(loc, blockStorage);
776+
777+
// Stash continuation in a buffer for a block object.
778+
779+
if (checkedBridging) {
780+
auto createIntrinsic =
781+
throws ? SGF.SGM.getCreateCheckedThrowingContinuation()
782+
: SGF.SGM.getCreateCheckedContinuation();
783+
784+
// In this case block storage captures `Any` which would be initialized
785+
// with an checked continuation.
786+
auto underlyingContinuationAddr =
787+
SGF.B.createInitExistentialAddr(loc, continuationAddr, continuationTy,
788+
SGF.getLoweredType(continuationTy),
789+
/*conformances=*/{});
790+
791+
auto subs = SubstitutionMap::get(createIntrinsic->getGenericSignature(),
792+
{calleeTypeInfo.substResultType},
793+
ArrayRef<ProtocolConformanceRef>{});
794+
795+
InitializationPtr underlyingInit(
796+
new KnownAddressInitialization(underlyingContinuationAddr));
797+
auto continuationMV =
798+
ManagedValue::forRValueWithoutOwnership(wrappedContinuation);
799+
SGF.emitApplyOfLibraryIntrinsic(loc, createIntrinsic, subs,
800+
{continuationMV}, SGFContext())
801+
.forwardInto(SGF, loc, underlyingInit.get());
802+
} else {
803+
SGF.B.createStore(loc, wrappedContinuation, continuationAddr,
804+
StoreOwnershipQualifier::Trivial);
805+
}
806+
807+
return std::make_tuple(blockStorage, blockStorageTy, continuationTy);
808+
}
809+
741810
ManagedValue
742811
emitForeignAsyncCompletionHandler(SILGenFunction &SGF,
743812
AbstractionPattern origFormalType,
@@ -751,28 +820,11 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
751820
continuation = SGF.B.createGetAsyncContinuationAddr(loc, resumeBuf,
752821
calleeTypeInfo.substResultType, throws);
753822

754-
// Wrap the Builtin.RawUnsafeContinuation in an
755-
// UnsafeContinuation<T, E>.
756-
auto continuationDecl = SGF.getASTContext().getUnsafeContinuationDecl();
757-
758-
auto errorTy = throws
759-
? SGF.getASTContext().getErrorExistentialType()
760-
: SGF.getASTContext().getNeverType();
761-
auto continuationTy = BoundGenericType::get(continuationDecl, Type(),
762-
{ calleeTypeInfo.substResultType, errorTy })
763-
->getCanonicalType();
764-
auto wrappedContinuation =
765-
SGF.B.createStruct(loc,
766-
SILType::getPrimitiveObjectType(continuationTy),
767-
{continuation});
768-
769-
// Stash it in a buffer for a block object.
770-
auto blockStorageTy = SILBlockStorageType::get(continuationTy);
771-
auto blockStorage = SGF.emitTemporaryAllocation(
772-
loc, SILType::getPrimitiveAddressType(blockStorageTy));
773-
auto continuationAddr = SGF.B.createProjectBlockStorage(loc, blockStorage);
774-
SGF.B.createStore(loc, wrappedContinuation, continuationAddr,
775-
StoreOwnershipQualifier::Trivial);
823+
SILValue blockStorage;
824+
CanType blockStorageTy;
825+
CanType continuationTy;
826+
std::tie(blockStorage, blockStorageTy, continuationTy) =
827+
emitBlockStorage(SGF, loc, throws);
776828

777829
// Get the block invocation function for the given completion block type.
778830
auto completionHandlerIndex = calleeTypeInfo.foreign.async
@@ -797,6 +849,7 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
797849
cast<SILFunctionType>(
798850
impFnTy->mapTypeOutOfContext()->getReducedType(sig)),
799851
blockStorageTy->mapTypeOutOfContext()->getReducedType(sig),
852+
continuationTy->mapTypeOutOfContext()->getReducedType(sig),
800853
origFormalType, sig, calleeTypeInfo);
801854
auto impRef = SGF.B.createFunctionRef(loc, impl);
802855

lib/SILGen/SILGen.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
190190
/// as `async` in Swift.
191191
SILFunction *getOrCreateForeignAsyncCompletionHandlerImplFunction(
192192
CanSILFunctionType blockType, CanType blockStorageType,
193-
AbstractionPattern origFormalType, CanGenericSignature sig,
194-
CalleeTypeInfo &calleeInfo);
193+
CanType continuationType, AbstractionPattern origFormalType,
194+
CanGenericSignature sig, CalleeTypeInfo &calleeInfo);
195195

196196
/// Determine whether the given class has any instance variables that
197197
/// need to be destroyed.

lib/SILGen/SILGenThunk.cpp

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,8 @@ static const clang::Type *prependParameterType(
200200

201201
SILFunction *SILGenModule::getOrCreateForeignAsyncCompletionHandlerImplFunction(
202202
CanSILFunctionType blockType, CanType blockStorageType,
203-
AbstractionPattern origFormalType, CanGenericSignature sig,
204-
CalleeTypeInfo &calleeInfo) {
203+
CanType continuationType, AbstractionPattern origFormalType,
204+
CanGenericSignature sig, CalleeTypeInfo &calleeInfo) {
205205
auto convention = *calleeInfo.foreign.async;
206206
auto resumeType =
207207
calleeInfo.substResultType->mapTypeOutOfContext()->getReducedType(sig);
@@ -293,11 +293,33 @@ SILFunction *SILGenModule::getOrCreateForeignAsyncCompletionHandlerImplFunction(
293293

294294
// Get the continuation out of the block object.
295295
auto blockStorage = params[0].getValue();
296-
auto continuationAddr = SGF.B.createProjectBlockStorage(loc, blockStorage);
297-
auto continuationVal = SGF.B.createLoad(loc, continuationAddr,
298-
LoadOwnershipQualifier::Trivial);
299-
auto continuation =
300-
ManagedValue::forObjectRValueWithoutOwnership(continuationVal);
296+
SILValue continuationAddr =
297+
SGF.B.createProjectBlockStorage(loc, blockStorage);
298+
299+
auto &ctx = SGF.getASTContext();
300+
301+
bool checkedBridging = ctx.LangOpts.UseCheckedAsyncObjCBridging;
302+
303+
ManagedValue continuation;
304+
if (checkedBridging) {
305+
FormalEvaluationScope scope(SGF);
306+
307+
auto underlyingValueTy = OpenedArchetypeType::get(ctx.TheAnyType, sig);
308+
309+
auto underlyingValueAddr = SGF.emitOpenExistential(
310+
loc, ManagedValue::forTrivialAddressRValue(continuationAddr),
311+
SGF.getLoweredType(underlyingValueTy), AccessKind::Read);
312+
313+
continuation = SGF.B.createUncheckedAddrCast(
314+
loc, underlyingValueAddr,
315+
SILType::getPrimitiveAddressType(
316+
F->mapTypeIntoContext(continuationType)->getCanonicalType()));
317+
} else {
318+
auto continuationVal = SGF.B.createLoad(
319+
loc, continuationAddr, LoadOwnershipQualifier::Trivial);
320+
continuation =
321+
ManagedValue::forObjectRValueWithoutOwnership(continuationVal);
322+
}
301323

302324
// Check for an error if the convention includes one.
303325
// Increment the error and flag indices if present. They do not account
@@ -311,8 +333,12 @@ SILFunction *SILGenModule::getOrCreateForeignAsyncCompletionHandlerImplFunction(
311333

312334
SILBasicBlock *returnBB = nullptr;
313335
if (errorIndex) {
314-
resumeIntrinsic = getResumeUnsafeThrowingContinuation();
315-
auto errorIntrinsic = getResumeUnsafeThrowingContinuationWithError();
336+
resumeIntrinsic = checkedBridging
337+
? getResumeCheckedThrowingContinuation()
338+
: getResumeUnsafeThrowingContinuation();
339+
auto errorIntrinsic =
340+
checkedBridging ? getResumeCheckedThrowingContinuationWithError()
341+
: getResumeUnsafeThrowingContinuationWithError();
316342

317343
auto errorArgument = params[*errorIndex];
318344
auto someErrorBB = SGF.createBasicBlock(FunctionSection::Postmatter);
@@ -385,9 +411,12 @@ SILFunction *SILGenModule::getOrCreateForeignAsyncCompletionHandlerImplFunction(
385411
SGF.B.createBranch(loc, returnBB);
386412
SGF.B.emitBlock(noneErrorBB);
387413
} else if (auto foreignError = calleeInfo.foreign.error) {
388-
resumeIntrinsic = getResumeUnsafeThrowingContinuation();
414+
resumeIntrinsic = checkedBridging
415+
? getResumeCheckedThrowingContinuation()
416+
: getResumeUnsafeThrowingContinuation();
389417
} else {
390-
resumeIntrinsic = getResumeUnsafeContinuation();
418+
resumeIntrinsic = checkedBridging ? getResumeCheckedContinuation()
419+
: getResumeUnsafeContinuation();
391420
}
392421

393422
auto loweredResumeTy = SGF.getLoweredType(AbstractionPattern::getOpaque(),

0 commit comments

Comments
 (0)