Skip to content

Commit 7db1728

Browse files
dan-zhengrxwei
authored andcommitted
---
yaml --- r: 262110 b: refs/heads/tensorflow c: 4d4f4b2 h: refs/heads/master
1 parent 0007dd2 commit 7db1728

File tree

6 files changed

+59
-33
lines changed

6 files changed

+59
-33
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-04-25-a: 22f738a831d43aff2b9c9773bcb65
818818
refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-05-08-a: 7d98cc16689baba5c8a3b90a9329bdcc1a12b4e9
819819
refs/heads/cherr42: a566ad54b073c2c56ac0a705d0a5bed9743135a5
820820
"refs/heads/codable_test_comment_fix": fc8f6824f7f347e1e8db55bff62db385c5728b5a
821-
refs/heads/tensorflow: 714f8d6dc4f20647679259a11a99bea287b0c975
821+
refs/heads/tensorflow: 4d4f4b253ca9deb4fd12428b110c244e17d18164
822822
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-11-a: 8126fd7a652e2f70ad6d76505239e34fb2ef3e1a
823823
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-12-a: b3fd3dd84df6717f2e2e9df58c6d7e99fed57086
824824
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-13-a: 71135119579039dc321c5f65d870050fe36efda2

branches/tensorflow/stdlib/public/core/AutoDiff.swift

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,24 +63,26 @@ public protocol ShapedVectorNumeric : VectorNumeric {
6363
/// elements are of `Tangent` type.
6464
public protocol Differentiable {
6565
/// The tangent vector space of this differentiable manifold.
66-
associatedtype TangentVector : Differentiable
67-
where TangentVector.TangentVector == TangentVector
66+
associatedtype TangentVector : Differentiable, VectorNumeric
67+
where TangentVector.TangentVector == TangentVector,
68+
TangentVector.Scalar : FloatingPoint
6869
/// The cotangent space of this differentiable manifold.
69-
associatedtype CotangentVector : Differentiable
70-
where CotangentVector.CotangentVector == CotangentVector
70+
associatedtype CotangentVector : Differentiable, VectorNumeric
71+
where CotangentVector.CotangentVector == CotangentVector,
72+
CotangentVector.Scalar : FloatingPoint
7173

7274
/// Returns `self` moved along the value space towards the given tangent
7375
/// vector. In Riemannian geometry (mathematics), this represents an
7476
/// exponential map.
75-
func moved(toward direction: TangentVector) -> Self
77+
func moved(along direction: TangentVector) -> Self
7678

7779
/// Convert a cotangent vector to its corresponding tangent vector.
7880
func tangentVector(from cotangent: CotangentVector) -> TangentVector
7981
}
8082

8183
public extension Differentiable
8284
where Self : VectorNumeric, TangentVector == Self {
83-
func moved(toward direction: TangentVector) -> Self {
85+
func moved(along direction: TangentVector) -> Self {
8486
return self + direction
8587
}
8688
}

branches/tensorflow/stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,12 +1832,7 @@ extension ${Self} : Strideable {
18321832
//===----------------------------------------------------------------------===//
18331833

18341834
extension ${Self} : VectorNumeric {
1835-
public typealias Shape = ()
18361835
public typealias Scalar = ${Self}
1837-
1838-
public init(repeating repeatedValue: ${Self}, shape: ()) {
1839-
self = repeatedValue
1840-
}
18411836
}
18421837

18431838
extension ${Self} : Differentiable {

branches/tensorflow/test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -540,15 +540,23 @@ struct JVPStruct {
540540
let p: Float
541541
}
542542

543+
extension JVPStruct : VectorNumeric {
544+
static var zero: JVPStruct { return JVPStruct(p: 0) }
545+
static func + (lhs: JVPStruct, rhs: JVPStruct) -> JVPStruct {
546+
return JVPStruct(p: lhs.p + rhs.p)
547+
}
548+
static func - (lhs: JVPStruct, rhs: JVPStruct) -> JVPStruct {
549+
return JVPStruct(p: lhs.p - rhs.p)
550+
}
551+
typealias Scalar = Float
552+
static func * (lhs: Float, rhs: JVPStruct) -> JVPStruct {
553+
return JVPStruct(p: lhs * rhs.p)
554+
}
555+
}
556+
543557
extension JVPStruct : Differentiable {
544558
typealias TangentVector = JVPStruct
545559
typealias CotangentVector = JVPStruct
546-
func moved(toward direction: JVPStruct) -> JVPStruct {
547-
return JVPStruct(p: p + direction.p)
548-
}
549-
func tangentVector(from cotangent: JVPStruct) -> JVPStruct {
550-
return cotangent
551-
}
552560
}
553561

554562
extension JVPStruct {
@@ -634,15 +642,23 @@ struct VJPStruct {
634642
let p: Float
635643
}
636644

645+
extension VJPStruct : VectorNumeric {
646+
static var zero: VJPStruct { return VJPStruct(p: 0) }
647+
static func + (lhs: VJPStruct, rhs: VJPStruct) -> VJPStruct {
648+
return VJPStruct(p: lhs.p + rhs.p)
649+
}
650+
static func - (lhs: VJPStruct, rhs: VJPStruct) -> VJPStruct {
651+
return VJPStruct(p: lhs.p - rhs.p)
652+
}
653+
typealias Scalar = Float
654+
static func * (lhs: Float, rhs: VJPStruct) -> VJPStruct {
655+
return VJPStruct(p: lhs * rhs.p)
656+
}
657+
}
658+
637659
extension VJPStruct : Differentiable {
638660
typealias TangentVector = VJPStruct
639661
typealias CotangentVector = VJPStruct
640-
func moved(toward direction: VJPStruct) -> VJPStruct {
641-
return VJPStruct(p: p + direction.p)
642-
}
643-
func tangentVector(from cotangent: VJPStruct) -> VJPStruct {
644-
return cotangent
645-
}
646662
}
647663

648664
extension VJPStruct {

branches/tensorflow/test/AutoDiff/protocol_requirement_autodiff.swift

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,7 @@ extension DiffReq {
2525

2626
struct Quadratic : DiffReq, Equatable {
2727
typealias TangentVector = Quadratic
28-
typealias CoangentVector = Quadratic
29-
func moved(toward q: Quadratic) -> Quadratic {
30-
return Quadratic(a + q.a, b + q.b, c + q.c)
31-
}
28+
typealias CotangentVector = Quadratic
3229

3330
let a, b, c: Float
3431
init(_ a: Float, _ b: Float, _ c: Float) {
@@ -42,13 +39,26 @@ struct Quadratic : DiffReq, Equatable {
4239
}
4340
}
4441

42+
extension Quadratic : VectorNumeric {
43+
static var zero: Quadratic { return Quadratic(0, 0, 0) }
44+
static func + (lhs: Quadratic, rhs: Quadratic) -> Quadratic {
45+
return Quadratic(lhs.a + rhs.a, lhs.b + rhs.b, lhs.c + rhs.c)
46+
}
47+
static func - (lhs: Quadratic, rhs: Quadratic) -> Quadratic {
48+
return Quadratic(lhs.a + rhs.a, lhs.b + rhs.b, lhs.c + rhs.c)
49+
}
50+
typealias Scalar = Float
51+
static func * (lhs: Float, rhs: Quadratic) -> Quadratic {
52+
return Quadratic(lhs * rhs.a, lhs * rhs.b, lhs * rhs.c)
53+
}
54+
}
55+
4556
ProtocolRequirementAutodiffTests.test("Trivial") {
4657
expectEqual((Quadratic(0, 0, 1), 12), Quadratic(11, 12, 13).gradF(at: 0))
4758
expectEqual((Quadratic(1, 1, 1), 2 * 11 + 12),
4859
Quadratic(11, 12, 13).gradF(at: 1))
4960
expectEqual((Quadratic(4, 2, 1), 2 * 11 * 2 + 12),
5061
Quadratic(11, 12, 13).gradF(at: 2))
51-
5262
}
5363

5464
runAllTests()

branches/tensorflow/test/AutoDiff/witness_table_silgen.swift

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@ protocol Proto : Differentiable {
1111
func function3(_ x: Float, _ y: Float) -> Float
1212
}
1313

14-
struct S : Proto {
14+
struct S : Proto, VectorNumeric {
15+
static var zero: S { return S(p: 0) }
16+
typealias Scalar = Float
17+
static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) }
18+
static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) }
19+
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }
20+
1521
typealias TangentVector = S
1622
typealias CotangentVector = S
17-
func moved(toward vector: TangentVector) -> S {
18-
fatalError("unimplemented")
19-
}
2023

2124
let p: Float
2225

0 commit comments

Comments
 (0)