Skip to content

Commit 247bb4b

Browse files
dan-zhengrxwei
authored andcommitted
Add KeyPathIterable protocol and synthesis. (#21557)
* Add `KeyPathIterable` protocol and synthesis. The `KeyPathIterable` protocol represents types whose values provide custom key paths to properties or elements. For types that conform to `KeyPathIterable`, the compiler can synthesize a default implementation of `allKeyPaths` based on the type's stored properties. --- `KeyPathIterable` enables generic machine learning optimizers and will be used to replace `ParameterGroup`. The `ParameterGroup` protocol requirements are not sufficiently general for many use cases: it is limited to a single parameter type and does not support joint iteration over parameters for multiple `ParameterGroup` instances. `KeyPathIterable` solves these problems. In addition, it is a generally useful language feature that can model both static stored properties and custom dynamic properties. * Address feedback from @rxwei. - Remove `@inlinable` from `KeyPathIterable` extension methods. - Enable `KeyPathIterable` synthesis for empty structs, add test. - Add comment by @rxwei about supporting synthesis for classes in the future. * Address feedback from @rxwei. Mark synthesized `allKeyPaths` declaration with `@inlinable`. Verified that `@inlinable` is correctly added in SILGen: ``` struct Parameters : KeyPathIterable { @sil_stored var w: Float { get set } @sil_stored var b: Float { get set } init(w: Float, b: Float) typealias AllKeyPaths = [PartialKeyPath<Parameters>] @inlinable var allKeyPaths: [PartialKeyPath<Parameters>] { get } } ```
1 parent a526b0e commit 247bb4b

14 files changed

+622
-2
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2405,6 +2405,8 @@ ERROR(broken_encodable_requirement,none,
24052405
ERROR(broken_decodable_requirement,none,
24062406
"Decodable protocol is broken: unexpected requirement", ())
24072407
// SWIFT_ENABLE_TENSORFLOW
2408+
ERROR(broken_key_path_iterable_requirement,none,
2409+
"KeyPathIterable protocol is broken: unexpected requirement", ())
24082410
ERROR(broken_parameter_group_requirement,none,
24092411
"ParameterGroup protocol is broken: unexpected requirement", ())
24102412
ERROR(broken_parameterized_requirement,none,

include/swift/AST/KnownIdentifiers.def

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ IDENTIFIER(decode)
4747
IDENTIFIER(decodeIfPresent)
4848
IDENTIFIER(Decoder)
4949
IDENTIFIER(decoder)
50-
// SWIFT_ENABLE_TENSORFLOW
5150
IDENTIFIER(dynamicCallable)
5251
IDENTIFIER(dynamicMember)
5352
IDENTIFIER(Element)
@@ -118,6 +117,11 @@ IDENTIFIER_WITH_NAME(value_, "_value")
118117
IDENTIFIER(with)
119118

120119
// SWIFT_ENABLE_TENSORFLOW
120+
IDENTIFIER(AllKeyPaths)
121+
IDENTIFIER(allKeyPaths)
122+
IDENTIFIER(recursivelyAllKeyPaths)
123+
IDENTIFIER(allWritableKeyPaths)
124+
IDENTIFIER(recursivelyAllWritableKeyPaths)
121125
IDENTIFIER(allParameters)
122126
IDENTIFIER(Parameter)
123127
IDENTIFIER(Parameters)

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ PROTOCOL(Decodable)
7070
PROTOCOL(AdditiveArithmetic)
7171
PROTOCOL(Numeric)
7272
PROTOCOL(FloatingPoint)
73+
PROTOCOL(KeyPathIterable)
7374
PROTOCOL(ParameterGroup)
7475
PROTOCOL(Parameterized)
7576
PROTOCOL(TensorArrayProtocol)

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4061,6 +4061,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
40614061
case KnownProtocolKind::FloatingPoint:
40624062
case KnownProtocolKind::AdditiveArithmetic:
40634063
case KnownProtocolKind::Numeric:
4064+
case KnownProtocolKind::KeyPathIterable:
40644065
case KnownProtocolKind::ParameterGroup:
40654066
case KnownProtocolKind::Parameterized:
40664067
case KnownProtocolKind::TensorArrayProtocol:

lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_swift_host_library(swiftSema STATIC
2727
DerivedConformanceEquatableHashable.cpp
2828
DerivedConformanceError.cpp
2929
# SWIFT_ENABLE_TENSORFLOW
30+
DerivedConformanceKeyPathIterable.cpp
3031
DerivedConformanceParameterGroup.cpp
3132
DerivedConformanceParameterized.cpp
3233
DerivedConformanceRawRepresentable.cpp
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
//===--- DerivedConformanceKeyPathIterable.cpp ----------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file implements explicit derivation of the KeyPathIterable protocol for
14+
// a nominal type.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#include "CodeSynthesis.h"
19+
#include "TypeChecker.h"
20+
#include "swift/AST/Decl.h"
21+
#include "swift/AST/Expr.h"
22+
#include "swift/AST/GenericSignature.h"
23+
#include "swift/AST/Module.h"
24+
#include "swift/AST/ParameterList.h"
25+
#include "swift/AST/Pattern.h"
26+
#include "swift/AST/ProtocolConformance.h"
27+
#include "swift/AST/Stmt.h"
28+
#include "swift/AST/Types.h"
29+
#include "DerivedConformances.h"
30+
31+
using namespace swift;
32+
33+
bool DerivedConformance::canDeriveKeyPathIterable(NominalTypeDecl *nominal) {
34+
// Note: we could extend synthesis to support classes.
35+
// Subclasses need to append `allKeyPaths` to `super.allKeyPaths`.
36+
return isa<StructDecl>(nominal);
37+
}
38+
39+
// Compute `PartialKeyPathType<Nominal>`, bound to the given nominal
40+
// declaration's type.
41+
static Type computePartialKeyPathType(NominalTypeDecl *nominal) {
42+
auto &C = nominal->getASTContext();
43+
auto nominalType = nominal->getDeclaredInterfaceType();
44+
if (!nominalType || nominalType->hasError())
45+
return nullptr;
46+
auto *partialKeyPathDecl = cast<ClassDecl>(C.getPartialKeyPathDecl());
47+
return BoundGenericClassType::get(partialKeyPathDecl, /*parent*/ Type(),
48+
{nominal->getDeclaredInterfaceType()});
49+
}
50+
51+
// Compute `AllKeyPaths` associated type for the given nominal declaration.
52+
// It should be `[PartialKeyPath<Nominal>]`.
53+
static ArraySliceType *computeAllKeyPathsType(NominalTypeDecl *nominal) {
54+
auto partialKeyPathType = computePartialKeyPathType(nominal);
55+
return ArraySliceType::get(partialKeyPathType);
56+
}
57+
58+
// Synthesize body for the `allKeyPaths` computed property getter.
59+
static void
60+
deriveBodyKeyPathIterable_allKeyPaths(AbstractFunctionDecl *funcDecl) {
61+
auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl();
62+
auto &C = nominal->getASTContext();
63+
64+
auto *nominalTypeExpr = TypeExpr::createForDecl(SourceLoc(), nominal,
65+
funcDecl, /*Implicit*/ true);
66+
67+
// Create array of key path expressions to stored properties.
68+
llvm::SmallVector<Expr *, 2> keyPathExprs;
69+
for (auto member : nominal->getStoredProperties()) {
70+
auto *dotExpr = new (C)
71+
UnresolvedDotExpr(nominalTypeExpr, SourceLoc(), member->getFullName(),
72+
DeclNameLoc(), /*Implicit*/ true);
73+
auto *keyPathExpr =
74+
new (C) KeyPathExpr(SourceLoc(), dotExpr, nullptr, /*Implicit*/ true);
75+
keyPathExprs.push_back(keyPathExpr);
76+
}
77+
// Return array of all key path expressions.
78+
auto keyPathsArrayExpr =
79+
ArrayExpr::create(C, SourceLoc(), keyPathExprs, {}, SourceLoc());
80+
auto *returnStmt = new (C) ReturnStmt(SourceLoc(), keyPathsArrayExpr);
81+
auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(),
82+
/*Implicit*/ true);
83+
funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(),
84+
/*Implicit*/ true));
85+
}
86+
87+
// Synthesize the `allKeyPaths` computed property declaration.
88+
static ValueDecl *
89+
deriveKeyPathIterable_allKeyPaths(DerivedConformance &derived) {
90+
auto nominal = derived.Nominal;
91+
auto &C = derived.TC.Context;
92+
93+
auto returnInterfaceTy = computeAllKeyPathsType(nominal);
94+
auto returnTy =
95+
derived.getConformanceContext()->mapTypeIntoContext(returnInterfaceTy);
96+
97+
// Create `allKeyPaths` property declaration.
98+
VarDecl *allKeyPathsDecl;
99+
PatternBindingDecl *pbDecl;
100+
std::tie(allKeyPathsDecl, pbDecl) = derived.declareDerivedProperty(
101+
C.Id_allKeyPaths, returnInterfaceTy, returnTy, /*isStatic*/ false,
102+
/*isFinal*/ true);
103+
104+
// Add `@inlinable` to the `allKeyPaths` declaration.
105+
allKeyPathsDecl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ true));
106+
107+
// Create `allKeyPaths` getter.
108+
auto *getterDecl = derived.declareDerivedPropertyGetter(
109+
derived.TC, allKeyPathsDecl, returnTy);
110+
getterDecl->setBodySynthesizer(deriveBodyKeyPathIterable_allKeyPaths);
111+
allKeyPathsDecl->setAccessors(StorageImplInfo::getImmutableComputed(),
112+
SourceLoc(), {getterDecl}, SourceLoc());
113+
derived.addMembersToConformanceContext({getterDecl, allKeyPathsDecl, pbDecl});
114+
115+
return allKeyPathsDecl;
116+
}
117+
118+
static Type deriveKeyPathIterable_AllKeyPaths(DerivedConformance &derived) {
119+
auto *rawInterfaceType = computeAllKeyPathsType(derived.Nominal);
120+
return derived.getConformanceContext()->mapTypeIntoContext(rawInterfaceType);
121+
}
122+
123+
ValueDecl *DerivedConformance::deriveKeyPathIterable(ValueDecl *requirement) {
124+
if (requirement->getBaseName() == TC.Context.Id_allKeyPaths) {
125+
return deriveKeyPathIterable_allKeyPaths(*this);
126+
}
127+
TC.diagnose(requirement->getLoc(),
128+
diag::broken_key_path_iterable_requirement);
129+
return nullptr;
130+
}
131+
132+
Type DerivedConformance::deriveKeyPathIterable(
133+
AssociatedTypeDecl *requirement) {
134+
if (requirement->getBaseName() == TC.Context.Id_AllKeyPaths) {
135+
return deriveKeyPathIterable_AllKeyPaths(*this);
136+
}
137+
TC.diagnose(requirement->getLoc(),
138+
diag::broken_key_path_iterable_requirement);
139+
return nullptr;
140+
}

lib/Sema/DerivedConformances.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC,
6262
return canDeriveHashable(Nominal);
6363
}
6464

65+
// SWIFT_ENABLE_TENSORFLOW
66+
if (*knownProtocol == KnownProtocolKind::KeyPathIterable)
67+
return canDeriveKeyPathIterable(Nominal);
68+
6569
// SWIFT_ENABLE_TENSORFLOW
6670
// The only requirement for deriving Parameterized is that there exist some
6771
// stored properties marked with @TFParameter. The `Parameters` struct can
@@ -206,6 +210,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
206210
if (name.isSimpleName(ctx.Id_intValue))
207211
return getRequirement(KnownProtocolKind::CodingKey);
208212

213+
// SWIFT_ENABLE_TENSORFLOW
214+
// KeyPathIterable.allKeyPaths
215+
if (name.isSimpleName(ctx.Id_allKeyPaths))
216+
return getRequirement(KnownProtocolKind::KeyPathIterable);
217+
209218
// SWIFT_ENABLE_TENSORFLOW
210219
// Parameterized.allParameters
211220
if (name.isSimpleName(ctx.Id_allParameters))
@@ -279,6 +288,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
279288
if (name.isSimpleName(ctx.Id_AllCases))
280289
return getRequirement(KnownProtocolKind::CaseIterable);
281290

291+
// SWIFT_ENABLE_TENSORFLOW
292+
// KeyPathIterable.AllKeyPaths
293+
if (name.isSimpleName(ctx.Id_AllKeyPaths))
294+
return getRequirement(KnownProtocolKind::KeyPathIterable);
295+
282296
// SWIFT_ENABLE_TENSORFLOW
283297
// Parameterized.Parameters
284298
if (name.isSimpleName(ctx.Id_Parameters))

lib/Sema/DerivedConformances.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,22 @@ class DerivedConformance {
164164
/// \returns the derived member, which will also be added to the type.
165165
ValueDecl *deriveDecodable(ValueDecl *requirement);
166166

167+
// SWIFT_ENABLE_TENSORFLOW
168+
/// Determine if a KeyPathIterable requirement can be derived for a type.
169+
///
170+
/// \returns True if the requirement can be derived.
171+
static bool canDeriveKeyPathIterable(NominalTypeDecl *type);
172+
173+
/// Derive a KeyPathIterable requirement for a nominal type.
174+
///
175+
/// \returns the derived member, which will also be added to the type.
176+
ValueDecl *deriveKeyPathIterable(ValueDecl *requirement);
177+
178+
/// Derive a KeyPathIterable type witness for a nominal type.
179+
///
180+
/// \returns the derived member, which will also be added to the type.
181+
Type deriveKeyPathIterable(AssociatedTypeDecl *assocType);
182+
167183
// SWIFT_ENABLE_TENSORFLOW
168184
/// Derive a Parameterized requirement for a nominal type.
169185
///

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5276,6 +5276,10 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
52765276
case KnownProtocolKind::Decodable:
52775277
return derived.deriveDecodable(Requirement);
52785278

5279+
// SWIFT_ENABLE_TENSORFLOW
5280+
case KnownProtocolKind::KeyPathIterable:
5281+
return derived.deriveKeyPathIterable(Requirement);
5282+
52795283
// SWIFT_ENABLE_TENSORFLOW
52805284
case KnownProtocolKind::Parameterized:
52815285
return derived.deriveParameterized(Requirement);
@@ -5308,6 +5312,8 @@ Type TypeChecker::deriveTypeWitness(DeclContext *DC,
53085312
case KnownProtocolKind::CaseIterable:
53095313
return derived.deriveCaseIterable(AssocType);
53105314
// SWIFT_ENABLE_TENSORFLOW
5315+
case KnownProtocolKind::KeyPathIterable:
5316+
return derived.deriveKeyPathIterable(AssocType);
53115317
case KnownProtocolKind::Parameterized:
53125318
return derived.deriveParameterized(AssocType);
53135319
case KnownProtocolKind::ParameterGroup:

stdlib/public/core/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ set(SWIFTLIB_ESSENTIAL
8888
IntegerTypes.swift.gyb
8989
Join.swift
9090
KeyPath.swift
91+
# SWIFT_ENABLE_TENSORFLOW
92+
KeyPathIterable.swift
9193
KeyValuePairs.swift
9294
LazyCollection.swift
9395
LazySequence.swift

stdlib/public/core/GroupInfo.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@
140140
"MemoryLayout.swift"
141141
],
142142
"KeyPaths": [
143-
"KeyPath.swift"
143+
"KeyPath.swift",
144+
"KeyPathIterable.swift"
144145
],
145146
"Reflection": [
146147
"Dump.swift",

0 commit comments

Comments
 (0)