Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add liveness tracking to LazyTensor #190

Merged
merged 6 commits into from
Jun 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions Sources/TensorFlow/LazyTensor/LazyTensorOperation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,72 @@ class LazyTensor: _AnyTensorHandle {
precondition(
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
handle = Handle.symbolic(op, index: index, isLive: false)
LazyTensor.incrementRefCount(op, isLive: false)
}

init(_lazyLive op: LazyTensorOperation, index: Int) {
precondition(
index < op.outputCount, "Symbolic Tensor Index is out-of-bounds")
handle = Handle.symbolic(op, index: index, isLive: true)
LazyTensor.incrementRefCount(op, isLive: true)
}

deinit {
if case let .symbolic(op, _, isLive) = handle {
LazyTensor.decrementRefCount(op, isLive: isLive)
}
}

// Liveness tracking for LazyTensorOperations
//
struct LazyTensorOperationRefCounts {
let op: LazyTensorOperation
let liveRefCount: Int
let allRefCount: Int
}

private static var operationRefCounts: [
ObjectIdentifier: LazyTensorOperationRefCounts] = [:]

static func incrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
let opID = ObjectIdentifier(op)
if let counts = operationRefCounts[opID] {
operationRefCounts[opID] = LazyTensorOperationRefCounts(
op: op,
liveRefCount: counts.liveRefCount + (isLive ? 1 : 0),
allRefCount: counts.allRefCount + 1)
} else {
operationRefCounts[opID] = LazyTensorOperationRefCounts(
op: op, liveRefCount: isLive ? 1 : 0, allRefCount: 1)
}
}

static func decrementRefCount(_ op: LazyTensorOperation, isLive: Bool) {
let opID = ObjectIdentifier(op)
if let counts = operationRefCounts[opID] {
if counts.allRefCount > 1 {
operationRefCounts[opID] = LazyTensorOperationRefCounts(
op: op,
liveRefCount: counts.liveRefCount - (isLive ? 1 : 0),
allRefCount: counts.allRefCount - 1)
} else {
operationRefCounts.removeValue(forKey: opID)
}
}
}

static func isLive(_ op: LazyTensorOperation) -> Bool {
let opID = ObjectIdentifier(op)
if let counts = operationRefCounts[opID] {
return counts.liveRefCount > 0
}
return false
}

static func onLiveOperations(_ perform: (LazyTensorOperation) -> ()) {
for (_, counts) in operationRefCounts where counts.liveRefCount > 0 {
perform(counts.op)
}
}

static var _materializationCallback: (String) -> () = { _ in }
Expand Down Expand Up @@ -85,19 +145,26 @@ class LazyTensorOperation: TensorOperation {
}
}

static var liveOperations: Int = 0

init(_id id: String?, name: String, outputCount: Int) {
self.name = name
self.inputs = []
self.attrs = [:]
self.outputCount = outputCount
self.outputs = nil
self.id = id
LazyTensorOperation.liveOperations += 1
}

required convenience init(_ name: String, _ outputCount: Int) {
self.init(_id: nil, name: name, outputCount: outputCount)
}

deinit {
LazyTensorOperation.liveOperations -= 1
}

func evaluate() -> [LazyTensor] {
return (0..<outputCount).map {
LazyTensor(_lazyLive: self, index: $0)
Expand Down
61 changes: 61 additions & 0 deletions Tests/TensorFlowTests/LazyTensorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@ import XCTest
@testable import TensorFlow
import CTensorFlow

struct LazyTensorOperationRef: Equatable, Hashable {
let value: LazyTensorOperation
init(_ value: LazyTensorOperation) { self.value = value }
func hash(into hasher: inout Hasher) {
hasher.combine(ObjectIdentifier(value))
}
}

func ==(lhs: LazyTensorOperationRef, rhs: LazyTensorOperationRef) -> Bool {
return lhs.value === rhs.value
}

final class LazyTensorTests: XCTestCase {
func testConstructions() {
let zero = Tensor<Float>(0.0)
Expand All @@ -40,8 +52,57 @@ final class LazyTensorTests: XCTestCase {
XCTAssertEqual("\(liveSymTensor)", "IdentityN_0:0*")
}

func testLivenessTracking() {
func assertLive(_ expectedLive: [LazyTensorOperation]) {
var actualLiveOps: Set<LazyTensorOperationRef> = []
LazyTensor.onLiveOperations {
actualLiveOps.insert(LazyTensorOperationRef($0))
}
let expectedLiveOps = Set<LazyTensorOperationRef>(
expectedLive.map { LazyTensorOperationRef($0) }
)
XCTAssertEqual(expectedLiveOps, actualLiveOps)
}

func isSymbolic(_ t: LazyTensor) -> Bool {
if case let .symbolic(_) = t.handle {
return true
} else {
return false
}
}

let op0 = LazyTensorOperation(
_id: "0", name: "IdentityN", outputCount: 2)
let op1 = LazyTensorOperation(
_id: "1", name: "IdentityN", outputCount: 2)

XCTAssertFalse(LazyTensor.isLive(op0))
XCTAssertFalse(LazyTensor.isLive(op1))

let t0 = LazyTensor(_lazyLive: op0, index: 0)
let t1 = LazyTensor(_lazy: op1, index: 1)
XCTAssertTrue(LazyTensor.isLive(op0))
XCTAssertFalse(LazyTensor.isLive(op1))

do {
let t3 = LazyTensor(_lazyLive: op1, index: 0)
XCTAssertTrue(LazyTensor.isLive(op1))
assertLive([op0, op1])
// The following is here just to ensure t3 is live.
XCTAssertTrue(isSymbolic(t3))
}
XCTAssertFalse(LazyTensor.isLive(op1))
assertLive([op0])

// The following are here just to ensure t0 and t1 are live.
XCTAssertTrue(isSymbolic(t1))
XCTAssertTrue(isSymbolic(t0))
}

static var allTests = [
("testConstructions", testConstructions),
("testLivenessTracking", testLivenessTracking),
]

}