Skip to content

[Typed throws] Handle error conversions in other SILGen thunks #69791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion lib/SILGen/SILGenBackDeploy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,18 @@ void SILGenFunction::emitBackDeploymentThunk(SILDeclRef thunk) {
// Generate the thunk prolog by collecting parameters.
SmallVector<ManagedValue, 4> params;
SmallVector<ManagedValue, 4> indirectParams;
collectThunkParams(loc, params, &indirectParams);
SmallVector<ManagedValue, 4> indirectErrorResults;
collectThunkParams(loc, params, &indirectParams, &indirectErrorResults);

// Build up the list of arguments that we're going to invoke the the real
// function with.
SmallVector<SILValue, 8> paramsForForwarding;
for (auto indirectParam : indirectParams) {
paramsForForwarding.emplace_back(indirectParam.getLValueAddress());
}
for (auto indirectErrorResult : indirectErrorResults) {
paramsForForwarding.emplace_back(indirectErrorResult.getLValueAddress());
}

for (auto param : params) {
// We're going to directly call either the original function or the fallback
Expand Down
94 changes: 68 additions & 26 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,37 @@ void SILGenFunction::collectThunkParams(
}
}

/// If the inner function we are calling (with type \c fnType) from the thunk
/// created by \c SGF requires an indirect error argument, returns that
/// argument.
static llvm::Optional<SILValue>
emitThunkIndirectErrorArgument(SILGenFunction &SGF, SILLocation loc,
CanSILFunctionType fnType) {
// If the function we're calling has as indirect error result, create an
// argument for it.
auto innerError = fnType->getOptionalErrorResult();
if (!innerError || innerError->getConvention() != ResultConvention::Indirect)
return llvm::None;

// If the type of the indirect error is the same for both the inner
// function and the thunk, so we can re-use the indirect error slot.
auto loweredErrorResultType = SGF.getSILType(*innerError, fnType);
if (SGF.IndirectErrorResult &&
SGF.IndirectErrorResult->getType().getObjectType()
== loweredErrorResultType) {
return SGF.IndirectErrorResult;
}

// The type of the indirect error in the inner function differs from
// that of the thunk, or the thunk has a direct error, so allocate a
// stack location for the inner indirect error.
SILValue innerIndirectErrorAddr =
SGF.B.createAllocStack(loc, loweredErrorResultType);
SGF.enterDeallocStackCleanup(innerIndirectErrorAddr);

return innerIndirectErrorAddr;
}

namespace {

class TranslateIndirect : public Cleanup {
Expand Down Expand Up @@ -4847,27 +4878,9 @@ static void buildThunkBody(SILGenFunction &SGF, SILLocation loc,

// If the function we're calling has as indirect error result, create an
// argument for it.
SILValue innerIndirectErrorAddr;
if (auto innerError = fnType->getOptionalErrorResult()) {
if (innerError->getConvention() == ResultConvention::Indirect) {
auto loweredErrorResultType = SGF.getSILType(*innerError, fnType);
if (SGF.IndirectErrorResult &&
SGF.IndirectErrorResult->getType().getObjectType()
== loweredErrorResultType) {
// The type of the indirect error is the same for both the inner
// function and the thunk, so we can re-use the indirect error slot.
innerIndirectErrorAddr = SGF.IndirectErrorResult;
} else {
// The type of the indirect error in the inner function differs from
// that of the thunk, or the thunk has a direct error, so allocate a
// stack location for the inner indirect error.
innerIndirectErrorAddr =
SGF.B.createAllocStack(loc, loweredErrorResultType);
SGF.enterDeallocStackCleanup(innerIndirectErrorAddr);
}

argValues.push_back(innerIndirectErrorAddr);
}
if (auto innerIndirectErrorAddr =
emitThunkIndirectErrorArgument(SGF, loc, fnType)) {
argValues.push_back(*innerIndirectErrorAddr);
}

// Add the rest of the arguments.
Expand Down Expand Up @@ -5175,7 +5188,8 @@ static void buildWithoutActuallyEscapingThunkBody(SILGenFunction &SGF,

SmallVector<ManagedValue, 8> params;
SmallVector<ManagedValue, 8> indirectResults;
SGF.collectThunkParams(loc, params, &indirectResults);
SmallVector<ManagedValue, 1> indirectErrorResults;
SGF.collectThunkParams(loc, params, &indirectResults, &indirectErrorResults);

// Ignore the self parameter at the SIL level. IRGen will use it to
// recover type metadata.
Expand All @@ -5185,13 +5199,16 @@ static void buildWithoutActuallyEscapingThunkBody(SILGenFunction &SGF,
ManagedValue fnValue = params.pop_back_val();
auto fnType = fnValue.getType().castTo<SILFunctionType>();

// Forward indirect result arguments.
SmallVector<SILValue, 8> argValues;
if (!indirectResults.empty()) {
for (auto result : indirectResults)
argValues.push_back(result.getLValueAddress());
}

// Forward indirect result arguments.
// Forward indirect error arguments.
for (auto indirectError : indirectErrorResults)
argValues.push_back(indirectError.getLValueAddress());

// Add the rest of the arguments.
forwardFunctionArguments(SGF, loc, fnType, params, argValues);
Expand Down Expand Up @@ -5376,14 +5393,18 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
SILGenFunction thunkSGF(SGM, *thunk, FunctionDC);
SmallVector<ManagedValue, 4> params;
SmallVector<ManagedValue, 4> thunkIndirectResults;
thunkSGF.collectThunkParams(loc, params, &thunkIndirectResults);
SmallVector<ManagedValue, 4> thunkIndirectErrorResults;
thunkSGF.collectThunkParams(
loc, params, &thunkIndirectResults, &thunkIndirectErrorResults);

SILFunctionConventions fromConv(fromType, getModule());
SILFunctionConventions toConv(toType, getModule());
if (!toConv.useLoweredAddresses()) {
SmallVector<ManagedValue, 4> thunkArguments;
for (auto indRes : thunkIndirectResults)
thunkArguments.push_back(indRes);
for (auto indErrRes : thunkIndirectErrorResults)
thunkArguments.push_back(indErrRes);
thunkArguments.append(params.begin(), params.end());
SmallVector<SILParameterInfo, 4> toParameters(
toConv.getParameters().begin(), toConv.getParameters().end());
Expand All @@ -5392,7 +5413,8 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
// Handle self reordering.
// - For pullbacks: reorder result infos.
// - For differentials: reorder parameter infos and arguments.
auto numIndirectResults = thunkIndirectResults.size();
auto numIndirectResults =
thunkIndirectResults.size() + thunkIndirectErrorResults.size();
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback &&
toResults.size() > 1) {
std::rotate(toResults.begin(), toResults.end() - 1, toResults.end());
Expand Down Expand Up @@ -5464,6 +5486,8 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
SmallVector<ManagedValue, 4> thunkArguments;
thunkArguments.append(thunkIndirectResults.begin(),
thunkIndirectResults.end());
thunkArguments.append(thunkIndirectErrorResults.begin(),
thunkIndirectErrorResults.end());
thunkArguments.append(params.begin(), params.end());
SmallVector<SILParameterInfo, 4> toParameters(toConv.getParameters().begin(),
toConv.getParameters().end());
Expand Down Expand Up @@ -5710,7 +5734,9 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
SILGenFunction thunkSGF(*this, *thunk, customDerivativeFn->getDeclContext());
SmallVector<ManagedValue, 4> params;
SmallVector<ManagedValue, 4> indirectResults;
thunkSGF.collectThunkParams(loc, params, &indirectResults);
SmallVector<ManagedValue, 1> indirectErrorResults;
thunkSGF.collectThunkParams(
loc, params, &indirectResults, &indirectErrorResults);

auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn);
auto fnRefType =
Expand Down Expand Up @@ -5754,6 +5780,8 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
SmallVector<SILValue, 8> arguments;
for (auto indRes : indirectResults)
arguments.push_back(indRes.getLValueAddress());
for (auto indErrorRes : indirectErrorResults)
arguments.push_back(indErrorRes.getLValueAddress());
forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments);

// Apply function argument.
Expand Down Expand Up @@ -6187,6 +6215,13 @@ SILGenFunction::emitVTableThunk(SILDeclRef base,
inputOrigType.getFunctionResultType(),
inputSubstType.getResult(),
derivedFTy, thunkTy);

// If the function we're calling has as indirect error result, create an
// argument for it.
if (auto innerIndirectErrorAddr =
emitThunkIndirectErrorArgument(*this, loc, derivedFTy)) {
args.push_back(*innerIndirectErrorAddr);
}
}

// Then, the arguments.
Expand Down Expand Up @@ -6571,6 +6606,13 @@ void SILGenFunction::emitProtocolWitness(
reqtOrigTy.getFunctionResultType(),
reqtSubstTy.getResult(),
witnessFTy, thunkTy);

// If the function we're calling has as indirect error result, create an
// argument for it.
if (auto innerIndirectErrorAddr =
emitThunkIndirectErrorArgument(*this, loc, witnessFTy)) {
args.push_back(*innerIndirectErrorAddr);
}
}

// - the rest of the arguments
Expand Down
2 changes: 1 addition & 1 deletion lib/SILGen/SILGenType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ SILFunction *SILGenModule::emitProtocolWitness(
CanAnyFunctionType::get(genericSig,
reqtSubstTy->getParams(),
reqtSubstTy.getResult(),
reqtOrigTy->getExtInfo());
reqtSubstTy->getExtInfo());

// Coroutine lowering requires us to provide these substitutions
// in order to recreate the appropriate yield types for the accessor
Expand Down
63 changes: 63 additions & 0 deletions test/SILGen/typed_throws_generic.swift
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,66 @@ func forcedMap<T, U>(_ source: [T]) -> [U] {
// CHECK: bb0(%0 : $*U, %1 : $*Never, %2 : $*T)
return source.typedMap { $0 as! U }
}

// Witness thunks
protocol P {
associatedtype E: Error
func f() throws(E)
}

struct Res<Success, Failure: Error>: P {
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic3ResVyxq_GAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0, τ_0_1 where τ_0_1 : Error> (@in_guaranteed Res<τ_0_0, τ_0_1>) -> @error_indirect τ_0_1
// CHECK: bb0(%0 : $*τ_0_1, %1 : $*Res<τ_0_0, τ_0_1>):
// CHECK: [[SELF:%.*]] = load [trivial] %1 : $*Res<τ_0_0, τ_0_1>
// CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic3ResV1fyyq_YKF : $@convention(method) <τ_0_0, τ_0_1 where τ_0_1 : Error> (Res<τ_0_0, τ_0_1>) -> @error_indirect τ_0_1
// CHECK-NEXT: [[INNER_ERROR_BOX:%.*]] = alloc_stack $τ_0_1
// CHECK-NEXT: try_apply [[WITNESS]]<τ_0_0, τ_0_1>([[INNER_ERROR_BOX]], [[SELF]]) : $@convention(method) <τ_0_0, τ_0_1 where τ_0_1 : Error> (Res<τ_0_0, τ_0_1>) -> @error_indirect τ_0_1, normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]]

// CHECK: [[NORMAL_BB]]
// CHECK: dealloc_stack [[INNER_ERROR_BOX]] : $*τ_0_1

// CHECK: [[ERROR_BB]]:
// CHECK: throw_addr
func f() throws(Failure) { }
}

struct TypedRes<Success>: P {
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic8TypedResVyxGAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0> (@in_guaranteed TypedRes<τ_0_0>) -> @error_indirect MyError
// CHECK: bb0(%0 : $*MyError, %1 : $*TypedRes<τ_0_0>)
// CHECK: [[SELF:%.*]] = load [trivial] %1 : $*TypedRes<τ_0_0>
// CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic8TypedResV1fyyAA7MyErrorOYKF : $@convention(method) <τ_0_0> (TypedRes<τ_0_0>) -> @error MyError
// CHECK: try_apply [[WITNESS]]<τ_0_0>([[SELF]]) : $@convention(method) <τ_0_0> (TypedRes<τ_0_0>) -> @error MyError, normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]]

// CHECK: [[NORMAL_BB]]
// CHECK: return

// CHECK: [[ERROR_BB]]([[ERROR:%.*]] : $MyError):
// CHECK-NEXT: store [[ERROR]] to [trivial] %0 : $*MyError
// CHECK-NEXT: throw_addr
func f() throws(MyError) { }
}

struct UntypedRes<Success>: P {
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic10UntypedResVyxGAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0> (@in_guaranteed UntypedRes<τ_0_0>) -> @error_indirect any Error
// CHECK: bb0(%0 : $*any Error, %1 : $*UntypedRes<τ_0_0>):
// CHECK: [[SELF:%.*]] = load [trivial] %1 : $*UntypedRes<τ_0_0>
// CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic10UntypedResV1fyyKF : $@convention(method) <τ_0_0> (UntypedRes<τ_0_0>) -> @error any Error
// CHECK: try_apply [[WITNESS]]<τ_0_0>([[SELF]]) : $@convention(method) <τ_0_0> (UntypedRes<τ_0_0>) -> @error any Error, normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]]

// CHECK: [[NORMAL_BB]]
// CHECK: return

// CHECK: [[ERROR_BB]]([[ERROR:%.*]] : @owned $any Error):
// CHECK-NEXT: store [[ERROR]] to [init] %0 : $*any Error
// CHECK-NEXT: throw_addr
func f() throws { }
}

struct InfallibleRes<Success>: P {
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s20typed_throws_generic13InfallibleResVyxGAA1PA2aEP1fyy1EQzYKFTW : $@convention(witness_method: P) <τ_0_0> (@in_guaranteed InfallibleRes<τ_0_0>) -> @error_indirect any Error
// CHECK: bb0(%0 : $*any Error, %1 : $*InfallibleRes<τ_0_0>):
// CHECK: [[SELF:%.*]] = load [trivial] %1 : $*InfallibleRes<τ_0_0>
// CHECK: [[WITNESS:%.*]] = function_ref @$s20typed_throws_generic13InfallibleResV1fyyF : $@convention(method) <τ_0_0> (InfallibleRes<τ_0_0>) -> ()
// CHECK: = apply [[WITNESS]]<τ_0_0>([[SELF]]) : $@convention(method) <τ_0_0> (InfallibleRes<τ_0_0>)
func f() { }
}