@@ -239,19 +239,26 @@ llvm::Value *IRGenFunction::getAsyncContext() {
239
239
return Builder.CreateLoad (asyncContextLocation);
240
240
}
241
241
242
- llvm::CallInst *
243
- IRGenFunction::emitSuspendAsyncCall (unsigned asyncContextIndex,
244
- llvm::StructType *resultTy,
245
- ArrayRef<llvm::Value *> args) {
246
- auto *id = Builder.CreateIntrinsicCall (llvm::Intrinsic::coro_suspend_async,
247
- {resultTy}, args);
248
- llvm::Value *calleeContext =
249
- Builder.CreateExtractValue (id, asyncContextIndex);
250
- calleeContext = Builder.CreateBitOrPointerCast (calleeContext, IGM.Int8PtrTy );
251
- llvm::Constant *projectFn = cast<llvm::Constant>(args[2 ])->stripPointerCasts ();
252
- llvm::Value *context = Builder.CreateCall (projectFn, {calleeContext});
242
+ void IRGenFunction::storeCurrentAsyncContext (llvm::Value *context) {
253
243
context = Builder.CreateBitCast (context, IGM.SwiftContextPtrTy );
254
244
Builder.CreateStore (context, asyncContextLocation);
245
+ }
246
+
247
+ llvm::CallInst *IRGenFunction::emitSuspendAsyncCall (
248
+ unsigned asyncContextIndex, llvm::StructType *resultTy,
249
+ ArrayRef<llvm::Value *> args, bool restoreCurrentContext) {
250
+ auto *id = Builder.CreateIntrinsicCall (llvm::Intrinsic::coro_suspend_async,
251
+ {resultTy}, args);
252
+ if (restoreCurrentContext) {
253
+ llvm::Value *calleeContext =
254
+ Builder.CreateExtractValue (id, asyncContextIndex);
255
+ calleeContext =
256
+ Builder.CreateBitOrPointerCast (calleeContext, IGM.Int8PtrTy );
257
+ llvm::Constant *projectFn =
258
+ cast<llvm::Constant>(args[2 ])->stripPointerCasts ();
259
+ llvm::Value *context = Builder.CreateCall (projectFn, {calleeContext});
260
+ storeCurrentAsyncContext (context);
261
+ }
255
262
256
263
return id;
257
264
}
@@ -2043,10 +2050,9 @@ std::pair<llvm::Value *, llvm::Value *> irgen::getAsyncFunctionAndSize(
2043
2050
Address (addrPtr, IGF.IGM .getPointerAlignment ()), /* isFar*/ false ,
2044
2051
/* expectedType*/ functionPointer.getFunctionType ()->getPointerTo ());
2045
2052
}
2046
- if (auto authInfo = functionPointer.getAuthInfo ()) {
2047
- auto newAuthInfo = PointerAuthInfo (authInfo.getCorrespondingCodeKey (),
2048
- authInfo.getDiscriminator ());
2049
- fn = emitPointerAuthSign (IGF, fn, newAuthInfo);
2053
+ if (auto authInfo =
2054
+ functionPointer.getAuthInfo ().getCorrespondingCodeAuthInfo ()) {
2055
+ fn = emitPointerAuthSign (IGF, fn, authInfo);
2050
2056
}
2051
2057
}
2052
2058
llvm::Value *size = nullptr ;
@@ -2374,13 +2380,11 @@ class AsyncCallEmission final : public CallEmission {
2374
2380
}
2375
2381
2376
2382
FunctionPointer getCalleeFunctionPointer () override {
2377
- PointerAuthInfo newAuthInfo;
2378
- if (auto authInfo = CurCallee.getFunctionPointer ().getAuthInfo ()) {
2379
- newAuthInfo = PointerAuthInfo (authInfo.getCorrespondingCodeKey (),
2380
- authInfo.getDiscriminator ());
2381
- }
2383
+ PointerAuthInfo codeAuthInfo = CurCallee.getFunctionPointer ()
2384
+ .getAuthInfo ()
2385
+ .getCorrespondingCodeAuthInfo ();
2382
2386
return FunctionPointer (
2383
- FunctionPointer::Kind::Function, calleeFunction, newAuthInfo ,
2387
+ FunctionPointer::Kind::Function, calleeFunction, codeAuthInfo ,
2384
2388
Signature::forAsyncAwait (IGF.IGM , getCallee ().getOrigFunctionType ()));
2385
2389
}
2386
2390
@@ -2591,20 +2595,6 @@ class AsyncCallEmission final : public CallEmission {
2591
2595
return IGF.getCalleeErrorResultSlot (errorType);
2592
2596
}
2593
2597
2594
- FunctionPointer getFunctionPointerForDispatchCall (const FunctionPointer &fn) {
2595
- auto &IGM = IGF.IGM ;
2596
- // Strip off the return type. The original function pointer signature
2597
- // captured both the entry point type and the resume function type.
2598
- auto *fnTy = llvm::FunctionType::get (
2599
- IGM.VoidTy , fn.getSignature ().getType ()->params (), false /* vaargs*/ );
2600
- auto signature =
2601
- Signature (fnTy, fn.getSignature ().getAttributes (), IGM.SwiftAsyncCC );
2602
- auto fnPtr =
2603
- FunctionPointer (FunctionPointer::Kind::Function, fn.getRawPointer (),
2604
- fn.getAuthInfo (), signature);
2605
- return fnPtr;
2606
- }
2607
-
2608
2598
llvm::CallInst *createCall (const FunctionPointer &fn,
2609
2599
ArrayRef<llvm::Value *> args) override {
2610
2600
auto &IGM = IGF.IGM ;
@@ -2624,8 +2614,8 @@ class AsyncCallEmission final : public CallEmission {
2624
2614
auto resumeProjFn = IGF.getOrCreateResumePrjFn ();
2625
2615
arguments.push_back (
2626
2616
Builder.CreateBitOrPointerCast (resumeProjFn, IGM.Int8PtrTy ));
2627
- auto dispatchFn =
2628
- IGF. createAsyncDispatchFn ( getFunctionPointerForDispatchCall (fn), args);
2617
+ auto dispatchFn = IGF. createAsyncDispatchFn (
2618
+ getFunctionPointerForDispatchCall (IGM, fn), args);
2629
2619
arguments.push_back (
2630
2620
Builder.CreateBitOrPointerCast (dispatchFn, IGM.Int8PtrTy ));
2631
2621
arguments.push_back (
@@ -4751,10 +4741,8 @@ llvm::Value *FunctionPointer::getPointer(IRGenFunction &IGF) const {
4751
4741
auto *result = IGF.emitLoadOfRelativePointer (
4752
4742
Address (addrPtr, IGF.IGM .getPointerAlignment ()), /* isFar*/ false ,
4753
4743
/* expectedType*/ getFunctionType ()->getPointerTo ());
4754
- if (auto authInfo = AuthInfo) {
4755
- auto newAuthInfo = PointerAuthInfo (authInfo.getCorrespondingCodeKey (),
4756
- authInfo.getDiscriminator ());
4757
- result = emitPointerAuthSign (IGF, result, newAuthInfo);
4744
+ if (auto codeAuthInfo = AuthInfo.getCorrespondingCodeAuthInfo ()) {
4745
+ result = emitPointerAuthSign (IGF, result, codeAuthInfo);
4758
4746
}
4759
4747
return result;
4760
4748
}
@@ -4795,11 +4783,7 @@ FunctionPointer FunctionPointer::getAsFunction(IRGenFunction &IGF) const {
4795
4783
case FunctionPointer::BasicKind::Function:
4796
4784
return *this ;
4797
4785
case FunctionPointer::BasicKind::AsyncFunctionPointer: {
4798
- auto authInfo = AuthInfo;
4799
- if (authInfo) {
4800
- authInfo = PointerAuthInfo (AuthInfo.getCorrespondingCodeKey (),
4801
- AuthInfo.getDiscriminator ());
4802
- }
4786
+ auto authInfo = AuthInfo.getCorrespondingCodeAuthInfo ();
4803
4787
return FunctionPointer (Kind::Function, getPointer (IGF), authInfo, Sig);
4804
4788
}
4805
4789
}
@@ -4944,3 +4928,51 @@ Address irgen::emitAutoDiffAllocateSubcontext(
4944
4928
call->setCallingConv (IGF.IGM .SwiftCC );
4945
4929
return Address (call, IGF.IGM .getPointerAlignment ());
4946
4930
}
4931
+
4932
+ FunctionPointer
4933
+ irgen::getFunctionPointerForDispatchCall (IRGenModule &IGM,
4934
+ const FunctionPointer &fn) {
4935
+ // Strip off the return type. The original function pointer signature
4936
+ // captured both the entry point type and the resume function type.
4937
+ auto *fnTy = llvm::FunctionType::get (
4938
+ IGM.VoidTy , fn.getSignature ().getType ()->params (), false /* vaargs*/ );
4939
+ auto signature =
4940
+ Signature (fnTy, fn.getSignature ().getAttributes (), IGM.SwiftAsyncCC );
4941
+ auto fnPtr = FunctionPointer (FunctionPointer::Kind::Function,
4942
+ fn.getRawPointer (), fn.getAuthInfo (), signature);
4943
+ return fnPtr;
4944
+ }
4945
+
4946
+ void irgen::forwardAsyncCallResult (IRGenFunction &IGF,
4947
+ CanSILFunctionType fnType,
4948
+ AsyncContextLayout &layout,
4949
+ llvm::CallInst *call) {
4950
+ auto &IGM = IGF.IGM ;
4951
+ auto numAsyncContextParams =
4952
+ Signature::forAsyncReturn (IGM, fnType).getAsyncContextIndex () + 1 ;
4953
+ llvm::Value *result = call;
4954
+ auto *suspendResultTy = cast<llvm::StructType>(result->getType ());
4955
+ Explosion resultExplosion;
4956
+ Explosion errorExplosion;
4957
+ auto hasError = fnType->hasErrorResult ();
4958
+ Optional<ArrayRef<llvm::Value *>> nativeResults = llvm::None;
4959
+ SmallVector<llvm::Value *, 16 > nativeResultsStorage;
4960
+
4961
+ if (suspendResultTy->getNumElements () == numAsyncContextParams) {
4962
+ // no result to forward.
4963
+ assert (!hasError);
4964
+ } else {
4965
+ auto &Builder = IGF.Builder ;
4966
+ auto resultTys =
4967
+ makeArrayRef (suspendResultTy->element_begin () + numAsyncContextParams,
4968
+ suspendResultTy->element_end ());
4969
+
4970
+ for (unsigned i = 0 , e = resultTys.size (); i != e; ++i) {
4971
+ llvm::Value *elt =
4972
+ Builder.CreateExtractValue (result, numAsyncContextParams + i);
4973
+ nativeResultsStorage.push_back (elt);
4974
+ }
4975
+ nativeResults = nativeResultsStorage;
4976
+ }
4977
+ emitAsyncReturn (IGF, layout, fnType, nativeResults);
4978
+ }
0 commit comments