Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Make liveness tracking of LazyTensorOperations a thread local state. #372

Merged
merged 5 commits into from
Jul 17, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions Sources/TensorFlow/Core/LazyTensorContext.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/// A class to keep track of runtime information about `LazyTensorOperation`
/// instances created by the program. This will be manaaged as a thread local
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

manaaged ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah. fixed now. :)

/// state.
class LazyTensorOperationsTracker {
struct RefCounts {
let op: LazyTensorOperation
let liveRefCount: Int
let allRefCount: Int
}

private var refCounts: [ObjectIdentifier: RefCounts] = [:]

func incrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
let opID = ObjectIdentifier(op)
if let counts = refCounts[opID] {
refCounts[opID] = RefCounts(
op: op,
liveRefCount: counts.liveRefCount + (isLive ? 1 : 0),
allRefCount: counts.allRefCount + 1)
} else {
refCounts[opID] = RefCounts(
op: op, liveRefCount: isLive ? 1 : 0, allRefCount: 1)
}
}

func decrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
let opID = ObjectIdentifier(op)
if let counts = refCounts[opID] {
if counts.allRefCount > 1 {
refCounts[opID] = RefCounts(
op: op,
liveRefCount: counts.liveRefCount - (isLive ? 1 : 0),
allRefCount: counts.allRefCount - 1)
} else {
refCounts.removeValue(forKey: opID)
}
}
}

func isLive(_ op: LazyTensorOperation) -> Bool {
let opID = ObjectIdentifier(op)
if let counts = refCounts[opID] {
return counts.liveRefCount > 0
}
return false
}

func forEachLiveOperation(
_ perform: (LazyTensorOperation) throws -> Void
) rethrows -> Void {
for (_, counts) in refCounts where counts.liveRefCount > 0 {
try perform(counts.op)
}
}

func forEachOperation(
_ perform: (LazyTensorOperation) throws -> Void
) rethrows -> Void {
for (_, counts) in refCounts { try perform(counts.op) }
}
}

struct LazyTensorContext {
var scopes = [LazyTensorOperationsTracker()]

static private var threadLocalContext: LazyTensorContext {
_ThreadLocalState.local.lazyTensorContext
}

static var operationsTracker: LazyTensorOperationsTracker {
return threadLocalContext.scopes.last!
}
}
56 changes: 7 additions & 49 deletions Sources/TensorFlow/Core/LazyTensorOperation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,80 +50,38 @@ class LazyTensorHandle: _AnyTensorHandle {
precondition(
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
handle = Handle.symbolic(op, index: index, isLive: false)
LazyTensorHandle.incrementRefCount(op, isLive: false)
LazyTensorContext.operationsTracker.incrementRefCount(op, isLive: false)
}

init(_lazyLive op: LazyTensorOperation, index: Int) {
precondition(
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
handle = Handle.symbolic(op, index: index, isLive: true)
LazyTensorHandle.incrementRefCount(op, isLive: true)
LazyTensorContext.operationsTracker.incrementRefCount(op, isLive: true)
}

deinit {
if case let .symbolic(op, _, isLive) = handle {
LazyTensorHandle.decrementRefCount(op, isLive: isLive)
LazyTensorContext.operationsTracker.decrementRefCount(op, isLive: isLive)
}
}

// Liveness tracking for LazyTensorOperations
//
struct LazyTensorOperationRefCounts {
let op: LazyTensorOperation
let liveRefCount: Int
let allRefCount: Int
}

private static var operationRefCounts: [
ObjectIdentifier: LazyTensorOperationRefCounts] = [:]

static func incrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
let opID = ObjectIdentifier(op)
if let counts = operationRefCounts[opID] {
operationRefCounts[opID] = LazyTensorOperationRefCounts(
op: op,
liveRefCount: counts.liveRefCount + (isLive ? 1 : 0),
allRefCount: counts.allRefCount + 1)
} else {
operationRefCounts[opID] = LazyTensorOperationRefCounts(
op: op, liveRefCount: isLive ? 1 : 0, allRefCount: 1)
}
}

static func decrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
let opID = ObjectIdentifier(op)
if let counts = operationRefCounts[opID] {
if counts.allRefCount > 1 {
operationRefCounts[opID] = LazyTensorOperationRefCounts(
op: op,
liveRefCount: counts.liveRefCount - (isLive ? 1 : 0),
allRefCount: counts.allRefCount - 1)
} else {
operationRefCounts.removeValue(forKey: opID)
}
}
}

static func isLive(_ op: LazyTensorOperation) -> Bool {
let opID = ObjectIdentifier(op)
if let counts = operationRefCounts[opID] {
return counts.liveRefCount > 0
}
return false
return LazyTensorContext.operationsTracker.isLive(op)
}

static func forEachLiveOperation(
_ perform: (LazyTensorOperation) throws -> Void
) rethrows -> Void {
for (_, counts) in operationRefCounts where counts.liveRefCount > 0 {
try perform(counts.op)
}
try LazyTensorContext.operationsTracker.forEachLiveOperation(perform)
}

static func forEachOperation(
_ perform: (LazyTensorOperation) throws -> Void
) rethrows -> Void {
for (_, counts) in operationRefCounts { try perform(counts.op) }
try LazyTensorContext.operationsTracker.forEachOperation(perform)
}

@usableFromInline
Expand Down
6 changes: 5 additions & 1 deletion Sources/TensorFlow/Core/Runtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1193,10 +1193,14 @@ func _TFCOpSetAttrTypeArray(
}
}

/// A class to keep around thread local state.
/// A class to keep around thread local state:
/// - DeviceScopes
/// - LazyTensorContext
class _ThreadLocalState {
var deviceScopes = DeviceScopes()

var lazyTensorContext = LazyTensorContext()

private static let key: pthread_key_t = {
var key = pthread_key_t()
pthread_key_create(&key) {
Expand Down