Skip to content

Commit f810eb0

Browse files
authored
Merge pull request #61316 from tshortli/accept-has-symbol-in-closures
Sema: Accept `if #_hasSymbol()` conditions in closure contexts
2 parents a694172 + 8729801 commit f810eb0

File tree

8 files changed

+155
-42
lines changed

8 files changed

+155
-42
lines changed

include/swift/AST/Stmt.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,14 +398,15 @@ class alignas(8) PoundAvailableInfo final :
398398
class PoundHasSymbolInfo final : public ASTAllocated<PoundHasSymbolInfo> {
399399
Expr *SymbolExpr;
400400
ConcreteDeclRef ReferencedDecl;
401+
bool Invalid;
401402

402403
SourceLoc PoundLoc;
403404
SourceLoc LParenLoc;
404405
SourceLoc RParenLoc;
405406

406407
PoundHasSymbolInfo(SourceLoc PoundLoc, SourceLoc LParenLoc, Expr *SymbolExpr,
407408
SourceLoc RParenLoc)
408-
: SymbolExpr(SymbolExpr), ReferencedDecl(), PoundLoc(PoundLoc),
409+
: SymbolExpr(SymbolExpr), ReferencedDecl(), Invalid(), PoundLoc(PoundLoc),
409410
LParenLoc(LParenLoc), RParenLoc(RParenLoc){};
410411

411412
public:
@@ -419,6 +420,10 @@ class PoundHasSymbolInfo final : public ASTAllocated<PoundHasSymbolInfo> {
419420
ConcreteDeclRef getReferencedDecl() { return ReferencedDecl; }
420421
void setReferencedDecl(ConcreteDeclRef CDR) { ReferencedDecl = CDR; }
421422

423+
/// Returns true if the referenced decl has been diagnosed as invalid.
424+
bool isInvalid() const { return Invalid; }
425+
void setInvalid() { Invalid = true; }
426+
422427
SourceLoc getLParenLoc() const { return LParenLoc; }
423428
SourceLoc getRParenLoc() const { return RParenLoc; }
424429
SourceLoc getStartLoc() const { return PoundLoc; }

lib/AST/ASTWalker.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1394,8 +1394,17 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
13941394
for (auto &elt : C) {
13951395
switch (elt.getKind()) {
13961396
case StmtConditionElement::CK_Availability:
1397-
case StmtConditionElement::CK_HasSymbol:
13981397
break;
1398+
case StmtConditionElement::CK_HasSymbol: {
1399+
auto E = elt.getHasSymbolInfo()->getSymbolExpr();
1400+
if (!E)
1401+
return true;
1402+
E = doIt(E);
1403+
if (!E)
1404+
return true;
1405+
elt.getHasSymbolInfo()->setSymbolExpr(E);
1406+
break;
1407+
}
13991408
case StmtConditionElement::CK_Boolean: {
14001409
auto E = elt.getBoolean();
14011410
// Walk an expression condition normally.

lib/Sema/CSApply.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8923,9 +8923,25 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
89238923
for (auto &condElement : *stmtCondition) {
89248924
switch (condElement.getKind()) {
89258925
case StmtConditionElement::CK_Availability:
8926-
case StmtConditionElement::CK_HasSymbol:
89278926
continue;
89288927

8928+
case StmtConditionElement::CK_HasSymbol: {
8929+
ConstraintSystem &cs = solution.getConstraintSystem();
8930+
auto info = condElement.getHasSymbolInfo();
8931+
auto target = *cs.getSolutionApplicationTarget(&condElement);
8932+
auto resolvedTarget = rewriteTarget(target);
8933+
if (!resolvedTarget) {
8934+
info->setInvalid();
8935+
return None;
8936+
}
8937+
8938+
auto rewrittenExpr = resolvedTarget->getAsExpr();
8939+
info->setSymbolExpr(rewrittenExpr);
8940+
info->setReferencedDecl(
8941+
TypeChecker::getReferencedDeclForHasSymbolCondition(rewrittenExpr));
8942+
continue;
8943+
}
8944+
89298945
case StmtConditionElement::CK_Boolean: {
89308946
auto condExpr = condElement.getBoolean();
89318947
auto finalCondExpr = condExpr->walk(*this);

lib/Sema/CSGen.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4369,10 +4369,15 @@ bool ConstraintSystem::generateConstraints(StmtCondition condition,
43694369
continue;
43704370

43714371
case StmtConditionElement::CK_HasSymbol: {
4372-
ASTContext &ctx = getASTContext();
4373-
ctx.Diags.diagnose(condElement.getStartLoc(),
4374-
diag::has_symbol_unsupported_in_closures);
4375-
return true;
4372+
Expr *symbolExpr = condElement.getHasSymbolInfo()->getSymbolExpr();
4373+
auto target = SolutionApplicationTarget(symbolExpr, dc, CTP_Unused,
4374+
Type(), /*isDiscarded=*/false);
4375+
4376+
if (generateConstraints(target, FreeTypeVariableBinding::Disallow))
4377+
return true;
4378+
4379+
setSolutionApplicationTarget(&condElement, target);
4380+
continue;
43764381
}
43774382

43784383
case StmtConditionElement::CK_Boolean: {

lib/Sema/MiscDiagnostics.cpp

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4302,6 +4302,55 @@ checkImplicitPromotionsInCondition(const StmtConditionElement &cond,
43024302
}
43034303
}
43044304

4305+
/// Perform MiscDiagnostics for the conditions belonging to a \c
4306+
/// LabeledConditionalStmt.
4307+
static void checkLabeledStmtConditions(ASTContext &ctx,
4308+
const LabeledConditionalStmt *stmt,
4309+
DeclContext *DC) {
4310+
for (auto elt : stmt->getCond()) {
4311+
// Check for implicit optional promotions in stmt-condition patterns.
4312+
checkImplicitPromotionsInCondition(elt, ctx);
4313+
4314+
switch (elt.getKind()) {
4315+
case StmtConditionElement::CK_Boolean:
4316+
case StmtConditionElement::CK_PatternBinding:
4317+
case StmtConditionElement::CK_Availability:
4318+
break;
4319+
4320+
case StmtConditionElement::CK_HasSymbol: {
4321+
auto info = elt.getHasSymbolInfo();
4322+
if (info->isInvalid())
4323+
break;
4324+
4325+
auto symbolExpr = info->getSymbolExpr();
4326+
if (!symbolExpr)
4327+
break;
4328+
4329+
if (!symbolExpr->getType())
4330+
break;
4331+
4332+
if (auto decl = info->getReferencedDecl().getDecl()) {
4333+
// `if #_hasSymbol(someStronglyLinkedSymbol)` is functionally a no-op
4334+
// and may indicate the developer has mis-identified the declaration
4335+
// they want to check (or forgot to import the module weakly).
4336+
if (!decl->isWeakImported(DC->getParentModule())) {
4337+
ctx.Diags.diagnose(symbolExpr->getLoc(),
4338+
diag::has_symbol_decl_must_be_weak,
4339+
decl->getDescriptiveKind(), decl->getName());
4340+
info->setInvalid();
4341+
}
4342+
} else {
4343+
// Diagnose because we weren't able to interpret the expression as one
4344+
// that uniquely identifies a single declaration.
4345+
ctx.Diags.diagnose(symbolExpr->getLoc(), diag::has_symbol_invalid_expr);
4346+
info->setInvalid();
4347+
}
4348+
break;
4349+
}
4350+
}
4351+
}
4352+
}
4353+
43054354
static void diagnoseUnintendedOptionalBehavior(const Expr *E,
43064355
const DeclContext *DC) {
43074356
if (!E || isa<ErrorExpr>(E) || !E->getType())
@@ -5290,11 +5339,9 @@ void swift::performStmtDiagnostics(const Stmt *S, DeclContext *DC) {
52905339
checkSwitch(ctx, switchStmt, DC);
52915340

52925341
checkStmtConditionTrailingClosure(ctx, S);
5293-
5294-
// Check for implicit optional promotions in stmt-condition patterns.
5342+
52955343
if (auto *lcs = dyn_cast<LabeledConditionalStmt>(S))
5296-
for (const auto &elt : lcs->getCond())
5297-
checkImplicitPromotionsInCondition(elt, ctx);
5344+
checkLabeledStmtConditions(ctx, lcs, DC);
52985345

52995346
if (!ctx.LangOpts.DisableAvailabilityChecking)
53005347
diagnoseStmtAvailability(S, const_cast<DeclContext*>(DC));

lib/Sema/TypeCheckStmt.cpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -448,8 +448,7 @@ static Expr *getDeclRefProvidingExpressionForHasSymbol(Expr *E) {
448448
return E;
449449
}
450450

451-
static ConcreteDeclRef
452-
getReferencedDeclForHasSymbolCondition(ASTContext &Context, Expr *E) {
451+
ConcreteDeclRef TypeChecker::getReferencedDeclForHasSymbolCondition(Expr *E) {
453452
// Match DotSelfExprs (e.g. `SomeStruct.self`) when the type is static.
454453
if (auto DSE = dyn_cast<DotSelfExpr>(E)) {
455454
if (DSE->isStaticallyDerivedMetatype())
@@ -461,7 +460,6 @@ getReferencedDeclForHasSymbolCondition(ASTContext &Context, Expr *E) {
461460
return CDR;
462461
}
463462

464-
Context.Diags.diagnose(E->getLoc(), diag::has_symbol_invalid_expr);
465463
return ConcreteDeclRef();
466464
}
467465

@@ -506,19 +504,12 @@ bool TypeChecker::typeCheckStmtConditionElement(StmtConditionElement &elt,
506504
auto exprTy = TypeChecker::typeCheckExpression(E, dc);
507505
Info->setSymbolExpr(E);
508506

509-
if (!exprTy)
510-
return true;
511-
512-
auto CDR = getReferencedDeclForHasSymbolCondition(Context, E);
513-
if (!CDR)
507+
if (!exprTy) {
508+
Info->setInvalid();
514509
return true;
515-
516-
auto decl = CDR.getDecl();
517-
if (!decl->isWeakImported(dc->getParentModule())) {
518-
Context.Diags.diagnose(E->getLoc(), diag::has_symbol_decl_must_be_weak,
519-
decl->getDescriptiveKind(), decl->getName());
520510
}
521-
Info->setReferencedDecl(CDR);
511+
512+
Info->setReferencedDecl(getReferencedDeclForHasSymbolCondition(E));
522513
return false;
523514
}
524515

lib/Sema/TypeChecker.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,10 @@ Expr *substituteInputSugarTypeForResult(ApplyExpr *E);
441441
bool typeCheckStmtConditionElement(StmtConditionElement &elt, bool &isFalsable,
442442
DeclContext *dc);
443443

444+
/// Returns the unique decl ref identified by the expr according to the
445+
/// requirements of the \c #_hasSymbol() condition type.
446+
ConcreteDeclRef getReferencedDeclForHasSymbolCondition(Expr *E);
447+
444448
void typeCheckASTNode(ASTNode &node, DeclContext *DC,
445449
bool LeaveBodyUnchecked = false);
446450

test/Sema/has_symbol.swift

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: %empty-directory(%t)
22
// RUN: %target-swift-frontend -emit-module -emit-module-path %t/has_symbol_helper.swiftmodule -parse-as-library %S/Inputs/has_symbol_helper.swift -enable-library-evolution
33
// RUN: %target-typecheck-verify-swift -disable-availability-checking -I %t
4+
// RUN: %target-typecheck-verify-swift -disable-availability-checking -I %t -enable-experimental-feature ResultBuilderASTTransform
45

56
// UNSUPPORTED: OS=windows-msvc
67

@@ -93,6 +94,7 @@ func testNotWeakDeclDiagnostics(_ s: LocalStruct) {
9394
}
9495

9596
func testInvalidExpressionsDiagnostics() {
97+
if #_hasSymbol(unknownDecl) {} // expected-error {{cannot find 'unknownDecl' in scope}}
9698
if #_hasSymbol(noArgFunc()) {} // expected-error {{#_hasSymbol condition must refer to a declaration}}
9799
if #_hasSymbol(global - 1) {} // expected-error {{#_hasSymbol condition must refer to a declaration}}
98100
if #_hasSymbol(S.staticFunc()) {} // expected-error {{#_hasSymbol condition must refer to a declaration}}
@@ -101,15 +103,36 @@ func testInvalidExpressionsDiagnostics() {
101103
if #_hasSymbol(1 as S) {} // expected-error {{cannot convert value of type 'Int' to type 'S' in coercion}}
102104
}
103105

104-
func testMultiStatementClosure() {
105-
let _: () -> Void = { // expected-error {{unable to infer closure type in the current context}}
106-
if #_hasSymbol(global) {} // expected-error 2 {{#_hasSymbol is not supported in closures}}
107-
}
108-
109-
let _: () -> Void = { // expected-error {{unable to infer closure type in the current context}}
110-
if #_hasSymbol(global) {} // expected-error 2 {{#_hasSymbol is not supported in closures}}
111-
localFunc()
112-
}
106+
func testGuard() {
107+
guard #_hasSymbol(global) else { return }
108+
guard #_hasSymbol(unknownDecl) else { return } // expected-error {{cannot find 'unknownDecl' in scope}}
109+
guard #_hasSymbol(localFunc) else { return } // expected-warning {{global function 'localFunc()' is not a weakly linked declaration}}
110+
}
111+
112+
func testWhile() {
113+
while #_hasSymbol(global) { break }
114+
while #_hasSymbol(unknownDecl) { break } // expected-error {{cannot find 'unknownDecl' in scope}}
115+
while #_hasSymbol(localFunc) { break } // expected-warning {{global function 'localFunc()' is not a weakly linked declaration}}
116+
}
117+
118+
func doIt(_ closure: () -> ()) {
119+
closure()
120+
}
121+
122+
func testClosure() {
123+
doIt { if #_hasSymbol(global) {} }
124+
doIt { if #_hasSymbol(noArgFunc) {} }
125+
doIt { if #_hasSymbol(ambiguousFunc as () -> Int) {} }
126+
doIt { if #_hasSymbol(S.self) {} }
127+
doIt { if #_hasSymbol(ambiguousFunc) {} } // expected-error {{ambiguous use of 'ambiguousFunc()'}}
128+
doIt { if #_hasSymbol(localFunc) {} } // expected-warning {{global function 'localFunc()' is not a weakly linked declaration}}
129+
doIt { if #_hasSymbol(unknownDecl) {} } // expected-error {{cannot find 'unknownDecl' in scope}}
130+
doIt { if #_hasSymbol(noArgFunc()) {} } // expected-error {{#_hasSymbol condition must refer to a declaration}}
131+
doIt { if #_hasSymbol(global - 1) {} } // expected-error {{#_hasSymbol condition must refer to a declaration}}
132+
doIt { if #_hasSymbol(S.staticFunc()) {} } // expected-error {{#_hasSymbol condition must refer to a declaration}}
133+
doIt { if #_hasSymbol(C.classFunc()) {} } // expected-error {{#_hasSymbol condition must refer to a declaration}}
134+
doIt { if #_hasSymbol(1 as Int) {} } // expected-error {{#_hasSymbol condition must refer to a declaration}}
135+
doIt { if #_hasSymbol(1 as S) {} } // expected-error {{cannot convert value of type 'Int' to type 'S' in coercion}}
113136
}
114137

115138
protocol View {}
@@ -120,15 +143,28 @@ protocol View {}
120143
static func buildEither<Content>(second content: Content) -> Content where Content : View { fatalError() }
121144
}
122145

123-
struct Image : View {
124-
}
146+
struct Image : View {}
125147

126148
struct MyView {
127-
@ViewBuilder var body: some View {
128-
if #_hasSymbol(global) { // expected-error {{#_hasSymbol is not supported in closures}}
129-
Image()
130-
} else {
131-
Image()
132-
}
149+
let image = Image()
150+
151+
@ViewBuilder var globalView: some View {
152+
if #_hasSymbol(global) { image }
153+
else { image }
154+
}
155+
156+
@ViewBuilder var ambiguousFuncView: some View {
157+
if #_hasSymbol(ambiguousFunc) { image } // expected-error {{ambiguous use of 'ambiguousFunc()'}}
158+
else { image }
159+
}
160+
161+
@ViewBuilder var localFuncView: some View {
162+
if #_hasSymbol(localFunc) { image } // expected-warning {{global function 'localFunc()' is not a weakly linked declaration}}
163+
else { image }
164+
}
165+
166+
@ViewBuilder var noArgFuncView: some View {
167+
if #_hasSymbol(noArgFunc()) { image } // expected-error {{#_hasSymbol condition must refer to a declaration}}
168+
else { image }
133169
}
134170
}

0 commit comments

Comments
 (0)