Skip to content

Commit f81e55c

Browse files
committed
SILGen: Base "currying" of functions on their lowered capture set, instead of their formal capture set.
This allows for slightly better codegen for nested functions that refer to other nested functions that don't transitively capture any local state, but more importantly, allows methods of local types to work while still referring to nested functions that don't capture local state, fixing rdar://problem/28015090.
1 parent cbc9d89 commit f81e55c

File tree

8 files changed

+171
-23
lines changed

8 files changed

+171
-23
lines changed

include/swift/SIL/TypeLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,7 @@ class TypeConverter {
746746
/// Get the capture list from a closure, with transitive function captures
747747
/// flattened.
748748
CaptureInfo getLoweredLocalCaptures(AnyFunctionRef fn);
749+
bool hasLoweredLocalCaptures(AnyFunctionRef fn);
749750

750751
enum class ABIDifference : uint8_t {
751752
// No ABI differences, function can be trivially bitcast to result type.

lib/SIL/SILDeclRef.cpp

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,127 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
112112
return vd->getAttrs().hasAttribute<DynamicAttr>();
113113
}
114114

115+
/// TODO: We should consult the cached LoweredLocalCaptures the SIL
116+
/// TypeConverter calculates, but that would require plumbing SILModule&
117+
/// through every SILDeclRef constructor. Since this is only used to determine
118+
/// "natural uncurry level", and "uncurry level" is a concept we'd like to
119+
/// phase out, it's not worth it.
120+
static bool hasLoweredLocalCaptures(AnyFunctionRef AFR,
121+
llvm::DenseSet<AnyFunctionRef> &visited) {
122+
if (!AFR.getCaptureInfo().hasLocalCaptures())
123+
return false;
124+
125+
// Scan for local, non-function captures.
126+
bool functionCapturesToRecursivelyCheck = false;
127+
auto addFunctionCapture = [&](AnyFunctionRef capture) {
128+
if (visited.find(capture) == visited.end())
129+
functionCapturesToRecursivelyCheck = true;
130+
};
131+
for (auto &capture : AFR.getCaptureInfo().getCaptures()) {
132+
if (!capture.getDecl()->getDeclContext()->isLocalContext())
133+
continue;
134+
// We transitively capture a local function's captures.
135+
if (auto func = dyn_cast<AbstractFunctionDecl>(capture.getDecl())) {
136+
addFunctionCapture(func);
137+
continue;
138+
}
139+
// We may either directly capture properties, or capture through their
140+
// accessors.
141+
if (auto var = dyn_cast<VarDecl>(capture.getDecl())) {
142+
switch (var->getStorageKind()) {
143+
case VarDecl::StoredWithTrivialAccessors:
144+
llvm_unreachable("stored local variable with trivial accessors?");
145+
146+
case VarDecl::InheritedWithObservers:
147+
llvm_unreachable("inherited local variable?");
148+
149+
case VarDecl::StoredWithObservers:
150+
case VarDecl::Addressed:
151+
case VarDecl::AddressedWithTrivialAccessors:
152+
case VarDecl::AddressedWithObservers:
153+
case VarDecl::ComputedWithMutableAddress:
154+
// Directly capture storage if we're supposed to.
155+
if (capture.isDirect())
156+
return true;
157+
158+
// Otherwise, transitively capture the accessors.
159+
SWIFT_FALLTHROUGH;
160+
161+
case VarDecl::Computed:
162+
addFunctionCapture(var->getGetter());
163+
if (auto setter = var->getSetter())
164+
addFunctionCapture(setter);
165+
continue;
166+
167+
case VarDecl::Stored:
168+
return true;
169+
}
170+
}
171+
// Anything else is directly captured.
172+
return true;
173+
}
174+
175+
// Recursively consider function captures, since we didn't have any direct
176+
// captures.
177+
auto captureHasLocalCaptures = [&](AnyFunctionRef capture) -> bool {
178+
if (visited.insert(capture).second)
179+
return hasLoweredLocalCaptures(capture, visited);
180+
return false;
181+
};
182+
183+
if (functionCapturesToRecursivelyCheck) {
184+
for (auto &capture : AFR.getCaptureInfo().getCaptures()) {
185+
if (!capture.getDecl()->getDeclContext()->isLocalContext())
186+
continue;
187+
if (auto func = dyn_cast<AbstractFunctionDecl>(capture.getDecl())) {
188+
if (captureHasLocalCaptures(func))
189+
return true;
190+
continue;
191+
}
192+
if (auto var = dyn_cast<VarDecl>(capture.getDecl())) {
193+
switch (var->getStorageKind()) {
194+
case VarDecl::StoredWithTrivialAccessors:
195+
llvm_unreachable("stored local variable with trivial accessors?");
196+
197+
case VarDecl::InheritedWithObservers:
198+
llvm_unreachable("inherited local variable?");
199+
200+
case VarDecl::StoredWithObservers:
201+
case VarDecl::Addressed:
202+
case VarDecl::AddressedWithTrivialAccessors:
203+
case VarDecl::AddressedWithObservers:
204+
case VarDecl::ComputedWithMutableAddress:
205+
assert(!capture.isDirect() && "should have short circuited out");
206+
// Otherwise, transitively capture the accessors.
207+
SWIFT_FALLTHROUGH;
208+
209+
case VarDecl::Computed:
210+
if (captureHasLocalCaptures(var->getGetter()))
211+
return true;
212+
if (auto setter = var->getSetter())
213+
if (captureHasLocalCaptures(setter))
214+
return true;
215+
continue;
216+
217+
case VarDecl::Stored:
218+
llvm_unreachable("should have short circuited out");
219+
}
220+
}
221+
llvm_unreachable("should have short circuited out");
222+
}
223+
}
224+
225+
return false;
226+
}
227+
115228
static unsigned getFuncNaturalUncurryLevel(AnyFunctionRef AFR) {
116229
assert(AFR.getParameterLists().size() >= 1 && "no arguments for func?!");
117230
unsigned Level = AFR.getParameterLists().size() - 1;
118231
// Functions with captures have an extra uncurry level for the capture
119232
// context.
120-
if (AFR.getCaptureInfo().hasLocalCaptures())
233+
llvm::DenseSet<AnyFunctionRef> visited;
234+
visited.insert(AFR);
235+
if (hasLoweredLocalCaptures(AFR, visited))
121236
Level += 1;
122237
return Level;
123238
}

lib/SIL/TypeLowering.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1666,7 +1666,7 @@ TypeConverter::getFunctionInterfaceTypeWithCaptures(CanAnyFunctionType funcType,
16661666

16671667
// If we don't have any local captures (including function captures),
16681668
// there's no context to apply.
1669-
if (!theClosure.getCaptureInfo().hasLocalCaptures()) {
1669+
if (!hasLoweredLocalCaptures(theClosure)) {
16701670
if (!genericSig)
16711671
return CanFunctionType::get(funcType.getInput(),
16721672
funcType.getResult(),
@@ -1991,6 +1991,11 @@ getAnyFunctionRefFromCapture(CapturedValue capture) {
19911991
return None;
19921992
}
19931993

1994+
bool
1995+
TypeConverter::hasLoweredLocalCaptures(AnyFunctionRef fn) {
1996+
return !getLoweredLocalCaptures(fn).getCaptures().empty();
1997+
}
1998+
19941999
CaptureInfo
19952000
TypeConverter::getLoweredLocalCaptures(AnyFunctionRef fn) {
19962001
// First, bail out if there are no local captures at all.

lib/SILGen/SILGenApply.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
11311131
// If there are captures, put the placeholder curry level in the formal
11321132
// type.
11331133
// TODO: Eliminate the need for this.
1134-
if (afd->getCaptureInfo().hasLocalCaptures())
1134+
if (SGF.SGM.M.Types.hasLoweredLocalCaptures(afd))
11351135
substFnType = CanFunctionType::get(
11361136
SGF.getASTContext().TheEmptyTupleType, substFnType);
11371137
}
@@ -1152,7 +1152,7 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
11521152
// captures in the constant info too, to generate more efficient
11531153
// code for mutually recursive local functions which otherwise
11541154
// capture no state.
1155-
if (afd->getCaptureInfo().hasLocalCaptures()) {
1155+
if (SGF.SGM.M.Types.hasLoweredLocalCaptures(afd)) {
11561156
SmallVector<ManagedValue, 4> captures;
11571157
SGF.emitCaptures(e, afd, CaptureEmission::ImmediateApplication,
11581158
captures);
@@ -1197,14 +1197,15 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
11971197
// If there are captures, put the placeholder curry level in the formal
11981198
// type.
11991199
// TODO: Eliminate the need for this.
1200-
if (e->getCaptureInfo().hasLocalCaptures())
1200+
bool hasCaptures = SGF.SGM.M.Types.hasLoweredLocalCaptures(e);
1201+
if (hasCaptures)
12011202
substFnType = CanFunctionType::get(
12021203
SGF.getASTContext().TheEmptyTupleType, substFnType);
12031204

12041205
setCallee(Callee::forDirect(SGF, constant, substFnType, e));
12051206

12061207
// If the closure requires captures, emit them.
1207-
if (e->getCaptureInfo().hasLocalCaptures()) {
1208+
if (hasCaptures) {
12081209
SmallVector<ManagedValue, 4> captures;
12091210
SGF.emitCaptures(e, e, CaptureEmission::ImmediateApplication,
12101211
captures);
@@ -4974,7 +4975,7 @@ emitSpecializedAccessorFunctionRef(SILGenFunction &gen,
49744975

49754976
// Collect captures if the accessor has them.
49764977
auto accessorFn = cast<AbstractFunctionDecl>(constant.getDecl());
4977-
if (accessorFn->getCaptureInfo().hasLocalCaptures()) {
4978+
if (gen.SGM.M.Types.hasLoweredLocalCaptures(accessorFn)) {
49784979
assert(!selfValue && "local property has self param?!");
49794980
SmallVector<ManagedValue, 4> captures;
49804981
gen.emitCaptures(loc, accessorFn, CaptureEmission::ImmediateApplication,

lib/SILGen/SILGenExpr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ emitRValueForDecl(SILLocation loc, ConcreteDeclRef declRef, Type ncRefType,
443443
bool hasLocalCaptures = false;
444444
unsigned uncurryLevel = 0;
445445
if (auto *fd = dyn_cast<FuncDecl>(decl)) {
446-
hasLocalCaptures = fd->getCaptureInfo().hasLocalCaptures();
446+
hasLocalCaptures = SGM.M.Types.hasLoweredLocalCaptures(fd);
447447
if (hasLocalCaptures)
448448
++uncurryLevel;
449449
}

lib/SILGen/SILGenFunction.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,9 @@ SILGenFunction::emitClosureValue(SILLocation loc, SILDeclRef constant,
362362
auto closure = *constant.getAnyFunctionRef();
363363
auto captureInfo = closure.getCaptureInfo();
364364
auto loweredCaptureInfo = SGM.Types.getLoweredLocalCaptures(closure);
365-
366-
assert(((constant.uncurryLevel == 1 &&
367-
captureInfo.hasLocalCaptures()) ||
368-
(constant.uncurryLevel == 0 &&
369-
!captureInfo.hasLocalCaptures())) &&
365+
auto hasCaptures = SGM.Types.hasLoweredLocalCaptures(closure);
366+
assert(((constant.uncurryLevel == 1 && hasCaptures) ||
367+
(constant.uncurryLevel == 0 && !hasCaptures)) &&
370368
"curried local functions not yet supported");
371369

372370
auto constantInfo = getConstantInfo(constant);
@@ -406,7 +404,7 @@ SILGenFunction::emitClosureValue(SILLocation loc, SILDeclRef constant,
406404
loc, captureInfo);
407405
}
408406

409-
if (!captureInfo.hasLocalCaptures() && !wasSpecialized) {
407+
if (!hasCaptures && !wasSpecialized) {
410408
auto result = ManagedValue::forUnmanaged(functionRef);
411409
return emitOrigToSubstValue(loc, result,
412410
AbstractionPattern(expectedType),
@@ -436,7 +434,7 @@ SILGenFunction::emitClosureValue(SILLocation loc, SILDeclRef constant,
436434
// - the original type
437435
auto origLoweredFormalType =
438436
AbstractionPattern(constantInfo.LoweredInterfaceType);
439-
if (captureInfo.hasLocalCaptures()) {
437+
if (hasCaptures) {
440438
// Get the unlowered formal type of the constant, stripping off
441439
// the first level of function application, which applies captures.
442440
origLoweredFormalType =
@@ -794,7 +792,7 @@ void SILGenFunction::emitCurryThunk(ValueDecl *vd,
794792

795793
} else if (auto fd = dyn_cast<AbstractFunctionDecl>(vd)) {
796794
// Forward implicit closure context arguments.
797-
bool hasCaptures = fd->getCaptureInfo().hasLocalCaptures();
795+
bool hasCaptures = SGM.M.Types.hasLoweredLocalCaptures(fd);
798796
if (hasCaptures)
799797
--paramCount;
800798

test/SILGen/local_captures.swift

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,12 @@ func globalfunc() -> () -> () {
1010
func localFunc() {
1111
}
1212

13-
// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_6callitfT_T_ : $@convention(thin) () -> ()
14-
// CHECK: function_ref @_TFF14local_captures10globalfuncFT_FT_T_L_9localFuncFT_T_ : $@convention(thin) () -> ()
15-
// CHECK-NEXT: apply
13+
// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_6callitFT_T_ : $@convention(thin) () -> ()
1614
func callit() {
1715
localFunc()
1816
}
1917

20-
// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_5getitfT_FT_T_ : $@convention(thin) () -> @owned @callee_owned () -> ()
21-
// CHECK: function_ref @_TFF14local_captures10globalfuncFT_FT_T_L_9localFuncFT_T_ : $@convention(thin) () -> ()
22-
// CHECK-NEXT: thin_to_thick_function
23-
// CHECK-NEXT: return
18+
// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_5getitFT_FT_T_ : $@convention(thin) () -> @owned @callee_owned () -> ()
2419
func getit() -> () -> () {
2520
return localFunc
2621
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
2+
3+
do {
4+
func foo() { bar(2) }
5+
func bar<T>(_: T) { foo() }
6+
7+
class Foo {
8+
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3FoocfT_S0_ : $@convention(method) (@owned Foo) -> @owned Foo {
9+
init() {
10+
foo()
11+
}
12+
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3Foo3zimfT_T_ : $@convention(method) (@guaranteed Foo) -> ()
13+
func zim() {
14+
foo()
15+
}
16+
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3Foo4zangurfxT_ : $@convention(method) <T> (@in T, @guaranteed Foo) -> ()
17+
func zang<T>(_ x: T) {
18+
bar(x)
19+
}
20+
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3FooD : $@convention(method) (@owned Foo) -> ()
21+
deinit {
22+
foo()
23+
}
24+
}
25+
26+
let x = Foo()
27+
x.zim()
28+
x.zang(1)
29+
_ = Foo.zim
30+
_ = Foo.zang as (Foo) -> (Int) -> ()
31+
_ = x.zim
32+
_ = x.zang as (Int) -> ()
33+
}

0 commit comments

Comments
 (0)