Skip to content

Commit bc8d3c7

Browse files
committed
[AutoDiff] Change Differentiable.moved(along:) to move(along:).
Change `Differentiable.moved(along:)` to `mutating func move(along:)`. This is important for upcoming `Differentiable` class support. Update `Differentiable` derived conformances (logic and diagnostics). Update tests.
1 parent 9319fce commit bc8d3c7

15 files changed

+172
-253
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2520,12 +2520,15 @@ ERROR(broken_differentiable_requirement,none,
25202520
"Differentiable protocol is broken: unexpected requirement", ())
25212521
WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
25222522
"stored property %0 has no derivative because it does not conform to "
2523-
"'Differentiable'; add '@noDerivative' to make it explicit",
2524-
(Identifier))
2525-
WARNING(differentiable_constant_property_implicit_noderivative_fixit,none,
2526-
"'let' properties with a default value do not have a derivative; add "
2527-
"'@noDerivative' to make it explicit, or change it to 'var' to allow "
2528-
"derivatives", ())
2523+
"'Differentiable'; add an explicit '@noDerivative' attribute"
2524+
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
2525+
(Identifier, Identifier, bool))
2526+
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
2527+
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
2528+
"requires all stored properties to be mutable; use 'var' instead, or add "
2529+
"an explicit '@noDerivative' attribute"
2530+
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
2531+
(Identifier, Identifier, bool))
25292532

25302533
NOTE(codable_extraneous_codingkey_case_here,none,
25312534
"CodingKey case %0 does not match any stored properties", (Identifier))

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ IDENTIFIER(scaled)
139139
IDENTIFIER(AllDifferentiableVariables)
140140
IDENTIFIER(TangentVector)
141141
IDENTIFIER(allDifferentiableVariables)
142-
IDENTIFIER(moved)
142+
IDENTIFIER(move)
143143

144144
// Kinds of layout constraints
145145
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 73 additions & 119 deletions
Large diffs are not rendered by default.

lib/Sema/DerivedConformances.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,9 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
299299
}
300300

301301
// SWIFT_ENABLE_TENSORFLOW
302-
// Differentiable.moved(along:)
302+
// Differentiable.move(along:)
303303
if (name.isCompoundName() &&
304-
name.getBaseName() == ctx.Id_moved) {
304+
name.getBaseName() == ctx.Id_move) {
305305
auto argumentNames = name.getArgumentNames();
306306
if (argumentNames.size() == 1 &&
307307
argumentNames[0] == ctx.getIdentifier("along")) {

stdlib/public/core/Array.swift

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1973,13 +1973,14 @@ extension Array where Element : Differentiable {
19731973
}
19741974
}
19751975

1976-
public func moved(along direction: TangentVector) -> DifferentiableView {
1976+
public mutating func move(along direction: TangentVector) {
19771977
precondition(
19781978
base.count == direction.base.count,
19791979
"cannot move Array.DifferentiableView with count \(base.count) along " +
19801980
"direction with different count \(direction.base.count)")
1981-
return DifferentiableView(
1982-
zip(base, direction.base).map { $0.moved(along: $1) })
1981+
for i in base.indices {
1982+
base[i].move(along: direction.base[i])
1983+
}
19831984
}
19841985
}
19851986
}
@@ -2072,8 +2073,10 @@ extension Array : Differentiable where Element : Differentiable {
20722073
}
20732074
}
20742075

2075-
public func moved(along direction: TangentVector) -> Array {
2076-
return DifferentiableView(self).moved(along: direction).base
2076+
public mutating func move(along direction: TangentVector) {
2077+
var view = DifferentiableView(self)
2078+
view.move(along: direction)
2079+
self = view.base
20772080
}
20782081
}
20792082

stdlib/public/core/AutoDiff.swift

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,9 @@ public protocol Differentiable {
6565
/// All differentiable variables of this value.
6666
var allDifferentiableVariables: AllDifferentiableVariables { get set }
6767

68-
/// Returns `self` moved along the value space towards the given tangent
69-
/// vector. In Riemannian geometry (mathematics), this represents an
70-
/// exponential map.
71-
func moved(along direction: TangentVector) -> Self
68+
/// Moves `self` along the value space towards the given tangent vector. In
69+
/// Riemannian geometry (mathematics), this represents an exponential map.
70+
mutating func move(along direction: TangentVector)
7271

7372
@available(*, deprecated,
7473
message: "'CotangentVector' is now equal to 'TangentVector' and will be removed")
@@ -82,13 +81,9 @@ public extension Differentiable where AllDifferentiableVariables == Self {
8281
}
8382
}
8483

85-
// FIXME: The `Self : AdditiveArithmetic` constraint should be implied by
86-
// `TangentVector == Self`, but the type checker errors out when it does not
87-
// exist.
88-
public extension Differentiable
89-
where TangentVector == Self, Self : AdditiveArithmetic {
90-
func moved(along direction: TangentVector) -> Self {
91-
return self + direction
84+
public extension Differentiable where TangentVector == Self {
85+
mutating func move(along direction: TangentVector) {
86+
self += direction
9287
}
9388
}
9489

@@ -451,7 +446,7 @@ internal protocol _AnyDerivativeBox {
451446

452447
// `Differentiable` requirements.
453448
var _allDifferentiableVariables: _AnyDerivativeBox { get }
454-
func _moved(along direction: _AnyDerivativeBox) -> _AnyDerivativeBox
449+
mutating func _move(along direction: _AnyDerivativeBox)
455450

456451
/// The underlying base value, type-erased to `Any`.
457452
var _typeErasedBase: Any { get }
@@ -555,18 +550,17 @@ internal struct _ConcreteDerivativeBox<T> : _AnyDerivativeBox
555550
return _ConcreteDerivativeBox(_base.allDifferentiableVariables)
556551
}
557552

558-
func _moved(along direction: _AnyDerivativeBox) -> _AnyDerivativeBox {
559-
if _isOpaqueZero() {
560-
return direction
561-
}
553+
mutating func _move(along direction: _AnyDerivativeBox) {
562554
if direction._isOpaqueZero() {
563-
return self
555+
return
564556
}
557+
// The case where `self._isOpaqueZero()` returns true is handled in
558+
// `AnyDerivative.move(along:)`.
565559
guard let directionBase =
566560
direction._unboxed(to: T.TangentVector.self) else {
567561
_derivativeTypeMismatch(T.self, type(of: direction._typeErasedBase))
568562
}
569-
return _ConcreteDerivativeBox<T>(_base.moved(along: directionBase))
563+
_base.move(along: directionBase)
570564
}
571565
}
572566

@@ -661,7 +655,11 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
661655
get { return AnyDerivative(_box: _box._allDifferentiableVariables) }
662656
// set { _box._allDifferentiableVariables = newValue._box }
663657
}
664-
public func moved(along direction: TangentVector) -> AnyDerivative {
665-
return AnyDerivative(_box: _box._moved(along: direction._box))
658+
public mutating func move(along direction: TangentVector) {
659+
if _box._isOpaqueZero() {
660+
_box = direction._box
661+
return
662+
}
663+
_box._move(along: direction._box)
666664
}
667665
}

stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,6 +1887,10 @@ extension ${Self} : VectorProtocol {
18871887
extension ${Self} : Differentiable {
18881888
public typealias TangentVector = ${Self}
18891889
public typealias AllDifferentiableVariables = ${Self}
1890+
1891+
public mutating func move(along direction: TangentVector) {
1892+
self += direction
1893+
}
18901894
}
18911895

18921896
//===----------------------------------------------------------------------===//

test/AutoDiff/anyderivative.swift

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,19 @@ import StdlibUnittest
66
var AnyDerivativeTests = TestSuite("AnyDerivative")
77

88
struct Vector : Differentiable {
9-
let x, y: Float
9+
var x, y: Float
1010
}
1111
struct Generic<T: Differentiable> : Differentiable {
12-
let x: T
12+
var x: T
13+
}
14+
15+
extension AnyDerivative {
16+
// This exists only to faciliate testing.
17+
func moved(along direction: TangentVector) -> Self {
18+
var result = self
19+
result.move(along: direction)
20+
return result
21+
}
1322
}
1423

1524
AnyDerivativeTests.test("Vector") {

test/AutoDiff/array.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ ArrayAutodiffTests.test("ArraySubscript") {
2929

3030
ArrayAutodiffTests.test("ArrayConcat") {
3131
struct TwoArrays : Differentiable {
32-
let a: [Float]
33-
let b: [Float]
32+
var a: [Float]
33+
var b: [Float]
3434
}
3535

3636
func sumFirstThreeConcatted(_ arrs: TwoArrays) -> Float {

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ _ = gradient(at: 0, in: one_to_one_0) // okay!
2525
//===----------------------------------------------------------------------===//
2626

2727
struct S {
28-
let p: Float
28+
var p: Float
2929
}
30-
3130
extension S : Differentiable, VectorProtocol {
31+
// Test custom `TangentVector` type with non-matching stored property name.
3232
struct TangentVector: Differentiable, VectorProtocol {
3333
var dp: Float
3434
}
@@ -39,8 +39,8 @@ extension S : Differentiable, VectorProtocol {
3939
static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) }
4040
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }
4141

42-
func moved(along direction: TangentVector) -> S {
43-
return S(p: p + direction.dp)
42+
mutating func move(along direction: TangentVector) {
43+
p.move(along: direction.dp)
4444
}
4545
}
4646

test/AutoDiff/derived_differentiable_properties.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,12 @@ struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic
7777
// CHECK-AST: @_implements(Equatable, ==(_:_:)) internal static func __derived_struct_equals(_ a: GenericTanMember<T>, _ b: GenericTanMember<T>) -> Bool
7878

7979
public struct ConditionallyDifferentiable<T> {
80-
public let x: T
80+
public var x: T
8181
}
8282
extension ConditionallyDifferentiable : Differentiable where T : Differentiable {}
8383

8484
// CHECK-AST-LABEL: public struct ConditionallyDifferentiable<T> {
8585
// CHECK-AST: @differentiable(wrt: self where T : Differentiable)
86-
// CHECK-AST: public let x: T
86+
// CHECK-AST: public var x: T
8787
// CHECK-AST: internal init(x: T)
8888
// CHECK-AST: }

test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,11 +548,7 @@ struct ResultLabelTest {
548548
}
549549

550550
struct Tensor<Scalar> : AdditiveArithmetic {}
551-
extension Tensor : Differentiable where Scalar : Differentiable {
552-
typealias TangentVector = Tensor
553-
typealias AllDifferentiableVariables = Tensor
554-
func moved(along direction: Tensor) -> Tensor { return self }
555-
}
551+
extension Tensor : Differentiable where Scalar : Differentiable {}
556552
@differentiable(where Scalar : Differentiable)
557553
func where2<Scalar : Numeric>(x: Tensor<Scalar>) -> Tensor<Scalar> {
558554
return x

test/AutoDiff/e2e_differentiable_property.swift

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,15 @@ struct Space {
2020
/// `x` is a computed property with a custom vjp.
2121
var x: Float {
2222
@differentiable(vjp: vjpX)
23-
get {
24-
return storedX
25-
}
23+
get { storedX }
24+
set { storedX = newValue }
2625
}
2726

2827
func vjpX() -> (Float, (Float) -> TangentSpace) {
2928
return (x, { v in TangentSpace(x: v, y: 0) } )
3029
}
3130

32-
private let storedX: Float
31+
private var storedX: Float
3332

3433
@differentiable
3534
var y: Float
@@ -42,8 +41,9 @@ struct Space {
4241

4342
extension Space : Differentiable {
4443
typealias TangentVector = TangentSpace
45-
func moved(along: TangentSpace) -> Space {
46-
return Space(x: x + along.x, y: y + along.y)
44+
mutating func move(along direction: TangentSpace) {
45+
x.move(along: direction.x)
46+
y.move(along: direction.y)
4747
}
4848
}
4949

@@ -106,13 +106,14 @@ extension ProductSpaceOtherTangentTangentSpace : Differentiable {
106106
}
107107

108108
struct ProductSpaceOtherTangent {
109-
let x, y: Float
109+
var x, y: Float
110110
}
111111

112112
extension ProductSpaceOtherTangent : Differentiable {
113113
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
114-
func moved(along: ProductSpaceOtherTangentTangentSpace) -> ProductSpaceOtherTangent {
115-
return ProductSpaceOtherTangent(x: x + along.x, y: y + along.y)
114+
mutating func move(along direction: ProductSpaceOtherTangentTangentSpace) {
115+
x.move(along: direction.x)
116+
y.move(along: direction.y)
116117
}
117118
}
118119

test/AutoDiff/separate_cotangent_type.swift

Lines changed: 0 additions & 48 deletions
This file was deleted.

0 commit comments

Comments
 (0)