Skip to content

[CS] Allow ExprPatterns to be type-checked in the solver #64280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ ERROR(cannot_match_expr_tuple_pattern_with_nontuple_value,none,
ERROR(cannot_match_unresolved_expr_pattern_with_value,none,
"pattern cannot match values of type %0",
(Type))
ERROR(cannot_match_value_with_pattern,none,
"pattern of type %1 cannot match %0",
(Type, Type))

ERROR(cannot_reference_compare_types,none,
"cannot check reference equality of functions; operands here have types "
Expand Down
13 changes: 7 additions & 6 deletions include/swift/Sema/CompletionContextFinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

namespace swift {

namespace constraints {
class SyntacticElementTarget;
}

class CompletionContextFinder : public ASTWalker {
enum class ContextKind {
FallbackExpression,
Expand Down Expand Up @@ -53,12 +57,9 @@ class CompletionContextFinder : public ASTWalker {
return MacroWalking::Arguments;
}

/// Finder for completion contexts within the provided initial expression.
CompletionContextFinder(ASTNode initialNode, DeclContext *DC)
: InitialExpr(initialNode.dyn_cast<Expr *>()), InitialDC(DC) {
assert(DC);
initialNode.walk(*this);
};
/// Finder for completion contexts within the provided SyntacticElementTarget.
CompletionContextFinder(constraints::SyntacticElementTarget target,
DeclContext *DC);

/// Finder for completion contexts within the outermost non-closure context of
/// the code completion expression's direct context.
Expand Down
3 changes: 3 additions & 0 deletions include/swift/Sema/ConstraintLocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ class ConstraintLocator : public llvm::FoldingSetNode {
/// otherwise \c nullptr.
NullablePtr<Pattern> getPatternMatch() const;

/// Whether the locator in question is for a pattern match.
bool isForPatternMatch() const;

/// Returns true if \p locator is ending with either of the following
/// - Member
/// - Member -> KeyPathDynamicMember
Expand Down
4 changes: 4 additions & 0 deletions include/swift/Sema/ConstraintLocatorPathElts.def
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ CUSTOM_LOCATOR_PATH_ELT(TernaryBranch)
/// Performing a pattern patch.
CUSTOM_LOCATOR_PATH_ELT(PatternMatch)

/// The constraint that models the allowed implicit casts for
/// an EnumElementPattern.
SIMPLE_LOCATOR_PATH_ELT(EnumPatternImplicitCastMatch)

/// Points to a particular attribute associated with one of
/// the arguments e.g. `inout` or its type e.g. `@escaping`.
///
Expand Down
35 changes: 35 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,10 @@ class Solution {
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
caseLabelItems;

/// A map of expressions to the ExprPatterns that they are being solved as
/// a part of.
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;

/// The set of parameters that have been inferred to be 'isolated'.
llvm::SmallVector<ParamDecl *, 2> isolatedParams;

Expand Down Expand Up @@ -1685,6 +1689,16 @@ class Solution {
: nullptr;
}

/// Retrieve the solved ExprPattern that corresponds to provided
/// sub-expression.
NullablePtr<ExprPattern> getExprPatternFor(Expr *E) const {
auto result = exprPatterns.find(E);
if (result == exprPatterns.end())
return nullptr;

return result->second;
}

/// This method implements functionality of `Expr::isTypeReference`
/// with data provided by a given solution.
bool isTypeReference(Expr *E) const;
Expand Down Expand Up @@ -2148,6 +2162,10 @@ class ConstraintSystem {
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
caseLabelItems;

/// A map of expressions to the ExprPatterns that they are being solved as
/// a part of.
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;

/// The set of parameters that have been inferred to be 'isolated'.
llvm::SmallSetVector<ParamDecl *, 2> isolatedParams;

Expand Down Expand Up @@ -2745,6 +2763,9 @@ class ConstraintSystem {
/// The length of \c caseLabelItems.
unsigned numCaseLabelItems;

/// The length of \c exprPatterns.
unsigned numExprPatterns;

/// The length of \c isolatedParams.
unsigned numIsolatedParams;

Expand Down Expand Up @@ -3166,6 +3187,15 @@ class ConstraintSystem {
caseLabelItems[item] = info;
}

/// Record a given ExprPattern as the parent of its sub-expression.
void setExprPatternFor(Expr *E, ExprPattern *EP) {
assert(E);
assert(EP);
auto inserted = exprPatterns.insert({E, EP}).second;
assert(inserted && "Mapping already defined?");
(void)inserted;
}

Optional<CaseLabelItemInfo> getCaseLabelItemInfo(
const CaseLabelItem *item) const {
auto known = caseLabelItems.find(item);
Expand Down Expand Up @@ -4315,6 +4345,11 @@ class ConstraintSystem {
/// \returns \c true if constraint generation failed, \c false otherwise
bool generateConstraints(SingleValueStmtExpr *E);

/// Generate constraints for an array of ExprPatterns, forming a conjunction
/// that solves each expression in turn.
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
ConstraintLocatorBuilder locator);

/// Generate constraints for the given (unchecked) expression.
///
/// \returns a possibly-sanitized expression, or null if an error occurred.
Expand Down
37 changes: 22 additions & 15 deletions lib/IDE/TypeCheckCompletionCallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,13 @@ Type swift::ide::getTypeForCompletion(const constraints::Solution &S,
/// \endcode
/// If the code completion expression occurs in such an AST, return the
/// declaration of the \c $match variable, otherwise return \c nullptr.
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, const Solution &S) {
if (auto EP = S.getExprPatternFor(E))
return EP.get()->getMatchVar();

// TODO: Once ExprPattern type-checking is fully moved into the solver,
// the below can be deleted.
auto &CS = S.getConstraintSystem();
auto &Context = CS.getASTContext();

auto *Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(E));
Expand Down Expand Up @@ -109,20 +115,21 @@ static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
}

Type swift::ide::getPatternMatchType(const constraints::Solution &S, Expr *E) {
if (auto MatchVar = getMatchVarIfInPatternMatch(E, S.getConstraintSystem())) {
Type MatchVarType;
// If the MatchVar has an explicit type, it's not part of the solution. But
// we can look it up in the constraint system directly.
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
MatchVarType = T;
} else {
MatchVarType = getTypeForCompletion(S, MatchVar);
}
if (MatchVarType) {
return MatchVarType;
}
}
return nullptr;
auto MatchVar = getMatchVarIfInPatternMatch(E, S);
if (!MatchVar)
return nullptr;

if (S.hasType(MatchVar))
return S.getResolvedType(MatchVar);

// If the ExprPattern wasn't solved as part of the constraint system, it's
// not part of the solution.
// TODO: This can be removed once ExprPattern type-checking is fully part
// of the constraint system.
if (auto T = S.getConstraintSystem().getVarType(MatchVar))
return T;

return getTypeForCompletion(S, MatchVar);
}

void swift::ide::getSolutionSpecificVarTypes(
Expand Down
3 changes: 2 additions & 1 deletion lib/Sema/BuilderTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,8 @@ Optional<BraceStmt *> TypeChecker::applyResultBuilderBodyTransform(
SmallVector<Solution, 4> solutions;
cs.solveForCodeCompletion(solutions);

CompletionContextFinder analyzer(func, func->getDeclContext());
SyntacticElementTarget funcTarget(func);
CompletionContextFinder analyzer(funcTarget, func->getDeclContext());
if (analyzer.hasCompletion()) {
filterSolutionsForCodeCompletion(solutions, analyzer);
for (const auto &solution : solutions) {
Expand Down
128 changes: 83 additions & 45 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8705,6 +8705,9 @@ namespace {
return Action::SkipChildren();
}

NullablePtr<Pattern>
rewritePattern(Pattern *pattern, DeclContext *DC);

/// Rewrite the target, producing a new target.
Optional<SyntacticElementTarget>
rewriteTarget(SyntacticElementTarget target);
Expand Down Expand Up @@ -8951,12 +8954,68 @@ static Expr *wrapAsyncLetInitializer(
return resultInit;
}

static Pattern *rewriteExprPattern(const SyntacticElementTarget &matchTarget,
Type patternTy,
RewriteTargetFn rewriteTarget) {
auto *EP = matchTarget.getExprPattern();

// See if we can simplify to another kind of pattern.
if (auto simplified = TypeChecker::trySimplifyExprPattern(EP, patternTy))
return simplified.get();

auto resultTarget = rewriteTarget(matchTarget);
if (!resultTarget)
return nullptr;

EP->setMatchExpr(resultTarget->getAsExpr());
EP->getMatchVar()->setInterfaceType(patternTy->mapTypeOutOfContext());
EP->setType(patternTy);
return EP;
}

/// Attempt to rewrite either an ExprPattern, or a pattern that was solved as
/// an ExprPattern, e.g an EnumElementPattern that could not refer to an enum
/// case.
static Optional<Pattern *>
tryRewriteExprPattern(Pattern *P, Solution &solution, Type patternTy,
RewriteTargetFn rewriteTarget) {
// See if we have a match expression target.
auto matchTarget = solution.getTargetFor(P);
if (!matchTarget)
return None;

return rewriteExprPattern(*matchTarget, patternTy, rewriteTarget);
}

NullablePtr<Pattern> ExprWalker::rewritePattern(Pattern *pattern,
DeclContext *DC) {
auto &solution = Rewriter.solution;

// Figure out the pattern type.
auto patternTy = solution.getResolvedType(pattern);
patternTy = patternTy->reconstituteSugar(/*recursive=*/false);

// Coerce the pattern to its appropriate type.
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
patternOptions |= TypeResolutionFlags::OverrideType;

auto tryRewritePattern = [&](Pattern *EP, Type ty) {
return ::tryRewriteExprPattern(
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
};

auto contextualPattern = ContextualPattern::forRawPattern(pattern, DC);
return TypeChecker::coercePatternToType(contextualPattern, patternTy,
patternOptions, tryRewritePattern);
}

/// Apply the given solution to the initialization target.
///
/// \returns the resulting initialization expression.
static Optional<SyntacticElementTarget>
applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
Expr *initializer) {
Expr *initializer,
RewriteTargetFn rewriteTarget) {
auto wrappedVar = target.getInitializationWrappedVar();
Type initType;
if (wrappedVar) {
Expand Down Expand Up @@ -9021,10 +9080,14 @@ applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,

finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);

auto tryRewritePattern = [&](Pattern *EP, Type ty) {
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
};

// Apply the solution to the pattern as well.
auto contextualPattern = target.getContextualPattern();
if (auto coercedPattern = TypeChecker::coercePatternToType(
contextualPattern, finalPatternType, options)) {
contextualPattern, finalPatternType, options, tryRewritePattern)) {
resultTarget.setPattern(coercedPattern);
} else {
return None;
Expand Down Expand Up @@ -9171,10 +9234,15 @@ static Optional<SyntacticElementTarget> applySolutionToForEachStmt(
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
options |= TypeResolutionFlags::OverrideType;

auto tryRewritePattern = [&](Pattern *EP, Type ty) {
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
};

// Apply the solution to the pattern as well.
auto contextualPattern = target.getContextualPattern();
auto coercedPattern = TypeChecker::coercePatternToType(
contextualPattern, forEachStmtInfo.initType, options);
contextualPattern, forEachStmtInfo.initType, options,
tryRewritePattern);
if (!coercedPattern)
return None;

Expand Down Expand Up @@ -9262,7 +9330,8 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
switch (target.getExprContextualTypePurpose()) {
case CTP_Initialization: {
auto initResultTarget = applySolutionToInitialization(
solution, target, rewrittenExpr);
solution, target, rewrittenExpr,
[&](auto target) { return rewriteTarget(target); });
if (!initResultTarget)
return None;

Expand Down Expand Up @@ -9353,47 +9422,11 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
ConstraintSystem &cs = solution.getConstraintSystem();
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);

// Figure out the pattern type.
Type patternType = solution.simplifyType(solution.getType(info.pattern));
patternType = patternType->reconstituteSugar(/*recursive=*/false);

// Check whether this enum element is resolved via ~= application.
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
if (auto target = cs.getTargetFor(enumElement)) {
auto *EP = target->getExprPattern();
auto enumType = solution.getResolvedType(EP);

auto *matchCall = target->getAsExpr();

auto *result = matchCall->walk(*this);
if (!result)
return None;

{
auto *matchVar = EP->getMatchVar();
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
}

EP->setMatchExpr(result);
EP->setType(enumType);

(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
return target;
}
}

// Coerce the pattern to its appropriate type.
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
patternOptions |= TypeResolutionFlags::OverrideType;
auto contextualPattern =
ContextualPattern::forRawPattern(info.pattern,
target.getDeclContext());
if (auto coercedPattern = TypeChecker::coercePatternToType(
contextualPattern, patternType, patternOptions)) {
(*caseLabelItem)->setPattern(coercedPattern, /*resolved=*/true);
} else {
auto pattern = rewritePattern(info.pattern, target.getDeclContext());
if (!pattern)
return None;
}

(*caseLabelItem)->setPattern(pattern.get(), /*resolved=*/true);

// If there is a guard expression, coerce that.
if (auto *guardExpr = info.guardExpr) {
Expand Down Expand Up @@ -9461,8 +9494,13 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
options |= TypeResolutionFlags::OverrideType;
}

auto tryRewritePattern = [&](Pattern *EP, Type ty) {
return ::tryRewriteExprPattern(
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
};

if (auto coercedPattern = TypeChecker::coercePatternToType(
contextualPattern, patternType, options)) {
contextualPattern, patternType, options, tryRewritePattern)) {
auto resultTarget = target;
resultTarget.setPattern(coercedPattern);
return resultTarget;
Expand Down
Loading