Skip to content

Commit a66ab83

Browse files
authored
Merge pull request #5574 from jckarter/nested-function-lowered-captures
SILGen: Base "currying" of functions on their lowered capture set, instead of their formal capture set.
2 parents b5965f2 + f81e55c commit a66ab83

File tree

8 files changed

+171
-23
lines changed

8 files changed

+171
-23
lines changed

include/swift/SIL/TypeLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,7 @@ class TypeConverter {
762762
/// Get the capture list from a closure, with transitive function captures
763763
/// flattened.
764764
CaptureInfo getLoweredLocalCaptures(AnyFunctionRef fn);
765+
bool hasLoweredLocalCaptures(AnyFunctionRef fn);
765766

766767
enum class ABIDifference : uint8_t {
767768
// No ABI differences, function can be trivially bitcast to result type.

lib/SIL/SILDeclRef.cpp

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,127 @@ bool swift::requiresForeignEntryPoint(ValueDecl *vd) {
112112
return vd->getAttrs().hasAttribute<DynamicAttr>();
113113
}
114114

115+
/// TODO: We should consult the cached LoweredLocalCaptures the SIL
116+
/// TypeConverter calculates, but that would require plumbing SILModule&
117+
/// through every SILDeclRef constructor. Since this is only used to determine
118+
/// "natural uncurry level", and "uncurry level" is a concept we'd like to
119+
/// phase out, it's not worth it.
120+
static bool hasLoweredLocalCaptures(AnyFunctionRef AFR,
121+
llvm::DenseSet<AnyFunctionRef> &visited) {
122+
if (!AFR.getCaptureInfo().hasLocalCaptures())
123+
return false;
124+
125+
// Scan for local, non-function captures.
126+
bool functionCapturesToRecursivelyCheck = false;
127+
auto addFunctionCapture = [&](AnyFunctionRef capture) {
128+
if (visited.find(capture) == visited.end())
129+
functionCapturesToRecursivelyCheck = true;
130+
};
131+
for (auto &capture : AFR.getCaptureInfo().getCaptures()) {
132+
if (!capture.getDecl()->getDeclContext()->isLocalContext())
133+
continue;
134+
// We transitively capture a local function's captures.
135+
if (auto func = dyn_cast<AbstractFunctionDecl>(capture.getDecl())) {
136+
addFunctionCapture(func);
137+
continue;
138+
}
139+
// We may either directly capture properties, or capture through their
140+
// accessors.
141+
if (auto var = dyn_cast<VarDecl>(capture.getDecl())) {
142+
switch (var->getStorageKind()) {
143+
case VarDecl::StoredWithTrivialAccessors:
144+
llvm_unreachable("stored local variable with trivial accessors?");
145+
146+
case VarDecl::InheritedWithObservers:
147+
llvm_unreachable("inherited local variable?");
148+
149+
case VarDecl::StoredWithObservers:
150+
case VarDecl::Addressed:
151+
case VarDecl::AddressedWithTrivialAccessors:
152+
case VarDecl::AddressedWithObservers:
153+
case VarDecl::ComputedWithMutableAddress:
154+
// Directly capture storage if we're supposed to.
155+
if (capture.isDirect())
156+
return true;
157+
158+
// Otherwise, transitively capture the accessors.
159+
SWIFT_FALLTHROUGH;
160+
161+
case VarDecl::Computed:
162+
addFunctionCapture(var->getGetter());
163+
if (auto setter = var->getSetter())
164+
addFunctionCapture(setter);
165+
continue;
166+
167+
case VarDecl::Stored:
168+
return true;
169+
}
170+
}
171+
// Anything else is directly captured.
172+
return true;
173+
}
174+
175+
// Recursively consider function captures, since we didn't have any direct
176+
// captures.
177+
auto captureHasLocalCaptures = [&](AnyFunctionRef capture) -> bool {
178+
if (visited.insert(capture).second)
179+
return hasLoweredLocalCaptures(capture, visited);
180+
return false;
181+
};
182+
183+
if (functionCapturesToRecursivelyCheck) {
184+
for (auto &capture : AFR.getCaptureInfo().getCaptures()) {
185+
if (!capture.getDecl()->getDeclContext()->isLocalContext())
186+
continue;
187+
if (auto func = dyn_cast<AbstractFunctionDecl>(capture.getDecl())) {
188+
if (captureHasLocalCaptures(func))
189+
return true;
190+
continue;
191+
}
192+
if (auto var = dyn_cast<VarDecl>(capture.getDecl())) {
193+
switch (var->getStorageKind()) {
194+
case VarDecl::StoredWithTrivialAccessors:
195+
llvm_unreachable("stored local variable with trivial accessors?");
196+
197+
case VarDecl::InheritedWithObservers:
198+
llvm_unreachable("inherited local variable?");
199+
200+
case VarDecl::StoredWithObservers:
201+
case VarDecl::Addressed:
202+
case VarDecl::AddressedWithTrivialAccessors:
203+
case VarDecl::AddressedWithObservers:
204+
case VarDecl::ComputedWithMutableAddress:
205+
assert(!capture.isDirect() && "should have short circuited out");
206+
// Otherwise, transitively capture the accessors.
207+
SWIFT_FALLTHROUGH;
208+
209+
case VarDecl::Computed:
210+
if (captureHasLocalCaptures(var->getGetter()))
211+
return true;
212+
if (auto setter = var->getSetter())
213+
if (captureHasLocalCaptures(setter))
214+
return true;
215+
continue;
216+
217+
case VarDecl::Stored:
218+
llvm_unreachable("should have short circuited out");
219+
}
220+
}
221+
llvm_unreachable("should have short circuited out");
222+
}
223+
}
224+
225+
return false;
226+
}
227+
115228
static unsigned getFuncNaturalUncurryLevel(AnyFunctionRef AFR) {
116229
assert(AFR.getParameterLists().size() >= 1 && "no arguments for func?!");
117230
unsigned Level = AFR.getParameterLists().size() - 1;
118231
// Functions with captures have an extra uncurry level for the capture
119232
// context.
120-
if (AFR.getCaptureInfo().hasLocalCaptures())
233+
llvm::DenseSet<AnyFunctionRef> visited;
234+
visited.insert(AFR);
235+
if (hasLoweredLocalCaptures(AFR, visited))
121236
Level += 1;
122237
return Level;
123238
}

lib/SIL/TypeLowering.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1696,7 +1696,7 @@ TypeConverter::getFunctionInterfaceTypeWithCaptures(CanAnyFunctionType funcType,
16961696

16971697
// If we don't have any local captures (including function captures),
16981698
// there's no context to apply.
1699-
if (!theClosure.getCaptureInfo().hasLocalCaptures()) {
1699+
if (!hasLoweredLocalCaptures(theClosure)) {
17001700
if (!genericSig)
17011701
return CanFunctionType::get(funcType.getInput(),
17021702
funcType.getResult(),
@@ -2021,6 +2021,11 @@ getAnyFunctionRefFromCapture(CapturedValue capture) {
20212021
return None;
20222022
}
20232023

2024+
bool
2025+
TypeConverter::hasLoweredLocalCaptures(AnyFunctionRef fn) {
2026+
return !getLoweredLocalCaptures(fn).getCaptures().empty();
2027+
}
2028+
20242029
CaptureInfo
20252030
TypeConverter::getLoweredLocalCaptures(AnyFunctionRef fn) {
20262031
// First, bail out if there are no local captures at all.

lib/SILGen/SILGenApply.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
11311131
// If there are captures, put the placeholder curry level in the formal
11321132
// type.
11331133
// TODO: Eliminate the need for this.
1134-
if (afd->getCaptureInfo().hasLocalCaptures())
1134+
if (SGF.SGM.M.Types.hasLoweredLocalCaptures(afd))
11351135
substFnType = CanFunctionType::get(
11361136
SGF.getASTContext().TheEmptyTupleType, substFnType);
11371137
}
@@ -1152,7 +1152,7 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
11521152
// captures in the constant info too, to generate more efficient
11531153
// code for mutually recursive local functions which otherwise
11541154
// capture no state.
1155-
if (afd->getCaptureInfo().hasLocalCaptures()) {
1155+
if (SGF.SGM.M.Types.hasLoweredLocalCaptures(afd)) {
11561156
SmallVector<ManagedValue, 4> captures;
11571157
SGF.emitCaptures(e, afd, CaptureEmission::ImmediateApplication,
11581158
captures);
@@ -1197,14 +1197,15 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
11971197
// If there are captures, put the placeholder curry level in the formal
11981198
// type.
11991199
// TODO: Eliminate the need for this.
1200-
if (e->getCaptureInfo().hasLocalCaptures())
1200+
bool hasCaptures = SGF.SGM.M.Types.hasLoweredLocalCaptures(e);
1201+
if (hasCaptures)
12011202
substFnType = CanFunctionType::get(
12021203
SGF.getASTContext().TheEmptyTupleType, substFnType);
12031204

12041205
setCallee(Callee::forDirect(SGF, constant, substFnType, e));
12051206

12061207
// If the closure requires captures, emit them.
1207-
if (e->getCaptureInfo().hasLocalCaptures()) {
1208+
if (hasCaptures) {
12081209
SmallVector<ManagedValue, 4> captures;
12091210
SGF.emitCaptures(e, e, CaptureEmission::ImmediateApplication,
12101211
captures);
@@ -4974,7 +4975,7 @@ emitSpecializedAccessorFunctionRef(SILGenFunction &gen,
49744975

49754976
// Collect captures if the accessor has them.
49764977
auto accessorFn = cast<AbstractFunctionDecl>(constant.getDecl());
4977-
if (accessorFn->getCaptureInfo().hasLocalCaptures()) {
4978+
if (gen.SGM.M.Types.hasLoweredLocalCaptures(accessorFn)) {
49784979
assert(!selfValue && "local property has self param?!");
49794980
SmallVector<ManagedValue, 4> captures;
49804981
gen.emitCaptures(loc, accessorFn, CaptureEmission::ImmediateApplication,

lib/SILGen/SILGenExpr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ emitRValueForDecl(SILLocation loc, ConcreteDeclRef declRef, Type ncRefType,
443443
bool hasLocalCaptures = false;
444444
unsigned uncurryLevel = 0;
445445
if (auto *fd = dyn_cast<FuncDecl>(decl)) {
446-
hasLocalCaptures = fd->getCaptureInfo().hasLocalCaptures();
446+
hasLocalCaptures = SGM.M.Types.hasLoweredLocalCaptures(fd);
447447
if (hasLocalCaptures)
448448
++uncurryLevel;
449449
}

lib/SILGen/SILGenFunction.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,9 @@ SILGenFunction::emitClosureValue(SILLocation loc, SILDeclRef constant,
362362
auto closure = *constant.getAnyFunctionRef();
363363
auto captureInfo = closure.getCaptureInfo();
364364
auto loweredCaptureInfo = SGM.Types.getLoweredLocalCaptures(closure);
365-
366-
assert(((constant.uncurryLevel == 1 &&
367-
captureInfo.hasLocalCaptures()) ||
368-
(constant.uncurryLevel == 0 &&
369-
!captureInfo.hasLocalCaptures())) &&
365+
auto hasCaptures = SGM.Types.hasLoweredLocalCaptures(closure);
366+
assert(((constant.uncurryLevel == 1 && hasCaptures) ||
367+
(constant.uncurryLevel == 0 && !hasCaptures)) &&
370368
"curried local functions not yet supported");
371369

372370
auto constantInfo = getConstantInfo(constant);
@@ -406,7 +404,7 @@ SILGenFunction::emitClosureValue(SILLocation loc, SILDeclRef constant,
406404
loc, captureInfo);
407405
}
408406

409-
if (!captureInfo.hasLocalCaptures() && !wasSpecialized) {
407+
if (!hasCaptures && !wasSpecialized) {
410408
auto result = ManagedValue::forUnmanaged(functionRef);
411409
return emitOrigToSubstValue(loc, result,
412410
AbstractionPattern(expectedType),
@@ -436,7 +434,7 @@ SILGenFunction::emitClosureValue(SILLocation loc, SILDeclRef constant,
436434
// - the original type
437435
auto origLoweredFormalType =
438436
AbstractionPattern(constantInfo.LoweredInterfaceType);
439-
if (captureInfo.hasLocalCaptures()) {
437+
if (hasCaptures) {
440438
// Get the unlowered formal type of the constant, stripping off
441439
// the first level of function application, which applies captures.
442440
origLoweredFormalType =
@@ -794,7 +792,7 @@ void SILGenFunction::emitCurryThunk(ValueDecl *vd,
794792

795793
} else if (auto fd = dyn_cast<AbstractFunctionDecl>(vd)) {
796794
// Forward implicit closure context arguments.
797-
bool hasCaptures = fd->getCaptureInfo().hasLocalCaptures();
795+
bool hasCaptures = SGM.M.Types.hasLoweredLocalCaptures(fd);
798796
if (hasCaptures)
799797
--paramCount;
800798

test/SILGen/local_captures.swift

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,12 @@ func globalfunc() -> () -> () {
1010
func localFunc() {
1111
}
1212

13-
// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_6callitfT_T_ : $@convention(thin) () -> ()
14-
// CHECK: function_ref @_TFF14local_captures10globalfuncFT_FT_T_L_9localFuncFT_T_ : $@convention(thin) () -> ()
15-
// CHECK-NEXT: apply
13+
// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_6callitFT_T_ : $@convention(thin) () -> ()
1614
func callit() {
1715
localFunc()
1816
}
1917

20-
// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_5getitfT_FT_T_ : $@convention(thin) () -> @owned @callee_owned () -> ()
21-
// CHECK: function_ref @_TFF14local_captures10globalfuncFT_FT_T_L_9localFuncFT_T_ : $@convention(thin) () -> ()
22-
// CHECK-NEXT: thin_to_thick_function
23-
// CHECK-NEXT: return
18+
// CHECK-LABEL: sil shared @_TFF14local_captures10globalfuncFT_FT_T_L_5getitFT_FT_T_ : $@convention(thin) () -> @owned @callee_owned () -> ()
2419
func getit() -> () -> () {
2520
return localFunc
2621
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
2+
3+
do {
4+
func foo() { bar(2) }
5+
func bar<T>(_: T) { foo() }
6+
7+
class Foo {
8+
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3FoocfT_S0_ : $@convention(method) (@owned Foo) -> @owned Foo {
9+
init() {
10+
foo()
11+
}
12+
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3Foo3zimfT_T_ : $@convention(method) (@guaranteed Foo) -> ()
13+
func zim() {
14+
foo()
15+
}
16+
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3Foo4zangurfxT_ : $@convention(method) <T> (@in T, @guaranteed Foo) -> ()
17+
func zang<T>(_ x: T) {
18+
bar(x)
19+
}
20+
// CHECK-LABEL: sil shared @_TFC41nested_types_referencing_nested_functionsL_3FooD : $@convention(method) (@owned Foo) -> ()
21+
deinit {
22+
foo()
23+
}
24+
}
25+
26+
let x = Foo()
27+
x.zim()
28+
x.zang(1)
29+
_ = Foo.zim
30+
_ = Foo.zang as (Foo) -> (Int) -> ()
31+
_ = x.zim
32+
_ = x.zang as (Int) -> ()
33+
}

0 commit comments

Comments
 (0)