Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 28c5991

Browse files
committed
Add @Freezable property wrapper.
`@Freezable` wraps differentiable values and provides toggleable trainability via the `isFrozen` property. When `isFrozen` is true, accesses to `value` have a derivative of zero.
1 parent 140fc41 commit 28c5991

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

Sources/TensorFlow/Layer.swift

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,34 @@ public final class Parameter<Scalar: TensorFlowScalar> {
191191
self.value = value
192192
}
193193
}
194+
195+
/// A wrapper around a differentiable value with "freezable" derivatives.
196+
///
197+
/// When `isFrozen` is true, accesses to `value` have a derivative of zero.
198+
@_propertyWrapper
199+
public struct Freezable<Value: Differentiable> : Differentiable {
200+
@noDerivative public var isFrozen: Bool = false
201+
private var _value: Value
202+
203+
public init(initialValue: Value) {
204+
_value = initialValue
205+
}
206+
207+
@differentiable(vjp: _vjpValue)
208+
public var value: Value {
209+
get { _value }
210+
set { _value = newValue }
211+
}
212+
213+
@usableFromInline
214+
func _vjpValue() -> (value: Value, pullback: (Value.TangentVector) -> TangentVector) {
215+
return (_value, { [isFrozen = self.isFrozen] v in
216+
isFrozen ? .zero : v
217+
})
218+
}
219+
220+
public typealias TangentVector = Value.TangentVector
221+
public mutating func move(along direction: TangentVector) {
222+
_value.move(along: direction)
223+
}
224+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
@testable import TensorFlow
17+
18+
final class FreezableTests: XCTestCase {
19+
func testFreezableParameters() {
20+
struct FreezableDense : Layer {
21+
@Freezable var weight: Float
22+
// Workaround for underlying storage private access.
23+
var freezableWeight: Freezable<Float> {
24+
get { $weight }
25+
set { $weight = newValue }
26+
}
27+
28+
@Freezable var bias: Float
29+
// Workaround for underlying storage private access.
30+
var freezableBias: Freezable<Float> {
31+
get { $bias }
32+
set { $bias = newValue }
33+
}
34+
35+
@differentiable
36+
func callAsFunction(_ input: Float) -> Float {
37+
return input * weight + bias
38+
}
39+
}
40+
41+
var dense = FreezableDense(weight: 2, bias: 3)
42+
let x: Float = 4
43+
do {
44+
let (value, gradient) = valueWithGradient(at: dense, x) { dense, x in dense(x) }
45+
XCTAssertEqual(11, value)
46+
// FIXME: '$' is not a valid identifier:
47+
// cannot declare entity named '$weight'; the '$' prefix is reserved for
48+
// implicitly-synthesized declarations.
49+
//
50+
// Tentative solution: change `Differentiable` derived conformances to use original
51+
// property names instead of underlying storage property names.
52+
//
53+
// Now: `FreezableDense.TangentVector($weight: ..., $bias: ...)
54+
// Goal: `FreezableDense.TangentVector(weight: ..., bias: ...)
55+
//
56+
// XCTAssertEqual(FreezableDense.TangentVector($weight: 4, $bias: 1), gradient.0)
57+
XCTAssertEqual(4, gradient.0.$weight)
58+
XCTAssertEqual(1, gradient.0.$bias)
59+
XCTAssertEqual(2, gradient.1)
60+
}
61+
dense.freezableWeight.isFrozen = true
62+
do {
63+
let (value, gradient) = valueWithGradient(at: dense, x) { dense, x in dense(x) }
64+
XCTAssertEqual(11, value)
65+
XCTAssertEqual(0, gradient.0.$weight)
66+
XCTAssertEqual(1, gradient.0.$bias)
67+
XCTAssertEqual(2, gradient.1)
68+
}
69+
}
70+
71+
static var allTests = [
72+
("testFreezableParameters", testFreezableParameters),
73+
]
74+
}

0 commit comments

Comments
 (0)