Skip to content

[CS] Allow bidirectional inference for ExprPatterns #64387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ ERROR(cannot_match_expr_tuple_pattern_with_nontuple_value,none,
ERROR(cannot_match_unresolved_expr_pattern_with_value,none,
"pattern cannot match values of type %0",
(Type))
ERROR(cannot_match_value_with_pattern,none,
"pattern of type %1 cannot match %0",
(Type, Type))

ERROR(cannot_reference_compare_types,none,
"cannot check reference equality of functions; operands here have types "
Expand Down
36 changes: 25 additions & 11 deletions include/swift/AST/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,11 @@ class alignas(8) Pattern : public ASTAllocated<Pattern> {
/// equivalent to matching this pattern.
///
/// Looks through ParenPattern, BindingPattern, and TypedPattern.
Pattern *getSemanticsProvidingPattern();
const Pattern *getSemanticsProvidingPattern() const {
return const_cast<Pattern*>(this)->getSemanticsProvidingPattern();
Pattern *getSemanticsProvidingPattern(bool allowTypedPattern = true);
const Pattern *
getSemanticsProvidingPattern(bool allowTypedPattern = true) const {
return const_cast<Pattern *>(this)->getSemanticsProvidingPattern(
allowTypedPattern);
}

/// Returns whether this pattern has been type-checked yet.
Expand Down Expand Up @@ -799,14 +801,26 @@ class BindingPattern : public Pattern {
}
};

inline Pattern *Pattern::getSemanticsProvidingPattern() {
if (auto *pp = dyn_cast<ParenPattern>(this))
return pp->getSubPattern()->getSemanticsProvidingPattern();
if (auto *tp = dyn_cast<TypedPattern>(this))
return tp->getSubPattern()->getSemanticsProvidingPattern();
if (auto *vp = dyn_cast<BindingPattern>(this))
return vp->getSubPattern()->getSemanticsProvidingPattern();
return this;
inline Pattern *Pattern::getSemanticsProvidingPattern(bool allowTypedPattern) {
auto *P = this;
while (true) {
if (auto *PP = dyn_cast<ParenPattern>(P)) {
P = PP->getSubPattern();
continue;
}
if (auto *BP = dyn_cast<BindingPattern>(P)) {
P = BP->getSubPattern();
continue;
}
if (allowTypedPattern) {
if (auto *TP = dyn_cast<TypedPattern>(P)) {
P = TP->getSubPattern();
continue;
}
}
break;
}
return P;
}

/// Describes a pattern and the context in which it occurs.
Expand Down
13 changes: 7 additions & 6 deletions include/swift/Sema/CompletionContextFinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

namespace swift {

namespace constraints {
class SyntacticElementTarget;
}

class CompletionContextFinder : public ASTWalker {
enum class ContextKind {
FallbackExpression,
Expand Down Expand Up @@ -53,12 +57,9 @@ class CompletionContextFinder : public ASTWalker {
return MacroWalking::Arguments;
}

/// Finder for completion contexts within the provided initial expression.
CompletionContextFinder(ASTNode initialNode, DeclContext *DC)
: InitialExpr(initialNode.dyn_cast<Expr *>()), InitialDC(DC) {
assert(DC);
initialNode.walk(*this);
};
/// Finder for completion contexts within the provided SyntacticElementTarget.
CompletionContextFinder(constraints::SyntacticElementTarget target,
DeclContext *DC);

/// Finder for completion contexts within the outermost non-closure context of
/// the code completion expression's direct context.
Expand Down
3 changes: 3 additions & 0 deletions include/swift/Sema/ConstraintLocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ class ConstraintLocator : public llvm::FoldingSetNode {
/// otherwise \c nullptr.
NullablePtr<Pattern> getPatternMatch() const;

/// Whether the locator in question is for a pattern match.
bool isForPatternMatch() const;

/// Returns true if \p locator is ending with either of the following
/// - Member
/// - Member -> KeyPathDynamicMember
Expand Down
4 changes: 4 additions & 0 deletions include/swift/Sema/ConstraintLocatorPathElts.def
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ CUSTOM_LOCATOR_PATH_ELT(TernaryBranch)
/// Performing a pattern patch.
CUSTOM_LOCATOR_PATH_ELT(PatternMatch)

/// The constraint that models the allowed implicit casts for
/// an EnumElementPattern.
SIMPLE_LOCATOR_PATH_ELT(EnumPatternImplicitCastMatch)

/// Points to a particular attribute associated with one of
/// the arguments e.g. `inout` or its type e.g. `@escaping`.
///
Expand Down
44 changes: 37 additions & 7 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,10 @@ class Solution {
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
caseLabelItems;

/// A map of expressions to the ExprPatterns that they are being solved as
/// a part of.
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;

/// The set of parameters that have been inferred to be 'isolated'.
llvm::SmallVector<ParamDecl *, 2> isolatedParams;

Expand Down Expand Up @@ -1685,6 +1689,16 @@ class Solution {
: nullptr;
}

/// Retrieve the solved ExprPattern that corresponds to provided
/// sub-expression.
NullablePtr<ExprPattern> getExprPatternFor(Expr *E) const {
auto result = exprPatterns.find(E);
if (result == exprPatterns.end())
return nullptr;

return result->second;
}

/// This method implements functionality of `Expr::isTypeReference`
/// with data provided by a given solution.
bool isTypeReference(Expr *E) const;
Expand Down Expand Up @@ -2148,6 +2162,10 @@ class ConstraintSystem {
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
caseLabelItems;

/// A map of expressions to the ExprPatterns that they are being solved as
/// a part of.
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;

/// The set of parameters that have been inferred to be 'isolated'.
llvm::SmallSetVector<ParamDecl *, 2> isolatedParams;

Expand Down Expand Up @@ -2745,6 +2763,9 @@ class ConstraintSystem {
/// The length of \c caseLabelItems.
unsigned numCaseLabelItems;

/// The length of \c exprPatterns.
unsigned numExprPatterns;

/// The length of \c isolatedParams.
unsigned numIsolatedParams;

Expand Down Expand Up @@ -3166,6 +3187,15 @@ class ConstraintSystem {
caseLabelItems[item] = info;
}

/// Record a given ExprPattern as the parent of its sub-expression.
void setExprPatternFor(Expr *E, ExprPattern *EP) {
assert(E);
assert(EP);
auto inserted = exprPatterns.insert({E, EP}).second;
assert(inserted && "Mapping already defined?");
(void)inserted;
}

Optional<CaseLabelItemInfo> getCaseLabelItemInfo(
const CaseLabelItem *item) const {
auto known = caseLabelItems.find(item);
Expand Down Expand Up @@ -4322,15 +4352,15 @@ class ConstraintSystem {
Expr *generateConstraints(Expr *E, DeclContext *dc,
bool isInputExpression = true);

/// Generate constraints for binding the given pattern to the
/// value of the given expression.
/// Generate constraints for a given pattern.
///
/// \returns a possibly-sanitized initializer, or null if an error occurred.
/// \returns The type of the pattern, or \c None if a failure occured.
[[nodiscard]]
Type generateConstraints(Pattern *P, ConstraintLocatorBuilder locator,
bool bindPatternVarsOneWay,
PatternBindingDecl *patternBinding,
unsigned patternIndex);
Optional<Type> generateConstraints(Pattern *P,
ConstraintLocatorBuilder locator,
bool bindPatternVarsOneWay,
PatternBindingDecl *patternBinding,
unsigned patternIndex);

/// Generate constraints for a statement condition.
///
Expand Down
37 changes: 22 additions & 15 deletions lib/IDE/TypeCheckCompletionCallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,13 @@ Type swift::ide::getTypeForCompletion(const constraints::Solution &S,
/// \endcode
/// If the code completion expression occurs in such an AST, return the
/// declaration of the \c $match variable, otherwise return \c nullptr.
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, const Solution &S) {
if (auto EP = S.getExprPatternFor(E))
return EP.get()->getMatchVar();

// TODO: Once ExprPattern type-checking is fully moved into the solver,
// the below can be deleted.
auto &CS = S.getConstraintSystem();
auto &Context = CS.getASTContext();

auto *Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(E));
Expand Down Expand Up @@ -109,20 +115,21 @@ static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
}

Type swift::ide::getPatternMatchType(const constraints::Solution &S, Expr *E) {
if (auto MatchVar = getMatchVarIfInPatternMatch(E, S.getConstraintSystem())) {
Type MatchVarType;
// If the MatchVar has an explicit type, it's not part of the solution. But
// we can look it up in the constraint system directly.
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
MatchVarType = T;
} else {
MatchVarType = getTypeForCompletion(S, MatchVar);
}
if (MatchVarType) {
return MatchVarType;
}
}
return nullptr;
auto MatchVar = getMatchVarIfInPatternMatch(E, S);
if (!MatchVar)
return nullptr;

if (S.hasType(MatchVar))
return S.getResolvedType(MatchVar);

// If the ExprPattern wasn't solved as part of the constraint system, it's
// not part of the solution.
// TODO: This can be removed once ExprPattern type-checking is fully part
// of the constraint system.
if (auto T = S.getConstraintSystem().getVarType(MatchVar))
return T;

return getTypeForCompletion(S, MatchVar);
}

void swift::ide::getSolutionSpecificVarTypes(
Expand Down
3 changes: 2 additions & 1 deletion lib/Sema/BuilderTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,8 @@ Optional<BraceStmt *> TypeChecker::applyResultBuilderBodyTransform(
SmallVector<Solution, 4> solutions;
cs.solveForCodeCompletion(solutions);

CompletionContextFinder analyzer(func, func->getDeclContext());
SyntacticElementTarget funcTarget(func);
CompletionContextFinder analyzer(funcTarget, func->getDeclContext());
if (analyzer.hasCompletion()) {
filterSolutionsForCodeCompletion(solutions, analyzer);
for (const auto &solution : solutions) {
Expand Down
Loading