@@ -892,6 +892,7 @@ class PrintAST : public ASTVisitor<PrintAST> {
892
892
bool openBracket = true , bool closeBracket = true );
893
893
void printGenericDeclGenericParams (GenericContext *decl);
894
894
void printDeclGenericRequirements (GenericContext *decl);
895
+ void printPrimaryAssociatedTypes (ProtocolDecl *decl);
895
896
void printBodyIfNecessary (const AbstractFunctionDecl *decl);
896
897
897
898
void printEnumElement (EnumElementDecl *elt);
@@ -1380,7 +1381,8 @@ struct RequirementPrintLocation {
1380
1381
// / function does: asking "where should this requirement be printed?" and then
1381
1382
// / callers check if the location is the ATD.
1382
1383
static RequirementPrintLocation
1383
- bestRequirementPrintLocation (ProtocolDecl *proto, const Requirement &req) {
1384
+ bestRequirementPrintLocation (ProtocolDecl *proto, const Requirement &req,
1385
+ PrintOptions opts, bool inheritanceClause) {
1384
1386
auto protoSelf = proto->getProtocolSelfType ();
1385
1387
// Returns the most relevant decl within proto connected to outerType (or null
1386
1388
// if one doesn't exist), and whether the type is an "direct use",
@@ -1397,6 +1399,7 @@ bestRequirementPrintLocation(ProtocolDecl *proto, const Requirement &req) {
1397
1399
return true ;
1398
1400
} else if (auto DMT = t->getAs <DependentMemberType>()) {
1399
1401
auto assocType = DMT->getAssocType ();
1402
+
1400
1403
if (assocType && assocType->getProtocol () == proto) {
1401
1404
relevantDecl = assocType;
1402
1405
foundType = t;
@@ -1411,6 +1414,17 @@ bestRequirementPrintLocation(ProtocolDecl *proto, const Requirement &req) {
1411
1414
// If we didn't find anything, relevantDecl and foundType will be null, as
1412
1415
// desired.
1413
1416
auto directUse = foundType && outerType->isEqual (foundType);
1417
+
1418
+ // Prefer to attach requirements to associated type declarations,
1419
+ // unless the associated type is a primary associated type and
1420
+ // we're printing primary associated types using the new syntax.
1421
+ if (!directUse &&
1422
+ relevantDecl &&
1423
+ opts.PrintPrimaryAssociatedTypes &&
1424
+ isa<AssociatedTypeDecl>(relevantDecl) &&
1425
+ cast<AssociatedTypeDecl>(relevantDecl)->isPrimary ())
1426
+ relevantDecl = proto;
1427
+
1414
1428
return std::make_pair (relevantDecl, directUse);
1415
1429
};
1416
1430
@@ -1481,7 +1495,8 @@ void PrintAST::printInheritedFromRequirementSignature(ProtocolDecl *proto,
1481
1495
return false ;
1482
1496
}
1483
1497
1484
- auto location = bestRequirementPrintLocation (proto, req);
1498
+ auto location = bestRequirementPrintLocation (proto, req, Options,
1499
+ /* inheritanceClause=*/ true );
1485
1500
return location.AttachedTo == attachingTo && !location.InWhereClause ;
1486
1501
});
1487
1502
}
@@ -1496,7 +1511,8 @@ void PrintAST::printWhereClauseFromRequirementSignature(ProtocolDecl *proto,
1496
1511
proto->getRequirementSignature ().getRequirements ()),
1497
1512
flags,
1498
1513
[&](const Requirement &req) {
1499
- auto location = bestRequirementPrintLocation (proto, req);
1514
+ auto location = bestRequirementPrintLocation (proto, req, Options,
1515
+ /* inheritanceClause=*/ false );
1500
1516
return location.AttachedTo == attachingTo && location.InWhereClause ;
1501
1517
});
1502
1518
}
@@ -2969,6 +2985,22 @@ static void suppressingFeatureUnsafeInheritExecutor(PrintOptions &options,
2969
2985
options.ExcludeAttrList .resize (originalExcludeAttrCount);
2970
2986
}
2971
2987
2988
+ static bool usesFeaturePrimaryAssociatedTypes (Decl *decl) {
2989
+ if (auto *protoDecl = dyn_cast<ProtocolDecl>(decl)) {
2990
+ if (protoDecl->getPrimaryAssociatedTypes ().size () > 0 )
2991
+ return true ;
2992
+ }
2993
+
2994
+ return false ;
2995
+ }
2996
+
2997
+ static void suppressingFeaturePrimaryAssociatedTypes (PrintOptions &options,
2998
+ llvm::function_ref<void ()> action) {
2999
+ bool originalPrintPrimaryAssociatedTypes = options.PrintPrimaryAssociatedTypes ;
3000
+ options.PrintPrimaryAssociatedTypes = false ;
3001
+ action ();
3002
+ options.PrintPrimaryAssociatedTypes = originalPrintPrimaryAssociatedTypes;
3003
+ }
2972
3004
2973
3005
// / Suppress the printing of a particular feature.
2974
3006
static void suppressingFeature (PrintOptions &options, Feature feature,
@@ -3485,6 +3517,38 @@ void PrintAST::visitClassDecl(ClassDecl *decl) {
3485
3517
}
3486
3518
}
3487
3519
3520
+ void PrintAST::printPrimaryAssociatedTypes (ProtocolDecl *decl) {
3521
+ auto primaryAssocTypes = decl->getPrimaryAssociatedTypes ();
3522
+ if (primaryAssocTypes.empty ())
3523
+ return ;
3524
+
3525
+ Printer.printStructurePre (PrintStructureKind::DeclGenericParameterClause);
3526
+
3527
+ Printer << " <" ;
3528
+ llvm::interleave (
3529
+ primaryAssocTypes,
3530
+ [&](AssociatedTypeDecl *assocType) {
3531
+ Printer.callPrintStructurePre (PrintStructureKind::GenericParameter,
3532
+ assocType);
3533
+ Printer.printName (assocType->getName (),
3534
+ PrintNameContext::GenericParameter);
3535
+
3536
+ printInheritedFromRequirementSignature (decl, assocType);
3537
+
3538
+ if (assocType->hasDefaultDefinitionType ()) {
3539
+ Printer << " = " ;
3540
+ assocType->getDefaultDefinitionType ().print (Printer, Options);
3541
+ }
3542
+
3543
+ Printer.printStructurePost (PrintStructureKind::GenericParameter,
3544
+ assocType);
3545
+ },
3546
+ [&] { Printer << " , " ; });
3547
+ Printer << " >" ;
3548
+
3549
+ Printer.printStructurePost (PrintStructureKind::DeclGenericParameterClause);
3550
+ }
3551
+
3488
3552
void PrintAST::visitProtocolDecl (ProtocolDecl *decl) {
3489
3553
printDocumentationComment (decl);
3490
3554
printAttributes (decl);
@@ -3502,6 +3566,10 @@ void PrintAST::visitProtocolDecl(ProtocolDecl *decl) {
3502
3566
Printer.printName (decl->getName ());
3503
3567
});
3504
3568
3569
+ if (Options.PrintPrimaryAssociatedTypes ) {
3570
+ printPrimaryAssociatedTypes (decl);
3571
+ }
3572
+
3505
3573
printInheritedFromRequirementSignature (decl, decl);
3506
3574
3507
3575
// The trailing where clause is a syntactic thing, which isn't serialized
@@ -4997,6 +5065,14 @@ bool Decl::shouldPrintInContext(const PrintOptions &PO) const {
4997
5065
return PO.PrintIfConfig ;
4998
5066
}
4999
5067
5068
+ if (auto *ATD = dyn_cast<AssociatedTypeDecl>(this )) {
5069
+ // If PO.PrintPrimaryAssociatedTypes is on, primary associated
5070
+ // types are printed as part of the protocol declaration itself,
5071
+ // so skip them here.
5072
+ if (ATD->isPrimary () && PO.PrintPrimaryAssociatedTypes )
5073
+ return false ;
5074
+ }
5075
+
5000
5076
// Print everything else.
5001
5077
return true ;
5002
5078
}
0 commit comments