Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add @_Freezable property wrapper. #250

Merged
merged 1 commit into from
Nov 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions Sources/TensorFlow/Freezable.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/// A wrapper around a differentiable value with "freezable" derivatives.
///
/// When `isFrozen` is true, accesses to `wrappedValue` have a derivative of zero.
@propertyWrapper
public struct _Freezable<Value: Differentiable> {
@noDerivative public var isFrozen: Bool = false
private var _value: Value

public init(wrappedValue: Value) {
_value = wrappedValue
}

public var projectedValue: Self {
get { return self }
set { self = newValue }
}

/// The wrapped differentiable value.
@differentiable(vjp: _vjpValue)
public var wrappedValue: Value {
get { _value }
set { _value = newValue }
}

@usableFromInline
func _vjpValue() -> (value: Value, pullback: (Value.TangentVector) -> TangentVector) {
return (_value, { [isFrozen = self.isFrozen] v in
isFrozen ? .zero : v
})
}
}

extension _Freezable {
/// Freeze derivatives for `wrappedValue`. Accesses to `wrappedValue` will always have a
/// derivative of zero.
public mutating func freeze() {
isFrozen = true
}

/// Unfreeze derivatives for `wrappedValue`.
public mutating func unfreeze() {
isFrozen = false
}
}

extension _Freezable: Differentiable {
public typealias TangentVector = Value.TangentVector
public mutating func move(along direction: TangentVector) {
_value.move(along: direction)
}
}

extension _Freezable: EuclideanDifferentiable where Value: EuclideanDifferentiable {
public var differentiableVectorView: TangentVector {
return _value.differentiableVectorView
}
}
76 changes: 76 additions & 0 deletions Tests/TensorFlowTests/FreezableTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest
@testable import TensorFlow

final class FreezableTests: XCTestCase {
func testFreezableParameters() {
// A dense layer with freezable properties.
struct FreezableDense : Layer {
@_Freezable var weight: Tensor<Float>
@_Freezable var bias: Tensor<Float>

init(weight: Tensor<Float>, bias: Tensor<Float>) {
// Require scalar weight and bias for simplicity.
precondition(weight.isScalar)
precondition(bias.isScalar)
self.weight = weight
self.bias = bias
}

@differentiable
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
return input * weight + bias
}
}

var dense = FreezableDense(weight: Tensor(2), bias: Tensor(3))
let x = Tensor<Float>(4)
do {
let (value, gradient) = valueWithGradient(at: dense, x) { dense, x in dense(x) }
XCTAssertEqual(Tensor(11), value)
// The gradient of `dense.weight` should be non-zero.
XCTAssertEqual(FreezableDense.TangentVector(_weight: Tensor(4), _bias: Tensor(1)),
gradient.0)
XCTAssertEqual(Tensor(2), gradient.1)
}

// Freeze derivatives for `dense.weight`.
dense.$weight.freeze()
do {
let (value, gradient) = valueWithGradient(at: dense, x) { dense, x in dense(x) }
// The gradient of `dense.weight` should now be zero.
XCTAssertEqual(Tensor(11), value)
XCTAssertEqual(FreezableDense.TangentVector(_weight: Tensor(0), _bias: Tensor(1)),
gradient.0)
XCTAssertEqual(Tensor(2), gradient.1)
}

// Unfreeze derivatives for `dense.weight`.
dense.$weight.unfreeze()
do {
let (value, gradient) = valueWithGradient(at: dense, x) { dense, x in dense(x) }
XCTAssertEqual(Tensor(11), value)
// The gradient of `dense.weight` should now be non-zero.
XCTAssertEqual(FreezableDense.TangentVector(_weight: Tensor(4), _bias: Tensor(1)),
gradient.0)
XCTAssertEqual(Tensor(2), gradient.1)
}
}

static var allTests = [
("testFreezableParameters", testFreezableParameters),
]
}