Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit d2578b3

Browse files
committed
Make useLazyTensor a thread-local state.
1 parent ef48ae9 commit d2578b3

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

Sources/TensorFlow/Core/LazyTensorTrace.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class LazyTensorTraceBuilder {
9191

9292
/// Returns a trace obtained by tracing the given function.
9393
static func trace<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> LazyTensorTrace {
94-
precondition(_RuntimeConfig.useLazyTensor, "Lazy tensor is not enabled for tracing.")
94+
precondition(_ThreadLocalState.useLazyTensor, "Lazy tensor is not enabled for tracing.")
9595

9696
// Set up inputs for running `fn`.
9797
let inputOps = In._typeList.map { Self.makePlaceholder(dataType: $0) }

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ extension _ExecutionContext {
693693
static func makeOp(
694694
_ name: String, _ outputCount: Int
695695
) -> TFTensorOperation {
696-
return _RuntimeConfig.useLazyTensor
696+
return _ThreadLocalState.useLazyTensor
697697
? LazyTensorOperation(name, outputCount)
698698
: TFE_Op(name, outputCount)
699699
}
@@ -1201,6 +1201,17 @@ class _ThreadLocalState {
12011201

12021202
var lazyTensorContext = LazyTensorContext()
12031203

1204+
static var useLazyTensor: Bool {
1205+
get {
1206+
_ThreadLocalState.local.lazyTensorEnabled ?? _RuntimeConfig.useLazyTensor
1207+
}
1208+
set(newValue) {
1209+
_ThreadLocalState.local.lazyTensorEnabled = newValue
1210+
}
1211+
}
1212+
1213+
private var lazyTensorEnabled: Bool? = nil
1214+
12041215
private static let key: pthread_key_t = {
12051216
var key = pthread_key_t()
12061217
pthread_key_create(&key) {

0 commit comments

Comments
 (0)