Skip to content

Provide ability to use CheckedContinuation when suspending for an async ObjC call #68390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownSDKTypes.def
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ KNOWN_SDK_TYPE_DECL(ObjectiveC, ObjCBool, StructDecl, 0)

// TODO(async): These might move to the stdlib module when concurrency is
// standardized
KNOWN_SDK_TYPE_DECL(Concurrency, CheckedContinuation, NominalTypeDecl, 2)
KNOWN_SDK_TYPE_DECL(Concurrency, UnsafeContinuation, NominalTypeDecl, 2)
KNOWN_SDK_TYPE_DECL(Concurrency, MainActor, NominalTypeDecl, 0)
KNOWN_SDK_TYPE_DECL(Concurrency, Job, StructDecl, 0) // TODO: remove in favor of ExecutorJob
Expand Down
4 changes: 4 additions & 0 deletions include/swift/Basic/LangOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ namespace swift {
bool DisableImplicitBacktracingModuleImport =
!SWIFT_IMPLICIT_BACKTRACING_IMPORT;

// Whether to use checked continuations when making an async call from
// Swift into ObjC. If false, will use unchecked continuations instead.
bool UseCheckedAsyncObjCBridging = false;

/// Should we check the target OSs of serialized modules to see that they're
/// new enough?
bool EnableTargetOSChecking = true;
Expand Down
4 changes: 4 additions & 0 deletions include/swift/Option/FrontendOptions.td
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,10 @@ def enable_experimental_async_top_level :
// HIDDEN FLAGS
let Flags = [FrontendOption, NoDriverOption, HelpHidden] in {

def checked_async_objc_bridging : Joined<["-"], "checked-async-objc-bridging=">,
HelpText<"Control whether checked continuations are used when bridging "
"async calls from Swift to ObjC: 'on', 'off' ">;

def debug_constraints : Flag<["-"], "debug-constraints">,
HelpText<"Debug the constraint-based type checker">;

Expand Down
15 changes: 15 additions & 0 deletions lib/Frontend/CompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,21 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
HadError = true;
}

if (auto A = Args.getLastArg(OPT_checked_async_objc_bridging)) {
auto value = llvm::StringSwitch<llvm::Optional<bool>>(A->getValue())
.Case("off", false)
.Case("on", true)
.Default(llvm::None);

if (value) {
Opts.UseCheckedAsyncObjCBridging = *value;
} else {
Diags.diagnose(SourceLoc(), diag::error_invalid_arg_value,
A->getAsString(Args), A->getValue());
HadError = true;
}
}

return HadError || UnsupportedOS || UnsupportedArch;
}

Expand Down
124 changes: 94 additions & 30 deletions lib/SILGen/ResultPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,75 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
// A foreign async function shouldn't have any indirect results.
}

std::tuple</*blockStorage=*/SILValue, /*blockStorageType=*/CanType,
/*continuationType=*/CanType>
emitBlockStorage(SILGenFunction &SGF, SILLocation loc, bool throws) {
auto &ctx = SGF.getASTContext();

// Wrap the Builtin.RawUnsafeContinuation in an
// UnsafeContinuation<T, E>.
auto *unsafeContinuationDecl = ctx.getUnsafeContinuationDecl();
auto errorTy = throws ? ctx.getErrorExistentialType() : ctx.getNeverType();
auto continuationTy =
BoundGenericType::get(unsafeContinuationDecl, /*parent=*/Type(),
{calleeTypeInfo.substResultType, errorTy})
->getCanonicalType();

auto wrappedContinuation = SGF.B.createStruct(
loc, SILType::getPrimitiveObjectType(continuationTy), {continuation});

const bool checkedBridging = ctx.LangOpts.UseCheckedAsyncObjCBridging;

// If checked bridging is enabled, wrap that continuation again in a
// CheckedContinuation<T, E>
if (checkedBridging) {
auto *checkedContinuationDecl = ctx.getCheckedContinuationDecl();
continuationTy =
BoundGenericType::get(checkedContinuationDecl, /*parent=*/Type(),
{calleeTypeInfo.substResultType, errorTy})
->getCanonicalType();
}

auto blockStorageTy = SILBlockStorageType::get(
checkedBridging ? ctx.TheAnyType : continuationTy);
auto blockStorage = SGF.emitTemporaryAllocation(
loc, SILType::getPrimitiveAddressType(blockStorageTy));

auto continuationAddr = SGF.B.createProjectBlockStorage(loc, blockStorage);

// Stash continuation in a buffer for a block object.

if (checkedBridging) {
auto createIntrinsic =
throws ? SGF.SGM.getCreateCheckedThrowingContinuation()
: SGF.SGM.getCreateCheckedContinuation();

// In this case block storage captures `Any` which would be initialized
// with an checked continuation.
auto underlyingContinuationAddr =
SGF.B.createInitExistentialAddr(loc, continuationAddr, continuationTy,
SGF.getLoweredType(continuationTy),
/*conformances=*/{});

auto subs = SubstitutionMap::get(createIntrinsic->getGenericSignature(),
{calleeTypeInfo.substResultType},
ArrayRef<ProtocolConformanceRef>{});

InitializationPtr underlyingInit(
new KnownAddressInitialization(underlyingContinuationAddr));
auto continuationMV =
ManagedValue::forRValueWithoutOwnership(wrappedContinuation);
SGF.emitApplyOfLibraryIntrinsic(loc, createIntrinsic, subs,
{continuationMV}, SGFContext())
.forwardInto(SGF, loc, underlyingInit.get());
} else {
SGF.B.createStore(loc, wrappedContinuation, continuationAddr,
StoreOwnershipQualifier::Trivial);
}

return std::make_tuple(blockStorage, blockStorageTy, continuationTy);
}

ManagedValue
emitForeignAsyncCompletionHandler(SILGenFunction &SGF,
AbstractionPattern origFormalType,
Expand All @@ -751,28 +820,11 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
continuation = SGF.B.createGetAsyncContinuationAddr(loc, resumeBuf,
calleeTypeInfo.substResultType, throws);

// Wrap the Builtin.RawUnsafeContinuation in an
// UnsafeContinuation<T, E>.
auto continuationDecl = SGF.getASTContext().getUnsafeContinuationDecl();

auto errorTy = throws
? SGF.getASTContext().getErrorExistentialType()
: SGF.getASTContext().getNeverType();
auto continuationTy = BoundGenericType::get(continuationDecl, Type(),
{ calleeTypeInfo.substResultType, errorTy })
->getCanonicalType();
auto wrappedContinuation =
SGF.B.createStruct(loc,
SILType::getPrimitiveObjectType(continuationTy),
{continuation});

// Stash it in a buffer for a block object.
auto blockStorageTy = SILType::getPrimitiveAddressType(
SILBlockStorageType::get(continuationTy));
auto blockStorage = SGF.emitTemporaryAllocation(loc, blockStorageTy);
auto continuationAddr = SGF.B.createProjectBlockStorage(loc, blockStorage);
SGF.B.createStore(loc, wrappedContinuation, continuationAddr,
StoreOwnershipQualifier::Trivial);
SILValue blockStorage;
CanType blockStorageTy;
CanType continuationTy;
std::tie(blockStorage, blockStorageTy, continuationTy) =
emitBlockStorage(SGF, loc, throws);

// Get the block invocation function for the given completion block type.
auto completionHandlerIndex = calleeTypeInfo.foreign.async
Expand All @@ -796,11 +848,11 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
SGF.SGM.getOrCreateForeignAsyncCompletionHandlerImplFunction(
cast<SILFunctionType>(
impFnTy->mapTypeOutOfContext()->getReducedType(sig)),
blockStorageTy->mapTypeOutOfContext()->getReducedType(sig),
continuationTy->mapTypeOutOfContext()->getReducedType(sig),
origFormalType, sig, *calleeTypeInfo.foreign.async,
calleeTypeInfo.foreign.error);
origFormalType, sig, calleeTypeInfo);
auto impRef = SGF.B.createFunctionRef(loc, impl);

// Initialize the block object for the completion handler.
SILValue block = SGF.B.createInitBlockStorageHeader(loc, blockStorage,
impRef, SILType::getPrimitiveObjectType(impFnTy),
Expand Down Expand Up @@ -829,7 +881,8 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
SILValue bridgedForeignError) override {
// There should be no direct results from the call.
assert(directResults.empty());

auto &ctx = SGF.getASTContext();

// Await the continuation we handed off to the completion handler.
SILBasicBlock *resumeBlock = SGF.createBasicBlock();
SILBasicBlock *errorBlock = nullptr;
Expand Down Expand Up @@ -871,9 +924,20 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
SGF.B.setInsertionPoint(
++bridgedForeignError->getDefiningInstruction()->getIterator());

auto continuationDecl = SGF.getASTContext().getUnsafeContinuationDecl();

auto errorTy = SGF.getASTContext().getErrorExistentialType();
// FIXME: this case is not respecting checked bridging, and it's a
// great candidate for that. This situation comes up when bridging
// to an ObjC completion-handler method that returns a bool. It seems
// that bool indicates whether the handler was invoked. If it was not
// then it writes out an error. Here for the unsafe bridging, we're
// invoking the continuation by re-wrapping it in an
// UnsafeContinuation<_, Error> and then immediately calling its
// resume(throwing: error) method. For a checked bridging scenario, we
// would need to use a copy of the original CheckedContinuation that
// was passed to the callee. Whether that's by invoking the block
// ourselves, or just invoking the CheckedContinuation.

auto continuationDecl = ctx.getUnsafeContinuationDecl();
auto errorTy = ctx.getErrorExistentialType();
auto continuationBGT =
BoundGenericType::get(continuationDecl, Type(),
{calleeTypeInfo.substResultType, errorTy});
Expand Down Expand Up @@ -924,7 +988,7 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {

Scope errorScope(SGF, loc);

auto errorTy = SGF.getASTContext().getErrorExistentialType();
auto errorTy = ctx.getErrorExistentialType();
auto errorVal = SGF.B.createTermResult(
SILType::getPrimitiveObjectType(errorTy), OwnershipKind::Owned);

Expand Down
31 changes: 31 additions & 0 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,37 @@ SILGenModule::getCheckExpectedExecutor() {
"_checkExpectedExecutor");
}

FuncDecl *
SILGenModule::getCreateCheckedContinuation() {
return lookupConcurrencyIntrinsic(getASTContext(),
CreateCheckedContinuation,
"_createCheckedContinuation");
}
FuncDecl *
SILGenModule::getCreateCheckedThrowingContinuation() {
return lookupConcurrencyIntrinsic(getASTContext(),
CreateCheckedThrowingContinuation,
"_createCheckedThrowingContinuation");
}
FuncDecl *
SILGenModule::getResumeCheckedContinuation() {
return lookupConcurrencyIntrinsic(getASTContext(),
ResumeCheckedContinuation,
"_resumeCheckedContinuation");
}
FuncDecl *
SILGenModule::getResumeCheckedThrowingContinuation() {
return lookupConcurrencyIntrinsic(getASTContext(),
ResumeCheckedThrowingContinuation,
"_resumeCheckedThrowingContinuation");
}
FuncDecl *
SILGenModule::getResumeCheckedThrowingContinuationWithError() {
return lookupConcurrencyIntrinsic(
getASTContext(), ResumeCheckedThrowingContinuationWithError,
"_resumeCheckedThrowingContinuationWithError");
}

FuncDecl *SILGenModule::getAsyncMainDrainQueue() {
return lookupConcurrencyIntrinsic(getASTContext(), AsyncMainDrainQueue,
"_asyncMainDrainQueue");
Expand Down
25 changes: 21 additions & 4 deletions lib/SILGen/SILGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace swift {
namespace Lowering {
class TypeConverter;
class SILGenFunction;
class CalleeTypeInfo;

/// An enum to indicate whether a protocol method requirement is satisfied by
/// a free function, as for an operator requirement.
Expand Down Expand Up @@ -127,6 +128,12 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
llvm::Optional<FuncDecl *> ResumeUnsafeThrowingContinuationWithError;
llvm::Optional<FuncDecl *> CheckExpectedExecutor;

llvm::Optional<FuncDecl*> CreateCheckedContinuation;
llvm::Optional<FuncDecl*> CreateCheckedThrowingContinuation;
llvm::Optional<FuncDecl*> ResumeCheckedContinuation;
llvm::Optional<FuncDecl*> ResumeCheckedThrowingContinuation;
llvm::Optional<FuncDecl*> ResumeCheckedThrowingContinuationWithError;

llvm::Optional<FuncDecl *> AsyncMainDrainQueue;
llvm::Optional<FuncDecl *> GetMainExecutor;
llvm::Optional<FuncDecl *> SwiftJobRun;
Expand Down Expand Up @@ -182,10 +189,9 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
/// implementation function for an ObjC API that was imported
/// as `async` in Swift.
SILFunction *getOrCreateForeignAsyncCompletionHandlerImplFunction(
CanSILFunctionType blockType, CanType continuationTy,
AbstractionPattern origFormalType, CanGenericSignature sig,
ForeignAsyncConvention convention,
llvm::Optional<ForeignErrorConvention> foreignError);
CanSILFunctionType blockType, CanType blockStorageType,
CanType continuationType, AbstractionPattern origFormalType,
CanGenericSignature sig, CalleeTypeInfo &calleeInfo);

/// Determine whether the given class has any instance variables that
/// need to be destroyed.
Expand Down Expand Up @@ -551,6 +557,17 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
/// Retrieve the _Concurrency._checkExpectedExecutor intrinsic.
FuncDecl *getCheckExpectedExecutor();

/// Retrieve the _Concurrency._createCheckedContinuation intrinsic.
FuncDecl *getCreateCheckedContinuation();
/// Retrieve the _Concurrency._createCheckedThrowingContinuation intrinsic.
FuncDecl *getCreateCheckedThrowingContinuation();
/// Retrieve the _Concurrency._resumeCheckedContinuation intrinsic.
FuncDecl *getResumeCheckedContinuation();
/// Retrieve the _Concurrency._resumeCheckedThrowingContinuation intrinsic.
FuncDecl *getResumeCheckedThrowingContinuation();
/// Retrieve the _Concurrency._resumeCheckedThrowingContinuationWithError intrinsic.
FuncDecl *getResumeCheckedThrowingContinuationWithError();

/// Retrieve the _Concurrency._asyncMainDrainQueue intrinsic.
FuncDecl *getAsyncMainDrainQueue();
/// Retrieve the _Concurrency._getMainExecutor intrinsic.
Expand Down
Loading