Skip to content

[Constraint solver] Migrate for-each statement checking into SolutionApplicationTarget #30924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 172 additions & 5 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7886,7 +7886,7 @@ bool ConstraintSystem::applySolutionFixes(const Solution &solution) {

/// Apply the given solution to the initialization target.
///
/// \returns the resulting initialiation expression.
/// \returns the resulting initialization expression.
static Optional<SolutionApplicationTarget> applySolutionToInitialization(
Solution &solution, SolutionApplicationTarget target,
Expr *initializer) {
Expand Down Expand Up @@ -7950,7 +7950,7 @@ static Optional<SolutionApplicationTarget> applySolutionToInitialization(
finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);

// Apply the solution to the pattern as well.
auto contextualPattern = target.getInitializationContextualPattern();
auto contextualPattern = target.getContextualPattern();
if (auto coercedPattern = TypeChecker::coercePatternToType(
contextualPattern, finalPatternType, options)) {
resultTarget.setPattern(coercedPattern);
Expand All @@ -7961,6 +7961,139 @@ static Optional<SolutionApplicationTarget> applySolutionToInitialization(
return resultTarget;
}

/// Apply the given solution to the for-each statement target.
///
/// \returns the resulting initialization expression.
static Optional<SolutionApplicationTarget> applySolutionToForEachStmt(
Solution &solution, SolutionApplicationTarget target, Expr *sequence) {
auto resultTarget = target;
auto &forEachStmtInfo = resultTarget.getForEachStmtInfo();

// Simplify the various types.
forEachStmtInfo.elementType =
solution.simplifyType(forEachStmtInfo.elementType);
forEachStmtInfo.iteratorType =
solution.simplifyType(forEachStmtInfo.iteratorType);
forEachStmtInfo.initType =
solution.simplifyType(forEachStmtInfo.initType);
forEachStmtInfo.sequenceType =
solution.simplifyType(forEachStmtInfo.sequenceType);

// Coerce the sequence to the sequence type.
auto &cs = solution.getConstraintSystem();
auto locator = cs.getConstraintLocator(target.getAsExpr());
sequence = solution.coerceToType(
sequence, forEachStmtInfo.sequenceType, locator);
if (!sequence)
return None;

resultTarget.setExpr(sequence);

// Get the conformance of the sequence type to the Sequence protocol.
auto stmt = forEachStmtInfo.stmt;
auto sequenceProto = TypeChecker::getProtocol(
cs.getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence);
auto contextualLocator = cs.getConstraintLocator(
target.getAsExpr(), LocatorPathElt::ContextualType());
auto sequenceConformance = solution.resolveConformance(
contextualLocator, sequenceProto);
assert(!sequenceConformance.isInvalid() &&
"Couldn't find sequence conformance");

// Coerce the pattern to the element type.
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
options |= TypeResolutionFlags::OverrideType;

// Apply the solution to the pattern as well.
auto contextualPattern = target.getContextualPattern();
if (auto coercedPattern = TypeChecker::coercePatternToType(
contextualPattern, forEachStmtInfo.initType, options)) {
resultTarget.setPattern(coercedPattern);
} else {
return None;
}

// Apply the solution to the filtering condition, if there is one.
auto dc = target.getDeclContext();
if (forEachStmtInfo.whereExpr) {
auto *boolDecl = dc->getASTContext().getBoolDecl();
assert(boolDecl);
Type boolType = boolDecl->getDeclaredType();
assert(boolType);

SolutionApplicationTarget whereTarget(
forEachStmtInfo.whereExpr, dc, CTP_Condition, boolType,
/*isDiscarded=*/false);
auto newWhereTarget = cs.applySolution(solution, whereTarget);
if (!newWhereTarget)
return None;

forEachStmtInfo.whereExpr = newWhereTarget->getAsExpr();
}

// Invoke iterator() to get an iterator from the sequence.
ASTContext &ctx = cs.getASTContext();
VarDecl *iterator;
Type nextResultType = OptionalType::get(forEachStmtInfo.elementType);
{
// Create a local variable to capture the iterator.
std::string name;
if (auto np = dyn_cast_or_null<NamedPattern>(stmt->getPattern()))
name = "$"+np->getBoundName().str().str();
name += "$generator";

iterator = new (ctx) VarDecl(
/*IsStatic*/ false, VarDecl::Introducer::Var,
/*IsCaptureList*/ false, stmt->getInLoc(),
ctx.getIdentifier(name), dc);
iterator->setInterfaceType(
forEachStmtInfo.iteratorType->mapTypeOutOfContext());
iterator->setImplicit();

auto genPat = new (ctx) NamedPattern(iterator);
genPat->setImplicit();

// TODO: test/DebugInfo/iteration.swift requires this extra info to
// be around.
PatternBindingDecl::createImplicit(
ctx, StaticSpellingKind::None, genPat,
new (ctx) OpaqueValueExpr(stmt->getInLoc(), nextResultType),
dc, /*VarLoc*/ stmt->getForLoc());
}

// Create the iterator variable.
auto *varRef = TypeChecker::buildCheckedRefExpr(
iterator, dc, DeclNameLoc(stmt->getInLoc()), /*implicit*/ true);

// Convert that Optional<Element> value to the type of the pattern.
auto optPatternType = OptionalType::get(forEachStmtInfo.initType);
if (!optPatternType->isEqual(nextResultType)) {
OpaqueValueExpr *elementExpr =
new (ctx) OpaqueValueExpr(stmt->getInLoc(), nextResultType,
/*isPlaceholder=*/true);
Expr *convertElementExpr = elementExpr;
if (TypeChecker::typeCheckExpression(
convertElementExpr, dc,
TypeLoc::withoutLoc(optPatternType),
CTP_CoerceOperand).isNull()) {
return None;
}
elementExpr->setIsPlaceholder(false);
stmt->setElementExpr(elementExpr);
stmt->setConvertElementExpr(convertElementExpr);
}

// Write the result back into the AST.
stmt->setSequence(resultTarget.getAsExpr());
stmt->setPattern(resultTarget.getContextualPattern().getPattern());
stmt->setSequenceConformance(sequenceConformance);
stmt->setWhere(forEachStmtInfo.whereExpr);
stmt->setIteratorVar(iterator);
stmt->setIteratorVarRef(varRef);

return resultTarget;
}

Optional<SolutionApplicationTarget>
ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
auto &solution = Rewriter.solution;
Expand All @@ -7972,16 +8105,50 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
if (!rewrittenExpr)
return None;

/// Handle application for initializations.
if (target.getExprContextualTypePurpose() == CTP_Initialization) {
/// Handle special cases for expressions.
switch (target.getExprContextualTypePurpose()) {
case CTP_Initialization: {
auto initResultTarget = applySolutionToInitialization(
solution, target, rewrittenExpr);
if (!initResultTarget)
return None;

result = *initResultTarget;
} else {
break;
}

case CTP_ForEachStmt: {
auto forEachResultTarget = applySolutionToForEachStmt(
solution, target, rewrittenExpr);
if (!forEachResultTarget)
return None;

result = *forEachResultTarget;
break;
}

case CTP_Unused:
case CTP_ReturnStmt:
case swift::CTP_ReturnSingleExpr:
case swift::CTP_YieldByValue:
case swift::CTP_YieldByReference:
case swift::CTP_ThrowStmt:
case swift::CTP_EnumCaseRawValue:
case swift::CTP_DefaultParameter:
case swift::CTP_AutoclosureDefaultParameter:
case swift::CTP_CalleeResult:
case swift::CTP_CallArgument:
case swift::CTP_ClosureResult:
case swift::CTP_ArrayElement:
case swift::CTP_DictionaryKey:
case swift::CTP_DictionaryValue:
case swift::CTP_CoerceOperand:
case swift::CTP_AssignSource:
case swift::CTP_SubscriptAssignSource:
case swift::CTP_Condition:
case swift::CTP_CannotFail:
result.setExpr(rewrittenExpr);
break;
}
} else if (auto stmtCondition = target.getAsStmtCondition()) {
for (auto &condElement : *stmtCondition) {
Expand Down
116 changes: 116 additions & 0 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4128,6 +4128,112 @@ static bool generateInitPatternConstraints(
return false;
}

/// Generate constraints for a for-each statement.
static Optional<SolutionApplicationTarget>
generateForEachStmtConstraints(
ConstraintSystem &cs, SolutionApplicationTarget target, Expr *sequence) {
auto forEachStmtInfo = target.getForEachStmtInfo();
ForEachStmt *stmt = forEachStmtInfo.stmt;

auto locator = cs.getConstraintLocator(sequence);
auto contextualLocator =
cs.getConstraintLocator(sequence, LocatorPathElt::ContextualType());

// The expression type must conform to the Sequence protocol.
auto sequenceProto = TypeChecker::getProtocol(
cs.getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence);
if (!sequenceProto) {
return None;
}

Type sequenceType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
cs.addConstraint(ConstraintKind::Conversion, cs.getType(sequence),
sequenceType, locator);
cs.addConstraint(ConstraintKind::ConformsTo, sequenceType,
sequenceProto->getDeclaredType(), contextualLocator);

// Check the element pattern.
ASTContext &ctx = cs.getASTContext();
auto dc = target.getDeclContext();
Pattern *pattern = TypeChecker::resolvePattern(stmt->getPattern(), dc,
/*isStmtCondition*/false);
if (!pattern)
return None;

auto contextualPattern =
ContextualPattern::forRawPattern(pattern, dc);
Type patternType = TypeChecker::typeCheckPattern(contextualPattern);
if (patternType->hasError()) {
return None;
}

// Collect constraints from the element pattern.
auto elementLocator = cs.getConstraintLocator(
contextualLocator, ConstraintLocator::SequenceElementType);
Type initType = cs.generateConstraints(
pattern, contextualLocator, target.shouldBindPatternVarsOneWay(),
nullptr, 0);
if (!initType)
return None;

// Add a conversion constraint between the element type of the sequence
// and the type of the element pattern.
auto elementAssocType =
sequenceProto->getAssociatedType(cs.getASTContext().Id_Element);
Type elementType = DependentMemberType::get(sequenceType, elementAssocType);
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
elementLocator);

// Determine the iterator type.
auto iteratorAssocType =
sequenceProto->getAssociatedType(cs.getASTContext().Id_Iterator);
Type iteratorType = DependentMemberType::get(sequenceType, iteratorAssocType);

// The iterator type must conform to IteratorProtocol.
ProtocolDecl *iteratorProto = TypeChecker::getProtocol(
cs.getASTContext(), stmt->getForLoc(),
KnownProtocolKind::IteratorProtocol);
if (!iteratorProto)
return None;

// Reference the makeIterator witness.
FuncDecl *makeIterator = ctx.getSequenceMakeIterator();
Type makeIteratorType =
cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
cs.addValueWitnessConstraint(
LValueType::get(sequenceType), makeIterator,
makeIteratorType, dc, FunctionRefKind::Compound,
contextualLocator);

// Generate constraints for the "where" expression, if there is one.
if (forEachStmtInfo.whereExpr) {
auto *boolDecl = dc->getASTContext().getBoolDecl();
if (!boolDecl)
return None;

Type boolType = boolDecl->getDeclaredType();
if (!boolType)
return None;

SolutionApplicationTarget whereTarget(
forEachStmtInfo.whereExpr, dc, CTP_Condition, boolType,
/*isDiscarded=*/false);
if (cs.generateConstraints(whereTarget, FreeTypeVariableBinding::Disallow))
return None;

forEachStmtInfo.whereExpr = whereTarget.getAsExpr();
}

// Populate all of the information for a for-each loop.
forEachStmtInfo.elementType = elementType;
forEachStmtInfo.iteratorType = iteratorType;
forEachStmtInfo.initType = initType;
forEachStmtInfo.sequenceType = sequenceType;
target.setPattern(pattern);
target.getForEachStmtInfo() = forEachStmtInfo;
return target;
}

bool ConstraintSystem::generateConstraints(
SolutionApplicationTarget &target,
FreeTypeVariableBinding allowFreeTypeVariables) {
Expand Down Expand Up @@ -4186,6 +4292,16 @@ bool ConstraintSystem::generateConstraints(
return true;
}

// For a for-each statement, generate constraints for the pattern, where
// clause, and sequence traversal.
if (target.getExprContextualTypePurpose() == CTP_ForEachStmt) {
auto resultTarget = generateForEachStmtConstraints(*this, target, expr);
if (!resultTarget)
return true;

target = *resultTarget;
}

if (getASTContext().TypeCheckerOpts.DebugConstraintSolver) {
auto &log = getASTContext().TypeCheckerDebug->getStream();
log << "---Initial constraints for the given expression---\n";
Expand Down
11 changes: 1 addition & 10 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,6 @@ static bool debugConstraintSolverForTarget(

Optional<std::vector<Solution>> ConstraintSystem::solve(
SolutionApplicationTarget &target,
ExprTypeCheckListener *listener,
FreeTypeVariableBinding allowFreeTypeVariables
) {
llvm::SaveAndRestore<bool> debugForExpr(
Expand Down Expand Up @@ -1171,7 +1170,7 @@ Optional<std::vector<Solution>> ConstraintSystem::solve(
// when there is an error and attempts to salvage an ill-formed program.
for (unsigned stage = 0; stage != 2; ++stage) {
auto solution = (stage == 0)
? solveImpl(target, listener, allowFreeTypeVariables)
? solveImpl(target, allowFreeTypeVariables)
: salvage();

switch (solution.getKind()) {
Expand Down Expand Up @@ -1237,7 +1236,6 @@ Optional<std::vector<Solution>> ConstraintSystem::solve(

SolutionResult
ConstraintSystem::solveImpl(SolutionApplicationTarget &target,
ExprTypeCheckListener *listener,
FreeTypeVariableBinding allowFreeTypeVariables) {
if (getASTContext().TypeCheckerOpts.DebugConstraintSolver) {
auto &log = getASTContext().TypeCheckerDebug->getStream();
Expand All @@ -1260,13 +1258,6 @@ ConstraintSystem::solveImpl(SolutionApplicationTarget &target,
if (generateConstraints(target, allowFreeTypeVariables))
return SolutionResult::forError();;

// Notify the listener that we've built the constraint system.
if (Expr *expr = target.getAsExpr()) {
if (listener && listener->builtConstraints(*this, expr)) {
return SolutionResult::forError();
}
}

// Try to solve the constraint system using computed suggestions.
SmallVector<Solution, 4> solutions;
solve(solutions, allowFreeTypeVariables);
Expand Down
Loading