Skip to content

Commit 580213b

Browse files
authored
Merge pull request #69139 from xedin/checked-continuations-for-foreign-callbacks-5.10
[5.10] Provide ability to use CheckedContinuation when suspending for an async ObjC call
2 parents 51e3b64 + d94dd41 commit 580213b

14 files changed

+756
-99
lines changed

include/swift/AST/KnownSDKTypes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ KNOWN_SDK_TYPE_DECL(ObjectiveC, ObjCBool, StructDecl, 0)
3737

3838
// TODO(async): These might move to the stdlib module when concurrency is
3939
// standardized
40+
KNOWN_SDK_TYPE_DECL(Concurrency, CheckedContinuation, NominalTypeDecl, 2)
4041
KNOWN_SDK_TYPE_DECL(Concurrency, UnsafeContinuation, NominalTypeDecl, 2)
4142
KNOWN_SDK_TYPE_DECL(Concurrency, MainActor, NominalTypeDecl, 0)
4243
KNOWN_SDK_TYPE_DECL(Concurrency, Job, StructDecl, 0) // TODO: remove in favor of ExecutorJob

include/swift/Basic/LangOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ namespace swift {
389389
bool DisableImplicitBacktracingModuleImport =
390390
!SWIFT_IMPLICIT_BACKTRACING_IMPORT;
391391

392+
// Whether to use checked continuations when making an async call from
393+
// Swift into ObjC. If false, will use unchecked continuations instead.
394+
bool UseCheckedAsyncObjCBridging = false;
395+
392396
/// Should we check the target OSs of serialized modules to see that they're
393397
/// new enough?
394398
bool EnableTargetOSChecking = true;

include/swift/Option/FrontendOptions.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,10 @@ def enable_experimental_async_top_level :
364364
// HIDDEN FLAGS
365365
let Flags = [FrontendOption, NoDriverOption, HelpHidden] in {
366366

367+
def checked_async_objc_bridging : Joined<["-"], "checked-async-objc-bridging=">,
368+
HelpText<"Control whether checked continuations are used when bridging "
369+
"async calls from Swift to ObjC: 'on', 'off' ">;
370+
367371
def debug_constraints : Flag<["-"], "debug-constraints">,
368372
HelpText<"Debug the constraint-based type checker">;
369373

lib/Frontend/CompilerInvocation.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,21 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
13261326
HadError = true;
13271327
}
13281328

1329+
if (auto A = Args.getLastArg(OPT_checked_async_objc_bridging)) {
1330+
auto value = llvm::StringSwitch<llvm::Optional<bool>>(A->getValue())
1331+
.Case("off", false)
1332+
.Case("on", true)
1333+
.Default(llvm::None);
1334+
1335+
if (value) {
1336+
Opts.UseCheckedAsyncObjCBridging = *value;
1337+
} else {
1338+
Diags.diagnose(SourceLoc(), diag::error_invalid_arg_value,
1339+
A->getAsString(Args), A->getValue());
1340+
HadError = true;
1341+
}
1342+
}
1343+
13291344
return HadError || UnsupportedOS || UnsupportedArch;
13301345
}
13311346

lib/SILGen/ResultPlan.cpp

Lines changed: 124 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,11 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
719719
SILValue resumeBuf;
720720
SILValue continuation;
721721
ExecutorBreadcrumb breadcrumb;
722-
722+
723+
SILValue blockStorage;
724+
CanType blockStorageTy;
725+
CanType continuationTy;
726+
723727
public:
724728
ForeignAsyncInitializationPlan(SILGenFunction &SGF, SILLocation loc,
725729
const CalleeTypeInfo &calleeTypeInfo)
@@ -738,6 +742,75 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
738742
// A foreign async function shouldn't have any indirect results.
739743
}
740744

745+
std::tuple</*blockStorage=*/SILValue, /*blockStorageType=*/CanType,
746+
/*continuationType=*/CanType>
747+
emitBlockStorage(SILGenFunction &SGF, SILLocation loc, bool throws) {
748+
auto &ctx = SGF.getASTContext();
749+
750+
// Wrap the Builtin.RawUnsafeContinuation in an
751+
// UnsafeContinuation<T, E>.
752+
auto *unsafeContinuationDecl = ctx.getUnsafeContinuationDecl();
753+
auto errorTy = throws ? ctx.getErrorExistentialType() : ctx.getNeverType();
754+
auto continuationTy =
755+
BoundGenericType::get(unsafeContinuationDecl, /*parent=*/Type(),
756+
{calleeTypeInfo.substResultType, errorTy})
757+
->getCanonicalType();
758+
759+
auto wrappedContinuation = SGF.B.createStruct(
760+
loc, SILType::getPrimitiveObjectType(continuationTy), {continuation});
761+
762+
const bool checkedBridging = ctx.LangOpts.UseCheckedAsyncObjCBridging;
763+
764+
// If checked bridging is enabled, wrap that continuation again in a
765+
// CheckedContinuation<T, E>
766+
if (checkedBridging) {
767+
auto *checkedContinuationDecl = ctx.getCheckedContinuationDecl();
768+
continuationTy =
769+
BoundGenericType::get(checkedContinuationDecl, /*parent=*/Type(),
770+
{calleeTypeInfo.substResultType, errorTy})
771+
->getCanonicalType();
772+
}
773+
774+
auto blockStorageTy = SILBlockStorageType::get(
775+
checkedBridging ? ctx.TheAnyType : continuationTy);
776+
auto blockStorage = SGF.emitTemporaryAllocation(
777+
loc, SILType::getPrimitiveAddressType(blockStorageTy));
778+
779+
auto continuationAddr = SGF.B.createProjectBlockStorage(loc, blockStorage);
780+
781+
// Stash continuation in a buffer for a block object.
782+
783+
if (checkedBridging) {
784+
auto createIntrinsic =
785+
throws ? SGF.SGM.getCreateCheckedThrowingContinuation()
786+
: SGF.SGM.getCreateCheckedContinuation();
787+
788+
// In this case block storage captures `Any` which would be initialized
789+
// with an checked continuation.
790+
auto underlyingContinuationAddr =
791+
SGF.B.createInitExistentialAddr(loc, continuationAddr, continuationTy,
792+
SGF.getLoweredType(continuationTy),
793+
/*conformances=*/{});
794+
795+
auto subs = SubstitutionMap::get(createIntrinsic->getGenericSignature(),
796+
{calleeTypeInfo.substResultType},
797+
ArrayRef<ProtocolConformanceRef>{});
798+
799+
InitializationPtr underlyingInit(
800+
new KnownAddressInitialization(underlyingContinuationAddr));
801+
auto continuationMV =
802+
ManagedValue::forRValueWithoutOwnership(wrappedContinuation);
803+
SGF.emitApplyOfLibraryIntrinsic(loc, createIntrinsic, subs,
804+
{continuationMV}, SGFContext())
805+
.forwardInto(SGF, loc, underlyingInit.get());
806+
} else {
807+
SGF.B.createStore(loc, wrappedContinuation, continuationAddr,
808+
StoreOwnershipQualifier::Trivial);
809+
}
810+
811+
return std::make_tuple(blockStorage, blockStorageTy, continuationTy);
812+
}
813+
741814
ManagedValue
742815
emitForeignAsyncCompletionHandler(SILGenFunction &SGF,
743816
AbstractionPattern origFormalType,
@@ -751,28 +824,8 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
751824
continuation = SGF.B.createGetAsyncContinuationAddr(loc, resumeBuf,
752825
calleeTypeInfo.substResultType, throws);
753826

754-
// Wrap the Builtin.RawUnsafeContinuation in an
755-
// UnsafeContinuation<T, E>.
756-
auto continuationDecl = SGF.getASTContext().getUnsafeContinuationDecl();
757-
758-
auto errorTy = throws
759-
? SGF.getASTContext().getErrorExistentialType()
760-
: SGF.getASTContext().getNeverType();
761-
auto continuationTy = BoundGenericType::get(continuationDecl, Type(),
762-
{ calleeTypeInfo.substResultType, errorTy })
763-
->getCanonicalType();
764-
auto wrappedContinuation =
765-
SGF.B.createStruct(loc,
766-
SILType::getPrimitiveObjectType(continuationTy),
767-
{continuation});
768-
769-
// Stash it in a buffer for a block object.
770-
auto blockStorageTy = SILType::getPrimitiveAddressType(
771-
SILBlockStorageType::get(continuationTy));
772-
auto blockStorage = SGF.emitTemporaryAllocation(loc, blockStorageTy);
773-
auto continuationAddr = SGF.B.createProjectBlockStorage(loc, blockStorage);
774-
SGF.B.createStore(loc, wrappedContinuation, continuationAddr,
775-
StoreOwnershipQualifier::Trivial);
827+
std::tie(blockStorage, blockStorageTy, continuationTy) =
828+
emitBlockStorage(SGF, loc, throws);
776829

777830
// Get the block invocation function for the given completion block type.
778831
auto completionHandlerIndex = calleeTypeInfo.foreign.async
@@ -796,11 +849,11 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
796849
SGF.SGM.getOrCreateForeignAsyncCompletionHandlerImplFunction(
797850
cast<SILFunctionType>(
798851
impFnTy->mapTypeOutOfContext()->getReducedType(sig)),
852+
blockStorageTy->mapTypeOutOfContext()->getReducedType(sig),
799853
continuationTy->mapTypeOutOfContext()->getReducedType(sig),
800-
origFormalType, sig, *calleeTypeInfo.foreign.async,
801-
calleeTypeInfo.foreign.error);
854+
origFormalType, sig, calleeTypeInfo);
802855
auto impRef = SGF.B.createFunctionRef(loc, impl);
803-
856+
804857
// Initialize the block object for the completion handler.
805858
SILValue block = SGF.B.createInitBlockStorageHeader(loc, blockStorage,
806859
impRef, SILType::getPrimitiveObjectType(impFnTy),
@@ -829,7 +882,8 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
829882
SILValue bridgedForeignError) override {
830883
// There should be no direct results from the call.
831884
assert(directResults.empty());
832-
885+
auto &ctx = SGF.getASTContext();
886+
833887
// Await the continuation we handed off to the completion handler.
834888
SILBasicBlock *resumeBlock = SGF.createBasicBlock();
835889
SILBasicBlock *errorBlock = nullptr;
@@ -854,54 +908,74 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
854908
// (1) fulfill the unsafe continuation with the foreign error
855909
// (2) branch to the await block
856910
{
857-
// First, fulfill the unsafe continuation with the foreign error.
911+
// First, fulfill the continuation with the foreign error.
858912
// Currently, that block's code looks something like
859913
// %foreignError = ... : $*Optional<NSError>
860914
// %converter = function_ref _convertNSErrorToError(_:)
861915
// %error = apply %converter(%foreignError)
862916
// [... insert here ...]
863917
// destroy_value %error
864918
// destroy_value %foreignError
865-
// Insert code to fulfill it after the native %error is defined. That
866-
// code should structure the RawUnsafeContinuation (continuation) into
867-
// an appropriately typed UnsafeContinuation and then pass that together
868-
// with (a copy of) the error to
869-
// _resumeUnsafeThrowingContinuationWithError.
919+
// Insert code to fulfill it after the native %error is defined. That
920+
// code should load UnsafeContinuation (or CheckedContinuation
921+
// depending on mode) and then pass that together with (a copy of) the
922+
// error to _resume{Unsafe, Checked}ThrowingContinuationWithError.
870923
// [foreign_error_block_with_foreign_async_convention]
871924
SGF.B.setInsertionPoint(
872925
++bridgedForeignError->getDefiningInstruction()->getIterator());
873926

874-
auto continuationDecl = SGF.getASTContext().getUnsafeContinuationDecl();
927+
bool checkedBridging = ctx.LangOpts.UseCheckedAsyncObjCBridging;
875928

876-
auto errorTy = SGF.getASTContext().getErrorExistentialType();
877-
auto continuationBGT =
878-
BoundGenericType::get(continuationDecl, Type(),
879-
{calleeTypeInfo.substResultType, errorTy});
880929
auto env = SGF.F.getGenericEnvironment();
881930
auto sig = env ? env->getGenericSignature().getCanonicalSignature()
882931
: CanGenericSignature();
883-
auto mappedContinuationTy =
884-
continuationBGT->mapTypeOutOfContext()->getReducedType(sig);
932+
933+
// Load unsafe or checked continuation from the block storage
934+
// and call _resume{Unsafe, Checked}ThrowingContinuationWithError.
935+
936+
SILValue continuationAddr =
937+
SGF.B.createProjectBlockStorage(loc, blockStorage);
938+
939+
ManagedValue continuation;
940+
if (checkedBridging) {
941+
FormalEvaluationScope scope(SGF);
942+
943+
auto underlyingValueTy =
944+
OpenedArchetypeType::get(ctx.TheAnyType, sig);
945+
946+
auto underlyingValueAddr = SGF.emitOpenExistential(
947+
loc, ManagedValue::forTrivialAddressRValue(continuationAddr),
948+
SGF.getLoweredType(underlyingValueTy), AccessKind::Read);
949+
950+
continuation = SGF.B.createUncheckedAddrCast(
951+
loc, underlyingValueAddr,
952+
SILType::getPrimitiveAddressType(continuationTy));
953+
} else {
954+
auto continuationVal = SGF.B.createLoad(
955+
loc, continuationAddr, LoadOwnershipQualifier::Trivial);
956+
continuation =
957+
ManagedValue::forObjectRValueWithoutOwnership(continuationVal);
958+
}
959+
960+
auto mappedOutContinuationTy =
961+
continuationTy->mapTypeOutOfContext()->getReducedType(sig);
885962
auto resumeType =
886-
cast<BoundGenericType>(mappedContinuationTy).getGenericArgs()[0];
887-
auto continuationTy = continuationBGT->getCanonicalType();
963+
cast<BoundGenericType>(mappedOutContinuationTy).getGenericArgs()[0];
888964

889965
auto errorIntrinsic =
890-
SGF.SGM.getResumeUnsafeThrowingContinuationWithError();
966+
checkedBridging
967+
? SGF.SGM.getResumeCheckedThrowingContinuationWithError()
968+
: SGF.SGM.getResumeUnsafeThrowingContinuationWithError();
969+
891970
Type replacementTypes[] = {
892971
SGF.F.mapTypeIntoContext(resumeType)->getCanonicalType()};
893972
auto subs = SubstitutionMap::get(errorIntrinsic->getGenericSignature(),
894973
replacementTypes,
895974
ArrayRef<ProtocolConformanceRef>{});
896-
auto wrappedContinuation = SGF.B.createStruct(
897-
loc, SILType::getPrimitiveObjectType(continuationTy),
898-
{continuation});
899975

900-
auto continuationMV = ManagedValue::forObjectRValueWithoutOwnership(
901-
SILValue(wrappedContinuation));
902976
SGF.emitApplyOfLibraryIntrinsic(
903977
loc, errorIntrinsic, subs,
904-
{continuationMV,
978+
{continuation,
905979
SGF.B.copyOwnedObjectRValue(loc, bridgedForeignError,
906980
ManagedValue::ScopeKind::Lexical)},
907981
SGFContext());
@@ -924,7 +998,7 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
924998

925999
Scope errorScope(SGF, loc);
9261000

927-
auto errorTy = SGF.getASTContext().getErrorExistentialType();
1001+
auto errorTy = ctx.getErrorExistentialType();
9281002
auto errorVal = SGF.B.createTermResult(
9291003
SILType::getPrimitiveObjectType(errorTy), OwnershipKind::Owned);
9301004

lib/SILGen/SILGen.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,37 @@ SILGenModule::getCheckExpectedExecutor() {
467467
"_checkExpectedExecutor");
468468
}
469469

470+
FuncDecl *
471+
SILGenModule::getCreateCheckedContinuation() {
472+
return lookupConcurrencyIntrinsic(getASTContext(),
473+
CreateCheckedContinuation,
474+
"_createCheckedContinuation");
475+
}
476+
FuncDecl *
477+
SILGenModule::getCreateCheckedThrowingContinuation() {
478+
return lookupConcurrencyIntrinsic(getASTContext(),
479+
CreateCheckedThrowingContinuation,
480+
"_createCheckedThrowingContinuation");
481+
}
482+
FuncDecl *
483+
SILGenModule::getResumeCheckedContinuation() {
484+
return lookupConcurrencyIntrinsic(getASTContext(),
485+
ResumeCheckedContinuation,
486+
"_resumeCheckedContinuation");
487+
}
488+
FuncDecl *
489+
SILGenModule::getResumeCheckedThrowingContinuation() {
490+
return lookupConcurrencyIntrinsic(getASTContext(),
491+
ResumeCheckedThrowingContinuation,
492+
"_resumeCheckedThrowingContinuation");
493+
}
494+
FuncDecl *
495+
SILGenModule::getResumeCheckedThrowingContinuationWithError() {
496+
return lookupConcurrencyIntrinsic(
497+
getASTContext(), ResumeCheckedThrowingContinuationWithError,
498+
"_resumeCheckedThrowingContinuationWithError");
499+
}
500+
470501
FuncDecl *SILGenModule::getAsyncMainDrainQueue() {
471502
return lookupConcurrencyIntrinsic(getASTContext(), AsyncMainDrainQueue,
472503
"_asyncMainDrainQueue");

lib/SILGen/SILGen.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ namespace swift {
3232
namespace Lowering {
3333
class TypeConverter;
3434
class SILGenFunction;
35+
class CalleeTypeInfo;
3536

3637
/// An enum to indicate whether a protocol method requirement is satisfied by
3738
/// a free function, as for an operator requirement.
@@ -127,6 +128,12 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
127128
llvm::Optional<FuncDecl *> ResumeUnsafeThrowingContinuationWithError;
128129
llvm::Optional<FuncDecl *> CheckExpectedExecutor;
129130

131+
llvm::Optional<FuncDecl*> CreateCheckedContinuation;
132+
llvm::Optional<FuncDecl*> CreateCheckedThrowingContinuation;
133+
llvm::Optional<FuncDecl*> ResumeCheckedContinuation;
134+
llvm::Optional<FuncDecl*> ResumeCheckedThrowingContinuation;
135+
llvm::Optional<FuncDecl*> ResumeCheckedThrowingContinuationWithError;
136+
130137
llvm::Optional<FuncDecl *> AsyncMainDrainQueue;
131138
llvm::Optional<FuncDecl *> GetMainExecutor;
132139
llvm::Optional<FuncDecl *> SwiftJobRun;
@@ -182,10 +189,9 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
182189
/// implementation function for an ObjC API that was imported
183190
/// as `async` in Swift.
184191
SILFunction *getOrCreateForeignAsyncCompletionHandlerImplFunction(
185-
CanSILFunctionType blockType, CanType continuationTy,
186-
AbstractionPattern origFormalType, CanGenericSignature sig,
187-
ForeignAsyncConvention convention,
188-
llvm::Optional<ForeignErrorConvention> foreignError);
192+
CanSILFunctionType blockType, CanType blockStorageType,
193+
CanType continuationType, AbstractionPattern origFormalType,
194+
CanGenericSignature sig, CalleeTypeInfo &calleeInfo);
189195

190196
/// Determine whether the given class has any instance variables that
191197
/// need to be destroyed.
@@ -551,6 +557,17 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
551557
/// Retrieve the _Concurrency._checkExpectedExecutor intrinsic.
552558
FuncDecl *getCheckExpectedExecutor();
553559

560+
/// Retrieve the _Concurrency._createCheckedContinuation intrinsic.
561+
FuncDecl *getCreateCheckedContinuation();
562+
/// Retrieve the _Concurrency._createCheckedThrowingContinuation intrinsic.
563+
FuncDecl *getCreateCheckedThrowingContinuation();
564+
/// Retrieve the _Concurrency._resumeCheckedContinuation intrinsic.
565+
FuncDecl *getResumeCheckedContinuation();
566+
/// Retrieve the _Concurrency._resumeCheckedThrowingContinuation intrinsic.
567+
FuncDecl *getResumeCheckedThrowingContinuation();
568+
/// Retrieve the _Concurrency._resumeCheckedThrowingContinuationWithError intrinsic.
569+
FuncDecl *getResumeCheckedThrowingContinuationWithError();
570+
554571
/// Retrieve the _Concurrency._asyncMainDrainQueue intrinsic.
555572
FuncDecl *getAsyncMainDrainQueue();
556573
/// Retrieve the _Concurrency._getMainExecutor intrinsic.

0 commit comments

Comments
 (0)