Skip to content

Commit 17fd8af

Browse files
authored
Merge pull request #15806 from slavapestov/reorder-associated-types
Resilient re-ordering of associated types
2 parents 3415a7e + ab6f20a commit 17fd8af

14 files changed

+192
-25
lines changed

include/swift/AST/Decl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2510,6 +2510,14 @@ class TypeDecl : public ValueDecl {
25102510

25112511
/// Compute an ordering between two type declarations that is ABI-stable.
25122512
static int compare(const TypeDecl *type1, const TypeDecl *type2);
2513+
2514+
/// Compute an ordering between two type declarations that is ABI-stable.
2515+
/// This version takes a pointer-to-a-pointer for use with
2516+
/// llvm::array_pod_sort() and similar.
2517+
template<typename T>
2518+
static int compare(T * const* type1, T * const* type2) {
2519+
return compare(*type1, *type2);
2520+
}
25132521
};
25142522

25152523
/// A type declaration that can have generic parameters attached to it. Because

include/swift/AST/Types.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4288,11 +4288,6 @@ class ProtocolType : public NominalType, public llvm::FoldingSetNode {
42884288
static bool visitAllProtocols(ArrayRef<ProtocolDecl *> protocols,
42894289
llvm::function_ref<bool(ProtocolDecl *)> fn);
42904290

4291-
/// Compare two protocols to provide them with a stable ordering for
4292-
/// use in sorting.
4293-
static int compareProtocols(ProtocolDecl * const* PP1,
4294-
ProtocolDecl * const* PP2);
4295-
42964291
void Profile(llvm::FoldingSetNodeID &ID) {
42974292
Profile(ID, getDecl(), getParent());
42984293
}

include/swift/SIL/SILWitnessVisitor.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,21 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
5858
if (haveAddedAssociatedTypes) return;
5959
haveAddedAssociatedTypes = true;
6060

61+
SmallVector<AssociatedTypeDecl *, 2> associatedTypes;
6162
for (Decl *member : protocol->getMembers()) {
6263
if (auto associatedType = dyn_cast<AssociatedTypeDecl>(member)) {
63-
// TODO: only add associated types when they're new?
64-
asDerived().addAssociatedType(AssociatedType(associatedType));
64+
associatedTypes.push_back(associatedType);
6565
}
6666
}
67+
68+
// Sort associated types by name, for resilience.
69+
llvm::array_pod_sort(associatedTypes.begin(), associatedTypes.end(),
70+
TypeDecl::compare);
71+
72+
for (auto *associatedType : associatedTypes) {
73+
// TODO: only add associated types when they're new?
74+
asDerived().addAssociatedType(AssociatedType(associatedType));
75+
}
6776
};
6877

6978
for (const auto &reqt : protocol->getRequirementSignature()) {

lib/AST/ASTContext.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1613,7 +1613,7 @@ static int compareSimilarAssociatedTypes(AssociatedTypeDecl *const *lhs,
16131613
AssociatedTypeDecl *const *rhs) {
16141614
auto lhsProto = (*lhs)->getProtocol();
16151615
auto rhsProto = (*rhs)->getProtocol();
1616-
return ProtocolType::compareProtocols(&lhsProto, &rhsProto);
1616+
return TypeDecl::compare(lhsProto, rhsProto);
16171617
}
16181618

16191619
ArrayRef<AssociatedTypeDecl *> AssociatedTypeDecl::getOverriddenDecls() const {

lib/AST/ConformanceLookupTable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1110,7 +1110,7 @@ int ConformanceLookupTable::compareProtocolConformances(
11101110
// Otherwise, sort by protocol.
11111111
ProtocolDecl *lhsProto = lhs->getProtocol();
11121112
ProtocolDecl *rhsProto = rhs->getProtocol();
1113-
return ProtocolType::compareProtocols(&lhsProto, &rhsProto);
1113+
return TypeDecl::compare(lhsProto, rhsProto);
11141114
}
11151115

11161116
void ConformanceLookupTable::getAllConformances(

lib/AST/GenericSignature.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ GenericSignature::getCanonical(TypeArrayView<GenericTypeParamType> params,
271271
auto prevProto =
272272
prevReqt.getSecondType()->castTo<ProtocolType>()->getDecl();
273273
auto proto = reqt.getSecondType()->castTo<ProtocolType>()->getDecl();
274-
assert(ProtocolType::compareProtocols(&prevProto, &proto) < 0 &&
274+
assert(TypeDecl::compare(prevProto, proto) < 0 &&
275275
"Out-of-order conformance requirements");
276276
}
277277
#endif

lib/AST/GenericSignatureBuilder.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2073,7 +2073,7 @@ static int compareAssociatedTypes(AssociatedTypeDecl *assocType1,
20732073
// - by protocol, so t_n_m.`P.T` < t_n_m.`Q.T` (given P < Q)
20742074
auto proto1 = assocType1->getProtocol();
20752075
auto proto2 = assocType2->getProtocol();
2076-
if (int compareProtocols = ProtocolType::compareProtocols(&proto1, &proto2))
2076+
if (int compareProtocols = TypeDecl::compare(proto1, proto2))
20772077
return compareProtocols;
20782078

20792079
// Error case: if we have two associated types with the same name in the
@@ -7286,7 +7286,7 @@ void GenericSignatureBuilder::enumerateRequirements(
72867286

72877287
// Sort the protocols in canonical order.
72887288
llvm::array_pod_sort(protocols.begin(), protocols.end(),
7289-
ProtocolType::compareProtocols);
7289+
TypeDecl::compare);
72907290

72917291
// Enumerate the conformance requirements.
72927292
for (auto proto : protocols) {

lib/AST/ProtocolConformance.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,8 +1188,7 @@ DeclContext::getLocalProtocols(
11881188

11891189
// Sort if required.
11901190
if (sorted) {
1191-
llvm::array_pod_sort(result.begin(), result.end(),
1192-
&ProtocolType::compareProtocols);
1191+
llvm::array_pod_sort(result.begin(), result.end(), TypeDecl::compare);
11931192
}
11941193

11951194
return result;

lib/AST/Type.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -971,12 +971,6 @@ static void addMinimumProtocols(Type T,
971971
}
972972
}
973973

974-
/// \brief Compare two protocols to establish an ordering between them.
975-
int ProtocolType::compareProtocols(ProtocolDecl * const* PP1,
976-
ProtocolDecl * const* PP2) {
977-
return TypeDecl::compare(*PP1, *PP2);
978-
}
979-
980974
bool ProtocolType::visitAllProtocols(
981975
ArrayRef<ProtocolDecl *> protocols,
982976
llvm::function_ref<bool(ProtocolDecl *)> fn) {
@@ -1052,7 +1046,7 @@ void ProtocolType::canonicalizeProtocols(
10521046

10531047
// Sort the set of protocols by module + name, to give a stable
10541048
// ordering.
1055-
llvm::array_pod_sort(protocols.begin(), protocols.end(), compareProtocols);
1049+
llvm::array_pod_sort(protocols.begin(), protocols.end(), TypeDecl::compare);
10561050
}
10571051

10581052
static Type

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4890,8 +4890,7 @@ TypeChecker::findWitnessedObjCRequirements(const ValueDecl *witness,
48904890
= cast<ProtocolDecl>(lhs->getDeclContext());
48914891
ProtocolDecl *rhsProto
48924892
= cast<ProtocolDecl>(rhs->getDeclContext());
4893-
return ProtocolType::compareProtocols(&lhsProto,
4894-
&rhsProto) < 0;
4893+
return TypeDecl::compare(lhsProto, rhsProto) < 0;
48954894
});
48964895
}
48974896
return result;

test/IRGen/associated_type_witness.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ protocol Assocked {
1212

1313
struct Universal : P, Q {}
1414

15-
// CHECK: [[ASSOC_TYPE_NAMES:@.*]] = private constant [29 x i8] c"OneAssoc TwoAssoc ThreeAssoc\00"
15+
// CHECK: [[ASSOC_TYPE_NAMES:@.*]] = private constant [29 x i8] c"OneAssoc ThreeAssoc TwoAssoc\00"
1616
// CHECK: @"$S23associated_type_witness18HasThreeAssocTypesMp" =
1717
// CHECK-SAME: [[ASSOC_TYPE_NAMES]] to i64
1818

test/IRGen/associated_types.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func testFastRuncible<T: Runcible, U: FastRuncible>(_ t: T, u: U)
7575
// 1. Get the type metadata for U.RuncerType.Runcee.
7676
// 1a. Get the type metadata for U.RuncerType.
7777
// Note that we actually look things up in T, which is going to prove unfortunate.
78-
// CHECK: [[T0_GEP:%.*]] = getelementptr inbounds i8*, i8** %T.Runcible, i32 1
78+
// CHECK: [[T0_GEP:%.*]] = getelementptr inbounds i8*, i8** %T.Runcible, i32 2
7979
// CHECK: [[T0:%.*]] = load i8*, i8** [[T0_GEP]]
8080
// CHECK-NEXT: [[T1:%.*]] = bitcast i8* [[T0]] to %swift.metadata_response ([[INT]], %swift.type*, i8**)*
8181
// CHECK-NEXT: [[T2:%.*]] = call swiftcc %swift.metadata_response [[T1]]([[INT]] 0, %swift.type* %T, i8** %T.Runcible)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
public protocol Bed {
2+
func squiggle()
3+
}
4+
5+
public protocol Outfit {
6+
var size: Int { get }
7+
}
8+
9+
public protocol Eater {
10+
#if BEFORE
11+
func eat()
12+
func poop()
13+
#else
14+
func poop()
15+
func eat()
16+
#endif
17+
}
18+
19+
public protocol Wiggler {
20+
#if BEFORE
21+
func wiggle()
22+
func cry()
23+
#else
24+
func cry()
25+
func wiggle()
26+
#endif
27+
}
28+
29+
#if BEFORE
30+
31+
public protocol Baby : Eater, Wiggler {
32+
associatedtype Bassinet : Bed
33+
associatedtype Onesie : Outfit
34+
35+
var outfitSize: Int { get }
36+
37+
func sleep(in: Bassinet)
38+
func wear(outfit: Onesie)
39+
}
40+
41+
#else
42+
43+
public protocol Baby : Wiggler, Eater {
44+
associatedtype Onesie : Outfit
45+
associatedtype Bassinet : Bed
46+
47+
var outfitSize: Int { get }
48+
49+
func wear(outfit: Onesie)
50+
func sleep(in: Bassinet)
51+
}
52+
53+
#endif
54+
55+
public func goodDay<B : Baby>(for baby: B,
56+
sleepingIn bed: B.Bassinet,
57+
wearing outfit: B.Onesie) {
58+
if baby.outfitSize != outfit.size {
59+
fatalError("I grew too much!")
60+
}
61+
62+
baby.wear(outfit: outfit)
63+
baby.sleep(in: bed)
64+
baby.poop()
65+
baby.sleep(in: bed)
66+
baby.eat()
67+
baby.sleep(in: bed)
68+
baby.wiggle()
69+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// RUN: %target-resilience-test --no-backward-deployment
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
import protocol_reorder_requirements
6+
7+
8+
var ProtocolReorderRequirementsTest = TestSuite("ProtocolReorderRequirements")
9+
10+
var log = [String]()
11+
12+
struct MyBassinet : Bed {
13+
func squiggle() {
14+
log.append("nap time")
15+
}
16+
}
17+
18+
struct MyOnesie : Outfit {
19+
let size = 3
20+
}
21+
22+
struct SillyBaby : Baby {
23+
func eat() {
24+
log.append("hangry!")
25+
}
26+
27+
func sleep(in bassinet: MyBassinet) {
28+
bassinet.squiggle()
29+
}
30+
31+
func wear(outfit: MyOnesie) {
32+
log.append("wearing outfit size \(outfit.size)")
33+
}
34+
35+
func poop() {
36+
log.append("change the diaper")
37+
}
38+
39+
func cry() {
40+
log.append("waaaaah!")
41+
}
42+
43+
func wiggle() {
44+
log.append("time to wiggle!")
45+
}
46+
47+
let outfitSize = 3
48+
}
49+
50+
func typicalDay<B : Baby>(for baby: B,
51+
sleepingIn bed: B.Bassinet,
52+
wearing outfit: B.Onesie) {
53+
baby.wear(outfit: outfit)
54+
baby.sleep(in: bed)
55+
baby.cry()
56+
baby.poop()
57+
baby.cry()
58+
baby.sleep(in: bed)
59+
baby.eat()
60+
baby.cry()
61+
}
62+
63+
ProtocolReorderRequirementsTest.test("ReorderProtocolRequirements") {
64+
let baby = SillyBaby()
65+
let bed = MyBassinet()
66+
let outfit = MyOnesie()
67+
68+
typicalDay(for: baby, sleepingIn: bed, wearing: outfit)
69+
expectEqual(log, [
70+
"wearing outfit size 3",
71+
"nap time",
72+
"waaaaah!",
73+
"change the diaper",
74+
"waaaaah!",
75+
"nap time",
76+
"hangry!",
77+
"waaaaah!"
78+
])
79+
log = []
80+
81+
goodDay(for: baby, sleepingIn: bed, wearing: outfit)
82+
expectEqual(log, [
83+
"wearing outfit size 3",
84+
"nap time",
85+
"change the diaper",
86+
"nap time",
87+
"hangry!",
88+
"nap time",
89+
"time to wiggle!"
90+
])
91+
}
92+
93+
runAllTests()
94+

0 commit comments

Comments
 (0)