Skip to content

New syntax for declaring primary associated types #41640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,7 @@ CONTEXTUAL_SIMPLE_DECL_ATTR(distributed, DistributedActor,
APIBreakingToAdd | APIBreakingToRemove,
118)

SIMPLE_DECL_ATTR(_primaryAssociatedType,
PrimaryAssociatedType, OnAssociatedType | UserInaccessible |
APIStableToAdd | ABIStableToAdd | APIBreakingToRemove | ABIStableToRemove,
119)
// 119 is unused

SIMPLE_DECL_ATTR(_assemblyVision, EmitAssemblyVisionRemarks,
OnFunc | UserInaccessible | NotSerialized | OnNominalType |
Expand Down
13 changes: 13 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,11 @@ class alignas(1 << DeclAlignInBits) Decl : public ASTAllocated<Decl> {
IsOpaqueType : 1
);

SWIFT_INLINE_BITFIELD_FULL(AssociatedTypeDecl, AbstractTypeParamDecl, 1,
/// Whether this is a primary associated type.
IsPrimary : 1
);

SWIFT_INLINE_BITFIELD_EMPTY(GenericTypeDecl, TypeDecl);

SWIFT_INLINE_BITFIELD(TypeAliasDecl, GenericTypeDecl, 1+1,
Expand Down Expand Up @@ -3208,6 +3213,14 @@ class AssociatedTypeDecl : public AbstractTypeParamDecl {
LazyMemberLoader *definitionResolver,
uint64_t resolverData);

bool isPrimary() const {
return Bits.AssociatedTypeDecl.IsPrimary;
}

void setPrimary() {
Bits.AssociatedTypeDecl.IsPrimary = true;
}

/// Get the protocol in which this associated type is declared.
ProtocolDecl *getProtocol() const {
return cast<ProtocolDecl>(getDeclContext());
Expand Down
10 changes: 8 additions & 2 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1797,8 +1797,14 @@ ERROR(redundant_class_requirement,none,
ERROR(late_class_requirement,none,
"'class' must come first in the requirement list", ())
ERROR(where_inside_brackets,none,
"'where' clause next to generic parameters is obsolete, "
"must be written following the declaration's type", ())
"'where' clause next to generic parameters is obsolete, "
"must be written following the declaration's type", ())


ERROR(expected_rangle_primary_associated_type_list,PointsToFirstBadToken,
"expected '>' to complete primary associated type list", ())
ERROR(expected_primary_associated_type_name,PointsToFirstBadToken,
"expected an identifier to name primary associated type", ())

//------------------------------------------------------------------------------
// MARK: Conditional compilation parsing diagnostics
Expand Down
6 changes: 6 additions & 0 deletions include/swift/AST/PrintOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,12 @@ struct PrintOptions {
QualifyNestedDeclarations ShouldQualifyNestedDeclarations =
QualifyNestedDeclarations::Never;

/// If true, we print a protocol's primary associated types using the
/// primary associated type syntax: protocol Foo<Type1, ...>.
///
/// If false, we print them as ordinary associated types.
bool PrintPrimaryAssociatedTypes = true;

/// If this is not \c nullptr then function bodies (including accessors
/// and constructors) will be printed by this function.
std::function<void(const ValueDecl *, ASTPrinter &)> FunctionBody;
Expand Down
2 changes: 1 addition & 1 deletion include/swift/Basic/Features.def
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,6 @@ LANGUAGE_FEATURE(BuiltinStackAlloc, 0, "Builtin.stackAlloc", true)
SUPPRESSIBLE_LANGUAGE_FEATURE(SpecializeAttributeWithAvailability, 0, "@_specialize attribute with availability", true)
LANGUAGE_FEATURE(BuiltinAssumeAlignment, 0, "Builtin.assumeAlignment", true)
SUPPRESSIBLE_LANGUAGE_FEATURE(UnsafeInheritExecutor, 0, "@_unsafeInheritExecutor", true)

SUPPRESSIBLE_LANGUAGE_FEATURE(PrimaryAssociatedTypes, 0, "Primary associated types", true)
#undef SUPPRESSIBLE_LANGUAGE_FEATURE
#undef LANGUAGE_FEATURE
3 changes: 3 additions & 0 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,9 @@ class Parser {
void parseAbstractFunctionBody(AbstractFunctionDecl *AFD);
BodyAndFingerprint
parseAbstractFunctionBodyDelayed(AbstractFunctionDecl *AFD);

ParserStatus parsePrimaryAssociatedTypes(
SmallVectorImpl<AssociatedTypeDecl *> &AssocTypes);
ParserResult<ProtocolDecl> parseDeclProtocol(ParseDeclOptions Flags,
DeclAttributes &Attributes);

Expand Down
82 changes: 79 additions & 3 deletions lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,7 @@ class PrintAST : public ASTVisitor<PrintAST> {
bool openBracket = true, bool closeBracket = true);
void printGenericDeclGenericParams(GenericContext *decl);
void printDeclGenericRequirements(GenericContext *decl);
void printPrimaryAssociatedTypes(ProtocolDecl *decl);
void printBodyIfNecessary(const AbstractFunctionDecl *decl);

void printEnumElement(EnumElementDecl *elt);
Expand Down Expand Up @@ -1380,7 +1381,8 @@ struct RequirementPrintLocation {
/// function does: asking "where should this requirement be printed?" and then
/// callers check if the location is the ATD.
static RequirementPrintLocation
bestRequirementPrintLocation(ProtocolDecl *proto, const Requirement &req) {
bestRequirementPrintLocation(ProtocolDecl *proto, const Requirement &req,
PrintOptions opts, bool inheritanceClause) {
auto protoSelf = proto->getProtocolSelfType();
// Returns the most relevant decl within proto connected to outerType (or null
// if one doesn't exist), and whether the type is an "direct use",
Expand All @@ -1397,6 +1399,7 @@ bestRequirementPrintLocation(ProtocolDecl *proto, const Requirement &req) {
return true;
} else if (auto DMT = t->getAs<DependentMemberType>()) {
auto assocType = DMT->getAssocType();

if (assocType && assocType->getProtocol() == proto) {
relevantDecl = assocType;
foundType = t;
Expand All @@ -1411,6 +1414,17 @@ bestRequirementPrintLocation(ProtocolDecl *proto, const Requirement &req) {
// If we didn't find anything, relevantDecl and foundType will be null, as
// desired.
auto directUse = foundType && outerType->isEqual(foundType);

// Prefer to attach requirements to associated type declarations,
// unless the associated type is a primary associated type and
// we're printing primary associated types using the new syntax.
if (!directUse &&
relevantDecl &&
opts.PrintPrimaryAssociatedTypes &&
isa<AssociatedTypeDecl>(relevantDecl) &&
cast<AssociatedTypeDecl>(relevantDecl)->isPrimary())
relevantDecl = proto;

return std::make_pair(relevantDecl, directUse);
};

Expand Down Expand Up @@ -1481,7 +1495,8 @@ void PrintAST::printInheritedFromRequirementSignature(ProtocolDecl *proto,
return false;
}

auto location = bestRequirementPrintLocation(proto, req);
auto location = bestRequirementPrintLocation(proto, req, Options,
/*inheritanceClause=*/true);
return location.AttachedTo == attachingTo && !location.InWhereClause;
});
}
Expand All @@ -1496,7 +1511,8 @@ void PrintAST::printWhereClauseFromRequirementSignature(ProtocolDecl *proto,
proto->getRequirementSignature().getRequirements()),
flags,
[&](const Requirement &req) {
auto location = bestRequirementPrintLocation(proto, req);
auto location = bestRequirementPrintLocation(proto, req, Options,
/*inheritanceClause=*/false);
return location.AttachedTo == attachingTo && location.InWhereClause;
});
}
Expand Down Expand Up @@ -2969,6 +2985,22 @@ static void suppressingFeatureUnsafeInheritExecutor(PrintOptions &options,
options.ExcludeAttrList.resize(originalExcludeAttrCount);
}

static bool usesFeaturePrimaryAssociatedTypes(Decl *decl) {
if (auto *protoDecl = dyn_cast<ProtocolDecl>(decl)) {
if (protoDecl->getPrimaryAssociatedTypes().size() > 0)
return true;
}

return false;
}

static void suppressingFeaturePrimaryAssociatedTypes(PrintOptions &options,
llvm::function_ref<void()> action) {
bool originalPrintPrimaryAssociatedTypes = options.PrintPrimaryAssociatedTypes;
options.PrintPrimaryAssociatedTypes = false;
action();
options.PrintPrimaryAssociatedTypes = originalPrintPrimaryAssociatedTypes;
}

/// Suppress the printing of a particular feature.
static void suppressingFeature(PrintOptions &options, Feature feature,
Expand Down Expand Up @@ -3485,6 +3517,38 @@ void PrintAST::visitClassDecl(ClassDecl *decl) {
}
}

void PrintAST::printPrimaryAssociatedTypes(ProtocolDecl *decl) {
auto primaryAssocTypes = decl->getPrimaryAssociatedTypes();
if (primaryAssocTypes.empty())
return;

Printer.printStructurePre(PrintStructureKind::DeclGenericParameterClause);

Printer << "<";
llvm::interleave(
primaryAssocTypes,
[&](AssociatedTypeDecl *assocType) {
Printer.callPrintStructurePre(PrintStructureKind::GenericParameter,
assocType);
Printer.printName(assocType->getName(),
PrintNameContext::GenericParameter);

printInheritedFromRequirementSignature(decl, assocType);

if (assocType->hasDefaultDefinitionType()) {
Printer << " = ";
assocType->getDefaultDefinitionType().print(Printer, Options);
}

Printer.printStructurePost(PrintStructureKind::GenericParameter,
assocType);
},
[&] { Printer << ", "; });
Printer << ">";

Printer.printStructurePost(PrintStructureKind::DeclGenericParameterClause);
}

void PrintAST::visitProtocolDecl(ProtocolDecl *decl) {
printDocumentationComment(decl);
printAttributes(decl);
Expand All @@ -3502,6 +3566,10 @@ void PrintAST::visitProtocolDecl(ProtocolDecl *decl) {
Printer.printName(decl->getName());
});

if (Options.PrintPrimaryAssociatedTypes) {
printPrimaryAssociatedTypes(decl);
}

printInheritedFromRequirementSignature(decl, decl);

// The trailing where clause is a syntactic thing, which isn't serialized
Expand Down Expand Up @@ -4997,6 +5065,14 @@ bool Decl::shouldPrintInContext(const PrintOptions &PO) const {
return PO.PrintIfConfig;
}

if (auto *ATD = dyn_cast<AssociatedTypeDecl>(this)) {
// If PO.PrintPrimaryAssociatedTypes is on, primary associated
// types are printed as part of the protocol declaration itself,
// so skip them here.
if (ATD->isPrimary() && PO.PrintPrimaryAssociatedTypes)
return false;
}

// Print everything else.
return true;
}
Expand Down
1 change: 1 addition & 0 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4539,6 +4539,7 @@ AssociatedTypeDecl::AssociatedTypeDecl(DeclContext *dc, SourceLoc keywordLoc,
: AbstractTypeParamDecl(DeclKind::AssociatedType, dc, name, nameLoc),
KeywordLoc(keywordLoc), DefaultDefinition(defaultDefinition),
TrailingWhere(trailingWhere) {
Bits.AssociatedTypeDecl.IsPrimary = 0;
}

AssociatedTypeDecl::AssociatedTypeDecl(DeclContext *dc, SourceLoc keywordLoc,
Expand Down
Loading