Skip to content

Sema: Undo changes in chronological order in SolverTrail::undo() #77174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions include/swift/Sema/ConstraintGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,12 @@ class ConstraintGraphNode {
/// gets removed for a constraint graph.
void retractFromInference(Constraint *constraint);

/// Re-evaluate the given constraint. This happens when there are changes
/// in associated type variables e.g. bound/unbound to/from a fixed type,
/// equivalence class changes.
void reintroduceToInference(Constraint *constraint);

/// Similar to \c introduceToInference(Constraint *, ...) this method is going
/// to notify inference that this type variable has been bound to a concrete
/// type.
/// Perform graph updates that must be undone after we bind a fixed type
/// to a type variable.
void retractFromInference(Type fixedType);

/// Perform graph updates that must be undone before we bind a fixed type
/// to a type variable.
///
/// The reason why this can't simplify be a part of \c bindTypeVariable
/// is related to the fact that it's sometimes expensive to re-compute
Expand All @@ -161,12 +159,18 @@ class ConstraintGraphNode {
///
/// This is useful in situations when type variable gets bound and unbound,
/// or equivalence class changes.
void notifyReferencingVars() const;
void notifyReferencingVars(
llvm::function_ref<void(ConstraintGraphNode &,
Constraint *)> notification) const;

/// Notify all of the type variables referenced by this one about a change.
void notifyReferencedVars(
llvm::function_ref<void(ConstraintGraphNode &)> notification);
llvm::function_ref<void(ConstraintGraphNode &)> notification) const;

void updateFixedType(
Type fixedType,
llvm::function_ref<void (ConstraintGraphNode &,
Constraint *)> notification) const;
/// }

/// The constraint graph this node belongs to.
Expand Down Expand Up @@ -261,16 +265,28 @@ class ConstraintGraph {
/// Primitive form for SolverTrail::Change::undo().
void removeConstraint(TypeVariableType *typeVar, Constraint *constraint);

/// Prepare to merge the given node into some other node.
///
/// This records graph changes that must be undone after the merge has
/// been undone.
void mergeNodesPre(TypeVariableType *typeVar2);

/// Merge the two nodes for the two given type variables.
///
/// The type variables must actually have been merged already; this
/// operation merges the two nodes.
/// operation merges the two nodes. This also records graph changes
/// that must be undone before the merge can be undone.
void mergeNodes(TypeVariableType *typeVar1, TypeVariableType *typeVar2);

/// Bind the given type variable to the given fixed type.
void bindTypeVariable(TypeVariableType *typeVar, Type fixedType);

/// Introduce the type variable's fixed type to inference.
/// Perform graph updates that must be undone after we bind a fixed type
/// to a type variable.
void retractFromInference(TypeVariableType *typeVar, Type fixedType);

/// Perform graph updates that must be undone before we bind a fixed type
/// to a type variable.
void introduceToInference(TypeVariableType *typeVar, Type fixedType);

/// Describes which constraints \c gatherConstraints should gather.
Expand Down
7 changes: 1 addition & 6 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,12 +555,7 @@ class TypeVariableType::Implementation {
/// \param trail The record of state changes.
void mergeEquivalenceClasses(TypeVariableType *other,
constraints::SolverTrail *trail) {
// Merge the equivalence classes corresponding to these two type
// variables. Always merge 'up' the constraint stack, because it is simpler.
if (getID() > other->getImpl().getID()) {
other->getImpl().mergeEquivalenceClasses(getTypeVariable(), trail);
return;
}
ASSERT(getID() < other->getImpl().getID());

auto otherRep = other->getImpl().getRepresentative(trail);
if (trail)
Expand Down
15 changes: 2 additions & 13 deletions lib/Sema/CSTrail.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,21 +730,10 @@ void SolverTrail::undo(unsigned toIndex) {
ASSERT(!UndoActive);
UndoActive = true;

// FIXME: Undo all changes in the correct order!
for (unsigned i = Changes.size(); i > toIndex; i--) {
auto change = Changes[i - 1];
if (change.Kind == ChangeKind::UpdatedTypeVariable) {
LLVM_DEBUG(llvm::dbgs() << "- "; change.dump(llvm::dbgs(), CS, 0));
change.undo(CS);
}
}

for (unsigned i = Changes.size(); i > toIndex; i--) {
auto change = Changes[i - 1];
if (change.Kind != ChangeKind::UpdatedTypeVariable) {
LLVM_DEBUG(llvm::dbgs() << "- "; change.dump(llvm::dbgs(), CS, 0));
change.undo(CS);
}
LLVM_DEBUG(llvm::dbgs() << "- "; change.dump(llvm::dbgs(), CS, 0));
change.undo(CS);
}

Changes.resize(toIndex);
Expand Down
120 changes: 76 additions & 44 deletions lib/Sema/ConstraintGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ ConstraintGraph::lookupNode(TypeVariableType *typeVar) {
// If this type variable is not the representative of its equivalence class,
// add it to its representative's set of equivalences.
auto typeVarRep = CS.getRepresentative(typeVar);
if (typeVar != typeVarRep)
mergeNodes(typeVar, typeVarRep);
if (typeVar != typeVarRep) {
mergeNodesPre(typeVar);
mergeNodes(typeVarRep, typeVar);
}
else if (auto fixed = CS.getFixedType(typeVarRep)) {
// Bind the type variable.
bindTypeVariable(typeVar, fixed);
Expand Down Expand Up @@ -177,7 +179,9 @@ void ConstraintGraphNode::removeConstraint(Constraint *constraint) {
Constraints.pop_back();
}

void ConstraintGraphNode::notifyReferencingVars() const {
void ConstraintGraphNode::notifyReferencingVars(
llvm::function_ref<void(ConstraintGraphNode &,
Constraint *)> notification) const {
SmallVector<TypeVariableType *, 4> stack;

stack.push_back(TypeVar);
Expand All @@ -199,7 +203,7 @@ void ConstraintGraphNode::notifyReferencingVars() const {
affectedVar->getImpl().getRepresentative(/*record=*/nullptr);

if (!repr->getImpl().getFixedType(/*record=*/nullptr))
CG[repr].reintroduceToInference(constraint);
notification(CG[repr], constraint);
}
}
};
Expand Down Expand Up @@ -236,7 +240,7 @@ void ConstraintGraphNode::notifyReferencingVars() const {
}

void ConstraintGraphNode::notifyReferencedVars(
llvm::function_ref<void(ConstraintGraphNode &)> notification) {
llvm::function_ref<void(ConstraintGraphNode &)> notification) const {
for (auto *fixedBinding : getReferencedVars()) {
notification(CG[fixedBinding]);
}
Expand All @@ -249,25 +253,6 @@ void ConstraintGraphNode::addToEquivalenceClass(
if (EquivalenceClass.empty())
EquivalenceClass.push_back(getTypeVariable());
EquivalenceClass.append(typeVars.begin(), typeVars.end());

{
for (auto *newMember : typeVars) {
auto &node = CG[newMember];

for (auto *constraint : node.getConstraints()) {
introduceToInference(constraint);

if (!isUsefulForReferencedVars(constraint))
continue;

notifyReferencedVars([&](ConstraintGraphNode &referencedVar) {
referencedVar.introduceToInference(constraint);
});
}

node.notifyReferencingVars();
}
}
}

void ConstraintGraphNode::truncateEquivalenceClass(unsigned prevSize) {
Expand Down Expand Up @@ -343,19 +328,17 @@ void ConstraintGraphNode::retractFromInference(Constraint *constraint) {
}
}

void ConstraintGraphNode::reintroduceToInference(Constraint *constraint) {
retractFromInference(constraint);
introduceToInference(constraint);
}

void ConstraintGraphNode::introduceToInference(Type fixedType) {
void ConstraintGraphNode::updateFixedType(
Type fixedType,
llvm::function_ref<void (ConstraintGraphNode &,
Constraint *)> notification) const {
// Notify all of the type variables that reference this one.
//
// Since this type variable has been replaced with a fixed type
// all of the concrete types that reference it are going to change,
// which means that all of the not-yet-attempted bindings should
// change as well.
notifyReferencingVars();
notifyReferencingVars(notification);

if (!fixedType->hasTypeVariable())
return;
Expand All @@ -371,11 +354,27 @@ void ConstraintGraphNode::introduceToInference(Type fixedType) {
// all of the constraints that reference bound type variable.
for (auto *constraint : getConstraints()) {
if (isUsefulForReferencedVars(constraint))
node.reintroduceToInference(constraint);
notification(node, constraint);
}
}
}

void ConstraintGraphNode::retractFromInference(Type fixedType) {
return updateFixedType(
fixedType,
[](ConstraintGraphNode &node, Constraint *constraint) {
node.retractFromInference(constraint);
});
}

void ConstraintGraphNode::introduceToInference(Type fixedType) {
return updateFixedType(
fixedType,
[](ConstraintGraphNode &node, Constraint *constraint) {
node.introduceToInference(constraint);
});
}

#pragma mark Graph mutation

void ConstraintGraph::removeNode(TypeVariableType *typeVar) {
Expand Down Expand Up @@ -486,31 +485,60 @@ void ConstraintGraph::removeConstraint(TypeVariableType *typeVar,
OrphanedConstraints.pop_back();
}

void ConstraintGraph::mergeNodesPre(TypeVariableType *typeVar2) {
// Merge equivalence class from the non-representative type variable.
auto &nonRepNode = (*this)[typeVar2];

for (auto *newMember : nonRepNode.getEquivalenceClassUnsafe()) {
auto &node = (*this)[newMember];

node.notifyReferencingVars(
[&](ConstraintGraphNode &node, Constraint *constraint) {
node.retractFromInference(constraint);
});
}
}

void ConstraintGraph::mergeNodes(TypeVariableType *typeVar1,
TypeVariableType *typeVar2) {
assert(CS.getRepresentative(typeVar1) == CS.getRepresentative(typeVar2) &&
"type representatives don't match");

// Retrieve the node for the representative that we're merging into.
auto typeVarRep = CS.getRepresentative(typeVar1);
auto &repNode = (*this)[typeVarRep];
ASSERT(CS.getRepresentative(typeVar1) == typeVar1);

// Retrieve the node for the non-representative.
assert((typeVar1 == typeVarRep || typeVar2 == typeVarRep) &&
"neither type variable is the new representative?");
auto typeVarNonRep = typeVar1 == typeVarRep? typeVar2 : typeVar1;
auto &repNode = (*this)[typeVar1];

// Record the change, if there are active scopes.
if (CS.isRecordingChanges()) {
CS.recordChange(
SolverTrail::Change::ExtendedEquivalenceClass(
typeVarRep,
typeVar1,
repNode.getEquivalenceClass().size()));
}

// Merge equivalence class from the non-representative type variable.
auto &nonRepNode = (*this)[typeVarNonRep];
repNode.addToEquivalenceClass(nonRepNode.getEquivalenceClassUnsafe());
auto &nonRepNode = (*this)[typeVar2];

auto typeVars = nonRepNode.getEquivalenceClassUnsafe();
repNode.addToEquivalenceClass(typeVars);

for (auto *newMember : typeVars) {
auto &node = (*this)[newMember];

for (auto *constraint : node.getConstraints()) {
repNode.introduceToInference(constraint);

if (!isUsefulForReferencedVars(constraint))
continue;

repNode.notifyReferencedVars([&](ConstraintGraphNode &referencedVar) {
referencedVar.introduceToInference(constraint);
});
}

node.notifyReferencingVars(
[&](ConstraintGraphNode &node, Constraint *constraint) {
node.introduceToInference(constraint);
});
}
}

void ConstraintGraph::bindTypeVariable(TypeVariableType *typeVar, Type fixed) {
Expand All @@ -537,6 +565,10 @@ void ConstraintGraph::bindTypeVariable(TypeVariableType *typeVar, Type fixed) {
}
}

void ConstraintGraph::retractFromInference(TypeVariableType *typeVar, Type fixed) {
(*this)[typeVar].retractFromInference(fixed);
}

void ConstraintGraph::introduceToInference(TypeVariableType *typeVar, Type fixed) {
(*this)[typeVar].introduceToInference(fixed);
}
Expand Down
9 changes: 7 additions & 2 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,13 @@ void ConstraintSystem::mergeEquivalenceClasses(TypeVariableType *typeVar1,
assert(typeVar2 == getRepresentative(typeVar2) &&
"typeVar2 is not the representative");
assert(typeVar1 != typeVar2 && "cannot merge type with itself");
typeVar1->getImpl().mergeEquivalenceClasses(typeVar2, getTrail());

// Merge nodes in the constraint graph.
// Always merge 'up' the constraint stack, because it is simpler.
if (typeVar1->getImpl().getID() > typeVar2->getImpl().getID())
std::swap(typeVar1, typeVar2);

CG.mergeNodesPre(typeVar2);
typeVar1->getImpl().mergeEquivalenceClasses(typeVar2, getTrail());
CG.mergeNodes(typeVar1, typeVar2);

if (updateWorkList) {
Expand Down Expand Up @@ -205,6 +209,7 @@ void ConstraintSystem::assignFixedType(TypeVariableType *typeVar, Type type,
assert(!type->hasError() &&
"Should not be assigning a type involving ErrorType!");

CG.retractFromInference(typeVar, type);
typeVar->getImpl().assignFixedType(type, getTrail());

if (!updateState)
Expand Down
2 changes: 1 addition & 1 deletion test/Sema/issue-46000.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct Data {}
extension DispatchData {
func asFoundationData<T>(execute: (Data) throws -> T) rethrows -> T {
return try withUnsafeBytes { (ptr: UnsafePointer<Int8>) -> Void in
// expected-error@-1 {{cannot convert return expression of type 'Void' to return type 'T'}}
// expected-error@-1 {{declared closure result 'Void' is incompatible with contextual type 'T'}}
let data = Data()
return try execute(data) // expected-error {{cannot convert value of type 'T' to closure result type 'Void'}}
}
Expand Down
5 changes: 1 addition & 4 deletions test/type/opaque.swift
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,8 @@ func associatedTypeIdentity() {
sameType(cr, dr) // expected-error {{conflicting arguments to generic parameter 'T' ('(some R).S' (result type of 'candace') vs. '(some R).S' (result type of 'doug'))}}
sameType(gary(candace()).r_out(), gary(candace()).r_out())
sameType(gary(doug()).r_out(), gary(doug()).r_out())
// TODO(diagnostics): This is not great but the problem comes from the way solver discovers and attempts bindings, if we could detect that
// `(some R).S` from first reference to `gary()` in inconsistent with the second one based on the parent type of `S` it would be much easier to diagnose.
sameType(gary(doug()).r_out(), gary(candace()).r_out())
// expected-error@-1:12 {{conflicting arguments to generic parameter 'T' ('some R' (result type of 'doug') vs. 'some R' (result type of 'candace'))}}
// expected-error@-2:34 {{conflicting arguments to generic parameter 'T' ('some R' (result type of 'doug') vs. 'some R' (result type of 'candace'))}}
// expected-error@-1:39 {{cannot convert value of type 'some R' (result of 'candace()') to expected argument type 'some R' (result of 'doug()')}}
}

func redeclaration() -> some P { return 0 } // expected-note 2{{previously declared}}
Expand Down