@@ -550,7 +550,7 @@ static void deriveBodyEquatable_struct_eq(AbstractFunctionDecl *eqDecl) {
550
550
551
551
// / Derive an '==' operator implementation for an enum or a struct.
552
552
static ValueDecl *
553
- deriveEquatable_eq (DerivedConformance &derived, Identifier generatedIdentifier,
553
+ deriveEquatable_eq (DerivedConformance &derived,
554
554
void (*bodySynthesizer)(AbstractFunctionDecl *)) {
555
555
// enum SomeEnum<T...> {
556
556
// case A, B(Int), C(String, Int)
@@ -590,14 +590,14 @@ deriveEquatable_eq(DerivedConformance &derived, Identifier generatedIdentifier,
590
590
ASTContext &C = derived.TC .Context ;
591
591
592
592
auto parentDC = derived.getConformanceContext ();
593
- auto enumTy = parentDC->getDeclaredTypeInContext ();
594
- auto enumIfaceTy = parentDC->getDeclaredInterfaceType ();
593
+ auto selfTy = parentDC->getDeclaredTypeInContext ();
594
+ auto selfIfaceTy = parentDC->getDeclaredInterfaceType ();
595
595
596
596
auto getParamDecl = [&](StringRef s) -> ParamDecl * {
597
597
auto *param = new (C) ParamDecl (VarDecl::Specifier::Default, SourceLoc (),
598
598
SourceLoc (), Identifier (), SourceLoc (),
599
- C.getIdentifier (s), enumTy , parentDC);
600
- param->setInterfaceType (enumIfaceTy );
599
+ C.getIdentifier (s), selfTy , parentDC);
600
+ param->setInterfaceType (selfIfaceTy );
601
601
return param;
602
602
};
603
603
@@ -611,6 +611,17 @@ deriveEquatable_eq(DerivedConformance &derived, Identifier generatedIdentifier,
611
611
612
612
auto boolTy = C.getBoolDecl ()->getDeclaredType ();
613
613
614
+ Identifier generatedIdentifier;
615
+ if (parentDC->getParentModule ()->getResilienceStrategy () ==
616
+ ResilienceStrategy::Resilient) {
617
+ generatedIdentifier = C.Id_EqualsOperator ;
618
+ } else if (selfTy->getEnumOrBoundGenericEnum ()) {
619
+ generatedIdentifier = C.Id_derived_enum_equals ;
620
+ } else {
621
+ assert (selfTy->getStructOrBoundGenericStruct ());
622
+ generatedIdentifier = C.Id_derived_struct_equals ;
623
+ }
624
+
614
625
DeclName name (C, generatedIdentifier, params);
615
626
auto eqDecl =
616
627
FuncDecl::create (C, /* StaticLoc=*/ SourceLoc (),
@@ -626,17 +637,19 @@ deriveEquatable_eq(DerivedConformance &derived, Identifier generatedIdentifier,
626
637
eqDecl->getAttrs ().add (new (C) InfixAttr (/* implicit*/ false ));
627
638
628
639
// Add the @_implements(Equatable, ==(_:_:)) attribute
629
- auto equatableProto = C.getProtocol (KnownProtocolKind::Equatable);
630
- auto equatableTy = equatableProto->getDeclaredType ();
631
- auto equatableTypeLoc = TypeLoc::withoutLoc (equatableTy);
632
- SmallVector<Identifier, 2 > argumentLabels = { Identifier (), Identifier () };
633
- auto equalsDeclName = DeclName (C, DeclBaseName (C.Id_EqualsOperator ),
634
- argumentLabels);
635
- eqDecl->getAttrs ().add (new (C) ImplementsAttr (SourceLoc (),
636
- SourceRange (),
637
- equatableTypeLoc,
638
- equalsDeclName,
639
- DeclNameLoc ()));
640
+ if (generatedIdentifier != C.Id_EqualsOperator ) {
641
+ auto equatableProto = C.getProtocol (KnownProtocolKind::Equatable);
642
+ auto equatableTy = equatableProto->getDeclaredType ();
643
+ auto equatableTypeLoc = TypeLoc::withoutLoc (equatableTy);
644
+ SmallVector<Identifier, 2 > argumentLabels = { Identifier (), Identifier () };
645
+ auto equalsDeclName = DeclName (C, DeclBaseName (C.Id_EqualsOperator ),
646
+ argumentLabels);
647
+ eqDecl->getAttrs ().add (new (C) ImplementsAttr (SourceLoc (),
648
+ SourceRange (),
649
+ equatableTypeLoc,
650
+ equalsDeclName,
651
+ DeclNameLoc ()));
652
+ }
640
653
641
654
if (!C.getEqualIntDecl ()) {
642
655
derived.TC .diagnose (derived.ConformanceDecl ->getLoc (),
@@ -683,11 +696,9 @@ ValueDecl *DerivedConformance::deriveEquatable(ValueDecl *requirement) {
683
696
: ed->hasOnlyCasesWithoutAssociatedValues ()
684
697
? &deriveBodyEquatable_enum_noAssociatedValues_eq
685
698
: &deriveBodyEquatable_enum_hasAssociatedValues_eq;
686
- return deriveEquatable_eq (*this , TC.Context .Id_derived_enum_equals ,
687
- bodySynthesizer);
699
+ return deriveEquatable_eq (*this , bodySynthesizer);
688
700
} else if (isa<StructDecl>(Nominal))
689
- return deriveEquatable_eq (*this , TC.Context .Id_derived_struct_equals ,
690
- &deriveBodyEquatable_struct_eq);
701
+ return deriveEquatable_eq (*this , &deriveBodyEquatable_struct_eq);
691
702
else
692
703
llvm_unreachable (" todo" );
693
704
}
0 commit comments