Skip to content

Commit 933c388

Browse files
authored
Merge pull request #69791 from DougGregor/typed-throws-silgen-thunks
[Typed throws] Handle error conversions in other SILGen thunks
2 parents e45afd3 + 76a8950 commit 933c388

File tree

4 files changed

+137
-28
lines changed

4 files changed

+137
-28
lines changed

lib/SILGen/SILGenBackDeploy.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,14 +221,18 @@ void SILGenFunction::emitBackDeploymentThunk(SILDeclRef thunk) {
221221
// Generate the thunk prolog by collecting parameters.
222222
SmallVector<ManagedValue, 4> params;
223223
SmallVector<ManagedValue, 4> indirectParams;
224-
collectThunkParams(loc, params, &indirectParams);
224+
SmallVector<ManagedValue, 4> indirectErrorResults;
225+
collectThunkParams(loc, params, &indirectParams, &indirectErrorResults);
225226

226227
// Build up the list of arguments that we're going to invoke the the real
227228
// function with.
228229
SmallVector<SILValue, 8> paramsForForwarding;
229230
for (auto indirectParam : indirectParams) {
230231
paramsForForwarding.emplace_back(indirectParam.getLValueAddress());
231232
}
233+
for (auto indirectErrorResult : indirectErrorResults) {
234+
paramsForForwarding.emplace_back(indirectErrorResult.getLValueAddress());
235+
}
232236

233237
for (auto param : params) {
234238
// We're going to directly call either the original function or the fallback

lib/SILGen/SILGenPoly.cpp

Lines changed: 68 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,37 @@ void SILGenFunction::collectThunkParams(
891891
}
892892
}
893893

894+
/// If the inner function we are calling (with type \c fnType) from the thunk
895+
/// created by \c SGF requires an indirect error argument, returns that
896+
/// argument.
897+
static llvm::Optional<SILValue>
898+
emitThunkIndirectErrorArgument(SILGenFunction &SGF, SILLocation loc,
899+
CanSILFunctionType fnType) {
900+
// If the function we're calling has as indirect error result, create an
901+
// argument for it.
902+
auto innerError = fnType->getOptionalErrorResult();
903+
if (!innerError || innerError->getConvention() != ResultConvention::Indirect)
904+
return llvm::None;
905+
906+
// If the type of the indirect error is the same for both the inner
907+
// function and the thunk, so we can re-use the indirect error slot.
908+
auto loweredErrorResultType = SGF.getSILType(*innerError, fnType);
909+
if (SGF.IndirectErrorResult &&
910+
SGF.IndirectErrorResult->getType().getObjectType()
911+
== loweredErrorResultType) {
912+
return SGF.IndirectErrorResult;
913+
}
914+
915+
// The type of the indirect error in the inner function differs from
916+
// that of the thunk, or the thunk has a direct error, so allocate a
917+
// stack location for the inner indirect error.
918+
SILValue innerIndirectErrorAddr =
919+
SGF.B.createAllocStack(loc, loweredErrorResultType);
920+
SGF.enterDeallocStackCleanup(innerIndirectErrorAddr);
921+
922+
return innerIndirectErrorAddr;
923+
}
924+
894925
namespace {
895926

896927
class TranslateIndirect : public Cleanup {
@@ -4847,27 +4878,9 @@ static void buildThunkBody(SILGenFunction &SGF, SILLocation loc,
48474878

48484879
// If the function we're calling has as indirect error result, create an
48494880
// argument for it.
4850-
SILValue innerIndirectErrorAddr;
4851-
if (auto innerError = fnType->getOptionalErrorResult()) {
4852-
if (innerError->getConvention() == ResultConvention::Indirect) {
4853-
auto loweredErrorResultType = SGF.getSILType(*innerError, fnType);
4854-
if (SGF.IndirectErrorResult &&
4855-
SGF.IndirectErrorResult->getType().getObjectType()
4856-
== loweredErrorResultType) {
4857-
// The type of the indirect error is the same for both the inner
4858-
// function and the thunk, so we can re-use the indirect error slot.
4859-
innerIndirectErrorAddr = SGF.IndirectErrorResult;
4860-
} else {
4861-
// The type of the indirect error in the inner function differs from
4862-
// that of the thunk, or the thunk has a direct error, so allocate a
4863-
// stack location for the inner indirect error.
4864-
innerIndirectErrorAddr =
4865-
SGF.B.createAllocStack(loc, loweredErrorResultType);
4866-
SGF.enterDeallocStackCleanup(innerIndirectErrorAddr);
4867-
}
4868-
4869-
argValues.push_back(innerIndirectErrorAddr);
4870-
}
4881+
if (auto innerIndirectErrorAddr =
4882+
emitThunkIndirectErrorArgument(SGF, loc, fnType)) {
4883+
argValues.push_back(*innerIndirectErrorAddr);
48714884
}
48724885

48734886
// Add the rest of the arguments.
@@ -5175,7 +5188,8 @@ static void buildWithoutActuallyEscapingThunkBody(SILGenFunction &SGF,
51755188

51765189
SmallVector<ManagedValue, 8> params;
51775190
SmallVector<ManagedValue, 8> indirectResults;
5178-
SGF.collectThunkParams(loc, params, &indirectResults);
5191+
SmallVector<ManagedValue, 1> indirectErrorResults;
5192+
SGF.collectThunkParams(loc, params, &indirectResults, &indirectErrorResults);
51795193

51805194
// Ignore the self parameter at the SIL level. IRGen will use it to
51815195
// recover type metadata.
@@ -5185,13 +5199,16 @@ static void buildWithoutActuallyEscapingThunkBody(SILGenFunction &SGF,
51855199
ManagedValue fnValue = params.pop_back_val();
51865200
auto fnType = fnValue.getType().castTo<SILFunctionType>();
51875201

5202+
// Forward indirect result arguments.
51885203
SmallVector<SILValue, 8> argValues;
51895204
if (!indirectResults.empty()) {
51905205
for (auto result : indirectResults)
51915206
argValues.push_back(result.getLValueAddress());
51925207
}
51935208

5194-
// Forward indirect result arguments.
5209+
// Forward indirect error arguments.
5210+
for (auto indirectError : indirectErrorResults)
5211+
argValues.push_back(indirectError.getLValueAddress());
51955212

51965213
// Add the rest of the arguments.
51975214
forwardFunctionArguments(SGF, loc, fnType, params, argValues);
@@ -5376,14 +5393,18 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
53765393
SILGenFunction thunkSGF(SGM, *thunk, FunctionDC);
53775394
SmallVector<ManagedValue, 4> params;
53785395
SmallVector<ManagedValue, 4> thunkIndirectResults;
5379-
thunkSGF.collectThunkParams(loc, params, &thunkIndirectResults);
5396+
SmallVector<ManagedValue, 4> thunkIndirectErrorResults;
5397+
thunkSGF.collectThunkParams(
5398+
loc, params, &thunkIndirectResults, &thunkIndirectErrorResults);
53805399

53815400
SILFunctionConventions fromConv(fromType, getModule());
53825401
SILFunctionConventions toConv(toType, getModule());
53835402
if (!toConv.useLoweredAddresses()) {
53845403
SmallVector<ManagedValue, 4> thunkArguments;
53855404
for (auto indRes : thunkIndirectResults)
53865405
thunkArguments.push_back(indRes);
5406+
for (auto indErrRes : thunkIndirectErrorResults)
5407+
thunkArguments.push_back(indErrRes);
53875408
thunkArguments.append(params.begin(), params.end());
53885409
SmallVector<SILParameterInfo, 4> toParameters(
53895410
toConv.getParameters().begin(), toConv.getParameters().end());
@@ -5392,7 +5413,8 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
53925413
// Handle self reordering.
53935414
// - For pullbacks: reorder result infos.
53945415
// - For differentials: reorder parameter infos and arguments.
5395-
auto numIndirectResults = thunkIndirectResults.size();
5416+
auto numIndirectResults =
5417+
thunkIndirectResults.size() + thunkIndirectErrorResults.size();
53965418
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback &&
53975419
toResults.size() > 1) {
53985420
std::rotate(toResults.begin(), toResults.end() - 1, toResults.end());
@@ -5464,6 +5486,8 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
54645486
SmallVector<ManagedValue, 4> thunkArguments;
54655487
thunkArguments.append(thunkIndirectResults.begin(),
54665488
thunkIndirectResults.end());
5489+
thunkArguments.append(thunkIndirectErrorResults.begin(),
5490+
thunkIndirectErrorResults.end());
54675491
thunkArguments.append(params.begin(), params.end());
54685492
SmallVector<SILParameterInfo, 4> toParameters(toConv.getParameters().begin(),
54695493
toConv.getParameters().end());
@@ -5710,7 +5734,9 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
57105734
SILGenFunction thunkSGF(*this, *thunk, customDerivativeFn->getDeclContext());
57115735
SmallVector<ManagedValue, 4> params;
57125736
SmallVector<ManagedValue, 4> indirectResults;
5713-
thunkSGF.collectThunkParams(loc, params, &indirectResults);
5737+
SmallVector<ManagedValue, 1> indirectErrorResults;
5738+
thunkSGF.collectThunkParams(
5739+
loc, params, &indirectResults, &indirectErrorResults);
57145740

57155741
auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn);
57165742
auto fnRefType =
@@ -5754,6 +5780,8 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
57545780
SmallVector<SILValue, 8> arguments;
57555781
for (auto indRes : indirectResults)
57565782
arguments.push_back(indRes.getLValueAddress());
5783+
for (auto indErrorRes : indirectErrorResults)
5784+
arguments.push_back(indErrorRes.getLValueAddress());
57575785
forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments);
57585786

57595787
// Apply function argument.
@@ -6187,6 +6215,13 @@ SILGenFunction::emitVTableThunk(SILDeclRef base,
61876215
inputOrigType.getFunctionResultType(),
61886216
inputSubstType.getResult(),
61896217
derivedFTy, thunkTy);
6218+
6219+
// If the function we're calling has as indirect error result, create an
6220+
// argument for it.
6221+
if (auto innerIndirectErrorAddr =
6222+
emitThunkIndirectErrorArgument(*this, loc, derivedFTy)) {
6223+
args.push_back(*innerIndirectErrorAddr);
6224+
}
61906225
}
61916226

61926227
// Then, the arguments.
@@ -6571,6 +6606,13 @@ void SILGenFunction::emitProtocolWitness(
65716606
reqtOrigTy.getFunctionResultType(),
65726607
reqtSubstTy.getResult(),
65736608
witnessFTy, thunkTy);
6609+
6610+
// If the function we're calling has as indirect error result, create an
6611+
// argument for it.
6612+
if (auto innerIndirectErrorAddr =
6613+
emitThunkIndirectErrorArgument(*this, loc, witnessFTy)) {
6614+
args.push_back(*innerIndirectErrorAddr);
6615+
}
65746616
}
65756617

65766618
// - the rest of the arguments

lib/SILGen/SILGenType.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,7 @@ SILFunction *SILGenModule::emitProtocolWitness(
758758
CanAnyFunctionType::get(genericSig,
759759
reqtSubstTy->getParams(),
760760
reqtSubstTy.getResult(),
761-
reqtOrigTy->getExtInfo());
761+
reqtSubstTy->getExtInfo());
762762

763763
// Coroutine lowering requires us to provide these substitutions
764764
// in order to recreate the appropriate yield types for the accessor

test/SILGen/typed_throws_generic.swift

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,66 @@ func forcedMap<T, U>(_ source: [T]) -> [U] {
254254
// CHECK: bb0(%0 : $*U, %1 : $*Never, %2 : $*T)
255255
return source.typedMap { $0 as! U }
256256
}
257+
258+
// Witness thunks
259+
protocol P {
260+
associatedtype E: Error
261+
func f() throws(E)
262+
}
263+
264+
struct Res<Success, Failure: Error>: P {
265+
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic3ResVyxq_GAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0, τ_0_1 where τ_0_1 : Error> (@in_guaranteed Res<τ_0_0, τ_0_1>) -> @error_indirect τ_0_1
266+
// CHECK: bb0(%0 : $*τ_0_1, %1 : $*Res<τ_0_0, τ_0_1>):
267+
// CHECK: [[SELF:%.*]] = load [trivial] %1 : $*Res<τ_0_0, τ_0_1>
268+
// CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic3ResV1fyyq_YKF : $@convention(method) <τ_0_0, τ_0_1 where τ_0_1 : Error> (Res<τ_0_0, τ_0_1>) -> @error_indirect τ_0_1
269+
// CHECK-NEXT: [[INNER_ERROR_BOX:%.*]] = alloc_stack $τ_0_1
270+
// CHECK-NEXT: try_apply [[WITNESS]]<τ_0_0, τ_0_1>([[INNER_ERROR_BOX]], [[SELF]]) : $@convention(method) <τ_0_0, τ_0_1 where τ_0_1 : Error> (Res<τ_0_0, τ_0_1>) -> @error_indirect τ_0_1, normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]]
271+
272+
// CHECK: [[NORMAL_BB]]
273+
// CHECK: dealloc_stack [[INNER_ERROR_BOX]] : $*τ_0_1
274+
275+
// CHECK: [[ERROR_BB]]:
276+
// CHECK: throw_addr
277+
func f() throws(Failure) { }
278+
}
279+
280+
struct TypedRes<Success>: P {
281+
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic8TypedResVyxGAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0> (@in_guaranteed TypedRes<τ_0_0>) -> @error_indirect MyError
282+
// CHECK: bb0(%0 : $*MyError, %1 : $*TypedRes<τ_0_0>)
283+
// CHECK: [[SELF:%.*]] = load [trivial] %1 : $*TypedRes<τ_0_0>
284+
// CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic8TypedResV1fyyAA7MyErrorOYKF : $@convention(method) <τ_0_0> (TypedRes<τ_0_0>) -> @error MyError
285+
// CHECK: try_apply [[WITNESS]]<τ_0_0>([[SELF]]) : $@convention(method) <τ_0_0> (TypedRes<τ_0_0>) -> @error MyError, normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]]
286+
287+
// CHECK: [[NORMAL_BB]]
288+
// CHECK: return
289+
290+
// CHECK: [[ERROR_BB]]([[ERROR:%.*]] : $MyError):
291+
// CHECK-NEXT: store [[ERROR]] to [trivial] %0 : $*MyError
292+
// CHECK-NEXT: throw_addr
293+
func f() throws(MyError) { }
294+
}
295+
296+
struct UntypedRes<Success>: P {
297+
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic10UntypedResVyxGAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0> (@in_guaranteed UntypedRes<τ_0_0>) -> @error_indirect any Error
298+
// CHECK: bb0(%0 : $*any Error, %1 : $*UntypedRes<τ_0_0>):
299+
// CHECK: [[SELF:%.*]] = load [trivial] %1 : $*UntypedRes<τ_0_0>
300+
// CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic10UntypedResV1fyyKF : $@convention(method) <τ_0_0> (UntypedRes<τ_0_0>) -> @error any Error
301+
// CHECK: try_apply [[WITNESS]]<τ_0_0>([[SELF]]) : $@convention(method) <τ_0_0> (UntypedRes<τ_0_0>) -> @error any Error, normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]]
302+
303+
// CHECK: [[NORMAL_BB]]
304+
// CHECK: return
305+
306+
// CHECK: [[ERROR_BB]]([[ERROR:%.*]] : @owned $any Error):
307+
// CHECK-NEXT: store [[ERROR]] to [init] %0 : $*any Error
308+
// CHECK-NEXT: throw_addr
309+
func f() throws { }
310+
}
311+
312+
struct InfallibleRes<Success>: P {
313+
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic13InfallibleResVyxGAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0> (@in_guaranteed InfallibleRes<τ_0_0>) -> @error_indirect any Error
314+
// CHECK: bb0(%0 : $*any Error, %1 : $*InfallibleRes<τ_0_0>):
315+
// CHECK: [[SELF:%.*]] = load [trivial] %1 : $*InfallibleRes<τ_0_0>
316+
// CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic13InfallibleResV1fyyF : $@convention(method) <τ_0_0> (InfallibleRes<τ_0_0>) -> ()
317+
// CHECK: = apply [[WITNESS]]<τ_0_0>([[SELF]]) : $@convention(method) <τ_0_0> (InfallibleRes<τ_0_0>)
318+
func f() { }
319+
}

0 commit comments

Comments
 (0)