Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 16d87eb

Browse files
authored
[Layer] Remove the 'context' argument from 'Layer.applied(to:in:)'. (#87)
* [Layer] Remove the 'context' argument to `Layer.applied(to:in:)`. * Add tests.
1 parent 861d1f5 commit 16d87eb

File tree

5 files changed

+298
-182
lines changed

5 files changed

+298
-182
lines changed

Sources/DeepLearning/Context.swift

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#if !COMPILING_TENSORFLOW_MODULE
16+
import TensorFlow
17+
#endif
18+
19+
#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS)
20+
import Darwin
21+
#else
22+
import Glibc
23+
#endif
24+
25+
/// A value that indicates the phase of using a machine learning model.
26+
public enum LearningPhase {
27+
case training
28+
case inference
29+
}
30+
31+
/// A context that stores thread-local contextual information used by deep learning APIs such as
32+
/// layers.
33+
///
34+
/// Use `Context.local` to retrieve the current thread-local context.
35+
///
36+
/// Examples:
37+
///
38+
/// * Set the current learning phase to training so that layers like `BatchNorm` will
39+
/// compute mean and variance when applied to inputs.
40+
///
41+
/// ```swift
42+
/// Context.local.learningPhase = .training
43+
/// ```
44+
/// * Set the current learning phase to inference so that layers like `Dropout` will not drop out
45+
/// units when applied to inputs.
46+
///
47+
/// ```swift
48+
/// Context.local.learningPhase = .inference
49+
/// ```
50+
public struct Context {
51+
/// The learning phase.
52+
public var learningPhase: LearningPhase = .inference
53+
54+
/// Creates a context with default properties.
55+
public init() {}
56+
57+
/// The current thread-local context.
58+
///
59+
/// - Note: Accessing this property is thread-safe.
60+
public static var local: Context {
61+
_read { yield ContextManager.local.currentContext }
62+
_modify { yield &ContextManager.local.currentContext }
63+
}
64+
}
65+
66+
/// Calls the given closure within a context that has everything identical to the current context
67+
/// except for the given learning phase.
68+
///
69+
/// - Parameters:
70+
/// - context: A context that will be set before the closure gets called and restored after the
71+
/// closure returns.
72+
/// - body: A nullary closure. If the closure has a return value, that value is also used as the
73+
/// return value of the `withContext(_:_:)` function.
74+
/// - Returns: The return value, if any, of the `body` closure.
75+
public func withContext<R>(_ context: Context, _ body: () throws -> R) rethrows -> R {
76+
ContextManager.local.push(context)
77+
defer { ContextManager.local.popContext() }
78+
return try body()
79+
}
80+
81+
/// Calls the given closure within a context that has everything identical to the current context
82+
/// except for the given learning phase.
83+
///
84+
/// - Parameters:
85+
/// - learningPhase: A learning phase that will be set before the closure gets called and restored
86+
/// after the closure returns.
87+
/// - body: A nullary closure. If the closure has a return value, that value is also used as the
88+
/// return value of the `withLearningPhase(_:_:)` function.
89+
/// - Returns: The return value, if any, of the `body` closure.
90+
public func withLearningPhase<R>(_ learningPhase: LearningPhase,
91+
_ body: () throws -> R) rethrows -> R {
92+
var context = ContextManager.local.currentContext
93+
context.learningPhase = learningPhase
94+
return try withContext(context, body)
95+
}
96+
97+
/// A manager that maintains and provides safe access to thread-local `Context` values.
98+
private final class ContextManager {
99+
var contextStack: [Context] = [Context()]
100+
101+
/// The data key for the singleton `Context` in the current thread.
102+
static let key: pthread_key_t = {
103+
var key = pthread_key_t()
104+
pthread_key_create(&key) { obj in
105+
#if !(os(macOS) || os(iOS) || os(watchOS) || os(tvOS))
106+
let obj = obj!
107+
#endif
108+
Unmanaged<ContextManager>.fromOpaque(obj).release()
109+
}
110+
return key
111+
}()
112+
113+
/// The thread-local singleton.
114+
static var local: ContextManager {
115+
if let address = pthread_getspecific(key) {
116+
return Unmanaged<ContextManager>.fromOpaque(address).takeUnretainedValue()
117+
}
118+
let context = ContextManager()
119+
pthread_setspecific(key, Unmanaged.passRetained(context).toOpaque())
120+
return context
121+
}
122+
123+
/// Pushes the given context to the context stack.
124+
func push(_ context: Context) {
125+
contextStack.append(context)
126+
}
127+
128+
/// Pops a context out of a stack.
129+
///
130+
/// - Precondition: The context stack must contain more than `1` contexts.
131+
func popContext() {
132+
assert(contextStack.count > 1,
133+
"Internal error: Only 1 context is available. Popping is not allowed.")
134+
contextStack.removeLast()
135+
}
136+
137+
/// The most recent context.
138+
var currentContext: Context {
139+
_read {
140+
assert(!contextStack.isEmpty, "Internal error: No contexts exist.")
141+
yield contextStack[contextStack.endIndex - 1]
142+
}
143+
_modify {
144+
assert(!contextStack.isEmpty, "Internal error: No contexts exist.")
145+
yield &contextStack[contextStack.endIndex - 1]
146+
}
147+
}
148+
}

0 commit comments

Comments
 (0)