Skip to content

Fix source compatibility problems with conforming to multiple Collection axes at once. #7136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 226 additions & 4 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/Decl.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/NameLookup.h"
#include "swift/AST/ReferencedNameTracker.h"
#include "swift/AST/TypeMatcher.h"
Expand Down Expand Up @@ -2976,7 +2977,56 @@ InferredAssociatedTypesByWitnesses
ConformanceChecker::inferTypeWitnessesViaValueWitnesses(ValueDecl *req) {
InferredAssociatedTypesByWitnesses result;

auto isExtensionUsableForInference = [&](ExtensionDecl *extension) -> bool {
// Assume unconstrained concrete extensions we found witnesses in are
// always viable.
if (!extension->getExtendedType()->isAnyExistentialType()) {
// TODO: When constrained extensions are a thing, we'll need an "is
// as specialized as" kind of check here.
return !extension->isConstrainedExtension();
}

// The extension may not have a generic signature set up yet, as a
// recursion breaker, in which case we can't yet confidently reject its
// witnesses.
if (!extension->getGenericSignature())
return true;

// The condition here is a bit more fickle than
// `isProtocolExtensionUsable`. That check would prematurely reject
// extensions like `P where AssocType == T` if we're relying on a
// default implementation inside the extension to infer `AssocType == T`
// in the first place. Only check conformances on the `Self` type,
// because those have to be explicitly declared on the type somewhere
// so won't be affected by whatever answer inference comes up with.
auto selfTy = GenericTypeParamType::get(0, 0, TC.Context);
for (const Requirement &reqt
: extension->getGenericSignature()->getRequirements()) {
switch (reqt.getKind()) {
case RequirementKind::Conformance:
case RequirementKind::Superclass:
if (selfTy->isEqual(reqt.getFirstType())
&& !TC.isSubtypeOf(Conformance->getType(),reqt.getSecondType(), DC))
return false;
break;

case RequirementKind::Layout:
case RequirementKind::SameType:
break;
}
}

return true;
};

for (auto witness : lookupValueWitnesses(req, /*ignoringNames=*/nullptr)) {
// If the potential witness came from an extension, and our `Self`
// type can't use it regardless of what associated types we end up
// inferring, skip the witness.
if (auto extension = dyn_cast<ExtensionDecl>(witness->getDeclContext()))
if (!isExtensionUsableForInference(extension))
continue;

// Try to resolve the type witness via this value witness.
auto witnessResult = inferTypeWitnessesViaValueWitness(req, witness);

Expand Down Expand Up @@ -3283,6 +3333,28 @@ namespace {
/// The number of value witnesses that occur in protocol
/// extensions.
unsigned NumValueWitnessesInProtocolExtensions;

#ifndef NDEBUG
LLVM_ATTRIBUTE_USED
#endif
void dump() {
llvm::errs() << "Type Witnesses:\n";
for (auto &typeWitness : TypeWitnesses) {
llvm::errs() << " " << typeWitness.first->getName() << " := ";
typeWitness.second.first->print(llvm::errs());
llvm::errs() << " value " << typeWitness.second.second << '\n';
}
llvm::errs() << "Value Witnesses:\n";
for (unsigned i : indices(ValueWitnesses)) {
auto &valueWitness = ValueWitnesses[i];
llvm::errs() << i << ": " << (Decl*)valueWitness.first
<< ' ' << valueWitness.first->getName() << '\n';
valueWitness.first->getDeclContext()->dumpContext();
llvm::errs() << " for " << (Decl*)valueWitness.second
<< ' ' << valueWitness.second->getName() << '\n';
valueWitness.second->getDeclContext()->dumpContext();
}
}
};

/// A failed type witness binding.
Expand Down Expand Up @@ -3327,6 +3399,156 @@ namespace {
};
} // end anonymous namespace

static Comparison
compareDeclsForInference(TypeChecker &TC, DeclContext *DC,
ValueDecl *decl1, ValueDecl *decl2) {
// TC.compareDeclarations assumes that it's comparing two decls that
// apply equally well to a call site. We haven't yet inferred the
// associated types for a type, so the ranking algorithm used by
// compareDeclarations to score protocol extensions is inappropriate,
// since we may have potential witnesses from extensions with mutually
// exclusive associated type constraints, and compareDeclarations will
// consider these unordered since neither extension's generic signature
// is a superset of the other.

// If the witnesses come from the same decl context, score normally.
auto dc1 = decl1->getDeclContext();
auto dc2 = decl2->getDeclContext();

if (dc1 == dc2)
return TC.compareDeclarations(DC, decl1, decl2);

auto isProtocolExt1 =
(bool)dc1->getAsProtocolExtensionContext();
auto isProtocolExt2 =
(bool)dc2->getAsProtocolExtensionContext();

// If one witness comes from a protocol extension, favor the one
// from a concrete context.
if (isProtocolExt1 != isProtocolExt2) {
return isProtocolExt1 ? Comparison::Worse : Comparison::Better;
}

// If both witnesses came from concrete contexts, score normally.
// Associated type inference shouldn't impact the result.
// FIXME: It could, if someone constrained to ConcreteType.AssocType...
if (!isProtocolExt1)
return TC.compareDeclarations(DC, decl1, decl2);

// Compare protocol extensions by which protocols they require Self to
// conform to. If one extension requires a superset of the other's
// constraints, it wins.
auto sig1 = dc1->getGenericSignatureOfContext();
auto sig2 = dc2->getGenericSignatureOfContext();

// FIXME: Extensions sometimes have null generic signatures while
// checking the standard library...
if (!sig1 || !sig2)
return TC.compareDeclarations(DC, decl1, decl2);

auto selfParam = GenericTypeParamType::get(0, 0, TC.Context);

// Collect the protocols required by extension 1.
Type class1;
SmallPtrSet<ProtocolDecl*, 4> protos1;

std::function<void (ProtocolDecl*)> insertProtocol;
insertProtocol = [&](ProtocolDecl *p) {
if (!protos1.insert(p).second)
return;

for (auto parent : p->getInheritedProtocols(&TC))
insertProtocol(parent);
};

for (auto &reqt : sig1->getRequirements()) {
if (!reqt.getFirstType()->isEqual(selfParam))
continue;
switch (reqt.getKind()) {
case RequirementKind::Conformance: {
SmallVector<ProtocolDecl*, 4> protos;
reqt.getSecondType()->getAnyExistentialTypeProtocols(protos);

for (auto proto : protos) {
insertProtocol(proto);
}
break;
}
case RequirementKind::Superclass:
class1 = reqt.getSecondType();
break;

case RequirementKind::SameType:
case RequirementKind::Layout:
break;
}
}

// Compare with the protocols required by extension 2.
Type class2;
SmallPtrSet<ProtocolDecl*, 4> protos2;
bool protos2AreSubsetOf1 = true;
std::function<void (ProtocolDecl*)> removeProtocol;
removeProtocol = [&](ProtocolDecl *p) {
if (!protos2.insert(p).second)
return;

protos2AreSubsetOf1 &= protos1.erase(p);
for (auto parent : p->getInheritedProtocols(&TC))
removeProtocol(parent);
};

for (auto &reqt : sig2->getRequirements()) {
if (!reqt.getFirstType()->isEqual(selfParam))
continue;
switch (reqt.getKind()) {
case RequirementKind::Conformance: {
SmallVector<ProtocolDecl*, 4> protos;
reqt.getSecondType()->getAnyExistentialTypeProtocols(protos);

for (auto proto : protos) {
removeProtocol(proto);
}
break;
}
case RequirementKind::Superclass:
class2 = reqt.getSecondType();
break;

case RequirementKind::SameType:
case RequirementKind::Layout:
break;
}
}

auto isClassConstraintAsStrict = [&](Type t1, Type t2) -> bool {
if (!t1)
return !t2;

if (!t2)
return true;

return TC.isSubtypeOf(t1, t2, DC);
};

bool protos1AreSubsetOf2 = protos1.empty();
// If the second extension requires strictly more protocols than the
// first, it's better.
if (protos1AreSubsetOf2 > protos2AreSubsetOf1
&& isClassConstraintAsStrict(class2, class1)) {
return Comparison::Worse;
// If the first extension requires strictly more protocols than the
// second, it's better.
} else if (protos2AreSubsetOf1 > protos1AreSubsetOf2
&& isClassConstraintAsStrict(class1, class2)) {
return Comparison::Better;
}

// If they require the same set of protocols, or non-overlapping
// sets, judge them normally.
return TC.compareDeclarations(DC, decl1, decl2);
}

void ConformanceChecker::resolveTypeWitnesses() {
llvm::SetVector<AssociatedTypeDecl *> unresolvedAssocTypes;

Expand Down Expand Up @@ -3793,7 +4015,7 @@ void ConformanceChecker::resolveTypeWitnesses() {
if (firstWitness == secondWitness)
continue;

switch (TC.compareDeclarations(DC, firstWitness, secondWitness)) {
switch (compareDeclsForInference(TC, DC, firstWitness, secondWitness)) {
case Comparison::Better:
if (secondBetter)
return false;
Expand Down Expand Up @@ -3822,16 +4044,16 @@ void ConformanceChecker::resolveTypeWitnesses() {
if (compareSolutions(solutions[i], solutions[bestIdx]))
bestIdx = i;
}

// Make sure that solution is better than any of the other solutions
// Make sure that solution is better than any of the other solutions.
bool ambiguous = false;
for (unsigned i = 1, n = solutions.size(); i != n; ++i) {
if (i != bestIdx && !compareSolutions(solutions[bestIdx], solutions[i])) {
ambiguous = true;
break;
}
}

// If we had a best solution, keep just that solution.
if (!ambiguous) {
if (bestIdx != 0)
Expand Down
24 changes: 24 additions & 0 deletions stdlib/public/core/MutableCollection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,30 @@ extension MutableCollection {
}
}

extension MutableCollection where Self: BidirectionalCollection {
public subscript(bounds: Range<Index>) -> MutableBidirectionalSlice<Self> {
get {
_failEarlyRangeCheck(bounds, bounds: startIndex..<endIndex)
return MutableBidirectionalSlice(base: self, bounds: bounds)
}
set {
_writeBackMutableSlice(&self, bounds: bounds, slice: newValue)
}
}
}

extension MutableCollection where Self: RandomAccessCollection {
public subscript(bounds: Range<Index>) -> MutableRandomAccessSlice<Self> {
get {
_failEarlyRangeCheck(bounds, bounds: startIndex..<endIndex)
return MutableRandomAccessSlice(base: self, bounds: bounds)
}
set {
_writeBackMutableSlice(&self, bounds: bounds, slice: newValue)
}
}
}

@available(*, unavailable, renamed: "MutableCollection")
public typealias MutableCollectionType = MutableCollection

Expand Down
37 changes: 37 additions & 0 deletions stdlib/public/core/RangeReplaceableCollection.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,43 @@ extension RangeReplaceableCollection {
public mutating func reserveCapacity(_ n: IndexDistance) {}
}

// Offer the most specific slice type available for each possible combination of
// RangeReplaceable * (1 + Bidirectional + RandomAccess) * (1 + Mutable)
// collections.

% for capability in ['', 'Bidirectional', 'RandomAccess']:
% if capability:
extension RangeReplaceableCollection where
Self: ${capability}Collection,
Self.SubSequence == RangeReplaceable${capability}Slice<Self> {
public subscript(bounds: Range<Index>)
-> RangeReplaceable${capability}Slice<Self> {
return RangeReplaceable${capability}Slice(base: self, bounds: bounds)
}
}
% end

extension RangeReplaceableCollection where
Self: MutableCollection,
% if capability:
Self: ${capability}Collection,
% end
Self.SubSequence == MutableRangeReplaceable${capability}Slice<Self>
{
public subscript(bounds: Range<Index>)
-> MutableRangeReplaceable${capability}Slice<Self> {
get {
_failEarlyRangeCheck(bounds, bounds: startIndex..<endIndex)
return MutableRangeReplaceable${capability}Slice(base: self,
bounds: bounds)
}
set {
_writeBackMutableSlice(&self, bounds: bounds, slice: newValue)
}
}
}
% end

extension RangeReplaceableCollection where SubSequence == Self {
/// Removes and returns the first element of the collection.
///
Expand Down
Loading