Skip to content

Sema: Fancier handling of associated type defaults #70241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6055,7 +6055,7 @@ class TypePrinter : public TypeVisitor<TypePrinter> {
switch (Options.OpaqueReturnTypePrinting) {
case PrintOptions::OpaqueReturnTypePrintingMode::StableReference:
case PrintOptions::OpaqueReturnTypePrintingMode::Description:
return true;
return false;
case PrintOptions::OpaqueReturnTypePrintingMode::WithOpaqueKeyword:
return opaque->getDecl()->hasExplicitGenericParams();
case PrintOptions::OpaqueReturnTypePrintingMode::WithoutOpaqueKeyword:
Expand Down Expand Up @@ -7413,39 +7413,51 @@ class TypePrinter : public TypeVisitor<TypePrinter> {
Printer << "each ";
}

void printArchetypeCommon(ArchetypeType *T) {
if (Options.AlternativeTypeNames) {
auto found = Options.AlternativeTypeNames->find(T->getCanonicalType());
if (found != Options.AlternativeTypeNames->end()) {
if (T->isParameterPack()) printEach();
Printer << found->second.str();
return;
void printArchetypeCommon(Type interfaceTy, ArchetypeType *archetypeTy) {
if (auto *paramTy = interfaceTy->getAs<GenericTypeParamType>()) {
assert(archetypeTy->isRoot());

if (Options.AlternativeTypeNames) {
auto found = Options.AlternativeTypeNames->find(CanType(archetypeTy));
if (found != Options.AlternativeTypeNames->end()) {
if (paramTy->isParameterPack()) printEach();
Printer << found->second.str();
return;
}
}

visit(paramTy);
return;
}

auto interfaceType = T->getInterfaceType();
if (auto *dependentMember = interfaceType->getAs<DependentMemberType>()) {
visitParentType(T->getParent());
printDependentMember(dependentMember);
} else {
visit(interfaceType);
auto *memberTy = interfaceTy->castTo<DependentMemberType>();
if (memberTy->getBase()->is<GenericTypeParamType>())
visitParentType(archetypeTy->getRoot());
else {
printArchetypeCommon(memberTy->getBase(), archetypeTy->getRoot());
Printer << ".";
}

printDependentMember(memberTy);
}

void visitPrimaryArchetypeType(PrimaryArchetypeType *T) {
printArchetypeCommon(T);
printArchetypeCommon(T->getInterfaceType(), T);
}

void visitOpaqueTypeArchetypeType(OpaqueTypeArchetypeType *T) {
if (auto parent = T->getParent()) {
printArchetypeCommon(T);
auto interfaceTy = T->getInterfaceType();
auto *paramTy = interfaceTy->getAs<GenericTypeParamType>();

if (!paramTy) {
assert(interfaceTy->is<DependentMemberType>());
printArchetypeCommon(interfaceTy, T);
return;
}

// Try to print a named opaque type.
auto printNamedOpaque = [&] {
unsigned ordinal =
T->getInterfaceType()->castTo<GenericTypeParamType>()->getIndex();
unsigned ordinal = paramTy->getIndex();
if (auto genericParam = T->getDecl()->getExplicitGenericParam(ordinal)) {
visit(genericParam->getDeclaredInterfaceType());
return true;
Expand Down Expand Up @@ -7492,9 +7504,7 @@ class TypePrinter : public TypeVisitor<TypePrinter> {
Printer.printEscapedStringLiteral(
decl->getOpaqueReturnTypeIdentifier().str());

Printer << ", " << T->getInterfaceType()
->castTo<GenericTypeParamType>()
->getIndex();
Printer << ", " << paramTy->getIndex();

// The identifier after the closing parenthesis is irrelevant and can be
// anything. It just needs to be there for the @_opaqueReturnTypeOf
Expand Down Expand Up @@ -7526,7 +7536,7 @@ class TypePrinter : public TypeVisitor<TypePrinter> {
}

void visitPackArchetypeType(PackArchetypeType *T) {
printArchetypeCommon(T);
printArchetypeCommon(T->getInterfaceType(), T);
}

void visitGenericTypeParamType(GenericTypeParamType *T) {
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/ASTVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,8 @@ class Verifier : public ASTWalker {
auto interfaceType = archetype->getInterfaceType();
auto contextType = archetypeEnv->mapTypeIntoContext(interfaceType);

if (contextType.getPointer() != archetype) {
Out << "Archetype " << archetype->getString() << "does not appear"
if (!contextType->isEqual(archetype)) {
Out << "Archetype " << archetype->getString() << " does not appear"
<< " inside its own generic environment\n";
Out << "Interface type: " << interfaceType.getString() << "\n";
Out << "Contextual type: " << contextType.getString() << "\n";
Expand Down
5 changes: 3 additions & 2 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7253,10 +7253,11 @@ void TypeChecker::inferDefaultWitnesses(ProtocolDecl *proto) {
DefaultWitnessChecker checker(proto);

// Find the default for the given associated type.
auto findAssociatedTypeDefault = [](AssociatedTypeDecl *assocType)
auto findAssociatedTypeDefault = [proto](AssociatedTypeDecl *assocType)
-> std::pair<Type, AssociatedTypeDecl *> {
auto defaultedAssocType =
AssociatedTypeInference::findDefaultedAssociatedType(assocType);
AssociatedTypeInference::findDefaultedAssociatedType(
proto, proto, assocType);
if (!defaultedAssocType)
return {Type(), nullptr};

Expand Down
3 changes: 2 additions & 1 deletion lib/Sema/TypeCheckProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -1271,7 +1271,8 @@ class AssociatedTypeInference {

/// Find an associated type declaration that provides a default definition.
static AssociatedTypeDecl *findDefaultedAssociatedType(
AssociatedTypeDecl *assocType);
DeclContext *dc, NominalTypeDecl *adoptee,
AssociatedTypeDecl *assocType);
};

/// Match the given witness to the given requirement.
Expand Down
30 changes: 18 additions & 12 deletions lib/Sema/TypeCheckProtocolInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -919,26 +919,31 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitness(ValueDecl *req,
}

AssociatedTypeDecl *AssociatedTypeInference::findDefaultedAssociatedType(
DeclContext *dc,
NominalTypeDecl *adoptee,
AssociatedTypeDecl *assocType) {
// If this associated type has a default, we're done.
if (assocType->hasDefaultDefinitionType())
return assocType;

// Look at overridden associated types.
// Otherwise, look for all associated types with the same name along all the
// protocols that the adoptee conforms to.
SmallVector<ValueDecl *, 4> decls;
auto options = NL_ProtocolMembers | NL_OnlyTypes;
dc->lookupQualified(adoptee, DeclNameRef(assocType->getName()),
SourceLoc(), options, decls);

SmallPtrSet<CanType, 4> canonicalTypes;
SmallVector<AssociatedTypeDecl *, 2> results;
for (auto overridden : assocType->getOverriddenDecls()) {
auto overriddenDefault = findDefaultedAssociatedType(overridden);
if (!overriddenDefault) continue;

Type overriddenType =
overriddenDefault->getDefaultDefinitionType();
assert(overriddenType);
if (!overriddenType) continue;
for (auto *decl : decls) {
if (auto *assocDecl = dyn_cast<AssociatedTypeDecl>(decl)) {
auto defaultType = assocDecl->getDefaultDefinitionType();
if (!defaultType) continue;

CanType key = overriddenType->getCanonicalType();
CanType key = defaultType->getCanonicalType();
if (canonicalTypes.insert(key).second)
results.push_back(overriddenDefault);
results.push_back(assocDecl);
}
}

// If there was a single result, return it.
Expand Down Expand Up @@ -999,7 +1004,8 @@ llvm::Optional<AbstractTypeWitness>
AssociatedTypeInference::computeDefaultTypeWitness(
AssociatedTypeDecl *assocType) const {
// Go find a default definition.
auto *const defaultedAssocType = findDefaultedAssociatedType(assocType);
auto *const defaultedAssocType = findDefaultedAssociatedType(
dc, dc->getSelfNominalTypeDecl(), assocType);
if (!defaultedAssocType)
return llvm::None;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ protocol HasRecursiveP {

extension HasRecursiveP where T == DefinesRecursiveP.T {}
// expected-error@-1 {{cannot build rewrite system for generic signature; rule length limit exceeded}}
// expected-note@-2 {{failed rewrite rule is τ_0_0.[HasRecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[concrete: ((((((((((((@_opaqueReturnTypeOf("$s56opaque_archetype_concrete_requirement_recursive_rejected17DefinesRecursivePV1tQrvp", 0) __.T).T).T).T).T).T).T).T).T).T).T).T).T] => τ_0_0.[HasRecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T]}}
// expected-note@-2 {{failed rewrite rule is τ_0_0.[HasRecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[concrete: (@_opaqueReturnTypeOf("$s56opaque_archetype_concrete_requirement_recursive_rejected17DefinesRecursivePV1tQrvp", 0) __).T.T.T.T.T.T.T.T.T.T.T.T.T] => τ_0_0.[HasRecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T].[RecursiveP:T]}}

2 changes: 1 addition & 1 deletion test/SILGen/opaque_result_type_private_assoc_type.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import opaque_result_type_private_assoc_type_other

// CHECK-LABEL: sil hidden [ossa] @$s033opaque_result_type_private_assoc_C0028usesAssocTypeOfPrivateResultH0yyF : $@convention(thin) () -> () {
func usesAssocTypeOfPrivateResultType() {
// CHECK: [[BOX:%.*]] = alloc_stack $Optional<@_opaqueReturnTypeOf("$s033opaque_result_type_private_assoc_C6_other11doSomethingQryF", 0) __.Element>
// CHECK: [[BOX:%.*]] = alloc_stack $Optional<(@_opaqueReturnTypeOf("$s033opaque_result_type_private_assoc_C6_other11doSomethingQryF", 0) __).Element>
// CHECK: [[METHOD:%.*]] = witness_method $@_opaqueReturnTypeOf("$s033opaque_result_type_private_assoc_C6_other11doSomethingQryF", 0) __, #IteratorProtocol.next : <Self where Self : opaque_result_type_private_assoc_type_other.IteratorProtocol> (Self) -> () -> Self.Element? : $@convention(witness_method: IteratorProtocol) <τ_0_0 where τ_0_0 : IteratorProtocol> (@in_guaranteed τ_0_0) -> @out Optional<τ_0_0.Element>
// CHECK: [[RESULT:%.*]] = apply [[METHOD]]<@_opaqueReturnTypeOf("$s033opaque_result_type_private_assoc_C6_other11doSomethingQryF", 0) __>([[BOX]], {{%.*}}) : $@convention(witness_method: IteratorProtocol) <τ_0_0 where τ_0_0 : IteratorProtocol> (@in_guaranteed τ_0_0) -> @out Optional<τ_0_0.Element>
let iterator = doSomething()
Expand Down
19 changes: 19 additions & 0 deletions test/decl/protocol/req/associated_type_default_lookup.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: %target-typecheck-verify-swift

protocol P1 {
associatedtype A

func f(_: A)
}

protocol P2: P1 {
associatedtype A = Int
}

func foo<T: P1>(_: T.Type) -> T.A.Type {}

_ = foo(S.self)

struct S: P2 {
func f(_: A) {}
}
8 changes: 4 additions & 4 deletions test/decl/protocol/req/associated_type_tuple.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ protocol P1 {
extension Tuple: P1 where repeat each T: P1 {} // expected-error {{type '(repeat each T)' does not conform to protocol 'P1'}}

protocol P2 {
associatedtype A = Int // expected-note {{default type 'Int' for associated type 'A' (from protocol 'P2') is unsuitable for tuple conformance; the associated type requirement must be fulfilled by a type alias with underlying type '(repeat (each T).A)'}}
associatedtype B = Int // expected-note {{default type 'Int' for associated type 'B' (from protocol 'P2') is unsuitable for tuple conformance; the associated type requirement must be fulfilled by a type alias with underlying type '(repeat (each T).B)'}}
}

extension Tuple: P2 where repeat each T: P2 {} // expected-error {{type '(repeat each T)' does not conform to protocol 'P2'}}

protocol P3 {
associatedtype A // expected-note {{unable to infer associated type 'A' for protocol 'P3'}}
func f() -> A
associatedtype C // expected-note {{unable to infer associated type 'C' for protocol 'P3'}}
func f() -> C
}

extension Tuple: P3 where repeat each T: P3 { // expected-error {{type '(repeat each T)' does not conform to protocol 'P3'}}
func f() -> Int {} // expected-note {{cannot infer 'A' = 'Int' in tuple conformance because the associated type requirement must be fulfilled by a type alias with underlying type '(repeat (each T).A)'}}
func f() -> Int {} // expected-note {{cannot infer 'C' = 'Int' in tuple conformance because the associated type requirement must be fulfilled by a type alias with underlying type '(repeat (each T).C)'}}
}