Skip to content

Commit e50d5c3

Browse files
Merge pull request swiftlang#37039 from aschwaighofer/fix_async_dyn_repl_compiler
IRGen: Fix async dynamic replacements
2 parents 5e91dea + 1f890dc commit e50d5c3

File tree

10 files changed

+383
-165
lines changed

10 files changed

+383
-165
lines changed

lib/IRGen/Callee.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,13 @@ namespace irgen {
135135
assert(isSigned());
136136
return Discriminator;
137137
}
138+
PointerAuthInfo getCorrespondingCodeAuthInfo() const {
139+
if (auto authInfo = *this) {
140+
return PointerAuthInfo(authInfo.getCorrespondingCodeKey(),
141+
authInfo.getDiscriminator());
142+
}
143+
return *this;
144+
}
138145

139146
/// Are the auth infos obviously the same?
140147
friend bool operator==(const PointerAuthInfo &lhs,

lib/IRGen/GenCall.cpp

Lines changed: 78 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -239,19 +239,26 @@ llvm::Value *IRGenFunction::getAsyncContext() {
239239
return Builder.CreateLoad(asyncContextLocation);
240240
}
241241

242-
llvm::CallInst *
243-
IRGenFunction::emitSuspendAsyncCall(unsigned asyncContextIndex,
244-
llvm::StructType *resultTy,
245-
ArrayRef<llvm::Value *> args) {
246-
auto *id = Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_suspend_async,
247-
{resultTy}, args);
248-
llvm::Value *calleeContext =
249-
Builder.CreateExtractValue(id, asyncContextIndex);
250-
calleeContext = Builder.CreateBitOrPointerCast(calleeContext, IGM.Int8PtrTy);
251-
llvm::Constant *projectFn = cast<llvm::Constant>(args[2])->stripPointerCasts();
252-
llvm::Value *context = Builder.CreateCall(projectFn, {calleeContext});
242+
void IRGenFunction::storeCurrentAsyncContext(llvm::Value *context) {
253243
context = Builder.CreateBitCast(context, IGM.SwiftContextPtrTy);
254244
Builder.CreateStore(context, asyncContextLocation);
245+
}
246+
247+
llvm::CallInst *IRGenFunction::emitSuspendAsyncCall(
248+
unsigned asyncContextIndex, llvm::StructType *resultTy,
249+
ArrayRef<llvm::Value *> args, bool restoreCurrentContext) {
250+
auto *id = Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_suspend_async,
251+
{resultTy}, args);
252+
if (restoreCurrentContext) {
253+
llvm::Value *calleeContext =
254+
Builder.CreateExtractValue(id, asyncContextIndex);
255+
calleeContext =
256+
Builder.CreateBitOrPointerCast(calleeContext, IGM.Int8PtrTy);
257+
llvm::Constant *projectFn =
258+
cast<llvm::Constant>(args[2])->stripPointerCasts();
259+
llvm::Value *context = Builder.CreateCall(projectFn, {calleeContext});
260+
storeCurrentAsyncContext(context);
261+
}
255262

256263
return id;
257264
}
@@ -2043,10 +2050,9 @@ std::pair<llvm::Value *, llvm::Value *> irgen::getAsyncFunctionAndSize(
20432050
Address(addrPtr, IGF.IGM.getPointerAlignment()), /*isFar*/ false,
20442051
/*expectedType*/ functionPointer.getFunctionType()->getPointerTo());
20452052
}
2046-
if (auto authInfo = functionPointer.getAuthInfo()) {
2047-
auto newAuthInfo = PointerAuthInfo(authInfo.getCorrespondingCodeKey(),
2048-
authInfo.getDiscriminator());
2049-
fn = emitPointerAuthSign(IGF, fn, newAuthInfo);
2053+
if (auto authInfo =
2054+
functionPointer.getAuthInfo().getCorrespondingCodeAuthInfo()) {
2055+
fn = emitPointerAuthSign(IGF, fn, authInfo);
20502056
}
20512057
}
20522058
llvm::Value *size = nullptr;
@@ -2374,13 +2380,11 @@ class AsyncCallEmission final : public CallEmission {
23742380
}
23752381

23762382
FunctionPointer getCalleeFunctionPointer() override {
2377-
PointerAuthInfo newAuthInfo;
2378-
if (auto authInfo = CurCallee.getFunctionPointer().getAuthInfo()) {
2379-
newAuthInfo = PointerAuthInfo(authInfo.getCorrespondingCodeKey(),
2380-
authInfo.getDiscriminator());
2381-
}
2383+
PointerAuthInfo codeAuthInfo = CurCallee.getFunctionPointer()
2384+
.getAuthInfo()
2385+
.getCorrespondingCodeAuthInfo();
23822386
return FunctionPointer(
2383-
FunctionPointer::Kind::Function, calleeFunction, newAuthInfo,
2387+
FunctionPointer::Kind::Function, calleeFunction, codeAuthInfo,
23842388
Signature::forAsyncAwait(IGF.IGM, getCallee().getOrigFunctionType()));
23852389
}
23862390

@@ -2591,20 +2595,6 @@ class AsyncCallEmission final : public CallEmission {
25912595
return IGF.getCalleeErrorResultSlot(errorType);
25922596
}
25932597

2594-
FunctionPointer getFunctionPointerForDispatchCall(const FunctionPointer &fn) {
2595-
auto &IGM = IGF.IGM;
2596-
// Strip off the return type. The original function pointer signature
2597-
// captured both the entry point type and the resume function type.
2598-
auto *fnTy = llvm::FunctionType::get(
2599-
IGM.VoidTy, fn.getSignature().getType()->params(), false /*vaargs*/);
2600-
auto signature =
2601-
Signature(fnTy, fn.getSignature().getAttributes(), IGM.SwiftAsyncCC);
2602-
auto fnPtr =
2603-
FunctionPointer(FunctionPointer::Kind::Function, fn.getRawPointer(),
2604-
fn.getAuthInfo(), signature);
2605-
return fnPtr;
2606-
}
2607-
26082598
llvm::CallInst *createCall(const FunctionPointer &fn,
26092599
ArrayRef<llvm::Value *> args) override {
26102600
auto &IGM = IGF.IGM;
@@ -2624,8 +2614,8 @@ class AsyncCallEmission final : public CallEmission {
26242614
auto resumeProjFn = IGF.getOrCreateResumePrjFn();
26252615
arguments.push_back(
26262616
Builder.CreateBitOrPointerCast(resumeProjFn, IGM.Int8PtrTy));
2627-
auto dispatchFn =
2628-
IGF.createAsyncDispatchFn(getFunctionPointerForDispatchCall(fn), args);
2617+
auto dispatchFn = IGF.createAsyncDispatchFn(
2618+
getFunctionPointerForDispatchCall(IGM, fn), args);
26292619
arguments.push_back(
26302620
Builder.CreateBitOrPointerCast(dispatchFn, IGM.Int8PtrTy));
26312621
arguments.push_back(
@@ -4751,10 +4741,8 @@ llvm::Value *FunctionPointer::getPointer(IRGenFunction &IGF) const {
47514741
auto *result = IGF.emitLoadOfRelativePointer(
47524742
Address(addrPtr, IGF.IGM.getPointerAlignment()), /*isFar*/ false,
47534743
/*expectedType*/ getFunctionType()->getPointerTo());
4754-
if (auto authInfo = AuthInfo) {
4755-
auto newAuthInfo = PointerAuthInfo(authInfo.getCorrespondingCodeKey(),
4756-
authInfo.getDiscriminator());
4757-
result = emitPointerAuthSign(IGF, result, newAuthInfo);
4744+
if (auto codeAuthInfo = AuthInfo.getCorrespondingCodeAuthInfo()) {
4745+
result = emitPointerAuthSign(IGF, result, codeAuthInfo);
47584746
}
47594747
return result;
47604748
}
@@ -4795,11 +4783,7 @@ FunctionPointer FunctionPointer::getAsFunction(IRGenFunction &IGF) const {
47954783
case FunctionPointer::BasicKind::Function:
47964784
return *this;
47974785
case FunctionPointer::BasicKind::AsyncFunctionPointer: {
4798-
auto authInfo = AuthInfo;
4799-
if (authInfo) {
4800-
authInfo = PointerAuthInfo(AuthInfo.getCorrespondingCodeKey(),
4801-
AuthInfo.getDiscriminator());
4802-
}
4786+
auto authInfo = AuthInfo.getCorrespondingCodeAuthInfo();
48034787
return FunctionPointer(Kind::Function, getPointer(IGF), authInfo, Sig);
48044788
}
48054789
}
@@ -4944,3 +4928,51 @@ Address irgen::emitAutoDiffAllocateSubcontext(
49444928
call->setCallingConv(IGF.IGM.SwiftCC);
49454929
return Address(call, IGF.IGM.getPointerAlignment());
49464930
}
4931+
4932+
FunctionPointer
4933+
irgen::getFunctionPointerForDispatchCall(IRGenModule &IGM,
4934+
const FunctionPointer &fn) {
4935+
// Strip off the return type. The original function pointer signature
4936+
// captured both the entry point type and the resume function type.
4937+
auto *fnTy = llvm::FunctionType::get(
4938+
IGM.VoidTy, fn.getSignature().getType()->params(), false /*vaargs*/);
4939+
auto signature =
4940+
Signature(fnTy, fn.getSignature().getAttributes(), IGM.SwiftAsyncCC);
4941+
auto fnPtr = FunctionPointer(FunctionPointer::Kind::Function,
4942+
fn.getRawPointer(), fn.getAuthInfo(), signature);
4943+
return fnPtr;
4944+
}
4945+
4946+
void irgen::forwardAsyncCallResult(IRGenFunction &IGF,
4947+
CanSILFunctionType fnType,
4948+
AsyncContextLayout &layout,
4949+
llvm::CallInst *call) {
4950+
auto &IGM = IGF.IGM;
4951+
auto numAsyncContextParams =
4952+
Signature::forAsyncReturn(IGM, fnType).getAsyncContextIndex() + 1;
4953+
llvm::Value *result = call;
4954+
auto *suspendResultTy = cast<llvm::StructType>(result->getType());
4955+
Explosion resultExplosion;
4956+
Explosion errorExplosion;
4957+
auto hasError = fnType->hasErrorResult();
4958+
Optional<ArrayRef<llvm::Value *>> nativeResults = llvm::None;
4959+
SmallVector<llvm::Value *, 16> nativeResultsStorage;
4960+
4961+
if (suspendResultTy->getNumElements() == numAsyncContextParams) {
4962+
// no result to forward.
4963+
assert(!hasError);
4964+
} else {
4965+
auto &Builder = IGF.Builder;
4966+
auto resultTys =
4967+
makeArrayRef(suspendResultTy->element_begin() + numAsyncContextParams,
4968+
suspendResultTy->element_end());
4969+
4970+
for (unsigned i = 0, e = resultTys.size(); i != e; ++i) {
4971+
llvm::Value *elt =
4972+
Builder.CreateExtractValue(result, numAsyncContextParams + i);
4973+
nativeResultsStorage.push_back(elt);
4974+
}
4975+
nativeResults = nativeResultsStorage;
4976+
}
4977+
emitAsyncReturn(IGF, layout, fnType, nativeResults);
4978+
}

lib/IRGen/GenCall.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,12 @@ namespace irgen {
272272
IRGenFunction &IGF, Address context);
273273
Address emitAutoDiffAllocateSubcontext(
274274
IRGenFunction &IGF, Address context, llvm::Value *size);
275+
276+
FunctionPointer getFunctionPointerForDispatchCall(IRGenModule &IGM,
277+
const FunctionPointer &fn);
278+
void forwardAsyncCallResult(IRGenFunction &IGF, CanSILFunctionType fnType,
279+
AsyncContextLayout &layout, llvm::CallInst *call);
280+
275281
} // end namespace irgen
276282
} // end namespace swift
277283

0 commit comments

Comments
 (0)