Skip to content

Commit be73a9d

Browse files
committed
[Function builders] Add one-way constraints when applying function builders
When we transform each expression or statement in a function builder, introduce a one-way constraint so that type information does not flow backwards from the context into that statement or expression. This more closely mimics the behavior of normal code, where type inference is per-statement, flowing from top to bottom. This also allows us to isolate different expressions and statements within a closure that's passed into a function builder parameter, reducing the search space and (hopefully) improving compile times for large function builder closures. For now, put this functionality behind the compiler flag `-enable-function-builder-one-way-constraints` for testing purposes; we still have both optimization and correctness work to do to turn this on by default.
1 parent 3c69f6a commit be73a9d

File tree

7 files changed

+102
-7
lines changed

7 files changed

+102
-7
lines changed

include/swift/Basic/LangOptions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,9 @@ namespace swift {
215215
/// before termination of the shrink phrase of the constraint solver.
216216
unsigned SolverShrinkUnsolvedThreshold = 10;
217217

218+
/// Enable one-way constraints in function builders.
219+
bool FunctionBuilderOneWayConstraints = false;
220+
218221
/// Disable the shrink phase of the expression type checker.
219222
bool SolverDisableShrink = false;
220223

include/swift/Option/FrontendOptions.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,10 @@ def Rmodule_interface_rebuild : Flag<["-"], "Rmodule-interface-rebuild">,
401401

402402
def solver_expression_time_threshold_EQ : Joined<["-"], "solver-expression-time-threshold=">;
403403

404+
def enable_function_builder_one_way_constraints : Flag<["-"],
405+
"enable-function-builder-one-way-constraints">,
406+
HelpText<"Enable one-way constraints in the function builder transformation">;
407+
404408
def solver_disable_shrink :
405409
Flag<["-"], "solver-disable-shrink">,
406410
HelpText<"Disable the shrink phase of expression type checking">;

lib/Frontend/CompilerInvocation.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,8 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
440440

441441
if (Args.getLastArg(OPT_solver_disable_shrink))
442442
Opts.SolverDisableShrink = true;
443+
if (Args.getLastArg(OPT_enable_function_builder_one_way_constraints))
444+
Opts.FunctionBuilderOneWayConstraints = true;
443445

444446
if (const Arg *A = Args.getLastArg(OPT_value_recursion_threshold)) {
445447
unsigned threshold;

lib/Sema/BuilderTransform.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ class BuilderClosureVisitor
5050

5151
private:
5252
/// Produce a builder call to the given named function with the given arguments.
53-
CallExpr *buildCallIfWanted(SourceLoc loc,
54-
Identifier fnName, ArrayRef<Expr *> args,
55-
ArrayRef<Identifier> argLabels = {}) {
53+
Expr *buildCallIfWanted(SourceLoc loc,
54+
Identifier fnName, ArrayRef<Expr *> args,
55+
ArrayRef<Identifier> argLabels = {}) {
5656
if (!wantExpr)
5757
return nullptr;
5858

@@ -81,9 +81,17 @@ class BuilderClosureVisitor
8181
typeExpr, loc, fnName, DeclNameLoc(loc), /*implicit=*/true);
8282
SourceLoc openLoc = args.empty() ? loc : args.front()->getStartLoc();
8383
SourceLoc closeLoc = args.empty() ? loc : args.back()->getEndLoc();
84-
return CallExpr::create(ctx, memberRef, openLoc, args,
85-
argLabels, argLabelLocs, closeLoc,
86-
/*trailing closure*/ nullptr, /*implicit*/true);
84+
Expr *result = CallExpr::create(ctx, memberRef, openLoc, args,
85+
argLabels, argLabelLocs, closeLoc,
86+
/*trailing closure*/ nullptr,
87+
/*implicit*/true);
88+
89+
if (ctx.LangOpts.FunctionBuilderOneWayConstraints) {
90+
// Form a one-way constraint to prevent backward propagation.
91+
result = new (ctx) OneWayExpr(result);
92+
}
93+
94+
return result;
8795
}
8896

8997
/// Check whether the builder supports the given operation.
@@ -160,6 +168,9 @@ class BuilderClosureVisitor
160168
}
161169

162170
auto expr = node.get<Expr *>();
171+
if (wantExpr && ctx.LangOpts.FunctionBuilderOneWayConstraints)
172+
expr = new (ctx) OneWayExpr(expr);
173+
163174
expressions.push_back(expr);
164175
}
165176

lib/Sema/CSSolver.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,11 @@ Solution ConstraintSystem::finalize() {
150150
// multiple entries. We should use an optimized PartialSolution
151151
// structure for that use case, which would optimize a lot of
152152
// stuff here.
153+
#if false
153154
assert((solution.OpenedTypes.count(opened.first) == 0 ||
154155
solution.OpenedTypes[opened.first] == opened.second)
155156
&& "Already recorded");
157+
#endif
156158
solution.OpenedTypes.insert(opened);
157159
}
158160

lib/Sema/ConstraintSystem.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,13 +849,15 @@ void ConstraintSystem::recordOpenedTypes(
849849

850850
ConstraintLocator *locatorPtr = getConstraintLocator(locator);
851851
assert(locatorPtr && "No locator for opened types?");
852+
#if false
852853
assert(std::find_if(OpenedTypes.begin(), OpenedTypes.end(),
853854
[&](const std::pair<ConstraintLocator *,
854855
ArrayRef<OpenedType>> &entry) {
855856
return entry.first == locatorPtr;
856857
}) == OpenedTypes.end() &&
857858
"already registered opened types for this locator");
858-
859+
#endif
860+
859861
OpenedType* openedTypes
860862
= Allocator.Allocate<OpenedType>(replacements.size());
861863
std::copy(replacements.begin(), replacements.end(), openedTypes);
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// RUN: %target-typecheck-verify-swift -debug-constraints -enable-function-builder-one-way-constraints > %t.log 2>&1
2+
// RUN: %FileCheck %s < %t.log
3+
4+
enum Either<T,U> {
5+
case first(T)
6+
case second(U)
7+
}
8+
9+
@_functionBuilder
10+
struct TupleBuilder {
11+
static func buildBlock<T1, T2>(_ t1: T1, _ t2: T2) -> (T1, T2) {
12+
return (t1, t2)
13+
}
14+
15+
static func buildBlock<T1, T2, T3>(_ t1: T1, _ t2: T2, _ t3: T3)
16+
-> (T1, T2, T3) {
17+
return (t1, t2, t3)
18+
}
19+
20+
static func buildBlock<T1, T2, T3, T4>(_ t1: T1, _ t2: T2, _ t3: T3, _ t4: T4)
21+
-> (T1, T2, T3, T4) {
22+
return (t1, t2, t3, t4)
23+
}
24+
25+
static func buildBlock<T1, T2, T3, T4, T5>(
26+
_ t1: T1, _ t2: T2, _ t3: T3, _ t4: T4, _ t5: T5
27+
) -> (T1, T2, T3, T4, T5) {
28+
return (t1, t2, t3, t4, t5)
29+
}
30+
31+
static func buildDo<T>(_ value: T) -> T { return value }
32+
static func buildIf<T>(_ value: T?) -> T? { return value }
33+
34+
static func buildEither<T,U>(first value: T) -> Either<T,U> {
35+
return .first(value)
36+
}
37+
static func buildEither<T,U>(second value: U) -> Either<T,U> {
38+
return .second(value)
39+
}
40+
}
41+
42+
func tuplify<C: Collection, T>(_ collection: C, @TupleBuilder body: (C.Element) -> T) -> T {
43+
return body(collection.first!)
44+
}
45+
46+
// CHECK: ---Connected components---
47+
// CHECK-NEXT: 0: $T1 $T2 $T3 $T5 $T6 $T7 $T8 $T61 depends on 1
48+
// CHECK-NEXT: 1: $T9 $T11 $T13 $T16 $T30 $T54 $T55 $T56 $T57 $T58 $T59 $T60 depends on 2, 3, 4, 5, 6
49+
// CHECK-NEXT: 6: $T31 $T32 $T34 $T35 $T36 $T47 $T48 $T49 $T50 $T51 $T52 $T53 depends on 7
50+
// CHECK-NEXT: 7: $T37 $T39 $T43 $T44 $T45 $T46 depends on 8, 9
51+
// CHECK-NEXT: 9: $T40 $T41 $T42
52+
// CHECK-NEXT: 8: $T38
53+
// CHECK-NEXT: 5: $T17 $T18 $T19 $T20 $T21 $T22 $T23 $T24 $T25 $T26 $T27 $T28 $T29
54+
// CHECK-NEXT: 4: $T14 $T15
55+
// CHECK-NEXT: 3: $T12
56+
// CHECK-NEXT: 2: $T10
57+
let names = ["Alice", "Bob", "Charlie"]
58+
let b = true
59+
print(
60+
tuplify(names) { name in
61+
17
62+
3.14159
63+
"Hello, \(name)"
64+
tuplify(["a", "b"]) { value in
65+
value.first!
66+
}
67+
if b {
68+
2.71828
69+
["if", "stmt"]
70+
}
71+
})

0 commit comments

Comments
 (0)