This repository was archived by the owner on Jul 1, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Expand file tree Collapse file tree 2 files changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -91,7 +91,7 @@ class LazyTensorTraceBuilder {
91
91
92
92
/// Returns a trace obtained by tracing the given function.
93
93
static func trace< In: TensorGroup , Out: TensorGroup > ( _ fn: ( In ) -> Out ) -> LazyTensorTrace {
94
- precondition ( _RuntimeConfig . useLazyTensor, " Lazy tensor is not enabled for tracing. " )
94
+ precondition ( _ThreadLocalState . useLazyTensor, " Lazy tensor is not enabled for tracing. " )
95
95
96
96
// Set up inputs for running `fn`.
97
97
let inputOps = In . _typeList. map { Self . makePlaceholder ( dataType: $0) }
Original file line number Diff line number Diff line change @@ -693,7 +693,7 @@ extension _ExecutionContext {
693
693
static func makeOp(
694
694
_ name: String , _ outputCount: Int
695
695
) -> TFTensorOperation {
696
- return _RuntimeConfig . useLazyTensor
696
+ return _ThreadLocalState . useLazyTensor
697
697
? LazyTensorOperation ( name, outputCount)
698
698
: TFE_Op ( name, outputCount)
699
699
}
@@ -1201,6 +1201,17 @@ class _ThreadLocalState {
1201
1201
1202
1202
var lazyTensorContext = LazyTensorContext ( )
1203
1203
1204
+ static var useLazyTensor : Bool {
1205
+ get {
1206
+ _ThreadLocalState. local. lazyTensorEnabled ?? _RuntimeConfig. useLazyTensor
1207
+ }
1208
+ set ( newValue) {
1209
+ _ThreadLocalState. local. lazyTensorEnabled = newValue
1210
+ }
1211
+ }
1212
+
1213
+ private var lazyTensorEnabled : Bool ? = nil
1214
+
1204
1215
private static let key : pthread_key_t = {
1205
1216
var key = pthread_key_t ( )
1206
1217
pthread_key_create ( & key) {
You can’t perform that action at this time.
0 commit comments