Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 2f75a2d

Browse files
authored
Make useLazyTensor a thread-local state. (#387)
1 parent f5222cd commit 2f75a2d

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

Sources/TensorFlow/Core/LazyTensorTrace.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class LazyTensorTraceBuilder {
9191

9292
/// Returns a trace obtained by tracing the given function.
9393
static func trace<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> LazyTensorTrace {
94-
precondition(_RuntimeConfig.useLazyTensor, "Lazy tensor is not enabled for tracing.")
94+
precondition(_ThreadLocalState.useLazyTensor, "Lazy tensor is not enabled for tracing.")
9595

9696
// Set up inputs for running `fn`.
9797
let inputOps = In._typeList.map { Self.makePlaceholder(dataType: $0) }

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ extension _ExecutionContext {
693693
static func makeOp(
694694
_ name: String, _ outputCount: Int
695695
) -> TFTensorOperation {
696-
return _RuntimeConfig.useLazyTensor
696+
return _ThreadLocalState.useLazyTensor
697697
? LazyTensorOperation(name, outputCount)
698698
: TFE_Op(name, outputCount)
699699
}
@@ -1201,6 +1201,20 @@ class _ThreadLocalState {
12011201

12021202
var lazyTensorContext = LazyTensorContext()
12031203

1204+
static var useLazyTensor: Bool {
1205+
get {
1206+
_ThreadLocalState.local.lazyTensorEnabled ?? _RuntimeConfig.useLazyTensor
1207+
}
1208+
set {
1209+
_ThreadLocalState.local.lazyTensorEnabled = newValue
1210+
}
1211+
}
1212+
1213+
/// When true, use lazy evaluation. If this is not set, we should use the
1214+
/// value of `_RuntimeConfig.useLazyTensor` to determine if lazy evaluation
1215+
/// is enabled.
1216+
private var lazyTensorEnabled: Bool? = nil
1217+
12041218
private static let key: pthread_key_t = {
12051219
var key = pthread_key_t()
12061220
pthread_key_create(&key) {

0 commit comments

Comments
 (0)