Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Cleanup: use .local for accessing thread local state. #416

Merged
merged 1 commit into from
Aug 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions Sources/TensorFlow/Core/LazyTensorContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,11 @@ class LazyTensorOperationsTracker {
}

struct LazyTensorContext {
private var operationsTracker = LazyTensorOperationsTracker()
private var isShapeTrackingEnabled = true
var operationsTracker = LazyTensorOperationsTracker()
var isShapeTrackingEnabled = true

static private var local: LazyTensorContext {
static var local: LazyTensorContext {
_read { yield _ThreadLocalState.local.lazyTensorContext }
_modify { yield &_ThreadLocalState.local.lazyTensorContext }
}

static var operationsTracker: LazyTensorOperationsTracker {
return local.operationsTracker
}

/// A flag that determines whether we should track shapes. We will need to disable shape
/// tracking within certain contexts. e.g., we won't be able to compute shapes when tracing.
static var isShapeTrackingEnabled: Bool {
get { local.isShapeTrackingEnabled }
set { local.isShapeTrackingEnabled = newValue }
}
}
16 changes: 8 additions & 8 deletions Sources/TensorFlow/Core/LazyTensorOperation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ class LazyTensorHandle: _AnyTensorHandle {
precondition(
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
handle = Handle.symbolic(op, index: index, isLive: false)
LazyTensorContext.operationsTracker.incrementRefCount(op, isLive: false)
LazyTensorContext.local.operationsTracker.incrementRefCount(op, isLive: false)
}

init(_lazyLive op: LazyTensorOperation, index: Int) {
precondition(
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
handle = Handle.symbolic(op, index: index, isLive: true)
LazyTensorContext.operationsTracker.incrementRefCount(op, isLive: true)
LazyTensorContext.local.operationsTracker.incrementRefCount(op, isLive: true)
}

deinit {
if case let .symbolic(op, _, isLive) = handle {
LazyTensorContext.operationsTracker.decrementRefCount(op, isLive: isLive)
LazyTensorContext.local.operationsTracker.decrementRefCount(op, isLive: isLive)
}
}

Expand All @@ -80,7 +80,7 @@ class LazyTensorHandle: _AnyTensorHandle {
get {
switch handle {
case .symbolic(let op, let index, _):
precondition(LazyTensorContext.isShapeTrackingEnabled,
precondition(LazyTensorContext.local.isShapeTrackingEnabled,
"Shape tracking is not enabled in this context.")
if let shape = op.outputShapes[index] { return shape }
// Materialize and get the shape from concrete tensor handle.
Expand All @@ -102,19 +102,19 @@ class LazyTensorHandle: _AnyTensorHandle {
// Liveness tracking for LazyTensorOperations
//
static func isLive(_ op: LazyTensorOperation) -> Bool {
return LazyTensorContext.operationsTracker.isLive(op)
return LazyTensorContext.local.operationsTracker.isLive(op)
}

static func forEachLiveOperation(
_ perform: (LazyTensorOperation) throws -> Void
) rethrows -> Void {
try LazyTensorContext.operationsTracker.forEachLiveOperation(perform)
try LazyTensorContext.local.operationsTracker.forEachLiveOperation(perform)
}

static func forEachOperation(
_ perform: (LazyTensorOperation) throws -> Void
) rethrows -> Void {
try LazyTensorContext.operationsTracker.forEachOperation(perform)
try LazyTensorContext.local.operationsTracker.forEachOperation(perform)
}

@usableFromInline
Expand Down Expand Up @@ -263,7 +263,7 @@ class LazyTensorOperation: TensorOperation {
}

func evaluate() -> [LazyTensorHandle] {
if LazyTensorContext.isShapeTrackingEnabled {
if LazyTensorContext.local.isShapeTrackingEnabled {
updateOutputShapes()
}
return (0..<outputCount).map {
Expand Down
6 changes: 3 additions & 3 deletions Sources/TensorFlow/Core/LazyTensorTrace.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ class LazyTensorTraceBuilder {
static func trace<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> LazyTensorTrace {
precondition(_ThreadLocalState.useLazyTensor, "Lazy tensor is not enabled for tracing.")
// Disable shape tracking and reset to original state when done.
let isShapeTrackingEnabled = LazyTensorContext.isShapeTrackingEnabled
defer { LazyTensorContext.isShapeTrackingEnabled = isShapeTrackingEnabled }
LazyTensorContext.isShapeTrackingEnabled = false
let isShapeTrackingEnabled = LazyTensorContext.local.isShapeTrackingEnabled
defer { LazyTensorContext.local.isShapeTrackingEnabled = isShapeTrackingEnabled }
LazyTensorContext.local.isShapeTrackingEnabled = false

// Set up inputs for running `fn`.
let inputOps = In._typeList.map { Self.makePlaceholder(dataType: $0) }
Expand Down