Skip to content

Commit 8c2c838

Browse files
dan-zhengrxwei
authored andcommitted
---
yaml --- r: 262135 b: refs/heads/tensorflow c: e0cdd32 h: refs/heads/master i: 262133: 4bc227f 262131: 0eddc0b 262127: a095223
1 parent d2ae8b4 commit 8c2c838

15 files changed

+608
-4
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-04-25-a: 22f738a831d43aff2b9c9773bcb65
818818
refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-05-08-a: 7d98cc16689baba5c8a3b90a9329bdcc1a12b4e9
819819
refs/heads/cherr42: a566ad54b073c2c56ac0a705d0a5bed9743135a5
820820
"refs/heads/codable_test_comment_fix": fc8f6824f7f347e1e8db55bff62db385c5728b5a
821-
refs/heads/tensorflow: 31290d09842654ead00dbd0b4852ece9ceb84639
821+
refs/heads/tensorflow: e0cdd321493fa70fafa3256fd614db405d3616bc
822822
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-11-a: 8126fd7a652e2f70ad6d76505239e34fb2ef3e1a
823823
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-12-a: b3fd3dd84df6717f2e2e9df58c6d7e99fed57086
824824
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-13-a: 71135119579039dc321c5f65d870050fe36efda2

branches/tensorflow/include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2421,6 +2421,8 @@ ERROR(broken_additive_arithmetic_requirement,none,
24212421
"AdditiveArithmetic protocol is broken: unexpected requirement", ())
24222422
ERROR(broken_vector_numeric_requirement,none,
24232423
"VectorNumeric protocol is broken: unexpected requirement", ())
2424+
ERROR(broken_differentiable_requirement,none,
2425+
"Differentiable protocol is broken: unexpected requirement", ())
24242426

24252427
NOTE(codable_extraneous_codingkey_case_here,none,
24262428
"CodingKey case %0 does not match any stored properties", (Identifier))

branches/tensorflow/include/swift/AST/KnownIdentifiers.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ IDENTIFIER(Parameters)
134134
// Differentiable
135135
IDENTIFIER(CotangentVector)
136136
IDENTIFIER(TangentVector)
137+
IDENTIFIER(moved)
138+
IDENTIFIER(tangentVector)
137139

138140
// Kinds of layout constraints
139141
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")

branches/tensorflow/lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ add_swift_host_library(swiftSema STATIC
2828
DerivedConformanceError.cpp
2929
# SWIFT_ENABLE_TENSORFLOW
3030
DerivedConformanceAdditiveArithmeticVectorNumeric.cpp
31+
DerivedConformanceDifferentiable.cpp
3132
DerivedConformanceKeyPathIterable.cpp
3233
DerivedConformanceParameterGroup.cpp
3334
DerivedConformanceParameterized.cpp

branches/tensorflow/lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 436 additions & 0 deletions
Large diffs are not rendered by default.

branches/tensorflow/lib/Sema/DerivedConformances.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC,
7474
if (*knownProtocol == KnownProtocolKind::VectorNumeric)
7575
return canDeriveVectorNumeric(Nominal);
7676

77+
// SWIFT_ENABLE_TENSORFLOW
78+
if (*knownProtocol == KnownProtocolKind::Differentiable)
79+
return canDeriveDifferentiable(Nominal);
80+
7781
// SWIFT_ENABLE_TENSORFLOW
7882
// The only requirement for deriving Parameterized is that there exist some
7983
// stored properties marked with @TFParameter. The `Parameters` struct can
@@ -274,6 +278,28 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
274278
}
275279
}
276280

281+
// SWIFT_ENABLE_TENSORFLOW
282+
// Differentiable.moved(along:)
283+
if (name.isCompoundName() &&
284+
name.getBaseName() == ctx.Id_moved) {
285+
auto argumentNames = name.getArgumentNames();
286+
if (argumentNames.size() == 1 &&
287+
argumentNames[0] == ctx.getIdentifier("along")) {
288+
return getRequirement(KnownProtocolKind::Differentiable);
289+
}
290+
}
291+
292+
// SWIFT_ENABLE_TENSORFLOW
293+
// Differentiable.tangentVector(from:)
294+
if (name.isCompoundName() &&
295+
name.getBaseName() == ctx.Id_tangentVector) {
296+
auto argumentNames = name.getArgumentNames();
297+
if (argumentNames.size() == 1 &&
298+
argumentNames[0] == ctx.getIdentifier("from")) {
299+
return getRequirement(KnownProtocolKind::Differentiable);
300+
}
301+
}
302+
277303
// SWIFT_ENABLE_TENSORFLOW
278304
// ParameterGroup.update(withGradients:_:)
279305
if (name.isCompoundName() &&
@@ -325,6 +351,13 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
325351
if (name.isSimpleName(ctx.Id_AllKeyPaths))
326352
return getRequirement(KnownProtocolKind::KeyPathIterable);
327353

354+
// SWIFT_ENABLE_TENSORFLOW
355+
// Differentiable.TangentVector
356+
// Differentiable.CotangentVector
357+
if (name.isSimpleName(ctx.Id_TangentVector) ||
358+
name.isSimpleName(ctx.Id_CotangentVector))
359+
return getRequirement(KnownProtocolKind::Differentiable);
360+
328361
// SWIFT_ENABLE_TENSORFLOW
329362
// Parameterized.Parameters
330363
if (name.isSimpleName(ctx.Id_Parameters))

branches/tensorflow/lib/Sema/DerivedConformances.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,22 @@ class DerivedConformance {
233233
/// \returns the derived member, which will also be added to the type.
234234
Type deriveVectorNumeric(AssociatedTypeDecl *assocType);
235235

236+
/// Determine if a Differentiable requirement can be derived for a type.
237+
///
238+
/// \returns True if the requirement can be derived.
239+
static bool canDeriveDifferentiable(NominalTypeDecl *type);
240+
241+
/// Derive a Differentiable requirement for a nominal type.
242+
///
243+
/// \returns the derived member, which will also be added to the type.
244+
ValueDecl *deriveDifferentiable(ValueDecl *requirement);
245+
246+
/// Derive a Differentiable type witness for a nominal type, if it has
247+
/// parameters (stored properties marked with @TFParameter).
248+
///
249+
/// \returns the derived member, which will also be added to the type.
250+
Type deriveDifferentiable(AssociatedTypeDecl *assocType);
251+
236252
/// Declare a read-only property.
237253
std::pair<VarDecl *, PatternBindingDecl *>
238254
declareDerivedProperty(Identifier name, Type propertyInterfaceType,

branches/tensorflow/lib/Sema/TypeCheckAttr.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,10 +2385,10 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23852385

23862386
// Predicate checking if a type has associated tangent and cotangent spaces.
23872387
auto hasAssociatedSpaces = [&](Type type) -> bool {
2388-
// No need to check for cotangent space because every type with a tangent
2389-
// space also has a cotangent space.
23902388
return (bool)type->getAutoDiffAssociatedVectorSpace(
2391-
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance);
2389+
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance) &&
2390+
(bool)type->getAutoDiffAssociatedVectorSpace(
2391+
AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance);
23922392
};
23932393

23942394
// Check that the user has only selected wrt params with allowed types.

branches/tensorflow/lib/Sema/TypeCheckProtocol.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5296,6 +5296,10 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
52965296
case KnownProtocolKind::VectorNumeric:
52975297
return derived.deriveVectorNumeric(Requirement);
52985298

5299+
// SWIFT_ENABLE_TENSORFLOW
5300+
case KnownProtocolKind::Differentiable:
5301+
return derived.deriveDifferentiable(Requirement);
5302+
52995303
default:
53005304
return nullptr;
53015305
}
@@ -5328,6 +5332,8 @@ Type TypeChecker::deriveTypeWitness(DeclContext *DC,
53285332
return derived.deriveParameterGroup(AssocType);
53295333
case KnownProtocolKind::VectorNumeric:
53305334
return derived.deriveVectorNumeric(AssocType);
5335+
case KnownProtocolKind::Differentiable:
5336+
return derived.deriveDifferentiable(AssocType);
53315337
default:
53325338
return nullptr;
53335339
}

branches/tensorflow/stdlib/public/TensorFlow/Ops.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ extension Tensor : ShapedVectorNumeric where Scalar : Numeric {}
104104
extension Tensor : Differentiable where Scalar : FloatingPoint {
105105
public typealias TangentVector = Tensor
106106
public typealias CotangentVector = Tensor
107+
public func tangentVector(from cotangent: CotangentVector) -> TangentVector {
108+
return cotangent
109+
}
107110
}
108111

109112
//===----------------------------------------------------------------------===//

branches/tensorflow/stdlib/public/core/AutoDiff.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,16 @@ public extension Differentiable
8787
}
8888
}
8989

90+
// FIXME: This is currently commented because the where clause leads to
91+
// associated type inference which conflicts with `Differentiable` derived
92+
// conformances.
93+
/*
9094
public extension Differentiable where TangentVector == CotangentVector {
9195
func tangentVector(from cotangent: CotangentVector) -> TangentVector {
9296
return cotangent
9397
}
9498
}
99+
*/
95100

96101
//===----------------------------------------------------------------------===//
97102
// Differential Operators

branches/tensorflow/stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,6 +1852,9 @@ extension ${Self} : VectorNumeric {
18521852
extension ${Self} : Differentiable {
18531853
public typealias TangentVector = ${Self}
18541854
public typealias CotangentVector = ${Self}
1855+
public func tangentVector(from cotangent: CotangentVector) -> TangentVector {
1856+
return cotangent
1857+
}
18551858
}
18561859

18571860
//===----------------------------------------------------------------------===//

branches/tensorflow/test/AutoDiff/e2e_differentiable_property.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ extension Space : Differentiable {
5252
func moved(along: TangentSpace) -> Space {
5353
return Space(x: x + along.dx, y: y + along.dy)
5454
}
55+
func tangentVector(from cotangent: CotangentVector) -> TangentVector {
56+
return cotangent
57+
}
5558
}
5659

5760
E2EDifferentiablePropertyTests.test("computed property") {
@@ -110,6 +113,9 @@ extension ProductSpaceOtherTangent : Differentiable {
110113
func moved(along: ProductSpaceOtherTangentTangentSpace) -> ProductSpaceOtherTangent {
111114
return ProductSpaceOtherTangent(x: x + along.x, y: y + along.y)
112115
}
116+
func tangentVector(from cotangent: CotangentVector) -> TangentVector {
117+
return cotangent
118+
}
113119
}
114120

115121
E2EDifferentiablePropertyTests.test("fieldwise product space, other tangent") {

branches/tensorflow/test/AutoDiff/refcounting.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ public struct Vector : AdditiveArithmetic, VectorNumeric, Differentiable, Equata
1212
public var y: Float
1313
public var nonTrivialStuff = NonTrivialStuff()
1414
public typealias TangentVector = Vector
15+
public typealias CotangentVector = Vector
16+
public func tangentVector(from cotangent: CotangentVector) -> TangentVector { return cotangent }
1517
public typealias Scalar = Float
1618
public static var zero: Vector { return Vector(0) }
1719
public init(_ scalar: Float) { self.x = scalar; self.y = scalar }
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// SWIFT_ENABLE_TENSORFLOW
2+
// RUN: %target-swift-frontend -typecheck -verify %s -verify-ignore-unknown
3+
4+
struct Simple : VectorNumeric, Differentiable {
5+
var w: Float
6+
var b: Float
7+
}
8+
var simple = Simple(w: 1, b: 1)
9+
assert(simple.moved(along: simple) == simple + simple)
10+
assert(simple.tangentVector(from: simple) == simple)
11+
12+
// Test type with mixed members.
13+
struct Mixed : VectorNumeric, Differentiable {
14+
var simple: Simple
15+
var float: Float
16+
}
17+
var mixed = Mixed(simple: simple, float: 1)
18+
assert(mixed.moved(along: mixed) == mixed + mixed)
19+
assert(mixed.tangentVector(from: mixed) == mixed)
20+
21+
// Test type with generic members that conform to `Differentiable`.
22+
// Since `T == T.TangentVector == T.CotangentVector`,
23+
// it's only necessary to synthesis typealiases:
24+
// `typealias TangentVector = Generic`
25+
// `typealias CotangentVector = Generic`
26+
struct Generic<T> : VectorNumeric, Differentiable
27+
where T : Differentiable, T == T.TangentVector, T == T.CotangentVector
28+
{
29+
var w: T
30+
var b: T
31+
}
32+
var generic = Generic<Double>(w: 1, b: 1)
33+
assert(generic.moved(along: generic) == generic + generic)
34+
assert(generic.tangentVector(from: generic) == generic)
35+
36+
// Test type with manual definition of vector space types to `Self`.
37+
struct VectorSpacesEqualSelf : VectorNumeric, Differentiable {
38+
var w: Float
39+
var b: Float
40+
typealias TangentVector = VectorSpacesEqualSelf
41+
typealias CotangentVector = VectorSpacesEqualSelf
42+
}
43+
44+
// TODO: Support the cases below after `Differentiable` derived conformances
45+
// limitations are lifted.
46+
47+
/*
48+
// Test type with generic members that conform to `Differentiable`.
49+
// Since it's not the case that
50+
// `T == T.TangentVector == T.CotangentVector`,
51+
// it's necessary to synthesize new vector space struct types.
52+
struct GenericNeedsVectorSpaceStructs<T> : VectorNumeric, Differentiable
53+
where T : VectorNumeric, T : Differentiable
54+
{
55+
var w: T
56+
var b: T
57+
}
58+
59+
// Test type that doesn't conform to `VectorNumeric`.
60+
// Thus, `Self` cannot be used as `TangentVector` or `CotangentVector`.
61+
// Vector space structs types must be synthesized.
62+
// Note: it would be nice to emit a warning if conforming `Self` to
63+
// `VectorNumeric` is possible.
64+
struct NotVectorNumeric : Differentiable {
65+
var w: Float
66+
var b: Float
67+
}
68+
*/
69+
70+
// Test errors.
71+
72+
// Test manually customizing vector space types.
73+
// Thees should fail. Synthesis is semantically unsupported if vector space
74+
// types are customized.
75+
struct VectorSpaceTypeAlias : VectorNumeric, Differentiable { // expected-error {{type 'VectorSpaceTypeAlias' does not conform to protocol 'Differentiable'}}
76+
var w: Float
77+
var b: Float
78+
typealias TangentVector = Simple
79+
}
80+
struct VectorSpaceCustomStruct : VectorNumeric, Differentiable { // expected-error {{type 'VectorSpaceCustomStruct' does not conform to protocol 'Differentiable'}}
81+
var w: Float
82+
var b: Float
83+
struct CotangentVector : VectorNumeric, Differentiable {
84+
var w: Float.CotangentVector
85+
var b: Float.CotangentVector
86+
typealias TangentVector = VectorSpaceCustomStruct.CotangentVector
87+
typealias CotangentVector = VectorSpaceCustomStruct.CotangentVector
88+
}
89+
}

0 commit comments

Comments
 (0)