Skip to content

Commit 254be25

Browse files
authored
Merge pull request #18564 from rintaro/ide-completion-contextanalysis
[CodeCompletion] Improve context type analysis
2 parents 31f0f28 + 137ca65 commit 254be25

12 files changed

+212
-76
lines changed

include/swift/AST/Expr.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -687,20 +687,13 @@ class ErrorExpr : public Expr {
687687
/// can help us preserve the context of the code completion position.
688688
class CodeCompletionExpr : public Expr {
689689
SourceRange Range;
690-
bool Activated;
691690

692691
public:
693-
CodeCompletionExpr(SourceRange Range, Type Ty = Type()) :
694-
Expr(ExprKind::CodeCompletion, /*Implicit=*/true, Ty),
695-
Range(Range) {
696-
Activated = false;
697-
}
692+
CodeCompletionExpr(SourceRange Range, Type Ty = Type())
693+
: Expr(ExprKind::CodeCompletion, /*Implicit=*/true, Ty), Range(Range) {}
698694

699695
SourceRange getSourceRange() const { return Range; }
700696

701-
bool isActivated() const { return Activated; }
702-
void setActivated() { Activated = true; }
703-
704697
static bool classof(const Expr *E) {
705698
return E->getKind() == ExprKind::CodeCompletion;
706699
}

lib/IDE/CodeCompletion.cpp

Lines changed: 83 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,13 +1509,20 @@ static bool isTopLevelContext(const DeclContext *DC) {
15091509
static Type getReturnTypeFromContext(const DeclContext *DC) {
15101510
if (auto FD = dyn_cast<AbstractFunctionDecl>(DC)) {
15111511
if (FD->hasInterfaceType()) {
1512-
if (auto FT = FD->getInterfaceType()->getAs<FunctionType>()) {
1512+
auto Ty = FD->getInterfaceType();
1513+
if (FD->getDeclContext()->isTypeContext())
1514+
Ty = FD->getMethodInterfaceType();
1515+
if (auto FT = Ty->getAs<AnyFunctionType>())
15131516
return FT->getResult();
1514-
}
15151517
}
1516-
} else if (auto CE = dyn_cast<AbstractClosureExpr>(DC)) {
1517-
if (CE->getType()) {
1518-
return CE->getResultType();
1518+
} else if (auto ACE = dyn_cast<AbstractClosureExpr>(DC)) {
1519+
if (ACE->getType())
1520+
return ACE->getResultType();
1521+
if (auto CE = dyn_cast<ClosureExpr>(ACE)) {
1522+
if (CE->hasExplicitResultType())
1523+
return const_cast<ClosureExpr *>(CE)
1524+
->getExplicitResultTypeLoc()
1525+
.getType();
15191526
}
15201527
}
15211528
return Type();
@@ -3893,7 +3900,7 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
38933900
using FunctionParams = ArrayRef<AnyFunctionType::Param>;
38943901

38953902
static bool
3896-
collectPossibleParamLists(DeclContext &DC, CallExpr *callExpr,
3903+
collectPossibleParamLists(DeclContext &DC, ApplyExpr *callExpr,
38973904
SmallVectorImpl<FunctionParams> &candidates) {
38983905
auto *fnExpr = callExpr->getFn();
38993906

@@ -3940,27 +3947,20 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
39403947
if (!tuple)
39413948
return false;
39423949

3943-
for (unsigned i = 0, n = tuple->getNumElements(); i != n; ++i) {
3944-
if (isa<CodeCompletionExpr>(tuple->getElement(i))) {
3945-
HasName = !tuple->getElementName(i).empty();
3946-
Position = i;
3947-
return true;
3948-
}
3949-
}
39503950
auto &SM = DC.getASTContext().SourceMgr;
39513951
for (unsigned i = 0, n = tuple->getNumElements(); i != n; ++i) {
39523952
if (SM.isBeforeInBuffer(tuple->getElement(i)->getEndLoc(),
39533953
CCExpr->getStartLoc()))
39543954
continue;
3955-
HasName = !tuple->getElementName(i).empty();
3955+
HasName = tuple->getElementNameLoc(i).isValid();
39563956
Position = i;
39573957
return true;
39583958
}
39593959
return false;
39603960
}
39613961

39623962
static bool
3963-
collectArgumentExpectation(DeclContext &DC, CallExpr *CallE, Expr *CCExpr,
3963+
collectArgumentExpectation(DeclContext &DC, ApplyExpr *CallE, Expr *CCExpr,
39643964
std::vector<Type> &ExpectedTypes,
39653965
std::vector<StringRef> &ExpectedNames) {
39663966
// Collect parameter lists for possible func decls.
@@ -3974,15 +3974,17 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
39743974
if (!getPositionInArgs(DC, CallE->getArg(), CCExpr, Position, HasName))
39753975
return false;
39763976

3977-
// Collect possible types at the position.
3977+
// Collect possible types (or labels) at the position.
39783978
{
3979+
bool MayNeedName =
3980+
!HasName && isa<CallExpr>(CallE) && !CallE->isImplicit();
39793981
SmallPtrSet<TypeBase *, 4> seenTypes;
39803982
SmallPtrSet<Identifier, 4> seenNames;
39813983
for (auto Params : Candidates) {
39823984
if (Position >= Params.size())
39833985
continue;
39843986
const auto &Param = Params[Position];
3985-
if (Param.hasLabel() && !HasName) {
3987+
if (Param.hasLabel() && MayNeedName) {
39863988
if (seenNames.insert(Param.getLabel()).second)
39873989
ExpectedNames.push_back(Param.getLabel().str());
39883990
} else {
@@ -4909,19 +4911,19 @@ namespace {
49094911
class ExprParentFinder : public ASTWalker {
49104912
friend class CodeCompletionTypeContextAnalyzer;
49114913
Expr *ChildExpr;
4912-
llvm::function_ref<bool(ASTNode)> Predicate;
4914+
llvm::function_ref<bool(ParentTy)> Predicate;
49134915

49144916
bool arePositionsSame(Expr *E1, Expr *E2) {
49154917
return E1->getSourceRange().Start == E2->getSourceRange().Start &&
49164918
E1->getSourceRange().End == E2->getSourceRange().End;
49174919
}
49184920

49194921
public:
4920-
llvm::SmallVector<ASTNode, 5> Ancestors;
4921-
ASTNode ParentClosest;
4922-
ASTNode ParentFarthest;
4922+
llvm::SmallVector<ParentTy, 5> Ancestors;
4923+
ParentTy ParentClosest;
4924+
ParentTy ParentFarthest;
49234925
ExprParentFinder(Expr* ChildExpr,
4924-
llvm::function_ref<bool(ASTNode)> Predicate) :
4926+
llvm::function_ref<bool(ParentTy)> Predicate) :
49254927
ChildExpr(ChildExpr), Predicate(Predicate) {}
49264928

49274929
std::pair<bool, Expr *> walkToExprPre(Expr *E) override {
@@ -4966,6 +4968,18 @@ namespace {
49664968
Ancestors.pop_back();
49674969
return true;
49684970
}
4971+
4972+
std::pair<bool, Pattern *> walkToPatternPre(Pattern *P) override {
4973+
if (Predicate(P))
4974+
Ancestors.push_back(P);
4975+
return { true, P };
4976+
}
4977+
4978+
Pattern *walkToPatternPost(Pattern *P) override {
4979+
if (Predicate(P))
4980+
Ancestors.pop_back();
4981+
return P;
4982+
}
49694983
};
49704984
} // end anonymous namespace
49714985

@@ -4981,16 +4995,19 @@ class CodeCompletionTypeContextAnalyzer {
49814995
public:
49824996
CodeCompletionTypeContextAnalyzer(DeclContext *DC, Expr *ParsedExpr) : DC(DC),
49834997
ParsedExpr(ParsedExpr), SM(DC->getASTContext().SourceMgr),
4984-
Context(DC->getASTContext()), Finder(ParsedExpr, [](ASTNode Node) {
4985-
if (auto E = Node.dyn_cast<Expr *>()) {
4998+
Context(DC->getASTContext()),
4999+
Finder(ParsedExpr, [](ASTWalker::ParentTy Node) {
5000+
if (auto E = Node.getAsExpr()) {
49865001
switch(E->getKind()) {
49875002
case ExprKind::Call:
5003+
case ExprKind::Binary:
5004+
case ExprKind::PrefixUnary:
49885005
case ExprKind::Assign:
49895006
return true;
49905007
default:
49915008
return false;
4992-
}
4993-
} else if (auto S = Node.dyn_cast<Stmt *>()) {
5009+
}
5010+
} else if (auto S = Node.getAsStmt()) {
49945011
switch (S->getKind()) {
49955012
case StmtKind::Return:
49965013
case StmtKind::ForEach:
@@ -5002,25 +5019,34 @@ class CodeCompletionTypeContextAnalyzer {
50025019
default:
50035020
return false;
50045021
}
5005-
} else if (auto D = Node.dyn_cast<Decl *>()) {
5022+
} else if (auto D = Node.getAsDecl()) {
50065023
switch (D->getKind()) {
50075024
case DeclKind::PatternBinding:
50085025
return true;
50095026
default:
50105027
return false;
50115028
}
5029+
} else if (auto P = Node.getAsPattern()) {
5030+
switch (P->getKind()) {
5031+
case PatternKind::Expr:
5032+
return true;
5033+
default:
5034+
return false;
5035+
}
50125036
} else
50135037
return false;
5014-
}) {}
5038+
}) {}
50155039

50165040
void analyzeExpr(Expr *Parent, llvm::function_ref<void(Type)> Callback,
50175041
SmallVectorImpl<StringRef> &PossibleNames) {
50185042
switch (Parent->getKind()) {
5019-
case ExprKind::Call: {
5043+
case ExprKind::Call:
5044+
case ExprKind::Binary:
5045+
case ExprKind::PrefixUnary: {
50205046
std::vector<Type> PotentialTypes;
50215047
std::vector<StringRef> ExpectedNames;
50225048
CompletionLookup::collectArgumentExpectation(
5023-
*DC, cast<CallExpr>(Parent), ParsedExpr, PotentialTypes,
5049+
*DC, cast<ApplyExpr>(Parent), ParsedExpr, PotentialTypes,
50245050
ExpectedNames);
50255051
for (Type Ty : PotentialTypes)
50265052
Callback(Ty);
@@ -5047,7 +5073,7 @@ class CodeCompletionTypeContextAnalyzer {
50475073
break;
50485074
}
50495075
default:
5050-
llvm_unreachable("Unhandled expression kinds.");
5076+
llvm_unreachable("Unhandled expression kind.");
50515077
}
50525078
}
50535079

@@ -5072,7 +5098,7 @@ class CodeCompletionTypeContextAnalyzer {
50725098
}
50735099
break;
50745100
default:
5075-
llvm_unreachable("Unhandled statement kinds.");
5101+
llvm_unreachable("Unhandled statement kind.");
50765102
}
50775103
}
50785104

@@ -5114,7 +5140,22 @@ class CodeCompletionTypeContextAnalyzer {
51145140
break;
51155141
}
51165142
default:
5117-
llvm_unreachable("Unhandled decl kinds.");
5143+
llvm_unreachable("Unhandled decl kind.");
5144+
}
5145+
}
5146+
5147+
void analyzePattern(Pattern *P, llvm::function_ref<void(Type)> Callback) {
5148+
switch (P->getKind()) {
5149+
case PatternKind::Expr: {
5150+
auto ExprPat = cast<ExprPattern>(P);
5151+
if (auto D = ExprPat->getMatchVar()) {
5152+
if (D->hasInterfaceType())
5153+
Callback(D->getInterfaceType());
5154+
}
5155+
break;
5156+
}
5157+
default:
5158+
llvm_unreachable("Unhandled pattern kind.");
51185159
}
51195160
}
51205161

@@ -5136,12 +5177,14 @@ class CodeCompletionTypeContextAnalyzer {
51365177

51375178
for (auto It = Finder.Ancestors.rbegin(); It != Finder.Ancestors.rend();
51385179
++ It) {
5139-
if (auto Parent = It->dyn_cast<Expr *>()) {
5180+
if (auto Parent = It->getAsExpr()) {
51405181
analyzeExpr(Parent, Callback, PossibleNames);
5141-
} else if (auto Parent = It->dyn_cast<Stmt *>()) {
5182+
} else if (auto Parent = It->getAsStmt()) {
51425183
analyzeStmt(Parent, Callback);
5143-
} else if (auto Parent = It->dyn_cast<Decl *>()) {
5184+
} else if (auto Parent = It->getAsDecl()) {
51445185
analyzeDecl(Parent, Callback);
5186+
} else if (auto Parent = It->getAsPattern()) {
5187+
analyzePattern(Parent, Callback);
51455188
}
51465189
if (!PossibleTypes.empty() || !PossibleNames.empty())
51475190
return true;
@@ -5417,12 +5460,12 @@ void CodeCompletionCallbacksImpl::doneParsing() {
54175460
case CompletionKind::UnresolvedMember : {
54185461
Lookup.setHaveDot(SourceLoc());
54195462
SmallVector<Type, 1> PossibleTypes;
5420-
ExprParentFinder Walker(UnresolvedExpr, [&](ASTNode Node) {
5421-
return Node.is<Expr *>();
5463+
ExprParentFinder Walker(UnresolvedExpr, [&](ASTWalker::ParentTy Node) {
5464+
return Node.getAsExpr();
54225465
});
54235466
CurDeclContext->walkContext(Walker);
54245467
bool Success = false;
5425-
if (auto PE = Walker.ParentFarthest.get<Expr *>()) {
5468+
if (auto PE = Walker.ParentFarthest.getAsExpr()) {
54265469
prepareForRetypechecking(PE);
54275470
Success = typeCheckUnresolvedExpr(*CurDeclContext, UnresolvedExpr, PE,
54285471
PossibleTypes);
@@ -5452,11 +5495,7 @@ void CodeCompletionCallbacksImpl::doneParsing() {
54525495

54535496
case CompletionKind::ReturnStmtExpr : {
54545497
SourceLoc Loc = P.Context.SourceMgr.getCodeCompletionLoc();
5455-
if (auto FD = dyn_cast<AbstractFunctionDecl>(CurDeclContext)) {
5456-
if (auto FT = FD->getInterfaceType()->getAs<FunctionType>()) {
5457-
Lookup.setExpectedTypes(FT->getResult());
5458-
}
5459-
}
5498+
Lookup.setExpectedTypes(getReturnTypeFromContext(CurDeclContext));
54605499
Lookup.getValueCompletionsInDeclContext(Loc);
54615500
break;
54625501
}

lib/Parse/ParseExpr.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -532,10 +532,9 @@ ParserResult<Expr> Parser::parseExprUnary(Diag<> Message, bool isExprBasic) {
532532
}
533533

534534
ParserResult<Expr> SubExpr = parseExprUnary(Message, isExprBasic);
535-
if (SubExpr.hasCodeCompletion())
536-
return makeParserCodeCompletionResult<Expr>();
535+
ParserStatus Status = SubExpr;
537536
if (SubExpr.isNull())
538-
return nullptr;
537+
return Status;
539538

540539
// We are sure we can create a prefix prefix operator expr now.
541540
UnaryContext.setCreateSyntax(SyntaxKind::PrefixOperatorExpr);
@@ -545,12 +544,13 @@ ParserResult<Expr> Parser::parseExprUnary(Diag<> Message, bool isExprBasic) {
545544
if (auto *LE = dyn_cast<NumberLiteralExpr>(SubExpr.get())) {
546545
if (Operator->hasName() && Operator->getName().getBaseName() == "-") {
547546
LE->setNegative(Operator->getLoc());
548-
return makeParserResult(LE);
547+
return makeParserResult(Status, LE);
549548
}
550549
}
551550

552-
return makeParserResult(new (Context) PrefixUnaryExpr(
553-
Operator, formUnaryArgument(Context, SubExpr.get())));
551+
return makeParserResult(
552+
Status, new (Context) PrefixUnaryExpr(
553+
Operator, formUnaryArgument(Context, SubExpr.get())));
554554
}
555555

556556
/// expr-keypath-swift:
@@ -1667,9 +1667,6 @@ ParserResult<Expr> Parser::parseExprPrimary(Diag<> ID, bool isExprBasic) {
16671667
rParenLoc,
16681668
trailingClosure,
16691669
SyntaxKind::FunctionCallArgumentList);
1670-
if (status.isError())
1671-
return nullptr;
1672-
16731670
SyntaxContext->createNodeInPlace(SyntaxKind::FunctionCallExpr);
16741671
return makeParserResult(
16751672
status,

lib/Parse/ParsePattern.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,10 +1116,9 @@ ParserResult<Pattern> Parser::parseMatchingPattern(bool isExprBasic) {
11161116
// disambiguate.
11171117
ParserResult<Expr> subExpr =
11181118
parseExprImpl(diag::expected_pattern, isExprBasic);
1119-
if (subExpr.hasCodeCompletion())
1120-
return makeParserCodeCompletionStatus();
1119+
ParserStatus status = subExpr;
11211120
if (subExpr.isNull())
1122-
return nullptr;
1121+
return status;
11231122

11241123
if (SyntaxContext->isEnabled()) {
11251124
if (auto UPES = PatternCtx.popIf<UnresolvedPatternExprSyntax>()) {
@@ -1132,9 +1131,9 @@ ParserResult<Pattern> Parser::parseMatchingPattern(bool isExprBasic) {
11321131
// obvious pattern, which will come back wrapped in an immediate
11331132
// UnresolvedPatternExpr. Transform this now to simplify later code.
11341133
if (auto *UPE = dyn_cast<UnresolvedPatternExpr>(subExpr.get()))
1135-
return makeParserResult(UPE->getSubPattern());
1134+
return makeParserResult(status, UPE->getSubPattern());
11361135

1137-
return makeParserResult(new (Context) ExprPattern(subExpr.get()));
1136+
return makeParserResult(status, new (Context) ExprPattern(subExpr.get()));
11381137
}
11391138

11401139
ParserResult<Pattern> Parser::parseMatchingPatternAsLetOrVar(bool isLet,

lib/Sema/CSGen.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,9 +1190,6 @@ namespace {
11901190
}
11911191

11921192
virtual Type visitCodeCompletionExpr(CodeCompletionExpr *E) {
1193-
if (!E->isActivated())
1194-
return Type();
1195-
11961193
CS.Options |= ConstraintSystemFlags::SuppressDiagnostics;
11971194
return CS.createTypeVariable(CS.getConstraintLocator(E),
11981195
TVO_CanBindToLValue);

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,9 +2231,6 @@ bool TypeChecker::typeCheckCompletionSequence(Expr *&expr, DeclContext *DC) {
22312231
// Ensure the output expression is up to date.
22322232
assert(exprAsBinOp == expr && isa<BinaryExpr>(expr) && "found wrong expr?");
22332233

2234-
// Add type variable for the code-completion expression.
2235-
CCE->setActivated();
2236-
22372234
if (auto generated = CS.generateConstraints(expr)) {
22382235
expr = generated;
22392236
} else {

test/IDE/complete_dynamic_lookup.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ func testAnyObject11_(_ dl: AnyObject) {
463463
dl.returnsObjcClass!(#^DL_FUNC_NAME_PAREN_1^#
464464
}
465465
// DL_FUNC_NAME_PAREN_1: Begin completions
466-
// DL_FUNC_NAME_PAREN_1-DAG: Pattern/CurrModule: ['(']{#Int#}[')'][#TopLevelObjcClass#]{{; name=.+$}}
466+
// DL_FUNC_NAME_PAREN_1-DAG: Pattern/CurrModule: ['(']{#(i): Int#}[')'][#TopLevelObjcClass#]{{; name=.+$}}
467467
// DL_FUNC_NAME_PAREN_1: End completions
468468

469469
func testAnyObject12(_ dl: AnyObject) {

0 commit comments

Comments
 (0)