Skip to content

SILGen: Base "currying" of functions on their lowered capture set, instead of their formal capture set. #5574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/SIL/TypeLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ class TypeConverter {
/// Get the capture list from a closure, with transitive function captures
/// flattened.
CaptureInfo getLoweredLocalCaptures(AnyFunctionRef fn);
bool hasLoweredLocalCaptures(AnyFunctionRef fn);

enum class ABIDifference : uint8_t {
// No ABI differences, function can be trivially bitcast to result type.
Expand Down
117 changes: 116 additions & 1 deletion lib/SIL/SILDeclRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,127 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
return vd->getAttrs().hasAttribute<DynamicAttr>();
}

/// TODO: We should consult the cached LoweredLocalCaptures the SIL
/// TypeConverter calculates, but that would require plumbing SILModule&
/// through every SILDeclRef constructor. Since this is only used to determine
/// "natural uncurry level", and "uncurry level" is a concept we'd like to
/// phase out, it's not worth it.
static bool hasLoweredLocalCaptures(AnyFunctionRef AFR,
llvm::DenseSet<AnyFunctionRef> &visited) {
if (!AFR.getCaptureInfo().hasLocalCaptures())
return false;

// Scan for local, non-function captures.
bool functionCapturesToRecursivelyCheck = false;
auto addFunctionCapture = [&](AnyFunctionRef capture) {
if (visited.find(capture) == visited.end())
functionCapturesToRecursivelyCheck = true;
};
for (auto &capture : AFR.getCaptureInfo().getCaptures()) {
if (!capture.getDecl()->getDeclContext()->isLocalContext())
continue;
// We transitively capture a local function's captures.
if (auto func = dyn_cast<AbstractFunctionDecl>(capture.getDecl())) {
addFunctionCapture(func);
continue;
}
// We may either directly capture properties, or capture through their
// accessors.
if (auto var = dyn_cast<VarDecl>(capture.getDecl())) {
switch (var->getStorageKind()) {
case VarDecl::StoredWithTrivialAccessors:
llvm_unreachable("stored local variable with trivial accessors?");

case VarDecl::InheritedWithObservers:
llvm_unreachable("inherited local variable?");

case VarDecl::StoredWithObservers:
case VarDecl::Addressed:
case VarDecl::AddressedWithTrivialAccessors:
case VarDecl::AddressedWithObservers:
case VarDecl::ComputedWithMutableAddress:
// Directly capture storage if we're supposed to.
if (capture.isDirect())
return true;

// Otherwise, transitively capture the accessors.
SWIFT_FALLTHROUGH;

case VarDecl::Computed:
addFunctionCapture(var->getGetter());
if (auto setter = var->getSetter())
addFunctionCapture(setter);
continue;

case VarDecl::Stored:
return true;
}
}
// Anything else is directly captured.
return true;
}

// Recursively consider function captures, since we didn't have any direct
// captures.
auto captureHasLocalCaptures = [&](AnyFunctionRef capture) -> bool {
if (visited.insert(capture).second)
return hasLoweredLocalCaptures(capture, visited);
return false;
};

if (functionCapturesToRecursivelyCheck) {
for (auto &capture : AFR.getCaptureInfo().getCaptures()) {
if (!capture.getDecl()->getDeclContext()->isLocalContext())
continue;
if (auto func = dyn_cast<AbstractFunctionDecl>(capture.getDecl())) {
if (captureHasLocalCaptures(func))
return true;
continue;
}
if (auto var = dyn_cast<VarDecl>(capture.getDecl())) {
switch (var->getStorageKind()) {
case VarDecl::StoredWithTrivialAccessors:
llvm_unreachable("stored local variable with trivial accessors?");

case VarDecl::InheritedWithObservers:
llvm_unreachable("inherited local variable?");

case VarDecl::StoredWithObservers:
case VarDecl::Addressed:
case VarDecl::AddressedWithTrivialAccessors:
case VarDecl::AddressedWithObservers:
case VarDecl::ComputedWithMutableAddress:
assert(!capture.isDirect() && "should have short circuited out");
// Otherwise, transitively capture the accessors.
SWIFT_FALLTHROUGH;

case VarDecl::Computed:
if (captureHasLocalCaptures(var->getGetter()))
return true;
if (auto setter = var->getSetter())
if (captureHasLocalCaptures(setter))
return true;
continue;

case VarDecl::Stored:
llvm_unreachable("should have short circuited out");
}
}
llvm_unreachable("should have short circuited out");
}
}

return false;
}

static unsigned getFuncNaturalUncurryLevel(AnyFunctionRef AFR) {
assert(AFR.getParameterLists().size() >= 1 && "no arguments for func?!");
unsigned Level = AFR.getParameterLists().size() - 1;
// Functions with captures have an extra uncurry level for the capture
// context.
if (AFR.getCaptureInfo().hasLocalCaptures())
llvm::DenseSet<AnyFunctionRef> visited;
visited.insert(AFR);
if (hasLoweredLocalCaptures(AFR, visited))
Level += 1;
return Level;
}
Expand Down
7 changes: 6 additions & 1 deletion lib/SIL/TypeLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1666,7 +1666,7 @@ TypeConverter::getFunctionInterfaceTypeWithCaptures(CanAnyFunctionType funcType,

// If we don't have any local captures (including function captures),
// there's no context to apply.
if (!theClosure.getCaptureInfo().hasLocalCaptures()) {
if (!hasLoweredLocalCaptures(theClosure)) {
if (!genericSig)
return CanFunctionType::get(funcType.getInput(),
funcType.getResult(),
Expand Down Expand Up @@ -1991,6 +1991,11 @@ getAnyFunctionRefFromCapture(CapturedValue capture) {
return None;
}

bool
TypeConverter::hasLoweredLocalCaptures(AnyFunctionRef fn) {
return !getLoweredLocalCaptures(fn).getCaptures().empty();
}

CaptureInfo
TypeConverter::getLoweredLocalCaptures(AnyFunctionRef fn) {
// First, bail out if there are no local captures at all.
Expand Down
11 changes: 6 additions & 5 deletions lib/SILGen/SILGenApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
// If there are captures, put the placeholder curry level in the formal
// type.
// TODO: Eliminate the need for this.
if (afd->getCaptureInfo().hasLocalCaptures())
if (SGF.SGM.M.Types.hasLoweredLocalCaptures(afd))
substFnType = CanFunctionType::get(
SGF.getASTContext().TheEmptyTupleType, substFnType);
}
Expand All @@ -1152,7 +1152,7 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
// captures in the constant info too, to generate more efficient
// code for mutually recursive local functions which otherwise
// capture no state.
if (afd->getCaptureInfo().hasLocalCaptures()) {
if (SGF.SGM.M.Types.hasLoweredLocalCaptures(afd)) {
SmallVector<ManagedValue, 4> captures;
SGF.emitCaptures(e, afd, CaptureEmission::ImmediateApplication,
captures);
Expand Down Expand Up @@ -1197,14 +1197,15 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
// If there are captures, put the placeholder curry level in the formal
// type.
// TODO: Eliminate the need for this.
if (e->getCaptureInfo().hasLocalCaptures())
bool hasCaptures = SGF.SGM.M.Types.hasLoweredLocalCaptures(e);
if (hasCaptures)
substFnType = CanFunctionType::get(
SGF.getASTContext().TheEmptyTupleType, substFnType);

setCallee(Callee::forDirect(SGF, constant, substFnType, e));

// If the closure requires captures, emit them.
if (e->getCaptureInfo().hasLocalCaptures()) {
if (hasCaptures) {
SmallVector<ManagedValue, 4> captures;
SGF.emitCaptures(e, e, CaptureEmission::ImmediateApplication,
captures);
Expand Down Expand Up @@ -4974,7 +4975,7 @@ emitSpecializedAccessorFunctionRef(SILGenFunction &gen,

// Collect captures if the accessor has them.
auto accessorFn = cast<AbstractFunctionDecl>(constant.getDecl());
if (accessorFn->getCaptureInfo().hasLocalCaptures()) {
if (gen.SGM.M.Types.hasLoweredLocalCaptures(accessorFn)) {
assert(!selfValue && "local property has self param?!");
SmallVector<ManagedValue, 4> captures;
gen.emitCaptures(loc, accessorFn, CaptureEmission::ImmediateApplication,
Expand Down
2 changes: 1 addition & 1 deletion lib/SILGen/SILGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ emitRValueForDecl(SILLocation loc, ConcreteDeclRef declRef, Type ncRefType,
bool hasLocalCaptures = false;
unsigned uncurryLevel = 0;
if (auto *fd = dyn_cast<FuncDecl>(decl)) {
hasLocalCaptures = fd->getCaptureInfo().hasLocalCaptures();
hasLocalCaptures = SGM.M.Types.hasLoweredLocalCaptures(fd);
if (hasLocalCaptures)
++uncurryLevel;
}
Expand Down
14 changes: 6 additions & 8 deletions lib/SILGen/SILGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,9 @@ SILGenFunction::emitClosureValue(SILLocation loc, SILDeclRef constant,
auto closure = *constant.getAnyFunctionRef();
auto captureInfo = closure.getCaptureInfo();
auto loweredCaptureInfo = SGM.Types.getLoweredLocalCaptures(closure);

assert(((constant.uncurryLevel == 1 &&
captureInfo.hasLocalCaptures()) ||
(constant.uncurryLevel == 0 &&
!captureInfo.hasLocalCaptures())) &&
auto hasCaptures = SGM.Types.hasLoweredLocalCaptures(closure);
assert(((constant.uncurryLevel == 1 && hasCaptures) ||
(constant.uncurryLevel == 0 && !hasCaptures)) &&
"curried local functions not yet supported");

auto constantInfo = getConstantInfo(constant);
Expand Down Expand Up @@ -406,7 +404,7 @@ SILGenFunction::emitClosureValue(SILLocation loc, SILDeclRef constant,
loc, captureInfo);
}

if (!captureInfo.hasLocalCaptures() && !wasSpecialized) {
if (!hasCaptures && !wasSpecialized) {
auto result = ManagedValue::forUnmanaged(functionRef);
return emitOrigToSubstValue(loc, result,
AbstractionPattern(expectedType),
Expand Down Expand Up @@ -436,7 +434,7 @@ SILGenFunction::emitClosureValue(SILLocation loc, SILDeclRef constant,
// - the original type
auto origLoweredFormalType =
AbstractionPattern(constantInfo.LoweredInterfaceType);
if (captureInfo.hasLocalCaptures()) {
if (hasCaptures) {
// Get the unlowered formal type of the constant, stripping off
// the first level of function application, which applies captures.
origLoweredFormalType =
Expand Down Expand Up @@ -794,7 +792,7 @@ void SILGenFunction::emitCurryThunk(ValueDecl *vd,

} else if (auto fd = dyn_cast<AbstractFunctionDecl>(vd)) {
// Forward implicit closure context arguments.
bool hasCaptures = fd->getCaptureInfo().hasLocalCaptures();
bool hasCaptures = SGM.M.Types.hasLoweredLocalCaptures(fd);
if (hasCaptures)
--paramCount;

Expand Down
9 changes: 2 additions & 7 deletions test/SILGen/local_captures.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,12 @@ func globalfunc() -> () -> () {
func localFunc() {
}

// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_6callitfT_T_ : $@convention(thin) () -> ()
// CHECK: function_ref @_TFF14local_captures10globalfuncFT_FT_T_L_9localFuncFT_T_ : $@convention(thin) () -> ()
// CHECK-NEXT: apply
// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_6callitFT_T_ : $@convention(thin) () -> ()
func callit() {
localFunc()
}

// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_5getitfT_FT_T_ : $@convention(thin) () -> @owned @callee_owned () -> ()
// CHECK: function_ref @_TFF14local_captures10globalfuncFT_FT_T_L_9localFuncFT_T_ : $@convention(thin) () -> ()
// CHECK-NEXT: thin_to_thick_function
// CHECK-NEXT: return
// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_5getitFT_FT_T_ : $@convention(thin) () -> @owned @callee_owned () -> ()
func getit() -> () -> () {
return localFunc
}
Expand Down
33 changes: 33 additions & 0 deletions test/SILGen/nested_types_referencing_nested_functions.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s

do {
func foo() { bar(2) }
func bar<T>(_: T) { foo() }

class Foo {
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3FoocfT_S0_ : $@convention(method) (@owned Foo) -> @owned Foo {
init() {
foo()
}
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3Foo3zimfT_T_ : $@convention(method) (@guaranteed Foo) -> ()
func zim() {
foo()
}
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3Foo4zangurfxT_ : $@convention(method) <T> (@in T, @guaranteed Foo) -> ()
func zang<T>(_ x: T) {
bar(x)
}
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3FooD : $@convention(method) (@owned Foo) -> ()
deinit {
foo()
}
}

let x = Foo()
x.zim()
x.zang(1)
_ = Foo.zim
_ = Foo.zang as (Foo) -> (Int) -> ()
_ = x.zim
_ = x.zang as (Int) -> ()
}