Skip to content

Commit 6172679

Browse files
authored
Merge pull request #19677 from rudkx/join-protocol-compositions
[TypeJoin] Implement Type::join for protocols and protocol compositions.
2 parents da463e0 + e3f7531 commit 6172679

File tree

2 files changed

+148
-14
lines changed

2 files changed

+148
-14
lines changed

lib/AST/TypeJoinMeet.cpp

Lines changed: 102 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ struct TypeJoin : CanTypeVisitor<TypeJoin, CanType> {
5151
}
5252

5353
static CanType getSuperclassJoin(CanType first, CanType second);
54+
CanType computeProtocolCompositionJoin(ArrayRef<Type> firstMembers,
55+
ArrayRef<Type> secondMembers);
56+
5457

5558
CanType visitErrorType(CanType second);
5659
CanType visitTupleType(CanType second);
@@ -105,10 +108,10 @@ struct TypeJoin : CanTypeVisitor<TypeJoin, CanType> {
105108

106109
// Likewise, rather than making every visitor deal with Any,
107110
// always dispatch to the protocol composition side of the join.
108-
if (first->isAny())
111+
if (first->is<ProtocolCompositionType>())
109112
return TypeJoin(second).visit(first);
110113

111-
if (second->isAny())
114+
if (second->is<ProtocolCompositionType>())
112115
return TypeJoin(first).visit(second);
113116

114117
// Otherwise the first type might be an optional (or not), so
@@ -184,16 +187,6 @@ CanType TypeJoin::visitClassType(CanType second) {
184187
return getSuperclassJoin(First, second);
185188
}
186189

187-
CanType TypeJoin::visitProtocolType(CanType second) {
188-
assert(First != second);
189-
190-
// FIXME: We should compute a tighter bound and/or return nullptr if
191-
// we cannot. We do this now because existing tests rely on
192-
// producing Any for the join of protocols that have a common
193-
// supertype.
194-
return TheAnyType;
195-
}
196-
197190
CanType TypeJoin::visitBoundGenericClassType(CanType second) {
198191
return getSuperclassJoin(First, second);
199192
}
@@ -352,16 +345,111 @@ CanType TypeJoin::visitGenericFunctionType(CanType second) {
352345
return Unimplemented;
353346
}
354347

348+
// Use the distributive law to compute the join of the protocol
349+
// compositions.
350+
//
351+
// (A ^ B) v (C ^ D)
352+
// = (A v C) ^ (A v D) ^ (B v C) ^ (B v D)
353+
//
354+
// In general this law only applies to distributive lattices.
355+
//
356+
// In our case, this should be safe because our meet operation only
357+
// produces an existing nominal type when it is one of the operands of
358+
// the operation. So we can never arbitrarily climb down the lattice
359+
// in ways that would break distributivity.
360+
//
361+
CanType TypeJoin::computeProtocolCompositionJoin(ArrayRef<Type> firstMembers,
362+
ArrayRef<Type> secondMembers) {
363+
SmallVector<Type, 8> result;
364+
for (auto first : firstMembers) {
365+
for (auto second : secondMembers) {
366+
auto joined = Type::join(first, second);
367+
if (!joined)
368+
return Unimplemented;
369+
370+
if ((*joined)->isAny())
371+
continue;
372+
373+
result.push_back(*joined);
374+
}
375+
}
376+
377+
if (result.empty())
378+
return TheAnyType;
379+
380+
auto &ctx = result[0]->getASTContext();
381+
return ProtocolCompositionType::get(ctx, result, false)->getCanonicalType();
382+
}
383+
355384
CanType TypeJoin::visitProtocolCompositionType(CanType second) {
385+
// The join of Any and a no-escape function doesn't exist; it isn't
386+
// Any. If it were Any, it would mean we would allow these functions
387+
// to escape through Any.
356388
if (second->isAny()) {
357389
auto *fnTy = First->getAs<AnyFunctionType>();
358390
if (fnTy && fnTy->getExtInfo().isNoEscape())
359391
return Nonexistent;
360392

361-
return second;
393+
return TheAnyType;
362394
}
363395

364-
return Unimplemented;
396+
assert(First != second);
397+
398+
// FIXME: Handle other types here.
399+
if (!First->isExistentialType())
400+
return TheAnyType;
401+
402+
SmallVector<Type, 1> protocolType;
403+
ArrayRef<Type> firstMembers;
404+
if (First->is<ProtocolType>()) {
405+
protocolType.push_back(First);
406+
firstMembers = protocolType;
407+
} else {
408+
firstMembers = cast<ProtocolCompositionType>(First)->getMembers();
409+
}
410+
auto secondMembers = cast<ProtocolCompositionType>(second)->getMembers();
411+
412+
return computeProtocolCompositionJoin(firstMembers, secondMembers);
413+
}
414+
415+
CanType TypeJoin::visitProtocolType(CanType second) {
416+
assert(First != second);
417+
418+
assert(!First->is<ProtocolCompositionType>() &&
419+
!second->is<ProtocolCompositionType>());
420+
421+
// FIXME: Handle other types here.
422+
if (First->getKind() != second->getKind())
423+
return TheAnyType;
424+
425+
auto *firstDecl =
426+
cast<ProtocolDecl>(First->getNominalOrBoundGenericNominal());
427+
428+
auto *secondDecl =
429+
cast<ProtocolDecl>(second->getNominalOrBoundGenericNominal());
430+
431+
if (firstDecl->getInheritedProtocols().empty() &&
432+
secondDecl->getInheritedProtocols().empty())
433+
return TheAnyType;
434+
435+
if (firstDecl->inheritsFrom(secondDecl))
436+
return second;
437+
438+
if (secondDecl->inheritsFrom(firstDecl))
439+
return First;
440+
441+
// One isn't the supertype of the other, so instead, treat each as
442+
// if it's a protocol composition of its inherited members, and join
443+
// those.
444+
SmallVector<Type, 4> firstMembers;
445+
for (auto *decl : firstDecl->getInheritedProtocols())
446+
firstMembers.push_back(decl->getDeclaredInterfaceType());
447+
448+
SmallVector<Type, 4> secondMembers;
449+
for (auto *decl : secondDecl->getInheritedProtocols())
450+
secondMembers.push_back(decl->getDeclaredInterfaceType());
451+
452+
return computeProtocolCompositionJoin(firstMembers, secondMembers);
365453
}
366454

367455
CanType TypeJoin::visitLValueType(CanType second) { return Unimplemented; }

test/Sema/type_join.swift

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,31 @@ import Swift
55
class C {}
66
class D : C {}
77

8+
protocol L {}
9+
protocol M : L {}
10+
protocol N : L {}
11+
protocol P : M {}
12+
protocol Q : M {}
13+
protocol R : L {}
14+
protocol Y {}
15+
16+
protocol FakeEquatable {}
17+
protocol FakeHashable : FakeEquatable {}
18+
protocol FakeExpressibleByIntegerLiteral {}
19+
protocol FakeNumeric : FakeEquatable, FakeExpressibleByIntegerLiteral {}
20+
protocol FakeSignedNumeric : FakeNumeric {}
21+
protocol FakeComparable : FakeEquatable {}
22+
protocol FakeStrideable : FakeComparable {}
23+
protocol FakeCustomStringConvertible {}
24+
protocol FakeBinaryInteger : FakeHashable, FakeNumeric, FakeCustomStringConvertible, FakeStrideable {}
25+
protocol FakeLosslessStringConvertible {}
26+
protocol FakeFixedWidthInteger : FakeBinaryInteger, FakeLosslessStringConvertible {}
27+
protocol FakeUnsignedInteger : FakeBinaryInteger {}
28+
protocol FakeSignedInteger : FakeBinaryInteger, FakeSignedNumeric {}
29+
protocol FakeFloatingPoint : FakeSignedNumeric, FakeStrideable, FakeHashable {}
30+
protocol FakeExpressibleByFloatLiteral {}
31+
protocol FakeBinaryFloatingPoint : FakeFloatingPoint, FakeExpressibleByFloatLiteral {}
32+
833
func expectEqualType<T>(_: T.Type, _: T.Type) {}
934
func commonSupertype<T>(_: T, _: T) -> T {}
1035

@@ -38,6 +63,27 @@ expectEqualType(Builtin.type_join(Builtin.Int1.self, Builtin.Int1.self), Builtin
3863
expectEqualType(Builtin.type_join(Builtin.Int32.self, Builtin.Int1.self), Any.self)
3964
expectEqualType(Builtin.type_join(Builtin.Int1.self, Builtin.Int32.self), Any.self)
4065

66+
expectEqualType(Builtin.type_join(L.self, L.self), L.self)
67+
expectEqualType(Builtin.type_join(L.self, M.self), L.self)
68+
expectEqualType(Builtin.type_join(L.self, P.self), L.self)
69+
expectEqualType(Builtin.type_join(L.self, Y.self), Any.self)
70+
expectEqualType(Builtin.type_join(N.self, P.self), L.self)
71+
expectEqualType(Builtin.type_join(Q.self, P.self), M.self)
72+
expectEqualType(Builtin.type_join((N & P).self, (Q & R).self), M.self)
73+
expectEqualType(Builtin.type_join((Q & P).self, (Y & R).self), L.self)
74+
expectEqualType(Builtin.type_join(FakeEquatable.self, FakeEquatable.self), FakeEquatable.self)
75+
expectEqualType(Builtin.type_join(FakeHashable.self, FakeEquatable.self), FakeEquatable.self)
76+
expectEqualType(Builtin.type_join(FakeEquatable.self, FakeHashable.self), FakeEquatable.self)
77+
expectEqualType(Builtin.type_join(FakeNumeric.self, FakeHashable.self), FakeEquatable.self)
78+
expectEqualType(Builtin.type_join((FakeHashable & FakeStrideable).self, (FakeHashable & FakeNumeric).self),
79+
FakeHashable.self)
80+
expectEqualType(Builtin.type_join((FakeNumeric & FakeStrideable).self,
81+
(FakeHashable & FakeNumeric).self), FakeNumeric.self)
82+
expectEqualType(Builtin.type_join(FakeBinaryInteger.self, FakeFloatingPoint.self),
83+
(FakeHashable & FakeNumeric & FakeStrideable).self)
84+
expectEqualType(Builtin.type_join(FakeFloatingPoint.self, FakeBinaryInteger.self),
85+
(FakeHashable & FakeNumeric & FakeStrideable).self)
86+
4187
func joinFunctions(
4288
_ escaping: @escaping () -> (),
4389
_ nonescaping: () -> ()

0 commit comments

Comments
 (0)