Skip to content

[TypeJoin] Implement Type::join for protocols and protocol compositions. #19677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 102 additions & 14 deletions lib/AST/TypeJoinMeet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ struct TypeJoin : CanTypeVisitor<TypeJoin, CanType> {
}

static CanType getSuperclassJoin(CanType first, CanType second);
CanType computeProtocolCompositionJoin(ArrayRef<Type> firstMembers,
ArrayRef<Type> secondMembers);


CanType visitErrorType(CanType second);
CanType visitTupleType(CanType second);
Expand Down Expand Up @@ -105,10 +108,10 @@ struct TypeJoin : CanTypeVisitor<TypeJoin, CanType> {

// Likewise, rather than making every visitor deal with Any,
// always dispatch to the protocol composition side of the join.
if (first->isAny())
if (first->is<ProtocolCompositionType>())
return TypeJoin(second).visit(first);

if (second->isAny())
if (second->is<ProtocolCompositionType>())
return TypeJoin(first).visit(second);

// Otherwise the first type might be an optional (or not), so
Expand Down Expand Up @@ -184,16 +187,6 @@ CanType TypeJoin::visitClassType(CanType second) {
return getSuperclassJoin(First, second);
}

CanType TypeJoin::visitProtocolType(CanType second) {
assert(First != second);

// FIXME: We should compute a tighter bound and/or return nullptr if
// we cannot. We do this now because existing tests rely on
// producing Any for the join of protocols that have a common
// supertype.
return TheAnyType;
}

CanType TypeJoin::visitBoundGenericClassType(CanType second) {
return getSuperclassJoin(First, second);
}
Expand Down Expand Up @@ -352,16 +345,111 @@ CanType TypeJoin::visitGenericFunctionType(CanType second) {
return Unimplemented;
}

// Use the distributive law to compute the join of the protocol
// compositions.
//
// (A ^ B) v (C ^ D)
// = (A v C) ^ (A v D) ^ (B v C) ^ (B v D)
//
// In general this law only applies to distributive lattices.
//
// In our case, this should be safe because our meet operation only
// produces an existing nominal type when it is one of the operands of
// the operation. So we can never arbitrarily climb down the lattice
// in ways that would break distributivity.
//
CanType TypeJoin::computeProtocolCompositionJoin(ArrayRef<Type> firstMembers,
ArrayRef<Type> secondMembers) {
SmallVector<Type, 8> result;
for (auto first : firstMembers) {
for (auto second : secondMembers) {
auto joined = Type::join(first, second);
if (!joined)
return Unimplemented;

if ((*joined)->isAny())
continue;

result.push_back(*joined);
}
}

if (result.empty())
return TheAnyType;

auto &ctx = result[0]->getASTContext();
return ProtocolCompositionType::get(ctx, result, false)->getCanonicalType();
}

CanType TypeJoin::visitProtocolCompositionType(CanType second) {
// The join of Any and a no-escape function doesn't exist; it isn't
// Any. If it were Any, it would mean we would allow these functions
// to escape through Any.
if (second->isAny()) {
auto *fnTy = First->getAs<AnyFunctionType>();
if (fnTy && fnTy->getExtInfo().isNoEscape())
return Nonexistent;

return second;
return TheAnyType;
}

return Unimplemented;
assert(First != second);

// FIXME: Handle other types here.
if (!First->isExistentialType())
return TheAnyType;

SmallVector<Type, 1> protocolType;
ArrayRef<Type> firstMembers;
if (First->is<ProtocolType>()) {
protocolType.push_back(First);
firstMembers = protocolType;
} else {
firstMembers = cast<ProtocolCompositionType>(First)->getMembers();
}
auto secondMembers = cast<ProtocolCompositionType>(second)->getMembers();

return computeProtocolCompositionJoin(firstMembers, secondMembers);
}

CanType TypeJoin::visitProtocolType(CanType second) {
assert(First != second);

assert(!First->is<ProtocolCompositionType>() &&
!second->is<ProtocolCompositionType>());

// FIXME: Handle other types here.
if (First->getKind() != second->getKind())
return TheAnyType;

auto *firstDecl =
cast<ProtocolDecl>(First->getNominalOrBoundGenericNominal());

auto *secondDecl =
cast<ProtocolDecl>(second->getNominalOrBoundGenericNominal());

if (firstDecl->getInheritedProtocols().empty() &&
secondDecl->getInheritedProtocols().empty())
return TheAnyType;

if (firstDecl->inheritsFrom(secondDecl))
return second;

if (secondDecl->inheritsFrom(firstDecl))
return First;

// One isn't the supertype of the other, so instead, treat each as
// if it's a protocol composition of its inherited members, and join
// those.
SmallVector<Type, 4> firstMembers;
for (auto *decl : firstDecl->getInheritedProtocols())
firstMembers.push_back(decl->getDeclaredInterfaceType());

SmallVector<Type, 4> secondMembers;
for (auto *decl : secondDecl->getInheritedProtocols())
secondMembers.push_back(decl->getDeclaredInterfaceType());

return computeProtocolCompositionJoin(firstMembers, secondMembers);
}

CanType TypeJoin::visitLValueType(CanType second) { return Unimplemented; }
Expand Down
46 changes: 46 additions & 0 deletions test/Sema/type_join.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,31 @@ import Swift
class C {}
class D : C {}

protocol L {}
protocol M : L {}
protocol N : L {}
protocol P : M {}
protocol Q : M {}
protocol R : L {}
protocol Y {}

protocol FakeEquatable {}
protocol FakeHashable : FakeEquatable {}
protocol FakeExpressibleByIntegerLiteral {}
protocol FakeNumeric : FakeEquatable, FakeExpressibleByIntegerLiteral {}
protocol FakeSignedNumeric : FakeNumeric {}
protocol FakeComparable : FakeEquatable {}
protocol FakeStrideable : FakeComparable {}
protocol FakeCustomStringConvertible {}
protocol FakeBinaryInteger : FakeHashable, FakeNumeric, FakeCustomStringConvertible, FakeStrideable {}
protocol FakeLosslessStringConvertible {}
protocol FakeFixedWidthInteger : FakeBinaryInteger, FakeLosslessStringConvertible {}
protocol FakeUnsignedInteger : FakeBinaryInteger {}
protocol FakeSignedInteger : FakeBinaryInteger, FakeSignedNumeric {}
protocol FakeFloatingPoint : FakeSignedNumeric, FakeStrideable, FakeHashable {}
protocol FakeExpressibleByFloatLiteral {}
protocol FakeBinaryFloatingPoint : FakeFloatingPoint, FakeExpressibleByFloatLiteral {}

func expectEqualType<T>(_: T.Type, _: T.Type) {}
func commonSupertype<T>(_: T, _: T) -> T {}

Expand Down Expand Up @@ -38,6 +63,27 @@ expectEqualType(Builtin.type_join(Builtin.Int1.self, Builtin.Int1.self), Builtin
expectEqualType(Builtin.type_join(Builtin.Int32.self, Builtin.Int1.self), Any.self)
expectEqualType(Builtin.type_join(Builtin.Int1.self, Builtin.Int32.self), Any.self)

expectEqualType(Builtin.type_join(L.self, L.self), L.self)
expectEqualType(Builtin.type_join(L.self, M.self), L.self)
expectEqualType(Builtin.type_join(L.self, P.self), L.self)
expectEqualType(Builtin.type_join(L.self, Y.self), Any.self)
expectEqualType(Builtin.type_join(N.self, P.self), L.self)
expectEqualType(Builtin.type_join(Q.self, P.self), M.self)
expectEqualType(Builtin.type_join((N & P).self, (Q & R).self), M.self)
expectEqualType(Builtin.type_join((Q & P).self, (Y & R).self), L.self)
expectEqualType(Builtin.type_join(FakeEquatable.self, FakeEquatable.self), FakeEquatable.self)
expectEqualType(Builtin.type_join(FakeHashable.self, FakeEquatable.self), FakeEquatable.self)
expectEqualType(Builtin.type_join(FakeEquatable.self, FakeHashable.self), FakeEquatable.self)
expectEqualType(Builtin.type_join(FakeNumeric.self, FakeHashable.self), FakeEquatable.self)
expectEqualType(Builtin.type_join((FakeHashable & FakeStrideable).self, (FakeHashable & FakeNumeric).self),
FakeHashable.self)
expectEqualType(Builtin.type_join((FakeNumeric & FakeStrideable).self,
(FakeHashable & FakeNumeric).self), FakeNumeric.self)
expectEqualType(Builtin.type_join(FakeBinaryInteger.self, FakeFloatingPoint.self),
(FakeHashable & FakeNumeric & FakeStrideable).self)
expectEqualType(Builtin.type_join(FakeFloatingPoint.self, FakeBinaryInteger.self),
(FakeHashable & FakeNumeric & FakeStrideable).self)

func joinFunctions(
_ escaping: @escaping () -> (),
_ nonescaping: () -> ()
Expand Down