Skip to content

Commit 3e65a7c

Browse files
authored
Merge pull request #60065 from xedin/result-builder-ast-transform-under-flag
[TypeChecker] Implement result builder transform via AST modification under a flag
2 parents 1a611e9 + 0044c8f commit 3e65a7c

19 files changed

+1809
-267
lines changed

include/swift/AST/AnyFunctionRef.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,23 @@ class AnyFunctionRef {
150150
return cast<AutoClosureExpr>(ACE)->getBody();
151151
}
152152

153+
void setParsedBody(BraceStmt *stmt, bool isSingleExpression) {
154+
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>()) {
155+
AFD->setBody(stmt, AbstractFunctionDecl::BodyKind::Parsed);
156+
AFD->setHasSingleExpressionBody(isSingleExpression);
157+
return;
158+
}
159+
160+
auto *ACE = TheFunction.get<AbstractClosureExpr *>();
161+
if (auto *CE = dyn_cast<ClosureExpr>(ACE)) {
162+
CE->setBody(stmt, isSingleExpression);
163+
CE->setBodyState(ClosureExpr::BodyState::ReadyForTypeChecking);
164+
return;
165+
}
166+
167+
llvm_unreachable("autoclosures don't have statement bodies");
168+
}
169+
153170
void setTypecheckedBody(BraceStmt *stmt, bool isSingleExpression) {
154171
if (auto *AFD = TheFunction.dyn_cast<AbstractFunctionDecl *>()) {
155172
AFD->setBody(stmt, AbstractFunctionDecl::BodyKind::TypeChecked);
@@ -159,7 +176,9 @@ class AnyFunctionRef {
159176

160177
auto *ACE = TheFunction.get<AbstractClosureExpr *>();
161178
if (auto *CE = dyn_cast<ClosureExpr>(ACE)) {
162-
return CE->setBody(stmt, isSingleExpression);
179+
CE->setBody(stmt, isSingleExpression);
180+
CE->setBodyState(ClosureExpr::BodyState::TypeCheckedWithSignature);
181+
return;
163182
}
164183

165184
llvm_unreachable("autoclosures don't have statement bodies");

include/swift/AST/Expr.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,11 @@ class alignas(8) Expr : public ASTAllocated<Expr> {
363363
: NumPadBits,
364364
NumElements : 32
365365
);
366+
367+
SWIFT_INLINE_BITFIELD_FULL(TypeJoinExpr, Expr, 32,
368+
: NumPadBits,
369+
NumElements : 32
370+
);
366371
} Bits;
367372

368373
private:
@@ -5940,6 +5945,55 @@ class PackExpr final : public Expr,
59405945
static bool classof(const Expr *E) { return E->getKind() == ExprKind::Pack; }
59415946
};
59425947

5948+
class TypeJoinExpr final : public Expr,
5949+
private llvm::TrailingObjects<TypeJoinExpr, Expr *> {
5950+
friend TrailingObjects;
5951+
5952+
DeclRefExpr *Var;
5953+
5954+
size_t numTrailingObjects() const {
5955+
return getNumElements();
5956+
}
5957+
5958+
MutableArrayRef<Expr *> getMutableElements() {
5959+
return { getTrailingObjects<Expr *>(), getNumElements() };
5960+
}
5961+
5962+
TypeJoinExpr(DeclRefExpr *var, ArrayRef<Expr *> elements);
5963+
5964+
public:
5965+
static TypeJoinExpr *create(ASTContext &ctx, DeclRefExpr *var,
5966+
ArrayRef<Expr *> exprs);
5967+
5968+
SourceLoc getLoc() const { return SourceLoc(); }
5969+
SourceRange getSourceRange() const { return SourceRange(); }
5970+
5971+
DeclRefExpr *getVar() const { return Var; }
5972+
5973+
void setVar(DeclRefExpr *var) {
5974+
assert(var && "cannot set variable reference to null");
5975+
Var = var;
5976+
}
5977+
5978+
ArrayRef<Expr *> getElements() const {
5979+
return { getTrailingObjects<Expr *>(), getNumElements() };
5980+
}
5981+
5982+
Expr *getElement(unsigned i) const {
5983+
return getElements()[i];
5984+
}
5985+
5986+
void setElement(unsigned i, Expr *E) {
5987+
getMutableElements()[i] = E;
5988+
}
5989+
5990+
unsigned getNumElements() const { return Bits.TypeJoinExpr.NumElements; }
5991+
5992+
static bool classof(const Expr *E) {
5993+
return E->getKind() == ExprKind::TypeJoin;
5994+
}
5995+
};
5996+
59435997
inline bool Expr::isInfixOperator() const {
59445998
return isa<BinaryExpr>(this) || isa<IfExpr>(this) ||
59455999
isa<AssignExpr>(this) || isa<ExplicitCastExpr>(this);

include/swift/AST/ExprNodes.def

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ UNCHECKED_EXPR(KeyPathDot, Expr)
205205
UNCHECKED_EXPR(OneWay, Expr)
206206
EXPR(Tap, Expr)
207207
EXPR(Pack, Expr)
208-
LAST_EXPR(Pack)
208+
UNCHECKED_EXPR(TypeJoin, Expr)
209+
LAST_EXPR(TypeJoin)
209210

210211
#undef EXPR_RANGE
211212
#undef LITERAL_EXPR

include/swift/AST/Stmt.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ class BraceStmt final : public Stmt,
159159
SourceLoc rbloc,
160160
Optional<bool> implicit = None);
161161

162+
static BraceStmt *createImplicit(ASTContext &ctx,
163+
ArrayRef<ASTNode> elements) {
164+
return create(ctx, /*lbloc=*/SourceLoc(), elements, /*rbloc=*/SourceLoc(),
165+
/*implicit=*/true);
166+
}
167+
162168
SourceLoc getLBraceLoc() const { return LBLoc; }
163169
SourceLoc getRBraceLoc() const { return RBLoc; }
164170

@@ -552,6 +558,9 @@ class DoStmt : public LabeledStmt {
552558
labelInfo),
553559
DoLoc(doLoc), Body(body) {}
554560

561+
static DoStmt *createImplicit(ASTContext &C, LabeledStmtInfo labelInfo,
562+
ArrayRef<ASTNode> body);
563+
555564
SourceLoc getDoLoc() const { return DoLoc; }
556565

557566
SourceLoc getStartLoc() const { return getLabelLocOrKeywordLoc(DoLoc); }
@@ -1001,6 +1010,11 @@ class CaseStmt final
10011010

10021011
unsigned getNumCaseLabelItems() const { return Bits.CaseStmt.NumPatterns; }
10031012

1013+
FallthroughStmt *getFallthroughStmt() const {
1014+
return hasFallthroughDest() ? *getTrailingObjects<FallthroughStmt *>()
1015+
: nullptr;
1016+
}
1017+
10041018
NullablePtr<CaseStmt> getFallthroughDest() const {
10051019
return const_cast<CaseStmt &>(*this).getFallthroughDest();
10061020
}
@@ -1030,6 +1044,8 @@ class CaseStmt final
10301044
}
10311045
SourceLoc getEndLoc() const { return getBody()->getEndLoc(); }
10321046

1047+
SourceLoc getItemTerminatorLoc() const { return ItemTerminatorLoc; }
1048+
10331049
SourceRange getLabelItemsRange() const {
10341050
switch (ParentKind) {
10351051
case CaseParentKind::Switch:

include/swift/Basic/Features.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ EXPERIMENTAL_FEATURE(FlowSensitiveConcurrencyCaptures)
101101
EXPERIMENTAL_FEATURE(MoveOnly)
102102
EXPERIMENTAL_FEATURE(OneWayClosureParameters)
103103
EXPERIMENTAL_FEATURE(TypeWitnessSystemInference)
104+
EXPERIMENTAL_FEATURE(ResultBuilderASTTransform)
104105

105106
/// Whether to enable experimental differentiable programming features:
106107
/// `@differentiable` declaration attribute, etc.

include/swift/Sema/ConstraintSystem.h

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ struct ResultBuilder {
134134
return BuilderType->getAnyNominal();
135135
}
136136

137+
VarDecl *getBuilderSelf() const { return BuilderSelf; }
138+
137139
Identifier getBuildOptionalId() const { return BuildOptionalId; }
138140

139141
bool supports(Identifier fnBaseName, ArrayRef<Identifier> argLabels = {},
@@ -647,6 +649,10 @@ T *getAsPattern(ASTNode node) {
647649
return nullptr;
648650
}
649651

652+
template <typename T = Stmt> T *castToStmt(ASTNode node) {
653+
return cast<T>(node.get<Stmt *>());
654+
}
655+
650656
SourceLoc getLoc(ASTNode node);
651657
SourceRange getSourceRange(ASTNode node);
652658

@@ -910,6 +916,10 @@ struct AppliedBuilderTransform {
910916
/// converted. Opaque types should be unopened.
911917
Type bodyResultType;
912918

919+
/// The version of the original body with result builder applied
920+
/// as AST transformation.
921+
NullablePtr<BraceStmt> transformedBody;
922+
913923
/// An expression whose value has been recorded for later use.
914924
struct RecordedExpr {
915925
/// The temporary value that captures the value of the expression, if
@@ -1059,6 +1069,7 @@ class SolutionApplicationTargetsKey {
10591069
pattern,
10601070
patternBindingEntry,
10611071
varDecl,
1072+
functionRef,
10621073
};
10631074

10641075
private:
@@ -1079,6 +1090,8 @@ class SolutionApplicationTargetsKey {
10791090
} patternBindingEntry;
10801091

10811092
const VarDecl *varDecl;
1093+
1094+
const DeclContext *functionRef;
10821095
} storage;
10831096

10841097
public:
@@ -1124,6 +1137,11 @@ class SolutionApplicationTargetsKey {
11241137
storage.varDecl = varDecl;
11251138
}
11261139

1140+
SolutionApplicationTargetsKey(const AnyFunctionRef functionRef) {
1141+
kind = Kind::functionRef;
1142+
storage.functionRef = functionRef.getAsDeclContext();
1143+
}
1144+
11271145
friend bool operator==(
11281146
SolutionApplicationTargetsKey lhs, SolutionApplicationTargetsKey rhs) {
11291147
if (lhs.kind != rhs.kind)
@@ -1155,6 +1173,9 @@ class SolutionApplicationTargetsKey {
11551173

11561174
case Kind::varDecl:
11571175
return lhs.storage.varDecl == rhs.storage.varDecl;
1176+
1177+
case Kind::functionRef:
1178+
return lhs.storage.functionRef == rhs.storage.functionRef;
11581179
}
11591180
llvm_unreachable("invalid SolutionApplicationTargetsKey kind");
11601181
}
@@ -1206,6 +1227,11 @@ class SolutionApplicationTargetsKey {
12061227
return hash_combine(
12071228
DenseMapInfo<unsigned>::getHashValue(static_cast<unsigned>(kind)),
12081229
DenseMapInfo<void *>::getHashValue(storage.varDecl));
1230+
1231+
case Kind::functionRef:
1232+
return hash_combine(
1233+
DenseMapInfo<unsigned>::getHashValue(static_cast<unsigned>(kind)),
1234+
DenseMapInfo<void *>::getHashValue(storage.functionRef));
12091235
}
12101236
llvm_unreachable("invalid statement kind");
12111237
}
@@ -2710,6 +2736,14 @@ class ConstraintSystem {
27102736
/// diagnostics when result builder has multiple overloads.
27112737
llvm::SmallDenseSet<AnyFunctionRef> InvalidResultBuilderBodies;
27122738

2739+
/// The *global* set of all functions that have a particular result builder
2740+
/// applied.
2741+
///
2742+
/// The value here is `$__builderSelf` variable and a transformed body.
2743+
llvm::DenseMap<std::pair<AnyFunctionRef, NominalTypeDecl *>,
2744+
std::pair<VarDecl *, BraceStmt *>>
2745+
BuilderTransformedBodies;
2746+
27132747
/// Arguments after the code completion token that were thus ignored (i.e.
27142748
/// assigned fresh type variables) for type checking.
27152749
llvm::SetVector<Expr *> IgnoredArguments;
@@ -3714,6 +3748,35 @@ class ConstraintSystem {
37143748
return known->second;
37153749
}
37163750

3751+
Optional<AppliedBuilderTransform>
3752+
getAppliedResultBuilderTransform(AnyFunctionRef fn) const {
3753+
auto transformed = resultBuilderTransformed.find(fn);
3754+
if (transformed != resultBuilderTransformed.end())
3755+
return transformed->second;
3756+
return None;
3757+
}
3758+
3759+
void setBuilderTransformedBody(AnyFunctionRef fn, NominalTypeDecl *builder,
3760+
NullablePtr<VarDecl> builderSelf,
3761+
NullablePtr<BraceStmt> body) {
3762+
assert(builder->getAttrs().hasAttribute<ResultBuilderAttr>());
3763+
assert(body);
3764+
assert(builderSelf);
3765+
3766+
auto existing = BuilderTransformedBodies.insert(
3767+
{{fn, builder}, {builderSelf.get(), body.get()}});
3768+
assert(existing.second && "Duplicate result builder transform");
3769+
(void)existing;
3770+
}
3771+
3772+
Optional<std::pair<VarDecl *, BraceStmt *>>
3773+
getBuilderTransformedBody(AnyFunctionRef fn, NominalTypeDecl *builder) const {
3774+
auto result = BuilderTransformedBodies.find({fn, builder});
3775+
if (result == BuilderTransformedBodies.end())
3776+
return None;
3777+
return result->second;
3778+
}
3779+
37173780
void setCaseLabelItemInfo(const CaseLabelItem *item, CaseLabelItemInfo info) {
37183781
assert(item != nullptr);
37193782
assert(caseLabelItems.count(item) == 0);
@@ -4792,6 +4855,16 @@ class ConstraintSystem {
47924855
LLVM_NODISCARD
47934856
bool generateConstraints(ClosureExpr *closure);
47944857

4858+
/// Generate constraints for the body of the given function.
4859+
///
4860+
/// \param fn The function or closure expression
4861+
/// \param body The body of the given function that should be
4862+
/// used for constraint generation.
4863+
///
4864+
/// \returns \c true if constraint generation failed, \c false otherwise
4865+
LLVM_NODISCARD
4866+
bool generateConstraints(AnyFunctionRef fn, BraceStmt *body);
4867+
47954868
/// Generate constraints for the given (unchecked) expression.
47964869
///
47974870
/// \returns a possibly-sanitized expression, or null if an error occurred.
@@ -5640,14 +5713,14 @@ class ConstraintSystem {
56405713
///
56415714
///
56425715
/// \param solution The solution to apply.
5643-
/// \param closure The closure to which the solution is being applied.
5716+
/// \param fn The function or closure to which the solution is being applied.
56445717
/// \param currentDC The declaration context in which transformations
56455718
/// will be applied.
56465719
/// \param rewriteTarget Function that performs a rewrite of any
56475720
/// solution application target within the context.
56485721
///
56495722
/// \returns true if solution cannot be applied.
5650-
bool applySolutionToBody(Solution &solution, ClosureExpr *closure,
5723+
bool applySolutionToBody(Solution &solution, AnyFunctionRef fn,
56515724
DeclContext *&currentDC,
56525725
std::function<Optional<SolutionApplicationTarget>(
56535726
SolutionApplicationTarget)>
@@ -6485,6 +6558,13 @@ bool isOperatorDisjunction(Constraint *disjunction);
64856558
/// or nested declarations).
64866559
ASTNode findAsyncNode(ClosureExpr *closure);
64876560

6561+
/// Check whether the given binding represents a placeholder variable that
6562+
/// has to get its type inferred at a first use site.
6563+
///
6564+
/// \returns The currently assigned type if it's a placeholder,
6565+
/// empty type otherwise.
6566+
Type isPlaceholderVar(PatternBindingDecl *PB);
6567+
64886568
} // end namespace constraints
64896569

64906570
template<typename ...Args>

lib/AST/ASTDumper.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2905,6 +2905,21 @@ class PrintExpr : public ExprVisitor<PrintExpr> {
29052905
printRec(E->getBody(), E->getVar()->getDeclContext()->getASTContext());
29062906
PrintWithColorRAII(OS, ParenthesisColor) << ')';
29072907
}
2908+
2909+
void visitTypeJoinExpr(TypeJoinExpr *E) {
2910+
printCommon(E, "type_join_expr");
2911+
2912+
PrintWithColorRAII(OS, DeclColor) << " var=";
2913+
printRec(E->getVar());
2914+
OS << '\n';
2915+
2916+
for (auto *member : E->getElements()) {
2917+
printRec(member);
2918+
OS << '\n';
2919+
}
2920+
2921+
PrintWithColorRAII(OS, ParenthesisColor) << ')';
2922+
}
29082923
};
29092924

29102925
} // end anonymous namespace

lib/AST/ASTPrinter.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3037,6 +3037,10 @@ static bool usesFeatureOneWayClosureParameters(Decl *decl) {
30373037
return false;
30383038
}
30393039

3040+
static bool usesFeatureResultBuilderASTTransform(Decl *decl) {
3041+
return false;
3042+
}
3043+
30403044
static bool usesFeatureTypeWitnessSystemInference(Decl *decl) {
30413045
return false;
30423046
}
@@ -4551,6 +4555,9 @@ void PrintAST::visitPackExpr(PackExpr *expr) {
45514555
void PrintAST::visitReifyPackExpr(ReifyPackExpr *expr) {
45524556
}
45534557

4558+
void PrintAST::visitTypeJoinExpr(TypeJoinExpr *expr) {
4559+
}
4560+
45544561
void PrintAST::visitAssignExpr(AssignExpr *expr) {
45554562
visit(expr->getDest());
45564563
Printer << " = ";

0 commit comments

Comments
 (0)