Skip to content

Commit 5dea77e

Browse files
authored
Merge pull request #38627 from ahoppen/pr/enum-matching-completion
[CodeCompletion] Explicitly support enum pattern matching
2 parents 6e35487 + 12ff361 commit 5dea77e

10 files changed

+280
-111
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ IDENTIFIER(decodeIfPresent)
6666
IDENTIFIER(Decoder)
6767
IDENTIFIER(decoder)
6868
IDENTIFIER_(Differentiation)
69+
IDENTIFIER_WITH_NAME(PatternMatchVar, "$match")
6970
IDENTIFIER(dynamicallyCall)
7071
IDENTIFIER(dynamicMember)
7172
IDENTIFIER(Element)

include/swift/Sema/CodeCompletionTypeChecking.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,21 +89,26 @@ namespace swift {
8989
/// formed during expression type-checking.
9090
class UnresolvedMemberTypeCheckCompletionCallback: public TypeCheckCompletionCallback {
9191
public:
92-
struct Result {
92+
struct ExprResult {
9393
Type ExpectedTy;
9494
bool IsImplicitSingleExpressionReturn;
9595
};
9696

9797
private:
9898
CodeCompletionExpr *CompletionExpr;
99-
SmallVector<Result, 4> Results;
99+
SmallVector<ExprResult, 4> ExprResults;
100+
SmallVector<Type, 1> EnumPatternTypes;
100101
bool GotCallback = false;
101102

102103
public:
103104
UnresolvedMemberTypeCheckCompletionCallback(CodeCompletionExpr *CompletionExpr)
104105
: CompletionExpr(CompletionExpr) {}
105106

106-
ArrayRef<Result> getResults() const { return Results; }
107+
ArrayRef<ExprResult> getExprResults() const { return ExprResults; }
108+
109+
/// If we are completing in a pattern matching position, the types of all
110+
/// enums for whose cases are valid as an \c EnumElementPattern.
111+
ArrayRef<Type> getEnumPatternTypes() const { return EnumPatternTypes; }
107112

108113
/// True if at least one solution was passed via the \c sawSolution
109114
/// callback.

lib/IDE/CodeCompletion.cpp

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3312,8 +3312,6 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
33123312
getSemanticContext(EED, Reason, dynamicLookupInfo),
33133313
expectedTypeContext);
33143314
Builder.setAssociatedDecl(EED);
3315-
if (HasTypeContext)
3316-
Builder.addFlair(CodeCompletionFlairBit::ExpressionSpecific);
33173315

33183316
addLeadingDot(Builder);
33193317
addValueBaseName(Builder, EED->getBaseIdentifier());
@@ -4372,6 +4370,23 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
43724370
addObjCPoundKeywordCompletions(/*needPound=*/true);
43734371
}
43744372

4373+
/// Returns \c true if \p VD is an initializer on the \c Optional or \c
4374+
/// Id_OptionalNilComparisonType type from the Swift stdlib.
4375+
static bool isInitializerOnOptional(Type T, ValueDecl *VD) {
4376+
bool IsOptionalType = false;
4377+
IsOptionalType |= static_cast<bool>(T->getOptionalObjectType());
4378+
if (auto *NTD = T->getAnyNominal()) {
4379+
IsOptionalType |= NTD->getBaseIdentifier() ==
4380+
VD->getASTContext().Id_OptionalNilComparisonType;
4381+
}
4382+
if (IsOptionalType && VD->getModuleContext()->isStdlibModule() &&
4383+
isa<ConstructorDecl>(VD)) {
4384+
return true;
4385+
} else {
4386+
return false;
4387+
}
4388+
}
4389+
43754390
void getUnresolvedMemberCompletions(Type T) {
43764391
if (!T->mayHaveMembers())
43774392
return;
@@ -4389,16 +4404,11 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
43894404
// We can only say .foo where foo is a static member of the contextual
43904405
// type and has the same type (or if the member is a function, then the
43914406
// same result type) as the contextual type.
4392-
FilteredDeclConsumer consumer(*this, [=](ValueDecl *VD,
4393-
DeclVisibilityKind Reason) {
4394-
if (T->getOptionalObjectType() &&
4395-
VD->getModuleContext()->isStdlibModule()) {
4396-
// In optional context, ignore '.init(<some>)', 'init(nilLiteral:)',
4397-
if (isa<ConstructorDecl>(VD))
4398-
return false;
4399-
}
4400-
return true;
4401-
});
4407+
FilteredDeclConsumer consumer(
4408+
*this, [=](ValueDecl *VD, DeclVisibilityKind Reason) {
4409+
// In optional context, ignore '.init(<some>)', 'init(nilLiteral:)',
4410+
return !isInitializerOnOptional(T, VD);
4411+
});
44024412

44034413
auto baseType = MetatypeType::get(T);
44044414
llvm::SaveAndRestore<LookupKind> SaveLook(Kind, LookupKind::ValueExpr);
@@ -4410,6 +4420,21 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
44104420
/*includeProtocolExtensionMembers*/true);
44114421
}
44124422

4423+
/// Complete all enum members declared on \p T.
4424+
void getEnumElementPatternCompletions(Type T) {
4425+
if (!isa_and_nonnull<EnumDecl>(T->getAnyNominal()))
4426+
return;
4427+
4428+
auto baseType = MetatypeType::get(T);
4429+
llvm::SaveAndRestore<LookupKind> SaveLook(Kind, LookupKind::EnumElement);
4430+
llvm::SaveAndRestore<Type> SaveType(ExprType, baseType);
4431+
llvm::SaveAndRestore<bool> SaveUnresolved(IsUnresolvedMember, true);
4432+
lookupVisibleMemberDecls(*this, baseType, CurrDeclContext,
4433+
/*includeInstanceMembers=*/false,
4434+
/*includeDerivedRequirements=*/false,
4435+
/*includeProtocolExtensionMembers=*/true);
4436+
}
4437+
44134438
void getUnresolvedMemberCompletions(ArrayRef<Type> Types) {
44144439
NeedLeadingDot = !HaveDot;
44154440

@@ -6461,8 +6486,8 @@ static void deliverCompletionResults(CodeCompletionContext &CompletionContext,
64616486
}
64626487

64636488
void deliverUnresolvedMemberResults(
6464-
ArrayRef<UnresolvedMemberTypeCheckCompletionCallback::Result> Results,
6465-
DeclContext *DC, SourceLoc DotLoc,
6489+
ArrayRef<UnresolvedMemberTypeCheckCompletionCallback::ExprResult> Results,
6490+
ArrayRef<Type> EnumPatternTypes, DeclContext *DC, SourceLoc DotLoc,
64666491
ide::CodeCompletionContext &CompletionCtx,
64676492
CodeCompletionConsumer &Consumer) {
64686493
ASTContext &Ctx = DC->getASTContext();
@@ -6471,7 +6496,7 @@ void deliverUnresolvedMemberResults(
64716496

64726497
assert(DotLoc.isValid());
64736498
Lookup.setHaveDot(DotLoc);
6474-
Lookup.shouldCheckForDuplicates(Results.size() > 1);
6499+
Lookup.shouldCheckForDuplicates(Results.size() + EnumPatternTypes.size() > 1);
64756500

64766501
// Get the canonical versions of the top-level types
64776502
SmallPtrSet<CanType, 4> originalTypes;
@@ -6496,6 +6521,22 @@ void deliverUnresolvedMemberResults(
64966521
Lookup.getUnresolvedMemberCompletions(Result.ExpectedTy);
64976522
}
64986523

6524+
// Offer completions when interpreting the pattern match as an
6525+
// EnumElementPattern.
6526+
for (auto &Ty : EnumPatternTypes) {
6527+
Lookup.setExpectedTypes({Ty}, /*IsImplicitSingleExpressionReturn=*/false,
6528+
/*expectsNonVoid=*/true);
6529+
Lookup.setIdealExpectedType(Ty);
6530+
6531+
// We can pattern match MyEnum against Optional<MyEnum>
6532+
if (Ty->getOptionalObjectType()) {
6533+
Type Unwrapped = Ty->lookThroughAllOptionalTypes();
6534+
Lookup.getEnumElementPatternCompletions(Unwrapped);
6535+
}
6536+
6537+
Lookup.getEnumElementPatternCompletions(Ty);
6538+
}
6539+
64996540
deliverCompletionResults(CompletionCtx, Lookup, DC, Consumer);
65006541
}
65016542

@@ -6608,8 +6649,9 @@ bool CodeCompletionCallbacksImpl::trySolverCompletion(bool MaybeFuncBody) {
66086649
Lookup.fallbackTypeCheck(CurDeclContext);
66096650

66106651
addKeywords(CompletionContext.getResultSink(), MaybeFuncBody);
6611-
deliverUnresolvedMemberResults(Lookup.getResults(), CurDeclContext, DotLoc,
6612-
CompletionContext, Consumer);
6652+
deliverUnresolvedMemberResults(Lookup.getExprResults(),
6653+
Lookup.getEnumPatternTypes(), CurDeclContext,
6654+
DotLoc, CompletionContext, Consumer);
66136655
return true;
66146656
}
66156657
case CompletionKind::KeyPathExprSwift: {

lib/Sema/TypeCheckCodeCompletion.cpp

Lines changed: 91 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -811,59 +811,11 @@ class CompletionContextFinder : public ASTWalker {
811811

812812
} // end namespace
813813

814-
// Determine if the target expression is the implicit BinaryExpr generated for
815-
// pattern-matching in a switch/if/guard case (<completion> ~= matchValue).
816-
static bool isForPatternMatch(SolutionApplicationTarget &target) {
817-
if (target.getExprContextualTypePurpose() != CTP_Condition)
818-
return false;
819-
Expr *condition = target.getAsExpr();
820-
if (!condition->isImplicit())
821-
return false;
822-
if (auto *BE = dyn_cast<BinaryExpr>(condition)) {
823-
Identifier id;
824-
if (auto *ODRE = dyn_cast<OverloadedDeclRefExpr>(BE->getFn())) {
825-
id = ODRE->getDecls().front()->getBaseIdentifier();
826-
} else if (auto *DRE = dyn_cast<DeclRefExpr>(BE->getFn())) {
827-
id = DRE->getDecl()->getBaseIdentifier();
828-
}
829-
if (id != target.getDeclContext()->getASTContext().Id_MatchOperator)
830-
return false;
831-
return isa<CodeCompletionExpr>(BE->getLHS());
832-
}
833-
return false;
834-
}
835-
836-
/// Remove any solutions from the provided vector that both require fixes and have a
837-
/// score worse than the best.
814+
/// Remove any solutions from the provided vector that both require fixes and
815+
/// have a score worse than the best.
838816
static void filterSolutions(SolutionApplicationTarget &target,
839817
SmallVectorImpl<Solution> &solutions,
840818
CodeCompletionExpr *completionExpr) {
841-
// FIXME: this is only needed because in pattern matching position, the
842-
// code completion expression always becomes an expression pattern, which
843-
// requires the ~= operator to be defined on the type being matched against.
844-
// Pattern matching against an enum doesn't require that however, so valid
845-
// solutions always end up having fixes. This is a problem because there will
846-
// always be a valid solution as well. Optional defines ~= between Optional
847-
// and _OptionalNilComparisonType (which defines a nilLiteral initializer),
848-
// and the matched-against value can implicitly be made Optional if it isn't
849-
// already, so _OptionalNilComparisonType is always a valid solution for the
850-
// completion. That only generates the 'nil' completion, which is rarely what
851-
// the user intends to write in this position and shouldn't be preferred over
852-
// the other formed solutions (which require fixes). We should generate enum
853-
// pattern completions separately, but for now ignore the
854-
// _OptionalNilComparisonType solution.
855-
if (isForPatternMatch(target) && completionExpr) {
856-
solutions.erase(llvm::remove_if(solutions, [&](const Solution &S) {
857-
ASTContext &ctx = S.getConstraintSystem().getASTContext();
858-
if (!S.hasType(completionExpr))
859-
return false;
860-
if (auto ty = S.getResolvedType(completionExpr))
861-
if (auto *NTD = ty->getAnyNominal())
862-
return NTD->getBaseIdentifier() == ctx.Id_OptionalNilComparisonType;
863-
return false;
864-
}), solutions.end());
865-
}
866-
867819
if (solutions.size() <= 1)
868820
return;
869821

@@ -1286,6 +1238,69 @@ sawSolution(const constraints::Solution &S) {
12861238
}
12871239
}
12881240

1241+
/// If the code completion variable occurs in a pattern matching position, we
1242+
/// have an AST that looks like this.
1243+
/// \code
1244+
/// (binary_expr implicit type='$T3'
1245+
/// (overloaded_decl_ref_expr function_ref=compound decls=[
1246+
/// Swift.(file).~=,
1247+
/// Swift.(file).Optional extension.~=])
1248+
/// (tuple_expr implicit type='($T1, (OtherEnum))'
1249+
/// (code_completion_expr implicit type='$T1')
1250+
/// (declref_expr implicit decl=swift_ide_test.(file).foo(x:).$match)))
1251+
/// \endcode
1252+
/// If the code completion expression occurs in such an AST, return the
1253+
/// declaration of the \c $match variable, otherwise return \c nullptr.
1254+
VarDecl *getMatchVarIfInPatternMatch(CodeCompletionExpr *CompletionExpr,
1255+
ConstraintSystem &CS) {
1256+
auto &Context = CS.getASTContext();
1257+
1258+
TupleExpr *ArgTuple =
1259+
dyn_cast_or_null<TupleExpr>(CS.getParentExpr(CompletionExpr));
1260+
if (!ArgTuple || !ArgTuple->isImplicit() || ArgTuple->getNumElements() != 2) {
1261+
return nullptr;
1262+
}
1263+
1264+
auto Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(ArgTuple));
1265+
if (!Binary || !Binary->isImplicit()) {
1266+
return nullptr;
1267+
}
1268+
1269+
auto CalledOperator = Binary->getFn();
1270+
if (!CalledOperator || !CalledOperator->isImplicit()) {
1271+
return nullptr;
1272+
}
1273+
// The reference to the ~= operator might be an OverloadedDeclRefExpr or a
1274+
// DeclRefExpr, depending on how many ~= operators are viable.
1275+
if (auto Overloaded =
1276+
dyn_cast_or_null<OverloadedDeclRefExpr>(CalledOperator)) {
1277+
if (!llvm::all_of(Overloaded->getDecls(), [&Context](ValueDecl *D) {
1278+
return D->getBaseName() == Context.Id_MatchOperator;
1279+
})) {
1280+
return nullptr;
1281+
}
1282+
} else if (auto Ref = dyn_cast_or_null<DeclRefExpr>(CalledOperator)) {
1283+
if (Ref->getDecl()->getBaseName() != Context.Id_MatchOperator) {
1284+
return nullptr;
1285+
}
1286+
} else {
1287+
return nullptr;
1288+
}
1289+
1290+
auto MatchArg = dyn_cast_or_null<DeclRefExpr>(ArgTuple->getElement(1));
1291+
if (!MatchArg || !MatchArg->isImplicit()) {
1292+
return nullptr;
1293+
}
1294+
1295+
auto MatchVar = MatchArg->getDecl();
1296+
if (MatchVar && MatchVar->isImplicit() &&
1297+
MatchVar->getBaseName() == Context.Id_PatternMatchVar) {
1298+
return dyn_cast<VarDecl>(MatchVar);
1299+
} else {
1300+
return nullptr;
1301+
}
1302+
}
1303+
12891304
void UnresolvedMemberTypeCheckCompletionCallback::
12901305
sawSolution(const constraints::Solution &S) {
12911306
GotCallback = true;
@@ -1295,18 +1310,34 @@ sawSolution(const constraints::Solution &S) {
12951310
// If the type couldn't be determined (e.g. because there isn't any context
12961311
// to derive it from), let's not attempt to do a lookup since it wouldn't
12971312
// produce any useful results anyway.
1298-
if (!ExpectedTy || ExpectedTy->is<UnresolvedType>())
1299-
return;
1300-
1301-
// If ExpectedTy is a duplicate of any other result, ignore this solution.
1302-
if (llvm::any_of(Results, [&](const Result &R) {
1303-
return R.ExpectedTy->isEqual(ExpectedTy);
1304-
})) {
1305-
return;
1313+
if (ExpectedTy && !ExpectedTy->is<UnresolvedType>()) {
1314+
// If ExpectedTy is a duplicate of any other result, ignore this solution.
1315+
if (!llvm::any_of(ExprResults, [&](const ExprResult &R) {
1316+
return R.ExpectedTy->isEqual(ExpectedTy);
1317+
})) {
1318+
bool SingleExprBody =
1319+
isImplicitSingleExpressionReturn(CS, CompletionExpr);
1320+
ExprResults.push_back({ExpectedTy, SingleExprBody});
1321+
}
13061322
}
13071323

1308-
bool SingleExprBody = isImplicitSingleExpressionReturn(CS, CompletionExpr);
1309-
Results.push_back({ExpectedTy, SingleExprBody});
1324+
if (auto MatchVar = getMatchVarIfInPatternMatch(CompletionExpr, CS)) {
1325+
Type MatchVarType;
1326+
// If the MatchVar has an explicit type, it's not part of the solution. But
1327+
// we can look it up in the constraint system directly.
1328+
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
1329+
MatchVarType = T;
1330+
} else {
1331+
MatchVarType = S.getResolvedType(MatchVar);
1332+
}
1333+
if (MatchVarType && !MatchVarType->is<UnresolvedType>()) {
1334+
if (!llvm::any_of(EnumPatternTypes, [&](const Type &R) {
1335+
return R->isEqual(MatchVarType);
1336+
})) {
1337+
EnumPatternTypes.push_back(MatchVarType);
1338+
}
1339+
}
1340+
}
13101341
}
13111342

13121343
void KeyPathTypeCheckCompletionCallback::sawSolution(

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -633,11 +633,9 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC,
633633
PrettyStackTracePattern stackTrace(Context, "type-checking", EP);
634634

635635
// Create a 'let' binding to stand in for the RHS value.
636-
auto *matchVar = new (Context) VarDecl(/*IsStatic*/false,
637-
VarDecl::Introducer::Let,
638-
EP->getLoc(),
639-
Context.getIdentifier("$match"),
640-
DC);
636+
auto *matchVar =
637+
new (Context) VarDecl(/*IsStatic*/ false, VarDecl::Introducer::Let,
638+
EP->getLoc(), Context.Id_PatternMatchVar, DC);
641639
matchVar->setInterfaceType(rhsType->mapTypeOutOfContext());
642640

643641
matchVar->setImplicit();

0 commit comments

Comments
 (0)