Skip to content

[5.10] Provide ability to use CheckedContinuation when suspending for an async ObjC call #69139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownSDKTypes.def
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ KNOWN_SDK_TYPE_DECL(ObjectiveC, ObjCBool, StructDecl, 0)

// TODO(async): These might move to the stdlib module when concurrency is
// standardized
KNOWN_SDK_TYPE_DECL(Concurrency, CheckedContinuation, NominalTypeDecl, 2)
KNOWN_SDK_TYPE_DECL(Concurrency, UnsafeContinuation, NominalTypeDecl, 2)
KNOWN_SDK_TYPE_DECL(Concurrency, MainActor, NominalTypeDecl, 0)
KNOWN_SDK_TYPE_DECL(Concurrency, Job, StructDecl, 0) // TODO: remove in favor of ExecutorJob
Expand Down
4 changes: 4 additions & 0 deletions include/swift/Basic/LangOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ namespace swift {
bool DisableImplicitBacktracingModuleImport =
!SWIFT_IMPLICIT_BACKTRACING_IMPORT;

// Whether to use checked continuations when making an async call from
// Swift into ObjC. If false, will use unchecked continuations instead.
bool UseCheckedAsyncObjCBridging = false;

/// Should we check the target OSs of serialized modules to see that they're
/// new enough?
bool EnableTargetOSChecking = true;
Expand Down
4 changes: 4 additions & 0 deletions include/swift/Option/FrontendOptions.td
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ def enable_experimental_async_top_level :
// HIDDEN FLAGS
let Flags = [FrontendOption, NoDriverOption, HelpHidden] in {

def checked_async_objc_bridging : Joined<["-"], "checked-async-objc-bridging=">,
HelpText<"Control whether checked continuations are used when bridging "
"async calls from Swift to ObjC: 'on', 'off' ">;

def debug_constraints : Flag<["-"], "debug-constraints">,
HelpText<"Debug the constraint-based type checker">;

Expand Down
15 changes: 15 additions & 0 deletions lib/Frontend/CompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,21 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
HadError = true;
}

if (auto A = Args.getLastArg(OPT_checked_async_objc_bridging)) {
auto value = llvm::StringSwitch<llvm::Optional<bool>>(A->getValue())
.Case("off", false)
.Case("on", true)
.Default(llvm::None);

if (value) {
Opts.UseCheckedAsyncObjCBridging = *value;
} else {
Diags.diagnose(SourceLoc(), diag::error_invalid_arg_value,
A->getAsString(Args), A->getValue());
HadError = true;
}
}

return HadError || UnsupportedOS || UnsupportedArch;
}

Expand Down
174 changes: 124 additions & 50 deletions lib/SILGen/ResultPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,11 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
SILValue resumeBuf;
SILValue continuation;
ExecutorBreadcrumb breadcrumb;


SILValue blockStorage;
CanType blockStorageTy;
CanType continuationTy;

public:
ForeignAsyncInitializationPlan(SILGenFunction &SGF, SILLocation loc,
const CalleeTypeInfo &calleeTypeInfo)
Expand All @@ -738,6 +742,75 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
// A foreign async function shouldn't have any indirect results.
}

std::tuple</*blockStorage=*/SILValue, /*blockStorageType=*/CanType,
/*continuationType=*/CanType>
emitBlockStorage(SILGenFunction &SGF, SILLocation loc, bool throws) {
auto &ctx = SGF.getASTContext();

// Wrap the Builtin.RawUnsafeContinuation in an
// UnsafeContinuation<T, E>.
auto *unsafeContinuationDecl = ctx.getUnsafeContinuationDecl();
auto errorTy = throws ? ctx.getErrorExistentialType() : ctx.getNeverType();
auto continuationTy =
BoundGenericType::get(unsafeContinuationDecl, /*parent=*/Type(),
{calleeTypeInfo.substResultType, errorTy})
->getCanonicalType();

auto wrappedContinuation = SGF.B.createStruct(
loc, SILType::getPrimitiveObjectType(continuationTy), {continuation});

const bool checkedBridging = ctx.LangOpts.UseCheckedAsyncObjCBridging;

// If checked bridging is enabled, wrap that continuation again in a
// CheckedContinuation<T, E>
if (checkedBridging) {
auto *checkedContinuationDecl = ctx.getCheckedContinuationDecl();
continuationTy =
BoundGenericType::get(checkedContinuationDecl, /*parent=*/Type(),
{calleeTypeInfo.substResultType, errorTy})
->getCanonicalType();
}

auto blockStorageTy = SILBlockStorageType::get(
checkedBridging ? ctx.TheAnyType : continuationTy);
auto blockStorage = SGF.emitTemporaryAllocation(
loc, SILType::getPrimitiveAddressType(blockStorageTy));

auto continuationAddr = SGF.B.createProjectBlockStorage(loc, blockStorage);

// Stash continuation in a buffer for a block object.

if (checkedBridging) {
auto createIntrinsic =
throws ? SGF.SGM.getCreateCheckedThrowingContinuation()
: SGF.SGM.getCreateCheckedContinuation();

// In this case block storage captures `Any` which would be initialized
// with an checked continuation.
auto underlyingContinuationAddr =
SGF.B.createInitExistentialAddr(loc, continuationAddr, continuationTy,
SGF.getLoweredType(continuationTy),
/*conformances=*/{});

auto subs = SubstitutionMap::get(createIntrinsic->getGenericSignature(),
{calleeTypeInfo.substResultType},
ArrayRef<ProtocolConformanceRef>{});

InitializationPtr underlyingInit(
new KnownAddressInitialization(underlyingContinuationAddr));
auto continuationMV =
ManagedValue::forRValueWithoutOwnership(wrappedContinuation);
SGF.emitApplyOfLibraryIntrinsic(loc, createIntrinsic, subs,
{continuationMV}, SGFContext())
.forwardInto(SGF, loc, underlyingInit.get());
} else {
SGF.B.createStore(loc, wrappedContinuation, continuationAddr,
StoreOwnershipQualifier::Trivial);
}

return std::make_tuple(blockStorage, blockStorageTy, continuationTy);
}

ManagedValue
emitForeignAsyncCompletionHandler(SILGenFunction &SGF,
AbstractionPattern origFormalType,
Expand All @@ -751,28 +824,8 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
continuation = SGF.B.createGetAsyncContinuationAddr(loc, resumeBuf,
calleeTypeInfo.substResultType, throws);

// Wrap the Builtin.RawUnsafeContinuation in an
// UnsafeContinuation<T, E>.
auto continuationDecl = SGF.getASTContext().getUnsafeContinuationDecl();

auto errorTy = throws
? SGF.getASTContext().getErrorExistentialType()
: SGF.getASTContext().getNeverType();
auto continuationTy = BoundGenericType::get(continuationDecl, Type(),
{ calleeTypeInfo.substResultType, errorTy })
->getCanonicalType();
auto wrappedContinuation =
SGF.B.createStruct(loc,
SILType::getPrimitiveObjectType(continuationTy),
{continuation});

// Stash it in a buffer for a block object.
auto blockStorageTy = SILType::getPrimitiveAddressType(
SILBlockStorageType::get(continuationTy));
auto blockStorage = SGF.emitTemporaryAllocation(loc, blockStorageTy);
auto continuationAddr = SGF.B.createProjectBlockStorage(loc, blockStorage);
SGF.B.createStore(loc, wrappedContinuation, continuationAddr,
StoreOwnershipQualifier::Trivial);
std::tie(blockStorage, blockStorageTy, continuationTy) =
emitBlockStorage(SGF, loc, throws);

// Get the block invocation function for the given completion block type.
auto completionHandlerIndex = calleeTypeInfo.foreign.async
Expand All @@ -796,11 +849,11 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
SGF.SGM.getOrCreateForeignAsyncCompletionHandlerImplFunction(
cast<SILFunctionType>(
impFnTy->mapTypeOutOfContext()->getReducedType(sig)),
blockStorageTy->mapTypeOutOfContext()->getReducedType(sig),
continuationTy->mapTypeOutOfContext()->getReducedType(sig),
origFormalType, sig, *calleeTypeInfo.foreign.async,
calleeTypeInfo.foreign.error);
origFormalType, sig, calleeTypeInfo);
auto impRef = SGF.B.createFunctionRef(loc, impl);

// Initialize the block object for the completion handler.
SILValue block = SGF.B.createInitBlockStorageHeader(loc, blockStorage,
impRef, SILType::getPrimitiveObjectType(impFnTy),
Expand Down Expand Up @@ -829,7 +882,8 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
SILValue bridgedForeignError) override {
// There should be no direct results from the call.
assert(directResults.empty());

auto &ctx = SGF.getASTContext();

// Await the continuation we handed off to the completion handler.
SILBasicBlock *resumeBlock = SGF.createBasicBlock();
SILBasicBlock *errorBlock = nullptr;
Expand All @@ -854,54 +908,74 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {
// (1) fulfill the unsafe continuation with the foreign error
// (2) branch to the await block
{
// First, fulfill the unsafe continuation with the foreign error.
// First, fulfill the continuation with the foreign error.
// Currently, that block's code looks something like
// %foreignError = ... : $*Optional<NSError>
// %converter = function_ref _convertNSErrorToError(_:)
// %error = apply %converter(%foreignError)
// [... insert here ...]
// destroy_value %error
// destroy_value %foreignError
// Insert code to fulfill it after the native %error is defined. That
// code should structure the RawUnsafeContinuation (continuation) into
// an appropriately typed UnsafeContinuation and then pass that together
// with (a copy of) the error to
// _resumeUnsafeThrowingContinuationWithError.
// Insert code to fulfill it after the native %error is defined. That
// code should load UnsafeContinuation (or CheckedContinuation
// depending on mode) and then pass that together with (a copy of) the
// error to _resume{Unsafe, Checked}ThrowingContinuationWithError.
// [foreign_error_block_with_foreign_async_convention]
SGF.B.setInsertionPoint(
++bridgedForeignError->getDefiningInstruction()->getIterator());

auto continuationDecl = SGF.getASTContext().getUnsafeContinuationDecl();
bool checkedBridging = ctx.LangOpts.UseCheckedAsyncObjCBridging;

auto errorTy = SGF.getASTContext().getErrorExistentialType();
auto continuationBGT =
BoundGenericType::get(continuationDecl, Type(),
{calleeTypeInfo.substResultType, errorTy});
auto env = SGF.F.getGenericEnvironment();
auto sig = env ? env->getGenericSignature().getCanonicalSignature()
: CanGenericSignature();
auto mappedContinuationTy =
continuationBGT->mapTypeOutOfContext()->getReducedType(sig);

// Load unsafe or checked continuation from the block storage
// and call _resume{Unsafe, Checked}ThrowingContinuationWithError.

SILValue continuationAddr =
SGF.B.createProjectBlockStorage(loc, blockStorage);

ManagedValue continuation;
if (checkedBridging) {
FormalEvaluationScope scope(SGF);

auto underlyingValueTy =
OpenedArchetypeType::get(ctx.TheAnyType, sig);

auto underlyingValueAddr = SGF.emitOpenExistential(
loc, ManagedValue::forTrivialAddressRValue(continuationAddr),
SGF.getLoweredType(underlyingValueTy), AccessKind::Read);

continuation = SGF.B.createUncheckedAddrCast(
loc, underlyingValueAddr,
SILType::getPrimitiveAddressType(continuationTy));
} else {
auto continuationVal = SGF.B.createLoad(
loc, continuationAddr, LoadOwnershipQualifier::Trivial);
continuation =
ManagedValue::forObjectRValueWithoutOwnership(continuationVal);
}

auto mappedOutContinuationTy =
continuationTy->mapTypeOutOfContext()->getReducedType(sig);
auto resumeType =
cast<BoundGenericType>(mappedContinuationTy).getGenericArgs()[0];
auto continuationTy = continuationBGT->getCanonicalType();
cast<BoundGenericType>(mappedOutContinuationTy).getGenericArgs()[0];

auto errorIntrinsic =
SGF.SGM.getResumeUnsafeThrowingContinuationWithError();
checkedBridging
? SGF.SGM.getResumeCheckedThrowingContinuationWithError()
: SGF.SGM.getResumeUnsafeThrowingContinuationWithError();

Type replacementTypes[] = {
SGF.F.mapTypeIntoContext(resumeType)->getCanonicalType()};
auto subs = SubstitutionMap::get(errorIntrinsic->getGenericSignature(),
replacementTypes,
ArrayRef<ProtocolConformanceRef>{});
auto wrappedContinuation = SGF.B.createStruct(
loc, SILType::getPrimitiveObjectType(continuationTy),
{continuation});

auto continuationMV = ManagedValue::forObjectRValueWithoutOwnership(
SILValue(wrappedContinuation));
SGF.emitApplyOfLibraryIntrinsic(
loc, errorIntrinsic, subs,
{continuationMV,
{continuation,
SGF.B.copyOwnedObjectRValue(loc, bridgedForeignError,
ManagedValue::ScopeKind::Lexical)},
SGFContext());
Expand All @@ -924,7 +998,7 @@ class ForeignAsyncInitializationPlan final : public ResultPlan {

Scope errorScope(SGF, loc);

auto errorTy = SGF.getASTContext().getErrorExistentialType();
auto errorTy = ctx.getErrorExistentialType();
auto errorVal = SGF.B.createTermResult(
SILType::getPrimitiveObjectType(errorTy), OwnershipKind::Owned);

Expand Down
31 changes: 31 additions & 0 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,37 @@ SILGenModule::getCheckExpectedExecutor() {
"_checkExpectedExecutor");
}

FuncDecl *
SILGenModule::getCreateCheckedContinuation() {
return lookupConcurrencyIntrinsic(getASTContext(),
CreateCheckedContinuation,
"_createCheckedContinuation");
}
FuncDecl *
SILGenModule::getCreateCheckedThrowingContinuation() {
return lookupConcurrencyIntrinsic(getASTContext(),
CreateCheckedThrowingContinuation,
"_createCheckedThrowingContinuation");
}
FuncDecl *
SILGenModule::getResumeCheckedContinuation() {
return lookupConcurrencyIntrinsic(getASTContext(),
ResumeCheckedContinuation,
"_resumeCheckedContinuation");
}
FuncDecl *
SILGenModule::getResumeCheckedThrowingContinuation() {
return lookupConcurrencyIntrinsic(getASTContext(),
ResumeCheckedThrowingContinuation,
"_resumeCheckedThrowingContinuation");
}
FuncDecl *
SILGenModule::getResumeCheckedThrowingContinuationWithError() {
return lookupConcurrencyIntrinsic(
getASTContext(), ResumeCheckedThrowingContinuationWithError,
"_resumeCheckedThrowingContinuationWithError");
}

FuncDecl *SILGenModule::getAsyncMainDrainQueue() {
return lookupConcurrencyIntrinsic(getASTContext(), AsyncMainDrainQueue,
"_asyncMainDrainQueue");
Expand Down
25 changes: 21 additions & 4 deletions lib/SILGen/SILGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace swift {
namespace Lowering {
class TypeConverter;
class SILGenFunction;
class CalleeTypeInfo;

/// An enum to indicate whether a protocol method requirement is satisfied by
/// a free function, as for an operator requirement.
Expand Down Expand Up @@ -127,6 +128,12 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
llvm::Optional<FuncDecl *> ResumeUnsafeThrowingContinuationWithError;
llvm::Optional<FuncDecl *> CheckExpectedExecutor;

llvm::Optional<FuncDecl*> CreateCheckedContinuation;
llvm::Optional<FuncDecl*> CreateCheckedThrowingContinuation;
llvm::Optional<FuncDecl*> ResumeCheckedContinuation;
llvm::Optional<FuncDecl*> ResumeCheckedThrowingContinuation;
llvm::Optional<FuncDecl*> ResumeCheckedThrowingContinuationWithError;

llvm::Optional<FuncDecl *> AsyncMainDrainQueue;
llvm::Optional<FuncDecl *> GetMainExecutor;
llvm::Optional<FuncDecl *> SwiftJobRun;
Expand Down Expand Up @@ -182,10 +189,9 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
/// implementation function for an ObjC API that was imported
/// as `async` in Swift.
SILFunction *getOrCreateForeignAsyncCompletionHandlerImplFunction(
CanSILFunctionType blockType, CanType continuationTy,
AbstractionPattern origFormalType, CanGenericSignature sig,
ForeignAsyncConvention convention,
llvm::Optional<ForeignErrorConvention> foreignError);
CanSILFunctionType blockType, CanType blockStorageType,
CanType continuationType, AbstractionPattern origFormalType,
CanGenericSignature sig, CalleeTypeInfo &calleeInfo);

/// Determine whether the given class has any instance variables that
/// need to be destroyed.
Expand Down Expand Up @@ -551,6 +557,17 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
/// Retrieve the _Concurrency._checkExpectedExecutor intrinsic.
FuncDecl *getCheckExpectedExecutor();

/// Retrieve the _Concurrency._createCheckedContinuation intrinsic.
FuncDecl *getCreateCheckedContinuation();
/// Retrieve the _Concurrency._createCheckedThrowingContinuation intrinsic.
FuncDecl *getCreateCheckedThrowingContinuation();
/// Retrieve the _Concurrency._resumeCheckedContinuation intrinsic.
FuncDecl *getResumeCheckedContinuation();
/// Retrieve the _Concurrency._resumeCheckedThrowingContinuation intrinsic.
FuncDecl *getResumeCheckedThrowingContinuation();
/// Retrieve the _Concurrency._resumeCheckedThrowingContinuationWithError intrinsic.
FuncDecl *getResumeCheckedThrowingContinuationWithError();

/// Retrieve the _Concurrency._asyncMainDrainQueue intrinsic.
FuncDecl *getAsyncMainDrainQueue();
/// Retrieve the _Concurrency._getMainExecutor intrinsic.
Expand Down
Loading