Skip to content

Add KeyPathIterable protocol and synthesis. #21557

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2405,6 +2405,8 @@ ERROR(broken_encodable_requirement,none,
ERROR(broken_decodable_requirement,none,
"Decodable protocol is broken: unexpected requirement", ())
// SWIFT_ENABLE_TENSORFLOW
ERROR(broken_key_path_iterable_requirement,none,
"KeyPathIterable protocol is broken: unexpected requirement", ())
ERROR(broken_parameter_group_requirement,none,
"ParameterGroup protocol is broken: unexpected requirement", ())
ERROR(broken_parameterized_requirement,none,
Expand Down
6 changes: 5 additions & 1 deletion include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ IDENTIFIER(decode)
IDENTIFIER(decodeIfPresent)
IDENTIFIER(Decoder)
IDENTIFIER(decoder)
// SWIFT_ENABLE_TENSORFLOW
IDENTIFIER(dynamicCallable)
IDENTIFIER(dynamicMember)
IDENTIFIER(Element)
Expand Down Expand Up @@ -118,6 +117,11 @@ IDENTIFIER_WITH_NAME(value_, "_value")
IDENTIFIER(with)

// SWIFT_ENABLE_TENSORFLOW
IDENTIFIER(AllKeyPaths)
IDENTIFIER(allKeyPaths)
IDENTIFIER(recursivelyAllKeyPaths)
IDENTIFIER(allWritableKeyPaths)
IDENTIFIER(recursivelyAllWritableKeyPaths)
IDENTIFIER(allParameters)
IDENTIFIER(Parameter)
IDENTIFIER(Parameters)
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ PROTOCOL(Decodable)
PROTOCOL(AdditiveArithmetic)
PROTOCOL(Numeric)
PROTOCOL(FloatingPoint)
PROTOCOL(KeyPathIterable)
PROTOCOL(ParameterGroup)
PROTOCOL(Parameterized)
PROTOCOL(TensorArrayProtocol)
Expand Down
1 change: 1 addition & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4061,6 +4061,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::FloatingPoint:
case KnownProtocolKind::AdditiveArithmetic:
case KnownProtocolKind::Numeric:
case KnownProtocolKind::KeyPathIterable:
case KnownProtocolKind::ParameterGroup:
case KnownProtocolKind::Parameterized:
case KnownProtocolKind::TensorArrayProtocol:
Expand Down
1 change: 1 addition & 0 deletions lib/Sema/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_swift_host_library(swiftSema STATIC
DerivedConformanceEquatableHashable.cpp
DerivedConformanceError.cpp
# SWIFT_ENABLE_TENSORFLOW
DerivedConformanceKeyPathIterable.cpp
DerivedConformanceParameterGroup.cpp
DerivedConformanceParameterized.cpp
DerivedConformanceRawRepresentable.cpp
Expand Down
140 changes: 140 additions & 0 deletions lib/Sema/DerivedConformanceKeyPathIterable.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
//===--- DerivedConformanceKeyPathIterable.cpp ----------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements explicit derivation of the KeyPathIterable protocol for
// a nominal type.
//
//===----------------------------------------------------------------------===//

#include "CodeSynthesis.h"
#include "TypeChecker.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/Module.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Types.h"
#include "DerivedConformances.h"

using namespace swift;

bool DerivedConformance::canDeriveKeyPathIterable(NominalTypeDecl *nominal) {
// Note: we could extend synthesis to support classes.
// Subclasses need to append `allKeyPaths` to `super.allKeyPaths`.
return isa<StructDecl>(nominal);
}

// Compute `PartialKeyPathType<Nominal>`, bound to the given nominal
// declaration's type.
static Type computePartialKeyPathType(NominalTypeDecl *nominal) {
auto &C = nominal->getASTContext();
auto nominalType = nominal->getDeclaredInterfaceType();
if (!nominalType || nominalType->hasError())
return nullptr;
auto *partialKeyPathDecl = cast<ClassDecl>(C.getPartialKeyPathDecl());
return BoundGenericClassType::get(partialKeyPathDecl, /*parent*/ Type(),
{nominal->getDeclaredInterfaceType()});
}

// Compute `AllKeyPaths` associated type for the given nominal declaration.
// It should be `[PartialKeyPath<Nominal>]`.
static ArraySliceType *computeAllKeyPathsType(NominalTypeDecl *nominal) {
auto partialKeyPathType = computePartialKeyPathType(nominal);
return ArraySliceType::get(partialKeyPathType);
}

// Synthesize body for the `allKeyPaths` computed property getter.
static void
deriveBodyKeyPathIterable_allKeyPaths(AbstractFunctionDecl *funcDecl) {
auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl();
auto &C = nominal->getASTContext();

auto *nominalTypeExpr = TypeExpr::createForDecl(SourceLoc(), nominal,
funcDecl, /*Implicit*/ true);

// Create array of key path expressions to stored properties.
llvm::SmallVector<Expr *, 2> keyPathExprs;
for (auto member : nominal->getStoredProperties()) {
auto *dotExpr = new (C)
UnresolvedDotExpr(nominalTypeExpr, SourceLoc(), member->getFullName(),
DeclNameLoc(), /*Implicit*/ true);
auto *keyPathExpr =
new (C) KeyPathExpr(SourceLoc(), dotExpr, nullptr, /*Implicit*/ true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woo! This is so simple.

keyPathExprs.push_back(keyPathExpr);
}
// Return array of all key path expressions.
auto keyPathsArrayExpr =
ArrayExpr::create(C, SourceLoc(), keyPathExprs, {}, SourceLoc());
auto *returnStmt = new (C) ReturnStmt(SourceLoc(), keyPathsArrayExpr);
auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(),
/*Implicit*/ true);
funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(),
/*Implicit*/ true));
}

// Synthesize the `allKeyPaths` computed property declaration.
static ValueDecl *
deriveKeyPathIterable_allKeyPaths(DerivedConformance &derived) {
auto nominal = derived.Nominal;
auto &C = derived.TC.Context;

auto returnInterfaceTy = computeAllKeyPathsType(nominal);
auto returnTy =
derived.getConformanceContext()->mapTypeIntoContext(returnInterfaceTy);

// Create `allKeyPaths` property declaration.
VarDecl *allKeyPathsDecl;
PatternBindingDecl *pbDecl;
std::tie(allKeyPathsDecl, pbDecl) = derived.declareDerivedProperty(
C.Id_allKeyPaths, returnInterfaceTy, returnTy, /*isStatic*/ false,
/*isFinal*/ true);

// Add `@inlinable` to the `allKeyPaths` declaration.
allKeyPathsDecl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ true));

// Create `allKeyPaths` getter.
auto *getterDecl = derived.declareDerivedPropertyGetter(
derived.TC, allKeyPathsDecl, returnTy);
getterDecl->setBodySynthesizer(deriveBodyKeyPathIterable_allKeyPaths);
allKeyPathsDecl->setAccessors(StorageImplInfo::getImmutableComputed(),
SourceLoc(), {getterDecl}, SourceLoc());
derived.addMembersToConformanceContext({getterDecl, allKeyPathsDecl, pbDecl});

return allKeyPathsDecl;
}

static Type deriveKeyPathIterable_AllKeyPaths(DerivedConformance &derived) {
auto *rawInterfaceType = computeAllKeyPathsType(derived.Nominal);
return derived.getConformanceContext()->mapTypeIntoContext(rawInterfaceType);
}

ValueDecl *DerivedConformance::deriveKeyPathIterable(ValueDecl *requirement) {
if (requirement->getBaseName() == TC.Context.Id_allKeyPaths) {
return deriveKeyPathIterable_allKeyPaths(*this);
}
TC.diagnose(requirement->getLoc(),
diag::broken_key_path_iterable_requirement);
return nullptr;
}

Type DerivedConformance::deriveKeyPathIterable(
AssociatedTypeDecl *requirement) {
if (requirement->getBaseName() == TC.Context.Id_AllKeyPaths) {
return deriveKeyPathIterable_AllKeyPaths(*this);
}
TC.diagnose(requirement->getLoc(),
diag::broken_key_path_iterable_requirement);
return nullptr;
}
14 changes: 14 additions & 0 deletions lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC,
return canDeriveHashable(Nominal);
}

// SWIFT_ENABLE_TENSORFLOW
if (*knownProtocol == KnownProtocolKind::KeyPathIterable)
return canDeriveKeyPathIterable(Nominal);

// SWIFT_ENABLE_TENSORFLOW
// The only requirement for deriving Parameterized is that there exist some
// stored properties marked with @TFParameter. The `Parameters` struct can
Expand Down Expand Up @@ -206,6 +210,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
if (name.isSimpleName(ctx.Id_intValue))
return getRequirement(KnownProtocolKind::CodingKey);

// SWIFT_ENABLE_TENSORFLOW
// KeyPathIterable.allKeyPaths
if (name.isSimpleName(ctx.Id_allKeyPaths))
return getRequirement(KnownProtocolKind::KeyPathIterable);

// SWIFT_ENABLE_TENSORFLOW
// Parameterized.allParameters
if (name.isSimpleName(ctx.Id_allParameters))
Expand Down Expand Up @@ -279,6 +288,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
if (name.isSimpleName(ctx.Id_AllCases))
return getRequirement(KnownProtocolKind::CaseIterable);

// SWIFT_ENABLE_TENSORFLOW
// KeyPathIterable.AllKeyPaths
if (name.isSimpleName(ctx.Id_AllKeyPaths))
return getRequirement(KnownProtocolKind::KeyPathIterable);

// SWIFT_ENABLE_TENSORFLOW
// Parameterized.Parameters
if (name.isSimpleName(ctx.Id_Parameters))
Expand Down
16 changes: 16 additions & 0 deletions lib/Sema/DerivedConformances.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,22 @@ class DerivedConformance {
/// \returns the derived member, which will also be added to the type.
ValueDecl *deriveDecodable(ValueDecl *requirement);

// SWIFT_ENABLE_TENSORFLOW
/// Determine if a KeyPathIterable requirement can be derived for a type.
///
/// \returns True if the requirement can be derived.
static bool canDeriveKeyPathIterable(NominalTypeDecl *type);

/// Derive a KeyPathIterable requirement for a nominal type.
///
/// \returns the derived member, which will also be added to the type.
ValueDecl *deriveKeyPathIterable(ValueDecl *requirement);

/// Derive a KeyPathIterable type witness for a nominal type.
///
/// \returns the derived member, which will also be added to the type.
Type deriveKeyPathIterable(AssociatedTypeDecl *assocType);

// SWIFT_ENABLE_TENSORFLOW
/// Derive a Parameterized requirement for a nominal type.
///
Expand Down
6 changes: 6 additions & 0 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5276,6 +5276,10 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
case KnownProtocolKind::Decodable:
return derived.deriveDecodable(Requirement);

// SWIFT_ENABLE_TENSORFLOW
case KnownProtocolKind::KeyPathIterable:
return derived.deriveKeyPathIterable(Requirement);

// SWIFT_ENABLE_TENSORFLOW
case KnownProtocolKind::Parameterized:
return derived.deriveParameterized(Requirement);
Expand Down Expand Up @@ -5308,6 +5312,8 @@ Type TypeChecker::deriveTypeWitness(DeclContext *DC,
case KnownProtocolKind::CaseIterable:
return derived.deriveCaseIterable(AssocType);
// SWIFT_ENABLE_TENSORFLOW
case KnownProtocolKind::KeyPathIterable:
return derived.deriveKeyPathIterable(AssocType);
case KnownProtocolKind::Parameterized:
return derived.deriveParameterized(AssocType);
case KnownProtocolKind::ParameterGroup:
Expand Down
2 changes: 2 additions & 0 deletions stdlib/public/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ set(SWIFTLIB_ESSENTIAL
IntegerTypes.swift.gyb
Join.swift
KeyPath.swift
# SWIFT_ENABLE_TENSORFLOW
KeyPathIterable.swift
KeyValuePairs.swift
LazyCollection.swift
LazySequence.swift
Expand Down
3 changes: 2 additions & 1 deletion stdlib/public/core/GroupInfo.json
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@
"MemoryLayout.swift"
],
"KeyPaths": [
"KeyPath.swift"
"KeyPath.swift",
"KeyPathIterable.swift"
],
"Reflection": [
"Dump.swift",
Expand Down
Loading