Skip to content

[ConstraintSystem] NFC: Unify type variable and disjunction choice re… #19117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ static SmallVector<Type, 4> enumerateDirectSupertypes(Type type) {
return result;
}

bool TypeVarBindingGenerator::computeNext() {
bool TypeVarBindingProducer::computeNext() {
SmallVector<Binding, 4> newBindings;
auto addNewBinding = [&](Binding binding) {
auto type = binding.BindingType;
Expand Down Expand Up @@ -848,3 +848,34 @@ bool TypeVarBindingGenerator::computeNext() {
Bindings = std::move(newBindings);
return true;
}

void TypeVariableBinding::attempt(ConstraintSystem &cs) const {
auto type = Binding.BindingType;
auto *locator = TypeVar->getImpl().getLocator();

if (Binding.DefaultedProtocol) {
type = cs.openUnboundGenericType(type, locator);
type = type->reconstituteSugar(/*recursive=*/false);
} else if (Binding.BindingSource == ConstraintKind::ArgumentConversion &&
!type->hasTypeVariable() && cs.isCollectionType(type)) {
// If the type binding comes from the argument conversion, let's
// instead of binding collection types directly, try to bind
// using temporary type variables substituted for element
// types, that's going to ensure that subtype relationship is
// always preserved.
auto *BGT = type->castTo<BoundGenericType>();
auto UGT = UnboundGenericType::get(BGT->getDecl(), BGT->getParent(),
BGT->getASTContext());

type = cs.openUnboundGenericType(UGT, locator);
type = type->reconstituteSugar(/*recursive=*/false);
}

// FIXME: We want the locator that indicates where the binding came
// from.
cs.addConstraint(ConstraintKind::Bind, TypeVar, type, locator);

// If this was from a defaultable binding note that.
if (Binding.isDefaultableBinding())
cs.DefaultedConstraints.push_back(Binding.DefaultableBinding);
}
60 changes: 20 additions & 40 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,38 +527,10 @@ bool ConstraintSystem::tryTypeVariableBindings(
bool anySolved = false;
bool sawFirstLiteralConstraint = false;

auto attemptTypeVarBinding = [&](PotentialBinding &binding) -> bool {
auto type = binding.BindingType;

auto attemptTypeVarBinding = [&](TypeVariableBinding &binding) -> bool {
// Try to solve the system with typeVar := type
ConstraintSystem::SolverScope scope(*this);
if (binding.DefaultedProtocol) {
type = openUnboundGenericType(type, typeVar->getImpl().getLocator());
type = type->reconstituteSugar(/*recursive=*/false);
} else if (binding.BindingSource == ConstraintKind::ArgumentConversion &&
!type->hasTypeVariable() && isCollectionType(type)) {
// If the type binding comes from the argument conversion, let's
// instead of binding collection types directly, try to bind
// using temporary type variables substituted for element
// types, that's going to ensure that subtype relationship is
// always preserved.
auto *BGT = type->castTo<BoundGenericType>();
auto UGT = UnboundGenericType::get(BGT->getDecl(), BGT->getParent(),
BGT->getASTContext());

type = openUnboundGenericType(UGT, typeVar->getImpl().getLocator());
type = type->reconstituteSugar(/*recursive=*/false);
}

// FIXME: We want the locator that indicates where the binding came
// from.
addConstraint(ConstraintKind::Bind, typeVar, type,
typeVar->getImpl().getLocator());

// If this was from a defaultable binding note that.
if (binding.isDefaultableBinding())
DefaultedConstraints.push_back(binding.DefaultableBinding);

binding.attempt(*this);
return !solveRec(solutions);
};

Expand All @@ -576,7 +548,7 @@ bool ConstraintSystem::tryTypeVariableBindings(
}

++solverState->NumTypeVariablesBound;
TypeVarBindingGenerator bindings(*this, typeVar, initialBindings);
TypeVarBindingProducer bindings(*this, typeVar, initialBindings);
while (auto binding = bindings()) {
// Try each of the bindings in turn.
++solverState->NumTypeVariableBindings;
Expand All @@ -589,18 +561,18 @@ bool ConstraintSystem::tryTypeVariableBindings(

// If we were able to solve this without considering
// default literals, don't bother looking at default literals.
if (binding->DefaultedProtocol && !sawFirstLiteralConstraint)
if (binding->hasDefaultedProtocol() && !sawFirstLiteralConstraint)
break;
}

if (TC.getLangOpts().DebugConstraintSolver) {
auto &log = getASTContext().TypeCheckerDebug->getStream();
log.indent(solverState->depth * 2)
<< "(trying " << typeVar->getString()
<< " := " << binding->BindingType->getString() << '\n';
<< " := " << binding->getType()->getString() << '\n';
}

if (binding->DefaultedProtocol)
if (binding->hasDefaultedProtocol())
sawFirstLiteralConstraint = true;

if (attemptTypeVarBinding(*binding))
Expand Down Expand Up @@ -1775,12 +1747,14 @@ Constraint *ConstraintSystem::selectDisjunction() {
}

bool ConstraintSystem::solveForDisjunctionChoices(
Disjunction &disjunction, SmallVectorImpl<Solution> &solutions) {
DisjunctionChoiceProducer &disjunction,
SmallVectorImpl<Solution> &solutions) {
Optional<Score> bestNonGenericScore;
Optional<std::pair<DisjunctionChoice, Score>> lastSolvedChoice;

// Try each of the constraints within the disjunction.
for (auto currentChoice : disjunction) {
while (auto binding = disjunction()) {
auto &currentChoice = *binding;
if (shouldSkipDisjunctionChoice(currentChoice, bestNonGenericScore))
continue;

Expand Down Expand Up @@ -1832,7 +1806,7 @@ bool ConstraintSystem::solveForDisjunctionChoices(
}
}

if (auto score = currentChoice.solve(solutions)) {
if (auto score = binding->attempt(solutions)) {
if (!currentChoice.isGenericOperator() &&
currentChoice.isSymmetricOperator()) {
if (!bestNonGenericScore || score < bestNonGenericScore)
Expand Down Expand Up @@ -1908,8 +1882,9 @@ bool ConstraintSystem::solveForDisjunction(
disjunction->shouldRememberChoice() ? disjunction->getLocator() : nullptr;
assert(!disjunction->shouldRememberChoice() || disjunction->getLocator());

auto choices = Disjunction(*this, disjunction->getNestedConstraints(),
locator, disjunction->isExplicitConversion());
auto choices =
DisjunctionChoiceProducer(*this, disjunction->getNestedConstraints(),
locator, disjunction->isExplicitConversion());

auto noSolutions = solveForDisjunctionChoices(choices, solutions);

Expand Down Expand Up @@ -1980,11 +1955,16 @@ bool ConstraintSystem::solveSimplified(SmallVectorImpl<Solution> &solutions) {
return false;
}

Optional<Score> DisjunctionChoice::solve(SmallVectorImpl<Solution> &solutions) {
void DisjunctionChoice::attempt(ConstraintSystem &cs) const {
CS->simplifyDisjunctionChoice(Choice);

if (ExplicitConversion)
propagateConversionInfo();
}

Optional<Score>
DisjunctionChoice::attempt(SmallVectorImpl<Solution> &solutions) {
attempt(*CS);

if (CS->solveRec(solutions))
return None;
Expand Down
74 changes: 52 additions & 22 deletions lib/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ namespace constraints {
class ConstraintGraph;
class ConstraintGraphNode;
class ConstraintSystem;
class Disjunction;
class TypeVarBindingGenerator;
class DisjunctionChoiceProducer;
class TypeVarBindingProducer;
class TypeVariableBinding;

} // end namespace constraints

Expand Down Expand Up @@ -925,7 +926,8 @@ class ConstraintSystem {
friend class DisjunctionChoice;
friend class Component;
friend class FailureDiagnostic;
friend class TypeVarBindingGenerator;
friend class TypeVarBindingProducer;
friend class TypeVariableBinding;

class SolverScope;

Expand Down Expand Up @@ -2983,7 +2985,7 @@ class ConstraintSystem {
/// the best solution to the constraint system.
///
/// \returns true if we failed to find any solutions, false otherwise.
bool solveForDisjunctionChoices(Disjunction &disjunction,
bool solveForDisjunctionChoices(DisjunctionChoiceProducer &disjunction,
SmallVectorImpl<Solution> &solutions);

/// \brief Solve the system of constraints after it has already been
Expand Down Expand Up @@ -3366,7 +3368,15 @@ void simplifyLocator(Expr *&anchor,
/// null otherwise.
Expr *simplifyLocatorToAnchor(ConstraintSystem &cs, ConstraintLocator *locator);

class DisjunctionChoice {
/// Common interface to encapsulate attempting choices of
/// different entities, such as type variables (types)
/// or disjunctions (their choices).
struct TypeBinding {
virtual ~TypeBinding() {}
virtual void attempt(ConstraintSystem &cs) const = 0;
};

class DisjunctionChoice : public TypeBinding {
ConstraintSystem *CS;
unsigned Index;
Constraint *Choice;
Expand Down Expand Up @@ -3398,8 +3408,10 @@ class DisjunctionChoice {
bool isGenericOperator() const;
bool isSymmetricOperator() const;

void attempt(ConstraintSystem &cs) const override;

/// \brief Apply given choice to the system and try to solve it.
Optional<Score> solve(SmallVectorImpl<Solution> &solutions);
Optional<Score> attempt(SmallVectorImpl<Solution> &solutions);

operator Constraint *() { return Choice; }

Expand Down Expand Up @@ -3428,7 +3440,25 @@ class DisjunctionChoice {
}
};

class TypeVarBindingGenerator {
class TypeVariableBinding : public TypeBinding {
TypeVariableType *TypeVar;
ConstraintSystem::PotentialBinding Binding;

public:
TypeVariableBinding(TypeVariableType *typeVar,
ConstraintSystem::PotentialBinding &binding)
: TypeVar(typeVar), Binding(binding) {}

Type getType() const { return Binding.BindingType; }

void attempt(ConstraintSystem &cs) const override;

bool isDefaultableBinding() const { return Binding.isDefaultableBinding(); }

bool hasDefaultedProtocol() const { return Binding.DefaultedProtocol; }
};

class TypeVarBindingProducer {
using BindingKind = ConstraintSystem::AllowedBindingKind;
using Binding = ConstraintSystem::PotentialBinding;

Expand All @@ -3446,19 +3476,19 @@ class TypeVarBindingGenerator {
llvm::SmallPtrSet<TypeBase *, 4> BoundTypes;

public:
TypeVarBindingGenerator(ConstraintSystem &cs, TypeVariableType *typeVar,
ArrayRef<Binding> initialBindings)
TypeVarBindingProducer(ConstraintSystem &cs, TypeVariableType *typeVar,
ArrayRef<Binding> initialBindings)
: CS(cs), TypeVar(typeVar),
Bindings(initialBindings.begin(), initialBindings.end()) {}

Optional<Binding> operator()() {
Optional<TypeVariableBinding> operator()() {
// Once we reach the end of the current bindings
// let's try to compute new ones, e.g. supertypes,
// literal defaults, if that fails, we are done.
if (needsToComputeNext() && !computeNext())
return None;

return Bindings[Index++];
return TypeVariableBinding(TypeVar, Bindings[Index++]);
}

/// Check whether generator would have to compute next
Expand All @@ -3481,7 +3511,7 @@ class TypeVarBindingGenerator {
/// Iterator over disjunction choices, makes it
/// easy to work with disjunction and encapsulates
/// some other important information such as locator.
class Disjunction {
class DisjunctionChoiceProducer {
ConstraintSystem &CS;
ArrayRef<Constraint *> Choices;
ConstraintLocator *Locator;
Expand All @@ -3490,20 +3520,20 @@ class Disjunction {
unsigned Index = 0;

public:
Disjunction(ConstraintSystem &cs, ArrayRef<Constraint *> choices,
ConstraintLocator *locator, bool explicitConversion)
DisjunctionChoiceProducer(ConstraintSystem &cs,
ArrayRef<Constraint *> choices,
ConstraintLocator *locator, bool explicitConversion)
: CS(cs), Choices(choices), Locator(locator),
IsExplicitConversion(explicitConversion) {}

const Disjunction &begin() const { return *this; }
const Disjunction &end() const { return *this; }

bool operator!=(const Disjunction &) const { return Index < Choices.size(); }

void operator++() { ++Index; }
Optional<DisjunctionChoice> operator()() {
unsigned currIndex = Index;
if (currIndex >= Choices.size())
return None;

DisjunctionChoice operator*() const {
return {&CS, Index, Choices[Index], IsExplicitConversion};
++Index;
return DisjunctionChoice(&CS, currIndex, Choices[currIndex],
IsExplicitConversion);
}

ConstraintLocator *getLocator() const { return Locator; }
Expand Down