Skip to content

Commit c11b301

Browse files
authored
Merge pull request #75221 from drexin/wip-129359370
[IRGen] Add direct error return support for async functions
2 parents 1921b60 + 9b1a82d commit c11b301

File tree

8 files changed

+410
-144
lines changed

8 files changed

+410
-144
lines changed

lib/IRGen/GenCall.cpp

Lines changed: 199 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2114,10 +2114,40 @@ void SignatureExpansion::expandAsyncReturnType() {
21142114
}
21152115
};
21162116

2117-
auto resultType = getSILFuncConventions().getSILResultType(
2118-
IGM.getMaximalTypeExpansionContext());
2117+
auto fnConv = getSILFuncConventions();
2118+
2119+
auto resultType =
2120+
fnConv.getSILResultType(IGM.getMaximalTypeExpansionContext());
21192121
auto &ti = IGM.getTypeInfo(resultType);
21202122
auto &native = ti.nativeReturnValueSchema(IGM);
2123+
2124+
if (!fnConv.hasIndirectSILResults() && !fnConv.hasIndirectSILErrorResults() &&
2125+
!native.requiresIndirect() && fnConv.funcTy->hasErrorResult() &&
2126+
fnConv.isTypedError()) {
2127+
auto errorType = getSILFuncConventions().getSILErrorType(
2128+
IGM.getMaximalTypeExpansionContext());
2129+
auto &errorTi = IGM.getTypeInfo(errorType);
2130+
auto &nativeError = errorTi.nativeReturnValueSchema(IGM);
2131+
if (!nativeError.shouldReturnTypedErrorIndirectly()) {
2132+
auto combined = combineResultAndTypedErrorType(IGM, native, nativeError);
2133+
2134+
if (combined.combinedTy->isVoidTy()) {
2135+
addErrorResult();
2136+
return;
2137+
}
2138+
2139+
if (auto *structTy = dyn_cast<llvm::StructType>(combined.combinedTy)) {
2140+
for (auto *elem : structTy->elements()) {
2141+
ParamIRTypes.push_back(elem);
2142+
}
2143+
} else {
2144+
ParamIRTypes.push_back(combined.combinedTy);
2145+
}
2146+
}
2147+
addErrorResult();
2148+
return;
2149+
}
2150+
21212151
if (native.requiresIndirect() || native.empty()) {
21222152
addErrorResult();
21232153
return;
@@ -2135,11 +2165,23 @@ void SignatureExpansion::expandAsyncReturnType() {
21352165
void SignatureExpansion::addIndirectThrowingResult() {
21362166
if (getSILFuncConventions().funcTy->hasErrorResult() &&
21372167
getSILFuncConventions().isTypedError()) {
2138-
auto resultType = getSILFuncConventions().getSILErrorType(
2139-
IGM.getMaximalTypeExpansionContext());
2140-
const TypeInfo &resultTI = IGM.getTypeInfo(resultType);
2141-
auto storageTy = resultTI.getStorageType();
2142-
ParamIRTypes.push_back(storageTy->getPointerTo());
2168+
auto resultType = getSILFuncConventions().getSILResultType(
2169+
IGM.getMaximalTypeExpansionContext());
2170+
auto &ti = IGM.getTypeInfo(resultType);
2171+
auto &native = ti.nativeReturnValueSchema(IGM);
2172+
2173+
auto errorType = getSILFuncConventions().getSILErrorType(
2174+
IGM.getMaximalTypeExpansionContext());
2175+
const TypeInfo &errorTI = IGM.getTypeInfo(errorType);
2176+
auto &nativeError = errorTI.nativeReturnValueSchema(IGM);
2177+
2178+
if (getSILFuncConventions().hasIndirectSILResults() ||
2179+
getSILFuncConventions().hasIndirectSILErrorResults() ||
2180+
native.requiresIndirect() ||
2181+
nativeError.shouldReturnTypedErrorIndirectly()) {
2182+
auto errorStorageTy = errorTI.getStorageType();
2183+
ParamIRTypes.push_back(errorStorageTy->getPointerTo());
2184+
}
21432185
}
21442186

21452187
}
@@ -2265,6 +2307,36 @@ void SignatureExpansion::expandAsyncAwaitType() {
22652307
IGM.getMaximalTypeExpansionContext());
22662308
auto &ti = IGM.getTypeInfo(resultType);
22672309
auto &native = ti.nativeReturnValueSchema(IGM);
2310+
2311+
if (!getSILFuncConventions().hasIndirectSILResults() &&
2312+
!getSILFuncConventions().hasIndirectSILErrorResults() &&
2313+
getSILFuncConventions().funcTy->hasErrorResult() &&
2314+
!native.requiresIndirect() && getSILFuncConventions().isTypedError()) {
2315+
auto errorType = getSILFuncConventions().getSILErrorType(
2316+
IGM.getMaximalTypeExpansionContext());
2317+
auto &errorTi = IGM.getTypeInfo(errorType);
2318+
auto &nativeError = errorTi.nativeReturnValueSchema(IGM);
2319+
if (!nativeError.shouldReturnTypedErrorIndirectly()) {
2320+
auto combined = combineResultAndTypedErrorType(IGM, native, nativeError);
2321+
2322+
if (combined.combinedTy->isVoidTy()) {
2323+
addErrorResult();
2324+
return;
2325+
}
2326+
2327+
if (auto *structTy = dyn_cast<llvm::StructType>(combined.combinedTy)) {
2328+
for (auto *elem : structTy->elements()) {
2329+
components.push_back(elem);
2330+
}
2331+
} else {
2332+
components.push_back(combined.combinedTy);
2333+
}
2334+
addErrorResult();
2335+
ResultIRType = llvm::StructType::get(IGM.getLLVMContext(), components);
2336+
return;
2337+
}
2338+
}
2339+
22682340
if (native.requiresIndirect() || native.empty()) {
22692341
addErrorResult();
22702342
ResultIRType = llvm::StructType::get(IGM.getLLVMContext(), components);
@@ -2278,7 +2350,6 @@ void SignatureExpansion::expandAsyncAwaitType() {
22782350
});
22792351

22802352
addErrorResult();
2281-
22822353
ResultIRType = llvm::StructType::get(IGM.getLLVMContext(), components);
22832354
}
22842355

@@ -2950,9 +3021,22 @@ class AsyncCallEmission final : public CallEmission {
29503021
setIndirectTypedErrorResultSlotArgsIndex(--LastArgWritten);
29513022
Args[LastArgWritten] = nullptr;
29523023
} else {
2953-
auto buf = IGF.getCalleeTypedErrorResultSlot(
2954-
fnConv.getSILErrorType(IGF.IGM.getMaximalTypeExpansionContext()));
2955-
Args[--LastArgWritten] = buf.getAddress();
3024+
auto silResultTy =
3025+
fnConv.getSILResultType(IGF.IGM.getMaximalTypeExpansionContext());
3026+
auto silErrorTy =
3027+
fnConv.getSILErrorType(IGF.IGM.getMaximalTypeExpansionContext());
3028+
3029+
auto &nativeSchema =
3030+
IGF.IGM.getTypeInfo(silResultTy).nativeReturnValueSchema(IGF.IGM);
3031+
auto &errorSchema =
3032+
IGF.IGM.getTypeInfo(silErrorTy).nativeReturnValueSchema(IGF.IGM);
3033+
3034+
if (nativeSchema.requiresIndirect() ||
3035+
errorSchema.shouldReturnTypedErrorIndirectly()) {
3036+
// Return the error indirectly.
3037+
auto buf = IGF.getCalleeTypedErrorResultSlot(silErrorTy);
3038+
Args[--LastArgWritten] = buf.getAddress();
3039+
}
29563040
}
29573041
}
29583042

@@ -3134,7 +3218,22 @@ class AsyncCallEmission final : public CallEmission {
31343218
errorType =
31353219
substConv.getSILErrorType(IGM.getMaximalTypeExpansionContext());
31363220

3137-
if (resultTys.size() == 1) {
3221+
SILFunctionConventions fnConv(getCallee().getOrigFunctionType(),
3222+
IGF.getSILModule());
3223+
3224+
// Get the natural IR type in the body of the function that makes
3225+
// the call. This may be different than the IR type returned by the
3226+
// call itself due to ABI type coercion.
3227+
auto resultType =
3228+
fnConv.getSILResultType(IGF.IGM.getMaximalTypeExpansionContext());
3229+
auto &nativeSchema =
3230+
IGF.IGM.getTypeInfo(resultType).nativeReturnValueSchema(IGF.IGM);
3231+
3232+
bool mayReturnErrorDirectly = mayReturnTypedErrorDirectly();
3233+
if (mayReturnErrorDirectly && !nativeSchema.requiresIndirect()) {
3234+
return emitToUnmappedExplosionWithDirectTypedError(resultType, result,
3235+
out);
3236+
} else if (resultTys.size() == 1) {
31383237
result = Builder.CreateExtractValue(result, numAsyncContextParams);
31393238
if (hasError) {
31403239
Address errorAddr = IGF.getCalleeErrorResultSlot(errorType,
@@ -3166,17 +3265,6 @@ class AsyncCallEmission final : public CallEmission {
31663265
result = resultAgg;
31673266
}
31683267

3169-
SILFunctionConventions fnConv(getCallee().getOrigFunctionType(),
3170-
IGF.getSILModule());
3171-
3172-
// Get the natural IR type in the body of the function that makes
3173-
// the call. This may be different than the IR type returned by the
3174-
// call itself due to ABI type coercion.
3175-
auto resultType =
3176-
fnConv.getSILResultType(IGF.IGM.getMaximalTypeExpansionContext());
3177-
auto &nativeSchema =
3178-
IGF.IGM.getTypeInfo(resultType).nativeReturnValueSchema(IGF.IGM);
3179-
31803268
// For ABI reasons the result type of the call might not actually match the
31813269
// expected result type.
31823270
//
@@ -3315,7 +3403,7 @@ void CallEmission::emitToUnmappedMemory(Address result) {
33153403
#ifndef NDEBUG
33163404
LastArgWritten = 0; // appease an assert
33173405
#endif
3318-
3406+
33193407
auto call = emitCallSite();
33203408

33213409
// Async calls need to store the error result that is passed as a parameter.
@@ -4403,32 +4491,21 @@ void CallEmission::emitToUnmappedExplosionWithDirectTypedError(
44034491
extractScalarResults(IGF, result->getType(), result, nativeExplosion);
44044492
auto values = nativeExplosion.claimAll();
44054493

4406-
auto convertIfNecessary = [&](llvm::Type *nativeTy,
4407-
llvm::Value *elt) -> llvm::Value * {
4408-
auto *eltTy = elt->getType();
4409-
if (nativeTy->isIntOrPtrTy() && eltTy->isIntOrPtrTy() &&
4410-
nativeTy->getPrimitiveSizeInBits() != eltTy->getPrimitiveSizeInBits()) {
4411-
if (nativeTy->isPointerTy() && eltTy == IGF.IGM.IntPtrTy) {
4412-
return IGF.Builder.CreateIntToPtr(elt, nativeTy);
4413-
}
4414-
return IGF.Builder.CreateTruncOrBitCast(elt, nativeTy);
4415-
}
4416-
return elt;
4417-
};
4418-
44194494
Explosion errorExplosion;
44204495
if (!errorSchema.empty()) {
44214496
if (auto *structTy =
44224497
dyn_cast<llvm::StructType>(errorSchema.getExpandedType(IGF.IGM))) {
44234498
for (unsigned i = 0, e = structTy->getNumElements(); i < e; ++i) {
44244499
llvm::Value *elt = values[combined.errorValueMapping[i]];
44254500
auto *nativeTy = structTy->getElementType(i);
4426-
elt = convertIfNecessary(nativeTy, elt);
4501+
elt = convertForAsyncDirect(IGF, elt, nativeTy, /*forExtraction*/ true);
44274502
errorExplosion.add(elt);
44284503
}
44294504
} else {
4430-
errorExplosion.add(convertIfNecessary(
4431-
combined.combinedTy, values[combined.errorValueMapping[0]]));
4505+
auto *converted =
4506+
convertForAsyncDirect(IGF, values[combined.errorValueMapping[0]],
4507+
combined.combinedTy, /*forExtraction*/ true);
4508+
errorExplosion.add(converted);
44324509
}
44334510

44344511
typedErrorExplosion =
@@ -4444,10 +4521,14 @@ void CallEmission::emitToUnmappedExplosionWithDirectTypedError(
44444521
dyn_cast<llvm::StructType>(nativeSchema.getExpandedType(IGF.IGM))) {
44454522
for (unsigned i = 0, e = structTy->getNumElements(); i < e; ++i) {
44464523
auto *nativeTy = structTy->getElementType(i);
4447-
resultExplosion.add(convertIfNecessary(nativeTy, values[i]));
4524+
auto *converted = convertForAsyncDirect(IGF, values[i], nativeTy,
4525+
/*forExtraction*/ true);
4526+
resultExplosion.add(converted);
44484527
}
44494528
} else {
4450-
resultExplosion.add(convertIfNecessary(combined.combinedTy, values[0]));
4529+
auto *converted = convertForAsyncDirect(
4530+
IGF, values[0], combined.combinedTy, /*forExtraction*/ true);
4531+
resultExplosion.add(converted);
44514532
}
44524533
out = nativeSchema.mapFromNative(IGF.IGM, IGF, resultExplosion, resultType);
44534534
}
@@ -5313,6 +5394,33 @@ llvm::Value* IRGenFunction::coerceValue(llvm::Value *value, llvm::Type *toTy,
53135394
return loaded;
53145395
}
53155396

5397+
llvm::Value *irgen::convertForAsyncDirect(IRGenFunction &IGF,
5398+
llvm::Value *value, llvm::Type *toTy,
5399+
bool forExtraction) {
5400+
auto &Builder = IGF.Builder;
5401+
auto *fromTy = value->getType();
5402+
if (toTy->isIntOrPtrTy() && fromTy->isIntOrPtrTy() && toTy != fromTy) {
5403+
5404+
if (toTy->isPointerTy()) {
5405+
if (fromTy->isPointerTy())
5406+
return Builder.CreateBitCast(value, toTy);
5407+
if (fromTy == IGF.IGM.IntPtrTy)
5408+
return Builder.CreateIntToPtr(value, toTy);
5409+
} else if (fromTy->isPointerTy()) {
5410+
if (toTy == IGF.IGM.IntPtrTy) {
5411+
return Builder.CreatePtrToInt(value, toTy);
5412+
}
5413+
}
5414+
5415+
if (forExtraction) {
5416+
return Builder.CreateTruncOrBitCast(value, toTy);
5417+
} else {
5418+
return Builder.CreateZExtOrBitCast(value, toTy);
5419+
}
5420+
}
5421+
return value;
5422+
}
5423+
53165424
void IRGenFunction::emitScalarReturn(llvm::Type *resultType,
53175425
Explosion &result) {
53185426
if (result.empty()) {
@@ -5754,32 +5862,18 @@ void IRGenFunction::emitScalarReturn(SILType returnResultType,
57545862
return;
57555863
}
57565864

5757-
auto convertIfNecessary = [&](llvm::Type *nativeTy,
5758-
llvm::Value *elt) -> llvm::Value * {
5759-
auto *eltTy = elt->getType();
5760-
if (nativeTy->isIntOrPtrTy() && eltTy->isIntOrPtrTy() &&
5761-
nativeTy->getPrimitiveSizeInBits() !=
5762-
eltTy->getPrimitiveSizeInBits()) {
5763-
assert(nativeTy->getPrimitiveSizeInBits() >
5764-
eltTy->getPrimitiveSizeInBits());
5765-
if (eltTy->isPointerTy()) {
5766-
return Builder.CreatePtrToInt(elt, nativeTy);
5767-
}
5768-
return Builder.CreateZExt(elt, nativeTy);
5769-
}
5770-
return elt;
5771-
};
5772-
57735865
if (auto *structTy = dyn_cast<llvm::StructType>(combinedTy)) {
57745866
nativeAgg = llvm::UndefValue::get(combinedTy);
57755867
for (unsigned i = 0, e = native.size(); i != e; ++i) {
57765868
llvm::Value *elt = native.claimNext();
57775869
auto *nativeTy = structTy->getElementType(i);
5778-
elt = convertIfNecessary(nativeTy, elt);
5870+
elt = convertForAsyncDirect(*this, elt, nativeTy,
5871+
/*forExtraction*/ false);
57795872
nativeAgg = Builder.CreateInsertValue(nativeAgg, elt, i);
57805873
}
57815874
} else {
5782-
nativeAgg = convertIfNecessary(combinedTy, native.claimNext());
5875+
nativeAgg = convertForAsyncDirect(*this, native.claimNext(), combinedTy,
5876+
/*forExtraction*/ false);
57835877
}
57845878
}
57855879

@@ -6089,6 +6183,51 @@ void irgen::emitAsyncReturn(IRGenFunction &IGF, AsyncContextLayout &asyncLayout,
60896183
SILFunctionConventions conv(fnType, IGF.getSILModule());
60906184
auto &nativeSchema =
60916185
IGM.getTypeInfo(funcResultTypeInContext).nativeReturnValueSchema(IGM);
6186+
6187+
if (fnType->hasErrorResult() && !conv.hasIndirectSILResults() &&
6188+
!conv.hasIndirectSILErrorResults() && !nativeSchema.requiresIndirect() &&
6189+
conv.isTypedError()) {
6190+
auto errorType = conv.getSILErrorType(IGM.getMaximalTypeExpansionContext());
6191+
auto &errorTI = IGM.getTypeInfo(errorType);
6192+
auto &nativeError = errorTI.nativeReturnValueSchema(IGM);
6193+
if (!nativeError.shouldReturnTypedErrorIndirectly()) {
6194+
assert(!error.empty() && "Direct error return must have error value");
6195+
auto *combinedTy =
6196+
combineResultAndTypedErrorType(IGM, nativeSchema, nativeError)
6197+
.combinedTy;
6198+
6199+
if (combinedTy->isVoidTy()) {
6200+
assert(result.empty() && "Unexpected result values");
6201+
} else {
6202+
if (auto *structTy = dyn_cast<llvm::StructType>(combinedTy)) {
6203+
llvm::Value *nativeAgg = llvm::UndefValue::get(structTy);
6204+
for (unsigned i = 0, e = result.size(); i != e; ++i) {
6205+
llvm::Value *elt = result.claimNext();
6206+
auto *nativeTy = structTy->getElementType(i);
6207+
elt = convertForAsyncDirect(IGF, elt, nativeTy,
6208+
/*forExtraction*/ false);
6209+
nativeAgg = IGF.Builder.CreateInsertValue(nativeAgg, elt, i);
6210+
}
6211+
Explosion out;
6212+
IGF.emitAllExtractValues(nativeAgg, structTy, out);
6213+
while (!out.empty()) {
6214+
nativeResultsStorage.push_back(out.claimNext());
6215+
}
6216+
} else {
6217+
auto *converted = convertForAsyncDirect(
6218+
IGF, result.claimNext(), combinedTy, /*forExtraction*/ false);
6219+
nativeResultsStorage.push_back(converted);
6220+
}
6221+
}
6222+
6223+
nativeResultsStorage.push_back(error.claimNext());
6224+
nativeResults = nativeResultsStorage;
6225+
6226+
emitAsyncReturn(IGF, asyncLayout, fnType, nativeResults);
6227+
return;
6228+
}
6229+
}
6230+
60926231
if (result.empty() && !nativeSchema.empty()) {
60936232
if (!nativeSchema.requiresIndirect())
60946233
// When we throw, we set the return values to undef.

lib/IRGen/GenCall.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ namespace irgen {
274274
void forwardAsyncCallResult(IRGenFunction &IGF, CanSILFunctionType fnType,
275275
AsyncContextLayout &layout, llvm::CallInst *call);
276276

277+
/// Converts a value for async direct errors.
278+
llvm::Value *convertForAsyncDirect(IRGenFunction &IGF, llvm::Value *value,
279+
llvm::Type *toTy, bool forExtraction);
280+
277281
} // end namespace irgen
278282
} // end namespace swift
279283

0 commit comments

Comments
 (0)