Skip to content

GenericSpecializer: support specializing typed throws #70000

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions SwiftCompilerSources/Sources/SIL/Function.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ final public class Function : CustomStringConvertible, HasShortDescription, Hash
blocks.reversed().lazy.flatMap { $0.instructions.reversed() }
}

/// The number of indirect result arguments.
public var numIndirectResultArguments: Int { bridged.getNumIndirectFormalResults() }

public var hasIndirectErrorArgument: Bool { bridged.hasIndirectErrorResult() }

/// The number of arguments which correspond to parameters (and not to indirect results).
public var numParameterArguments: Int { bridged.getNumParameters() }

Expand All @@ -66,7 +66,9 @@ final public class Function : CustomStringConvertible, HasShortDescription, Hash
/// This is the sum of indirect result arguments and parameter arguments.
/// If the function is a definition (i.e. it has at least an entry block), this is the
/// number of arguments of the function's entry block.
public var numArguments: Int { numIndirectResultArguments + numParameterArguments }
public var numArguments: Int {
numIndirectResultArguments + (hasIndirectErrorArgument ? 1 : 0) + numParameterArguments
}

public var hasSelfArgument: Bool {
bridged.getSelfArgumentIndex() >= 0
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4785,6 +4785,9 @@ class SILFunctionType final
unsigned getNumPackResults() const {
return isCoroutine() ? 0 : NumPackResults;
}
bool hasIndirectErrorResult() const {
return hasErrorResult() && getErrorResult().isFormalIndirect();
}

struct IndirectFormalResultFilter {
bool operator()(SILResultInfo result) const {
Expand Down
1 change: 1 addition & 0 deletions include/swift/SIL/SILBridging.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ struct BridgedFunction {
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE OptionalBridgedBasicBlock getFirstBlock() const;
SWIFT_IMPORT_UNSAFE BRIDGED_INLINE OptionalBridgedBasicBlock getLastBlock() const;
BRIDGED_INLINE SwiftInt getNumIndirectFormalResults() const;
BRIDGED_INLINE bool hasIndirectErrorResult() const;
BRIDGED_INLINE SwiftInt getNumParameters() const;
BRIDGED_INLINE SwiftInt getSelfArgumentIndex() const;
BRIDGED_INLINE SwiftInt getNumSILArguments() const;
Expand Down
4 changes: 4 additions & 0 deletions include/swift/SIL/SILBridgingImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,10 @@ SwiftInt BridgedFunction::getNumIndirectFormalResults() const {
return (SwiftInt)getFunction()->getLoweredFunctionType()->getNumIndirectFormalResults();
}

bool BridgedFunction::hasIndirectErrorResult() const {
return (SwiftInt)getFunction()->getLoweredFunctionType()->hasIndirectErrorResult();
}

SwiftInt BridgedFunction::getNumParameters() const {
return (SwiftInt)getFunction()->getLoweredFunctionType()->getNumParameters();
}
Expand Down
6 changes: 6 additions & 0 deletions include/swift/SIL/SILFunctionConventions.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ class SILFunctionConventions {
return 0;
}

bool isArgumentIndexOfIndirectErrorResult(unsigned idx) {
unsigned indirectResults = getNumIndirectSILResults();
return idx >= indirectResults &&
idx < indirectResults + getNumIndirectSILErrorResults();
}

/// Are any SIL results passed as address-typed arguments?
bool hasIndirectSILResults() const { return getNumIndirectSILResults() != 0; }
bool hasIndirectSILErrorResults() const { return getNumIndirectSILErrorResults() != 0; }
Expand Down
7 changes: 1 addition & 6 deletions include/swift/SILOptimizer/Utils/GenericCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class GenericCloner

llvm::SmallVector<AllocStackInst *, 8> AllocStacks;
llvm::SmallVector<StoreBorrowInst *, 8> StoreBorrowsToCleanup;
llvm::SmallVector<TermInst *, 8> FunctionExits;
AllocStackInst *ReturnValueAddr = nullptr;
AllocStackInst *ErrorValueAddr = nullptr;

public:
friend class SILCloner<GenericCloner>;
Expand Down Expand Up @@ -94,11 +94,6 @@ class GenericCloner
if (Callback)
Callback(Orig, Cloned);

if (auto *termInst = dyn_cast<TermInst>(Cloned)) {
if (termInst->isFunctionExiting()) {
FunctionExits.push_back(termInst);
}
}
SILClonerWithScopes<GenericCloner>::postProcess(Orig, Cloned);
}

Expand Down
15 changes: 14 additions & 1 deletion include/swift/SILOptimizer/Utils/Generics.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class ReabstractionInfo {
/// See `droppedMetatypeArgs`.
bool dropMetatypeArgs = false;

bool hasIndirectErrorResult = false;

/// The first NumResults bits in Conversions refer to formal indirect
/// out-parameters.
unsigned NumFormalIndirectResults = 0;
Expand Down Expand Up @@ -190,6 +192,8 @@ class ReabstractionInfo {
void finishPartialSpecializationPreparation(
FunctionSignaturePartialSpecializer &FSPS);

TypeCategory handleReturnAndError(SILResultInfo RI, unsigned argIdx);

public:
ReabstractionInfo(SILModule &M) : M(&M) {}

Expand Down Expand Up @@ -227,7 +231,12 @@ class ReabstractionInfo {
}

unsigned param2ArgIndex(unsigned ParamIdx) const {
return ParamIdx + NumFormalIndirectResults;
return ParamIdx + NumFormalIndirectResults + (hasIndirectErrorResult ? 1: 0);
}

unsigned indirectErrorIndex() const {
assert(hasIndirectErrorResult);
return NumFormalIndirectResults;
}

/// Returns true if the specialized function needs an alternative mangling.
Expand Down Expand Up @@ -255,6 +264,10 @@ class ReabstractionInfo {
return ConvertIndirectToDirect && Conversions.test(ResultIdx);
}

bool isErrorResultConverted() const {
return ConvertIndirectToDirect && Conversions.test(indirectErrorIndex());
}

/// Gets the total number of original function arguments.
unsigned getNumArguments() const { return Conversions.size(); }

Expand Down
11 changes: 7 additions & 4 deletions lib/SILOptimizer/Transforms/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2550,18 +2550,21 @@ bool SimplifyCFG::simplifyTryApplyBlock(TryApplyInst *TAI) {
auto context = TAI->getFunction()->getTypeExpansionContext();
SmallVector<SILValue, 8> Args;
unsigned numArgs = TAI->getNumArguments();
unsigned calleeArgIdx = 0;
for (unsigned i = 0; i < numArgs; ++i) {
auto Arg = TAI->getArgument(i);
if (origConv.isArgumentIndexOfIndirectErrorResult(i) &&
!targetConv.isArgumentIndexOfIndirectErrorResult(i)) {
continue;
}
// Cast argument if required.
std::tie(Arg, std::ignore) = castValueToABICompatibleType(
&Builder, TAI->getLoc(), Arg, origConv.getSILArgumentType(i, context),
targetConv.getSILArgumentType(i, context), {TAI});
targetConv.getSILArgumentType(calleeArgIdx, context), {TAI});
Args.push_back(Arg);
calleeArgIdx += 1;
}

assert(calleeConv.getNumSILArguments() == Args.size()
&& "The number of arguments should match");

LLVM_DEBUG(llvm::dbgs() << "replace with apply: " << *TAI);

// If the new callee is owned, copy it to extend the lifetime
Expand Down
53 changes: 40 additions & 13 deletions lib/SILOptimizer/Utils/GenericCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ SILFunction *GenericCloner::createDeclaration(
void GenericCloner::populateCloned() {
assert(AllocStacks.empty() && "Stale cloner state.");
assert(!ReturnValueAddr && "Stale cloner state.");
assert(!ErrorValueAddr && "Stale cloner state.");

SILFunction *Cloned = getCloned();
// Create arguments for the entry block.
Expand Down Expand Up @@ -95,17 +96,33 @@ void GenericCloner::populateCloned() {
return false;

if (ArgIdx < origConv.getSILArgIndexOfFirstParam()) {
// Handle result arguments.
unsigned formalIdx =
origConv.getIndirectFormalResultIndexForSILArg(ArgIdx);
if (ReInfo.isFormalResultConverted(formalIdx)) {
// This result is converted from indirect to direct. The return inst
// needs to load the value from the alloc_stack. See below.
createAllocStack();
assert(!ReturnValueAddr);
ReturnValueAddr = ASI;
entryArgs.push_back(ASI);
return true;
if (ArgIdx < origConv.getNumIndirectSILResults()) {
// Handle result arguments.
unsigned formalIdx =
origConv.getIndirectFormalResultIndexForSILArg(ArgIdx);
if (ReInfo.isFormalResultConverted(formalIdx)) {
// This result is converted from indirect to direct. The return inst
// needs to load the value from the alloc_stack. See below.
createAllocStack();
assert(!ReturnValueAddr);
ReturnValueAddr = ASI;
entryArgs.push_back(ASI);
return true;
}
} else {
assert(origConv.getNumIndirectSILErrorResults() == 1 &&
"only a single indirect error result is supported");
assert(ArgIdx == origConv.getNumIndirectSILResults());

if (ReInfo.isErrorResultConverted()) {
// This error result is converted from indirect to direct. The throw
// instruction needs to load the value from the alloc_stack. See below.
createAllocStack();
assert(!ErrorValueAddr);
ErrorValueAddr = ASI;
entryArgs.push_back(ASI);
return true;
}
}
} else if (ReInfo.isDroppedMetatypeArg(ArgIdx)) {
// Replace the metatype argument with an `metatype` instruction in the
Expand Down Expand Up @@ -192,10 +209,20 @@ void GenericCloner::visitTerminator(SILBasicBlock *BB) {
getBuilder().createDeallocStack(ASI->getLoc(), ASI);
}
if (ReturnValue) {
auto *NewReturn = getBuilder().createReturn(RI->getLoc(), ReturnValue);
FunctionExits.push_back(NewReturn);
getBuilder().createReturn(RI->getLoc(), ReturnValue);
return;
}
} else if (isa<ThrowAddrInst>(OrigTermInst) && ErrorValueAddr) {
// The result is converted from indirect to direct. We have to load the
// returned value from the alloc_stack.
SILValue errorValue = getBuilder().emitLoadValueOperation(
ErrorValueAddr->getLoc(), ErrorValueAddr,
LoadOwnershipQualifier::Take);
for (AllocStackInst *ASI : reverse(AllocStacks)) {
getBuilder().createDeallocStack(ASI->getLoc(), ASI);
}
getBuilder().createThrow(OrigTermInst->getLoc(), errorValue);
return;
} else if (OrigTermInst->isFunctionExiting()) {
for (AllocStackInst *ASI : reverse(AllocStacks)) {
getBuilder().createDeallocStack(ASI->getLoc(), ASI);
Expand Down
Loading