Skip to content

Sema: Mostly remove one-way constraints #79914

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions include/swift/Sema/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -857,11 +857,6 @@ class Constraint final : public llvm::ilist_node<Constraint>,
/// from the rest of the constraint system.
bool isIsolated() const { return IsIsolated; }

/// Whether this is a one-way constraint.
bool isOneWayConstraint() const {
return Kind == ConstraintKind::OneWayEqual;
}

/// Retrieve the overload choice for an overload-binding constraint.
OverloadChoice getOverloadChoice() const {
assert(Kind == ConstraintKind::BindOverload);
Expand Down
38 changes: 10 additions & 28 deletions include/swift/Sema/ConstraintGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,25 +295,18 @@ class ConstraintGraph {
/// to a type variable.
void introduceToInference(TypeVariableType *typeVar, Type fixedType);

/// Describes which constraints \c gatherConstraints should gather.
enum class GatheringKind {
/// Gather constraints associated with all of the variables within the
/// same equivalence class as the given type variable, as well as its
/// immediate fixed bindings.
EquivalenceClass,
/// Gather all constraints that mention this type variable or type variables
/// that it is a fixed binding of. Unlike EquivalenceClass, this looks
/// through transitive fixed bindings. This can be used to find all the
/// constraints that may be affected when binding a type variable.
AllMentions,
};
/// Gather constraints associated with all of the variables within the
/// same equivalence class as the given type variable, as well as its
/// immediate fixed bindings.
llvm::TinyPtrVector<Constraint *>
gatherAllConstraints(TypeVariableType *typeVar);

/// Gather the set of constraints that involve the given type variable,
/// i.e., those constraints that will be affected when the type variable
/// gets merged or bound to a fixed type.
/// Gather all constraints that mention this type variable or type variables
/// that it is a fixed binding of. Unlike EquivalenceClass, this looks
/// through transitive fixed bindings. This can be used to find all the
/// constraints that may be affected when binding a type variable.
llvm::TinyPtrVector<Constraint *>
gatherConstraints(TypeVariableType *typeVar,
GatheringKind kind,
gatherNearbyConstraints(TypeVariableType *typeVar,
llvm::function_ref<bool(Constraint *)> acceptConstraint =
[](Constraint *constraint) { return true; });

Expand All @@ -338,12 +331,6 @@ class ConstraintGraph {
/// The constraints in this component.
TinyPtrVector<Constraint *> constraints;

/// The set of components that this component depends on, such that
/// the partial solutions of the those components need to be available
/// before this component can be solved.
///
SmallVector<unsigned, 2> dependencies;

public:
Component(unsigned solutionIndex) : solutionIndex(solutionIndex) { }

Expand All @@ -359,11 +346,6 @@ class ConstraintGraph {
return constraints;
}

/// Records a component which this component depends on.
void recordDependency(const Component &component);

ArrayRef<unsigned> getDependencies() const { return dependencies; }

unsigned getNumDisjunctions() const { return numDisjunctions; }
};

Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,8 @@ namespace {

llvm::SmallSetVector<ProtocolDecl *, 2> literalProtos;
if (auto argTypeVar = argTy->getAs<TypeVariableType>()) {
auto constraints = CS.getConstraintGraph().gatherConstraints(
argTypeVar, ConstraintGraph::GatheringKind::EquivalenceClass,
auto constraints = CS.getConstraintGraph().gatherNearbyConstraints(
argTypeVar,
[](Constraint *constraint) {
return constraint->getKind() == ConstraintKind::LiteralConformsTo;
});
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15953,8 +15953,8 @@ ConstraintSystem::addKeyPathApplicationRootConstraint(Type root, ConstraintLocat
if (!typeVar)
return;

auto constraints = CG.gatherConstraints(
typeVar, ConstraintGraph::GatheringKind::EquivalenceClass,
auto constraints = CG.gatherNearbyConstraints(
typeVar,
[&keyPathExpr](Constraint *constraint) -> bool {
if (constraint->getKind() != ConstraintKind::KeyPath)
return false;
Expand Down
11 changes: 5 additions & 6 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1876,8 +1876,8 @@ static Constraint *selectBestBindingDisjunction(
if (!firstBindDisjunction)
firstBindDisjunction = disjunction;

auto constraints = cs.getConstraintGraph().gatherConstraints(
typeVar, ConstraintGraph::GatheringKind::EquivalenceClass,
auto constraints = cs.getConstraintGraph().gatherNearbyConstraints(
typeVar,
[](Constraint *constraint) {
return constraint->getKind() == ConstraintKind::Conversion;
});
Expand Down Expand Up @@ -1906,8 +1906,8 @@ ConstraintSystem::findConstraintThroughOptionals(
while (visitedVars.insert(rep).second) {
// Look for a disjunction that binds this type variable to an overload set.
TypeVariableType *optionalObjectTypeVar = nullptr;
auto constraints = getConstraintGraph().gatherConstraints(
rep, ConstraintGraph::GatheringKind::EquivalenceClass,
auto constraints = getConstraintGraph().gatherNearbyConstraints(
rep,
[&](Constraint *match) {
// If we have an "optional object of" constraint, we may need to
// look through it to find the constraint we're looking for.
Expand Down Expand Up @@ -2549,9 +2549,8 @@ void DisjunctionChoice::propagateConversionInfo(ConstraintSystem &cs) const {
}
}

auto constraints = cs.CG.gatherConstraints(
auto constraints = cs.CG.gatherNearbyConstraints(
typeVar,
ConstraintGraph::GatheringKind::EquivalenceClass,
[](Constraint *constraint) -> bool {
switch (constraint->getKind()) {
case ConstraintKind::Conversion:
Expand Down
104 changes: 5 additions & 99 deletions lib/Sema/CSStep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ void SplitterStep::computeFollowupSteps(
// Take the orphaned constraints, because they'll go into a component now.
OrphanedConstraints = CG.takeOrphanedConstraints();

IncludeInMergedResults.resize(numComponents, true);
Components.resize(numComponents);
PartialSolutions = std::unique_ptr<SmallVector<Solution, 4>[]>(
new SmallVector<Solution, 4>[numComponents]);
Expand All @@ -135,26 +134,9 @@ void SplitterStep::computeFollowupSteps(
for (unsigned i : indices(components)) {
unsigned solutionIndex = components[i].solutionIndex;

// If there are no dependencies, build a normal component step.
if (components[i].getDependencies().empty()) {
steps.push_back(std::make_unique<ComponentStep>(
CS, solutionIndex, &Components[i], std::move(components[i]),
PartialSolutions[solutionIndex]));
continue;
}

// Note that the partial results from any dependencies of this component
// need not be included in the final merged results, because they'll
// already be part of the partial results for this component.
for (auto dependsOn : components[i].getDependencies()) {
IncludeInMergedResults[dependsOn] = false;
}

// Otherwise, build a dependent component "splitter" step, which
// handles all combinations of incoming partial solutions.
steps.push_back(std::make_unique<DependentComponentSplitterStep>(
CS, &Components[i], solutionIndex, std::move(components[i]),
llvm::MutableArrayRef(PartialSolutions.get(), numComponents)));
steps.push_back(std::make_unique<ComponentStep>(
CS, solutionIndex, &Components[i], std::move(components[i]),
PartialSolutions[solutionIndex]));
}

assert(CS.InactiveConstraints.empty() && "Missed a constraint");
Expand Down Expand Up @@ -223,8 +205,7 @@ bool SplitterStep::mergePartialSolutions() const {
SmallVector<unsigned, 2> countsVec;
countsVec.reserve(numComponents);
for (unsigned idx : range(numComponents)) {
countsVec.push_back(
IncludeInMergedResults[idx] ? PartialSolutions[idx].size() : 1);
countsVec.push_back(PartialSolutions[idx].size());
}

// Produce all combinations of partial solutions.
Expand All @@ -237,9 +218,6 @@ bool SplitterStep::mergePartialSolutions() const {
// solutions.
ConstraintSystem::SolverScope scope(CS);
for (unsigned i : range(numComponents)) {
if (!IncludeInMergedResults[i])
continue;

CS.replaySolution(PartialSolutions[i][indices[i]]);
}

Expand Down Expand Up @@ -271,87 +249,15 @@ bool SplitterStep::mergePartialSolutions() const {
return anySolutions;
}

StepResult DependentComponentSplitterStep::take(bool prevFailed) {
// "split" is considered a failure if previous step failed,
// or there is a failure recorded by constraint system, or
// system can't be simplified.
if (prevFailed || CS.getFailedConstraint() || CS.simplify())
return done(/*isSuccess=*/false);

// Figure out the sets of partial solutions that this component depends on.
SmallVector<const SmallVector<Solution, 4> *, 2> dependsOnSets;
for (auto index : Component.getDependencies()) {
dependsOnSets.push_back(&AllPartialSolutions[index]);
}

// Produce all combinations of partial solutions for the inputs.
SmallVector<std::unique_ptr<SolverStep>, 4> followup;
SmallVector<unsigned, 2> indices(Component.getDependencies().size(), 0);
auto dependsOnSetsRef = llvm::ArrayRef(dependsOnSets);
do {
// Form the set of input partial solutions.
SmallVector<const Solution *, 2> dependsOnSolutions;
for (auto index : swift::indices(indices)) {
dependsOnSolutions.push_back(&(*dependsOnSets[index])[indices[index]]);
}
ContextualSolutions.push_back(std::make_unique<SmallVector<Solution, 2>>());

followup.push_back(std::make_unique<ComponentStep>(
CS, Index, Constraints, Component, std::move(dependsOnSolutions),
*ContextualSolutions.back()));
} while (nextCombination(dependsOnSetsRef, indices));

/// Wait until all of the component steps are done.
return suspend(followup);
}

StepResult DependentComponentSplitterStep::resume(bool prevFailed) {
for (auto &ComponentStepSolutions : ContextualSolutions) {
Solutions.append(std::make_move_iterator(ComponentStepSolutions->begin()),
std::make_move_iterator(ComponentStepSolutions->end()));
}
return done(/*isSuccess=*/!Solutions.empty());
}

void DependentComponentSplitterStep::print(llvm::raw_ostream &Out) {
Out << "DependentComponentSplitterStep for dependencies on [";
interleave(
Component.getDependencies(), [&](unsigned index) { Out << index; },
[&] { Out << ", "; });
Out << "]\n";
}

StepResult ComponentStep::take(bool prevFailed) {
// One of the previous components created by "split"
// failed, it means that we can't solve this component.
if ((prevFailed && DependsOnPartialSolutions.empty()) ||
CS.isTooComplex(Solutions) || CS.worseThanBestSolution())
if (prevFailed || CS.isTooComplex(Solutions) || CS.worseThanBestSolution())
return done(/*isSuccess=*/false);

// Setup active scope, only if previous component didn't fail.
setupScope();

// If there are any dependent partial solutions to compose, do so now.
if (!DependsOnPartialSolutions.empty()) {
for (auto partial : DependsOnPartialSolutions) {
CS.replaySolution(*partial);
}

// Activate all of the one-way constraints.
SmallVector<Constraint *, 4> oneWayConstraints;
for (auto &constraint : CS.InactiveConstraints) {
if (constraint.isOneWayConstraint())
oneWayConstraints.push_back(&constraint);
}
for (auto constraint : oneWayConstraints) {
CS.activateConstraint(constraint);
}

// Simplify again.
if (CS.failedConstraint || CS.simplify())
return done(/*isSuccess=*/false);
}

/// Try to figure out what this step is going to be,
/// after the scope has been established.
SmallString<64> potentialBindings;
Expand Down
66 changes: 1 addition & 65 deletions lib/Sema/CSStep.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,6 @@ class SplitterStep final : public SolverStep {

SmallVector<Constraint *, 4> OrphanedConstraints;

/// Whether to include the partial results of this component in the final
/// merged results.
SmallVector<bool, 4> IncludeInMergedResults;

public:
SplitterStep(ConstraintSystem &cs, SmallVectorImpl<Solution> &solutions)
: SolverStep(cs, solutions) {}
Expand All @@ -269,56 +265,6 @@ class SplitterStep final : public SolverStep {
bool mergePartialSolutions() const;
};

/// `DependentComponentSplitterStep` is responsible for composing the partial
/// solutions from other components (on which this component depends) into
/// the inputs based on which we can solve a particular component.
class DependentComponentSplitterStep final : public SolverStep {
/// Constraints "in scope" of this step.
ConstraintList *Constraints;

/// Index into the parent splitter step.
unsigned Index;

/// The component that has dependencies.
ConstraintGraph::Component Component;

/// Array containing all of the partial solutions for the parent split.
MutableArrayRef<SmallVector<Solution, 4>> AllPartialSolutions;

/// The solutions computed the \c ComponentSteps created for each partial
/// solution combinations. Will be merged into the final \c Solutions vector
/// in \c resume.
std::vector<std::unique_ptr<SmallVector<Solution, 2>>> ContextualSolutions;

/// Take all of the constraints in this component and put them into
/// \c Constraints.
void injectConstraints() {
for (auto constraint : Component.getConstraints()) {
Constraints->erase(constraint);
Constraints->push_back(constraint);
}
}

public:
DependentComponentSplitterStep(
ConstraintSystem &cs,
ConstraintList *constraints,
unsigned index,
ConstraintGraph::Component &&component,
MutableArrayRef<SmallVector<Solution, 4>> allPartialSolutions)
: SolverStep(cs, allPartialSolutions[index]), Constraints(constraints),
Index(index), Component(std::move(component)),
AllPartialSolutions(allPartialSolutions) {
assert(!Component.getDependencies().empty() && "Should use ComponentStep");
injectConstraints();
}

StepResult take(bool prevFailed) override;
StepResult resume(bool prevFailed) override;

void print(llvm::raw_ostream &Out) override;
};


/// `ComponentStep` represents a set of type variables and related
/// constraints which could be solved independently. It's further
Expand Down Expand Up @@ -381,10 +327,6 @@ class ComponentStep final : public SolverStep {
/// Constraints "in scope" of this step.
ConstraintList *Constraints;

/// The set of partial solutions that should be composed before evaluating
/// this component.
SmallVector<const Solution *, 2> DependsOnPartialSolutions;

/// Constraint which doesn't have any free type variables associated
/// with it, which makes it disconnected in the graph.
Constraint *OrphanedConstraint = nullptr;
Expand Down Expand Up @@ -419,8 +361,6 @@ class ComponentStep final : public SolverStep {
constraints->erase(constraint);
Constraints->push_back(constraint);
}

assert(component.getDependencies().empty());
}

/// Create a component step that composes existing partial solutions before
Expand All @@ -429,15 +369,11 @@ class ComponentStep final : public SolverStep {
ConstraintSystem &cs, unsigned index,
ConstraintList *constraints,
const ConstraintGraph::Component &component,
llvm::SmallVectorImpl<const Solution *> &&dependsOnPartialSolutions,
SmallVectorImpl<Solution> &solutions)
: SolverStep(cs, solutions), Index(index), IsSingle(false),
OriginalScore(getCurrentScore()), OriginalBestScore(getBestScore()),
Constraints(constraints),
DependsOnPartialSolutions(std::move(dependsOnPartialSolutions)) {
Constraints(constraints) {
TypeVars = component.typeVars;
assert(DependsOnPartialSolutions.size() ==
component.getDependencies().size());

for (auto constraint : component.getConstraints()) {
constraints->erase(constraint);
Expand Down
Loading